Search in sources :

Example 91 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeExplainerTest method testZeroSampleSize.

/**
 * Checks that a {@link LimeExplainer} configured with a sample size of zero
 * still completes and hands back a non-null saliency map.
 */
@Test
void testZeroSampleSize() throws ExecutionException, InterruptedException, TimeoutException {
    LimeExplainer explainer = new LimeExplainer(new LimeConfig().withSamples(0));

    // Build the single input out of four mocked numeric features.
    List<Feature> mockedFeatures = new ArrayList<>();
    int featureIndex = 0;
    while (featureIndex < 4) {
        mockedFeatures.add(TestUtils.getMockedNumericFeature(featureIndex));
        featureIndex++;
    }
    PredictionInput predictionInput = new PredictionInput(mockedFeatures);

    // Ask the model for the prediction that is going to be explained.
    PredictionProvider provider = TestUtils.getSumSkipModel(0);
    List<PredictionOutput> outputs = provider.predictAsync(List.of(predictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction toExplain = new SimplePrediction(predictionInput, outputs.get(0));

    Map<String, Saliency> explanation = explainer.explainAsync(toExplain, provider)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertNotNull(explanation);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 92 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeExplainerTest method testDeterministic.

/**
 * Runs the same LIME explanation twice with an identical seed and checks that
 * the per-feature importance scores of the "sum-but0" output are equal in both runs.
 *
 * @param seed seed handed to the {@link PerturbationContext} of each run
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testDeterministic(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    List<Saliency> perRunSaliency = new ArrayList<>();
    for (int run = 0; run < 2; run++) {
        // Every run gets its own Random instance but the very same seed.
        Random rng = new Random();
        PerturbationContext perturbationContext = new PerturbationContext(seed, rng, DEFAULT_NO_OF_PERTURBATIONS);
        LimeConfig config = new LimeConfig()
                .withPerturbationContext(perturbationContext)
                .withSamples(10);
        LimeExplainer explainer = new LimeExplainer(config);

        List<Feature> mockedFeatures = new ArrayList<>();
        for (int featureIndex = 0; featureIndex < 4; featureIndex++) {
            mockedFeatures.add(TestUtils.getMockedNumericFeature(featureIndex));
        }
        PredictionInput predictionInput = new PredictionInput(mockedFeatures);

        PredictionProvider provider = TestUtils.getSumSkipModel(0);
        List<PredictionOutput> outputs = provider.predictAsync(List.of(predictionInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction toExplain = new SimplePrediction(predictionInput, outputs.get(0));

        Map<String, Saliency> explanation = explainer.explainAsync(toExplain, provider)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        perRunSaliency.add(explanation.get("sum-but0"));
    }

    // Compare the two runs score by score, in feature order.
    List<Double> firstRunScores = perRunSaliency.get(0).getPerFeatureImportance().stream()
            .map(FeatureImportance::getScore)
            .collect(Collectors.toList());
    List<Double> secondRunScores = perRunSaliency.get(1).getPerFeatureImportance().stream()
            .map(FeatureImportance::getScore)
            .collect(Collectors.toList());
    assertThat(firstRunScores).isEqualTo(secondRunScores);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 93 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeExplainerTest method testNormalizedWeights.

/**
 * With weight normalization switched on, every per-feature importance score of
 * the "sum-but0" output is expected to fall within the closed range [0, 1].
 */
@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    Random rng = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(4L, rng, 2);
    LimeExplainer explainer = new LimeExplainer(new LimeConfig()
            .withNormalizeWeights(true)
            .withPerturbationContext(perturbationContext)
            .withSamples(10));

    int featureCount = 4;
    List<Feature> mockedFeatures = new ArrayList<>();
    for (int featureIndex = 0; featureIndex < featureCount; featureIndex++) {
        mockedFeatures.add(TestUtils.getMockedNumericFeature(featureIndex));
    }
    PredictionInput predictionInput = new PredictionInput(mockedFeatures);

    PredictionProvider provider = TestUtils.getSumSkipModel(0);
    List<PredictionOutput> outputs = provider.predictAsync(List.of(predictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction toExplain = new SimplePrediction(predictionInput, outputs.get(0));

    Map<String, Saliency> explanation = explainer.explainAsync(toExplain, provider)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(explanation).isNotNull();

    // Only the saliency of the "sum-but0" output is inspected.
    Saliency sumSaliency = explanation.get("sum-but0");
    sumSaliency.getPerFeatureImportance()
            .forEach(importance -> assertThat(importance.getScore()).isBetween(0d, 1d));
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 94 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeStabilityTest method assertStable.

/**
 * Asserts that repeated LIME explanations of the same prediction are stable.
 * <p>
 * For every output the model produces for {@code featureList}, the explainer is
 * invoked 100 times and two checks are made on the collected saliencies:
 * <ol>
 * <li>some feature name shows up as the top positive feature at least
 * {@code TOP_FEATURE_THRESHOLD} times;</li>
 * <li>some impact score (computed on the top 2 features) occurs at least
 * {@code TOP_FEATURE_THRESHOLD} times.</li>
 * </ol>
 *
 * @param limeExplainer the explainer under test
 * @param model the model used both for predicting and for explaining
 * @param featureList the features making up the single input to explain
 * @throws Exception if predicting, explaining or computing the impact score fails
 */
private void assertStable(LimeExplainer limeExplainer, PredictionProvider model, List<Feature> featureList) throws Exception {
    PredictionInput input = new PredictionInput(featureList);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (PredictionOutput predictionOutput : predictionOutputs) {
        Prediction prediction = new SimplePrediction(input, predictionOutput);

        // Plain loop on purpose: the blocking get(...) throws checked exceptions.
        List<Saliency> saliencies = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            saliencies.addAll(saliencyMap.values());
        }

        // check that the topmost important feature is stable
        // (saliencies without any positive feature are left out of the count)
        Map<String, Long> frequencyMap = saliencies.stream()
                .map(s -> s.getPositiveFeatures(1))
                .filter(f -> !f.isEmpty())
                .map(f -> f.get(0).getFeature().getName())
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        // NOTE(review): the threshold is compared against raw occurrence counts;
        // confirm TOP_FEATURE_THRESHOLD is expressed as a count and not as a ratio.
        boolean topFeature = frequencyMap.values().stream()
                .anyMatch(count -> count >= TOP_FEATURE_THRESHOLD);
        assertTrue(topFeature, "no feature was the top positive feature often enough, frequencies: " + frequencyMap);

        // check that the impact is stable
        // Kept as a loop so that anything thrown by impactScore propagates unchanged.
        List<Double> impacts = new ArrayList<>(saliencies.size());
        for (Saliency saliency : saliencies) {
            impacts.add(ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2)));
        }
        // Impact values are grouped by exact Double equality.
        Map<Double, Long> impactMap = impacts.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        boolean topImpact = impactMap.values().stream()
                .anyMatch(count -> count >= TOP_FEATURE_THRESHOLD);
        assertTrue(topImpact, "no impact score occurred often enough, frequencies: " + impactMap);
    }
}
Also used : ValueSource(org.junit.jupiter.params.provider.ValueSource) FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Random(java.util.Random) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) TestUtils(org.kie.kogito.explainability.TestUtils) ExplainabilityMetrics(org.kie.kogito.explainability.utils.ExplainabilityMetrics) Map(java.util.Map) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) LocalSaliencyStability(org.kie.kogito.explainability.utils.LocalSaliencyStability) LinkedList(java.util.LinkedList) Config(org.kie.kogito.explainability.Config) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) LinkedList(java.util.LinkedList) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Map(java.util.Map)

