
Example 26 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From the class ShapResultsTest, method buildShapResults.

ShapResults buildShapResults(int nOutputs, int nFeatures, int scalar1, int scalar2) {
    // One Saliency per output.
    Saliency[] saliencies = new Saliency[nOutputs];
    for (int i = 0; i < nOutputs; i++) {
        // Feature j of output i gets the importance i * j * scalar1.
        List<FeatureImportance> fis = new ArrayList<>();
        for (int j = 0; j < nFeatures; j++) {
            fis.add(new FeatureImportance(new Feature("f" + j, Type.NUMBER, new Value(j)), i * j * scalar1));
        }
        saliencies[i] = new Saliency(new Output("o" + i, Type.NUMBER, new Value(i), 1.0), fis);
    }
    // Null (base) vector: one entry per output, each equal to scalar2.
    RealVector fnull = MatrixUtils.createRealVector(new double[nOutputs]);
    fnull.mapAddToSelf(scalar2);
    return new ShapResults(saliencies, fnull);
}
Also used : FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) RealVector(org.apache.commons.math3.linear.RealVector) Output(org.kie.kogito.explainability.model.Output) ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature)
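The fixture above fills a grid in which feature j of output i gets the importance i * j * scalar1, and every output shares the same null (base) value scalar2. The following is a minimal, dependency-free sketch of the same arithmetic. It also assumes the usual SHAP additivity reading, in which an output equals its base value plus the sum of its feature importances. The fixture itself does not enforce that: its output values are simply i. The class and method names are hypothetical.

```java
import java.util.Arrays;

// Hypothetical, dependency-free sketch of the buildShapResults fixture arithmetic.
class ShapFixtureSketch {

    // Importance assigned to feature j of output i, mirroring i * j * scalar1.
    static double importance(int i, int j, int scalar1) {
        return i * j * scalar1;
    }

    // Importance grid: rows are outputs, columns are features.
    static double[][] importances(int nOutputs, int nFeatures, int scalar1) {
        double[][] grid = new double[nOutputs][nFeatures];
        for (int i = 0; i < nOutputs; i++) {
            for (int j = 0; j < nFeatures; j++) {
                grid[i][j] = importance(i, j, scalar1);
            }
        }
        return grid;
    }

    // Null (base) vector: one entry per output, all equal to scalar2,
    // like createRealVector(new double[nOutputs]) followed by mapAddToSelf(scalar2).
    static double[] fnull(int nOutputs, int scalar2) {
        double[] v = new double[nOutputs];
        Arrays.fill(v, scalar2);
        return v;
    }

    // SHAP-style additive reconstruction (assumption): base value + sum of importances.
    static double additiveOutput(double base, double[] row) {
        return base + Arrays.stream(row).sum();
    }

    public static void main(String[] args) {
        double[][] grid = importances(3, 4, 2);
        double[] base = fnull(3, 5);
        for (int i = 0; i < grid.length; i++) {
            System.out.println("o" + i + " -> " + Arrays.toString(grid[i])
                    + ", base=" + base[i] + ", additive=" + additiveOutput(base[i], grid[i]));
        }
    }
}
```

Output o0 always has all-zero importances because i = 0, so any test built on this fixture sees one output whose explanation is just the base value.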

Example 27 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From the class ExplainabilityMetricsTest, method testFidelityWithTextClassifier.

@Test
void testFidelityWithTextClassifier() throws ExecutionException, InterruptedException, TimeoutException {
    List<Pair<Saliency, Prediction>> pairs = new LinkedList<>();
    LimeConfig limeConfig = new LimeConfig().withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newFulltextFeature("f-0", "brown fox", s -> Arrays.asList(s.split(" "))));
    features.add(FeatureFactory.newTextFeature("f-1", "money"));
    PredictionInput input = new PredictionInput(features);
    Prediction prediction = new SimplePrediction(input, model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0));
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        pairs.add(Pair.of(saliency, prediction));
    }
    Assertions.assertDoesNotThrow(() -> {
        ExplainabilityMetrics.classificationFidelity(pairs);
    });
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Arrays(java.util.Arrays) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) TimeoutException(java.util.concurrent.TimeoutException) Saliency(org.kie.kogito.explainability.model.Saliency) Pair(org.apache.commons.lang3.tuple.Pair) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) CompletableFuture.supplyAsync(java.util.concurrent.CompletableFuture.supplyAsync) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Awaitility.await(org.awaitility.Awaitility.await) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) TestUtils(org.kie.kogito.explainability.TestUtils) Assertions(org.junit.jupiter.api.Assertions) Config(org.kie.kogito.explainability.Config)
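The full-text feature f-0 is built with the tokenizer s -> Arrays.asList(s.split(" ")), so "brown fox" becomes two tokens that can be perturbed independently. The sketch below reproduces that split. As a generic illustration of LIME-style text perturbation (not Kogito's internal code), it then enumerates every keep/drop mask over the tokens. Enumerating all masks is only feasible for a handful of tokens; LIME itself samples masks at random. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

class TextPerturbationSketch {

    // Same tokenizer as the test: split on single spaces.
    static final Function<String, List<String>> TOKENIZER = s -> Arrays.asList(s.split(" "));

    // Enumerate all 2^n keep/drop masks over the tokens and rebuild each variant.
    // Bit k of the mask set means "keep token k".
    static List<String> allVariants(String text) {
        List<String> tokens = TOKENIZER.apply(text);
        int n = tokens.size();
        List<String> variants = new ArrayList<>();
        for (int mask = 0; mask < (1 << n); mask++) {
            List<String> kept = new ArrayList<>();
            for (int k = 0; k < n; k++) {
                if ((mask & (1 << k)) != 0) {
                    kept.add(tokens.get(k));
                }
            }
            variants.add(String.join(" ", kept));
        }
        return variants;
    }

    public static void main(String[] args) {
        System.out.println(TOKENIZER.apply("brown fox"));
        System.out.println(allVariants("brown fox"));
    }
}
```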

Example 28 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From the class PmmlScorecardCategoricalLimeExplainerTest, method testPMMLScorecardCategorical.

@Test
void testPMMLScorecardCategorical() throws Exception {
    PredictionInput input = getTestInput();
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(0L, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    PredictionProvider model = getModel();
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(predictionOutputs).isNotNull().isNotEmpty();
    PredictionOutput output = predictionOutputs.get(0);
    assertThat(output).isNotNull();
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertThat(saliency).isNotNull();
        double v = ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2));
        assertThat(v).isGreaterThan(0d);
    }
    assertDoesNotThrow(() -> ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.4, 0.4));
    List<PredictionInput> inputs = getSamples();
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    String decision = "score";
    int k = 1;
    int chunkSize = 2;
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    AssertionsForClassTypes.assertThat(f1).isBetween(0d, 1d);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Random(java.util.Random) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test)
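The loop above computes an impactScore from saliency.getTopFeatures(2) and requires it to be positive. The sketch below shows how such a top-k selection could work. It assumes that getTopFeatures ranks features by the magnitude of their importance score; that ranking rule is an assumption, so check the Saliency sources before relying on it. The scores are made-up illustration inputs, and the Scored type is hypothetical.

```java
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

class TopFeaturesSketch {

    // Hypothetical stand-in for a (feature name, importance score) pair.
    static final class Scored {
        final String name;
        final double score;

        Scored(String name, double score) {
            this.name = name;
            this.score = score;
        }
    }

    // Keep the k entries with the largest |score|, most important first (assumed ranking rule).
    static List<String> topK(List<Scored> importances, int k) {
        return importances.stream()
                .sorted(Comparator.comparingDouble((Scored s) -> Math.abs(s.score)).reversed())
                .limit(k)
                .map(s -> s.name)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Scored> scores = List.of(
                new Scored("f-0", 0.10),
                new Scored("f-1", -0.80),
                new Scored("f-2", 0.35),
                new Scored("f-3", -0.05));
        System.out.println(topK(scores, 2)); // prints [f-1, f-2]
    }
}
```

Ranking by absolute value treats a strongly negative contribution (f-1) as just as informative as a strongly positive one. Sorting by signed score would instead pick f-2 and f-0.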

Example 29 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From the class LimeExplainerTest, method testWithDataDistribution.

@Test
void testWithDataDistribution() throws InterruptedException, ExecutionException, TimeoutException {
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(4L, random, 1);
    List<FeatureDistribution> featureDistributions = new ArrayList<>();
    int nf = 4;
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < nf; i++) {
        Feature numericalFeature = FeatureFactory.newNumericalFeature("f-" + i, Double.NaN);
        features.add(numericalFeature);
        List<Value> values = new ArrayList<>();
        for (int r = 0; r < 4; r++) {
            values.add(Type.NUMBER.randomValue(perturbationContext));
        }
        featureDistributions.add(new GenericFeatureDistribution(numericalFeature, values));
    }
    DataDistribution dataDistribution = new IndependentFeaturesDataDistribution(featureDistributions);
    LimeConfig limeConfig = new LimeConfig().withDataDistribution(dataDistribution).withPerturbationContext(perturbationContext).withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumThresholdModel(random.nextDouble(), random.nextDouble());
    PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencyMap).isNotNull();
    String decisionName = "inside";
    Saliency saliency = saliencyMap.get(decisionName);
    assertThat(saliency).isNotNull();
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) GenericFeatureDistribution(org.kie.kogito.explainability.model.GenericFeatureDistribution) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) Random(java.util.Random) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) IndependentFeaturesDataDistribution(org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Value(org.kie.kogito.explainability.model.Value) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
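In testWithDataDistribution, the IndependentFeaturesDataDistribution is built from one GenericFeatureDistribution per feature, and each holds four random values. The class name suggests that a sample row is drawn by picking each feature's value independently of the others. The sketch below illustrates that idea with plain arrays and a seeded java.util.Random. It reflects the assumption above and is not the library's implementation; the names are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

class IndependentSamplingSketch {

    // values[f] holds the observed values for feature f.
    // Each feature is sampled on its own, so correlations between features are not preserved.
    static double[] sampleRow(double[][] values, Random random) {
        double[] row = new double[values.length];
        for (int f = 0; f < values.length; f++) {
            double[] observed = values[f];
            row[f] = observed[random.nextInt(observed.length)];
        }
        return row;
    }

    public static void main(String[] args) {
        double[][] values = {
                {1.0, 2.0, 3.0, 4.0},
                {10.0, 20.0, 30.0, 40.0},
                {-1.0, -2.0, -3.0, -4.0},
                {0.5, 0.25, 0.125, 0.0625}
        };
        Random random = new Random(4L);
        for (int n = 0; n < 3; n++) {
            System.out.println(Arrays.toString(sampleRow(values, random)));
        }
    }
}
```

Independence makes sampling cheap, since the per-feature lists never have to be kept aligned. The cost is that generated rows can combine values that never co-occur in real data.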

Example 30 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From the class LimeExplainerTest, method testSparseBalance.

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testSparseBalance(long seed) throws InterruptedException, ExecutionException, TimeoutException {
    for (int nf = 1; nf < 4; nf++) {
        Random random = new Random();
        int noOfSamples = 100;
        LimeConfig limeConfigNoPenalty = new LimeConfig().withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS)).withSamples(noOfSamples).withPenalizeBalanceSparse(false);
        LimeExplainer limeExplainerNoPenalty = new LimeExplainer(limeConfigNoPenalty);
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < nf; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
        Prediction prediction = new SimplePrediction(input, output);
        Map<String, Saliency> saliencyMapNoPenalty = limeExplainerNoPenalty.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        assertThat(saliencyMapNoPenalty).isNotNull();
        String decisionName = "sum-but0";
        Saliency saliencyNoPenalty = saliencyMapNoPenalty.get(decisionName);
        LimeConfig limeConfig = new LimeConfig().withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS)).withSamples(noOfSamples).withPenalizeBalanceSparse(true);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        assertThat(saliencyMap).isNotNull();
        Saliency saliency = saliencyMap.get(decisionName);
        // With withPenalizeBalanceSparse(true), no feature's |score| may exceed its unpenalized |score|.
        for (int i = 0; i < features.size(); i++) {
            double score = saliency.getPerFeatureImportance().get(i).getScore();
            double scoreNoPenalty = saliencyNoPenalty.getPerFeatureImportance().get(i).getScore();
            assertThat(Math.abs(score)).isLessThanOrEqualTo(Math.abs(scoreNoPenalty));
        }
    }
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
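The decision name "sum-but0" and the factory call TestUtils.getSumSkipModel(0) suggest a model whose output is the sum of all features except feature 0. That is an assumption, since TestUtils is not shown here. Under that reading, feature 0 should receive little importance. The test's final loop then asserts, feature by feature, that the penalized score never exceeds the unpenalized one in absolute value. The sketch reproduces both pieces in plain Java; the method names are hypothetical.

```java
class SparseBalanceSketch {

    // Assumed behavior of a "sum-skip" model: sum every feature value except the one at skipIndex.
    static double sumSkip(double[] features, int skipIndex) {
        double sum = 0.0;
        for (int i = 0; i < features.length; i++) {
            if (i != skipIndex) {
                sum += features[i];
            }
        }
        return sum;
    }

    // Mirrors the assertion loop: |penalized[i]| <= |unpenalized[i]| for every feature.
    static boolean penaltyNeverIncreasesMagnitude(double[] penalized, double[] unpenalized) {
        if (penalized.length != unpenalized.length) {
            throw new IllegalArgumentException("score arrays must have the same length");
        }
        for (int i = 0; i < penalized.length; i++) {
            if (Math.abs(penalized[i]) > Math.abs(unpenalized[i])) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(sumSkip(new double[] {5.0, 1.0, 2.0}, 0)); // prints 3.0
        System.out.println(penaltyNeverIncreasesMagnitude(
                new double[] {0.1, -0.4}, new double[] {0.2, -0.5})); // prints true
    }
}
```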

Aggregations

Saliency (org.kie.kogito.explainability.model.Saliency) 51
Prediction (org.kie.kogito.explainability.model.Prediction) 44
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 43
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 43
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 39
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction) 39
ArrayList (java.util.ArrayList) 34
Random (java.util.Random) 28
Feature (org.kie.kogito.explainability.model.Feature) 26
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext) 26
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 25
Test (org.junit.jupiter.api.Test) 23
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance) 23
DataDistribution (org.kie.kogito.explainability.model.DataDistribution) 21
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution) 18
ValueSource (org.junit.jupiter.params.provider.ValueSource) 16
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig) 16
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer) 16
LinkedList (java.util.LinkedList) 13
RealMatrix (org.apache.commons.math3.linear.RealMatrix) 9