Search in sources :

Example 71 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeExplainerTest method testEmptyInput.

/**
 * Checks the failure message produced when {@code LimeExplainer#explainAsync} is handed a
 * {@code Prediction} mock on which nothing has been stubbed.
 */
@Test
void testEmptyInput() {
    PredictionProvider mockedModel = mock(PredictionProvider.class);
    Prediction mockedPrediction = mock(Prediction.class);
    LimeExplainer limeExplainer = new LimeExplainer();
    // The assertion targets the message of whatever the call throws.
    assertThatCode(() -> limeExplainer.explainAsync(mockedPrediction, mockedModel))
            .hasMessage("cannot explain a prediction whose input is empty");
}
Also used : Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 72 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeStabilityTest method testStabilityDeterministic.

/**
 * Computes the local saliency stability twice with identically configured LIME explainers
 * (same seed, same model, same input) and asserts that the negative and positive stability
 * scores of decision {@code "sum-but0"} are equal across the two runs, for the second-argument
 * values 1 and 2.
 *
 * @param seed seed passed to the {@code PerturbationContext} of both runs
 * @throws Exception if obtaining the model output or computing the stability fails
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testStabilityDeterministic(long seed) throws Exception {
    List<LocalSaliencyStability> results = new ArrayList<>();
    for (int run = 0; run < 2; run++) {
        Random random = new Random();
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        List<Feature> features = new LinkedList<>();
        for (int idx = 0; idx < 5; idx++) {
            features.add(TestUtils.getMockedNumericFeature(idx));
        }
        PredictionInput input = new PredictionInput(features);
        List<PredictionOutput> outputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new SimplePrediction(input, outputs.get(0));
        PerturbationContext perturbationContext = new PerturbationContext(seed, random, 1);
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(10)
                .withPerturbationContext(perturbationContext);
        LimeExplainer explainer = new LimeExplainer(limeConfig);
        results.add(ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, explainer, 2, 10));
    }
    LocalSaliencyStability first = results.get(0);
    LocalSaliencyStability second = results.get(1);
    String decisionName = "sum-but0";
    // Same checks, same order as before: negative then positive, for 1 and then for 2.
    for (int n = 1; n <= 2; n++) {
        assertThat(first.getNegativeStabilityScore(decisionName, n))
                .isEqualTo(second.getNegativeStabilityScore(decisionName, n));
        assertThat(first.getPositiveStabilityScore(decisionName, n))
                .isEqualTo(second.getPositiveStabilityScore(decisionName, n));
    }
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LocalSaliencyStability(org.kie.kogito.explainability.utils.LocalSaliencyStability) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 73 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class CountingOptimizationStrategyTest method testMaybeOptimize.

/**
 * Verifies that {@code CountingOptimizationStrategy#maybeOptimize} does not throw when it is
 * called with an empty list of recorded predictions, a mocked model and a mocked
 * optimization service.
 */
@Test
void testMaybeOptimize() {
    LimeOptimizationService mockedService = mock(LimeOptimizationService.class);
    CountingOptimizationStrategy countingStrategy = new CountingOptimizationStrategy(10, mockedService);
    PredictionProvider mockedModel = mock(PredictionProvider.class);
    List<Prediction> noPredictions = Collections.emptyList();
    LimeConfig limeConfig = new LimeConfig();
    LimeExplainer limeExplainer = new LimeExplainer();
    assertThatCode(() -> countingStrategy.maybeOptimize(noPredictions, mockedModel, limeExplainer, limeConfig))
            .doesNotThrowAnyException();
}
Also used : LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Test(org.junit.jupiter.api.Test)

Example 74 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class AggregatedLimeExplainerTest method testExplainWithPredictions.

/**
 * Samples ten inputs from a randomly generated data distribution, runs them through the
 * "sum skip" model built with argument 1, and explains the resulting predictions with an
 * {@code AggregatedLimeExplainer}. The explanation must hold exactly one saliency, keyed
 * {@code "sum-but1"}, and feature {@code "f1"} must not be among its two top positive features.
 *
 * @param seed seed for the random generator used to build the data distribution
 * @throws ExecutionException if the asynchronous prediction or explanation fails
 * @throws InterruptedException if the test thread is interrupted while waiting
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
    final String outputName = "sum-but1";
    Random rng = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    DataDistribution distribution = DataUtils.generateRandomDataDistribution(3, 100, rng);
    List<PredictionInput> inputs = distribution.sample(10);
    List<PredictionOutput> outputs = model.predictAsync(inputs).get();
    List<Prediction> predictions = DataUtils.getPredictions(inputs, outputs);
    AggregatedLimeExplainer explainer = new AggregatedLimeExplainer();
    Map<String, Saliency> saliencies = explainer.explainFromPredictions(model, predictions).get();
    assertNotNull(saliencies);
    assertEquals(1, saliencies.size());
    assertTrue(saliencies.containsKey(outputName));
    Saliency saliency = saliencies.get(outputName);
    assertNotNull(saliency);
    List<String> topPositiveNames = saliency.getPositiveFeatures(2).stream()
            .map(FeatureImportance::getFeature)
            .map(Feature::getName)
            .collect(Collectors.toList());
    // skipped feature should not appear in top two positive features
    assertFalse(topPositiveNames.contains("f1"));
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Random(java.util.Random) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 75 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class PartialDependencePlotExplainerTest method testTextClassifier.

/**
 * Builds one prediction per sample text using the dummy text classifier and checks that the
 * partial dependence plot explainer returns at least one graph for them.
 *
 * @param seed value supplied by {@code @ValueSource}; the body does not read it
 * @throws Exception if obtaining a model output or computing the explanation fails
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testTextClassifier(int seed) throws Exception {
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner", "huge donation for you!", "bitcoin for you");
    // One prediction is added per text, so size the collection from the text list.
    Collection<Prediction> predictions = new ArrayList<>(texts.size());
    for (String text : texts) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput)).get().get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) ArrayList(java.util.ArrayList) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction)134 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)117 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)107 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)105 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)96 Test (org.junit.jupiter.api.Test)95 Random (java.util.Random)65 PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext)61 LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig)57 ArrayList (java.util.ArrayList)51 Feature (org.kie.kogito.explainability.model.Feature)48 Saliency (org.kie.kogito.explainability.model.Saliency)48 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)42 LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer)40 LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer)28 DataDistribution (org.kie.kogito.explainability.model.DataDistribution)24 ValueSource (org.junit.jupiter.params.provider.ValueSource)22 FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance)22 Output (org.kie.kogito.explainability.model.Output)22 LinkedList (java.util.LinkedList)21