Example 95 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeCombinedScoreCalculatorTest method testNonZeroScore.

/**
 * Checks that {@link LimeCombinedScoreCalculator} produces a non-null score whose
 * value is numerically different from zero for a solution built from a single
 * prediction of the dummy text classifier.
 */
@Test
void testNonZeroScore() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    LimeCombinedScoreCalculator scoreCalculator = new LimeCombinedScoreCalculator();
    LimeConfig config = new LimeConfig();
    List<Feature> features = List.of(FeatureFactory.newFulltextFeature("text", "money so they say is the root of all evil today"));
    PredictionInput input = new PredictionInput(features);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
    // hasSize reports the actual list content on failure, unlike comparing size() to 1.
    assertThat(predictionOutputs).isNotNull().hasSize(1);
    PredictionOutput output = predictionOutputs.get(0);
    Prediction prediction = new SimplePrediction(input, output);
    List<Prediction> predictions = List.of(prediction);
    List<LimeConfigEntity> entities = LimeConfigEntityFactory.createEncodingEntities(config);
    LimeConfigSolution solution = new LimeConfigSolution(config, predictions, entities, model);
    SimpleBigDecimalScore score = scoreCalculator.calculateScore(solution);
    assertThat(score).isNotNull();
    // BigDecimal.equals is scale-sensitive (0.0 is not equal to 0), so a zero score
    // with a non-zero scale would slip through isNotEqualTo; compare numerically instead.
    assertThat(score.getScore()).isNotNull().isNotEqualByComparingTo(BigDecimal.ZERO);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) SimpleBigDecimalScore(org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore) Test(org.junit.jupiter.api.Test)

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction): 134; PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 117; PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 107; PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 105; SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 96; Test (org.junit.jupiter.api.Test): 95; Random (java.util.Random): 65; PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 61; LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 57; ArrayList (java.util.ArrayList): 51; Feature (org.kie.kogito.explainability.model.Feature): 48; Saliency (org.kie.kogito.explainability.model.Saliency): 48; ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 42; LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 40; LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 28; DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 24; ValueSource (org.junit.jupiter.params.provider.ValueSource): 22; FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 22; Output (org.kie.kogito.explainability.model.Output): 22; LinkedList (java.util.LinkedList): 21