
Example 76 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.

In class LimeExplainerTest, method testSparseBalance.

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testSparseBalance(long seed) throws InterruptedException, ExecutionException, TimeoutException {
    // nf = number of features in the input (1, 2 and 3)
    for (int nf = 1; nf < 4; nf++) {
        Random random = new Random();
        int noOfSamples = 100;
        // Baseline explainer: same seed and sample count, sparse-balance penalty disabled
        LimeConfig limeConfigNoPenalty = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(noOfSamples)
                .withPenalizeBalanceSparse(false);
        LimeExplainer limeExplainerNoPenalty = new LimeExplainer(limeConfigNoPenalty);
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < nf; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
        Prediction prediction = new SimplePrediction(input, output);
        Map<String, Saliency> saliencyMapNoPenalty = limeExplainerNoPenalty.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        assertThat(saliencyMapNoPenalty).isNotNull();
        String decisionName = "sum-but0";
        Saliency saliencyNoPenalty = saliencyMapNoPenalty.get(decisionName);
        // Same seed and sample count, with the sparse-balance penalty enabled
        LimeConfig limeConfig = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(noOfSamples)
                .withPenalizeBalanceSparse(true);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        assertThat(saliencyMap).isNotNull();
        Saliency saliency = saliencyMap.get(decisionName);
        // Enabling the penalty may shrink a feature's score but must never increase its magnitude
        for (int i = 0; i < features.size(); i++) {
            double score = saliency.getPerFeatureImportance().get(i).getScore();
            double scoreNoPenalty = saliencyNoPenalty.getPerFeatureImportance().get(i).getScore();
            assertThat(Math.abs(score)).isLessThanOrEqualTo(Math.abs(scoreNoPenalty));
        }
    }
}
Also used: SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), ArrayList (java.util.ArrayList), Saliency (org.kie.kogito.explainability.model.Saliency), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)
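The final loop of testSparseBalance checks only a magnitude property: with the penalty enabled, no feature's absolute score may exceed its unpenalized counterpart. Because both explainers are built with the same seed, the comparison presumably isolates the effect of the penalty. The sketch below does not reproduce Kogito's actual sparse-balance penalty; it assumes, purely for illustration, a penalty that scales each score by a factor clamped to [0, 1], which is enough to satisfy that assertion. The class PenaltyInvariantSketch and its methods are hypothetical.

```java
import java.util.Arrays;

// Hypothetical illustration of the property testSparseBalance asserts.
// This is NOT the Kogito penalty, only a scaling by factors clamped to [0, 1].
class PenaltyInvariantSketch {

    // Scales each raw score by its penalty factor, clamped to [0, 1].
    static double[] penalize(double[] scores, double[] factors) {
        double[] out = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            double f = Math.max(0.0, Math.min(1.0, factors[i]));
            out[i] = scores[i] * f;
        }
        return out;
    }

    // The check from the test: |penalized| <= |unpenalized| for every feature.
    static boolean neverAmplifies(double[] penalized, double[] raw) {
        for (int i = 0; i < raw.length; i++) {
            if (Math.abs(penalized[i]) > Math.abs(raw[i])) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] raw = {0.0, 1.0, -0.8};
        double[] penalized = penalize(raw, new double[] {0.5, 0.9, 1.2});
        System.out.println(Arrays.toString(penalized) + " " + neverAmplifies(penalized, raw));
    }
}
```

Any penalty of this multiplicative kind can only move scores toward zero, which is why the test compares absolute values rather than signed scores.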

Example 77 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.

In class LimeExplainerTest, method testEmptyPrediction.

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testEmptyPrediction(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig().withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS)).withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    PredictionInput input = new PredictionInput(Collections.emptyList());
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    Prediction prediction = new SimplePrediction(input, output);
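    // explainAsync validates its input eagerly, so the exception is thrown before any future is returned
    assertThrows(LocalExplanationException.class, () -> limeExplainer.explainAsync(prediction, model));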
}
Also used: SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext), Random (java.util.Random), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Prediction (org.kie.kogito.explainability.model.Prediction), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)
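In testEmptyPrediction, assertThrows wraps the explainAsync call itself, with no get() on the returned future, so the LocalExplanationException must be thrown synchronously, before any asynchronous work starts. A minimal sketch of that eager-validation pattern follows; it uses a hypothetical EmptyInputException in place of LocalExplanationException and a placeholder sum in place of LIME, and the message text is borrowed from testEmptyInput.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;

// Hypothetical stand-in for Kogito's LocalExplanationException.
class EmptyInputException extends RuntimeException {
    EmptyInputException(String message) {
        super(message);
    }
}

class EagerValidationSketch {

    // Validates on the caller's thread, then defers the real work to a future.
    static CompletableFuture<Double> explainAsync(List<Double> features) {
        if (features.isEmpty()) {
            // Thrown before any future is created, so no get() is needed to observe it.
            throw new EmptyInputException("cannot explain a prediction whose input is empty");
        }
        // Placeholder "explanation": the sum of the features.
        return CompletableFuture.supplyAsync(
                () -> features.stream().mapToDouble(Double::doubleValue).sum());
    }

    public static void main(String[] args) {
        System.out.println(explainAsync(List.of(1.0, 2.0)).join());
        try {
            explainAsync(List.of());
        } catch (EmptyInputException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Had the check been placed inside the supplier instead, the error would only surface as an ExecutionException from get(), and a bare assertThrows around explainAsync would fail.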

Example 78 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.

In class LimeExplainerTest, method testNonEmptyInput.

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testNonEmptyInput(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig().withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS)).withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertNotNull(saliencyMap);
}
Also used: SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), ArrayList (java.util.ArrayList), Saliency (org.kie.kogito.explainability.model.Saliency), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)
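Every example obtains its PredictionOutput via model.predictAsync(List.of(input)).get(timeout, unit).get(0): a bounded wait on the future, then the first output of the batch. The decision name "sum-but0" suggests that TestUtils.getSumSkipModel(0) outputs the sum of all features except the one at index 0. The following dependency-free sketch assumes exactly that behavior, using plain doubles instead of Kogito Feature objects; SumSkipSketch and firstOutput are hypothetical names, and the timeout is passed explicitly rather than read from Config.INSTANCE.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Hypothetical, dependency-free stand-in for TestUtils.getSumSkipModel(skip).
class SumSkipSketch {
    private final int skip;

    SumSkipSketch(int skip) {
        this.skip = skip;
    }

    // Output name mirrors the decision name used in the tests, e.g. "sum-but0".
    String outputName() {
        return "sum-but" + skip;
    }

    // One output per input: the sum of every feature except the one at index `skip`.
    CompletableFuture<List<Double>> predictAsync(List<List<Double>> inputs) {
        List<Double> outputs = new ArrayList<>();
        for (List<Double> features : inputs) {
            double sum = 0.0;
            for (int i = 0; i < features.size(); i++) {
                if (i != skip) {
                    sum += features.get(i);
                }
            }
            outputs.add(sum);
        }
        return CompletableFuture.completedFuture(outputs);
    }

    // Same shape as predictAsync(...).get(timeout, unit).get(0) in the tests.
    static <T> T firstOutput(CompletableFuture<List<T>> future, long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
        return future.get(timeout, unit).get(0);
    }

    public static void main(String[] args) throws Exception {
        SumSkipSketch model = new SumSkipSketch(0);
        List<Double> features = List.of(1.0, 2.0, 3.0, 4.0);
        double out = firstOutput(model.predictAsync(List.of(features)), 1, TimeUnit.SECONDS);
        System.out.println(model.outputName() + " = " + out);
    }
}
```

A linear model like this is a convenient target for LIME because the expected influence of each feature is easy to reason about: the skipped feature has no effect on the output at all.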

Example 79 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.

In class LimeExplainerTest, method testEmptyInput.

@Test
void testEmptyInput() {
    LimeExplainer recordingLimeExplainer = new LimeExplainer();
    PredictionProvider model = mock(PredictionProvider.class);
    Prediction prediction = mock(Prediction.class);
    assertThatCode(() -> recordingLimeExplainer.explainAsync(prediction, model)).hasMessage("cannot explain a prediction whose input is empty");
}
Also used: Prediction (org.kie.kogito.explainability.model.Prediction), SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Test (org.junit.jupiter.api.Test), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)
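testEmptyInput relies on AssertJ's assertThatCode(...).hasMessage(...), which captures whatever the lambda throws and compares its message exactly; if nothing is thrown, the assertion fails. As a rough, dependency-free approximation of that check (the helper names are hypothetical, and unlike AssertJ it only catches RuntimeException rather than every Throwable):

```java
import java.util.Objects;

// Hypothetical, dependency-free approximation of
// assertThatCode(action).hasMessage(expected) from AssertJ.
class MessageCheckSketch {

    // Runs the action and returns the message of the exception it raises, or null if none.
    static String thrownMessage(Runnable action) {
        try {
            action.run();
            return null;
        } catch (RuntimeException e) {
            return e.getMessage();
        }
    }

    // Fails unless the action throws with exactly the expected (non-null) message.
    static void assertThrowsWithMessage(Runnable action, String expected) {
        Objects.requireNonNull(expected, "expected message");
        String actual = thrownMessage(action);
        if (!expected.equals(actual)) {
            throw new AssertionError("expected message <" + expected + "> but was <" + actual + ">");
        }
    }

    public static void main(String[] args) {
        assertThrowsWithMessage(
                () -> { throw new IllegalStateException("cannot explain a prediction whose input is empty"); },
                "cannot explain a prediction whose input is empty");
        System.out.println("message matched");
    }
}
```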

Example 80 with PredictionProvider

Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.

In class LimeStabilityTest, method testStabilityDeterministic.

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testStabilityDeterministic(long seed) throws Exception {
    List<LocalSaliencyStability> stabilities = new ArrayList<>();
    // Build the same model, prediction and seeded explainer twice and record each run's stability
    for (int j = 0; j < 2; j++) {
        Random random = new Random();
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        List<Feature> featureList = new LinkedList<>();
        for (int i = 0; i < 5; i++) {
            featureList.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(featureList);
        List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new SimplePrediction(input, predictionOutputs.get(0));
        LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(seed, random, 1));
        LimeExplainer explainer = new LimeExplainer(limeConfig);
        LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, explainer, 2, 10);
        stabilities.add(stability);
    }
    LocalSaliencyStability first = stabilities.get(0);
    LocalSaliencyStability second = stabilities.get(1);
    String decisionName = "sum-but0";
    assertThat(first.getNegativeStabilityScore(decisionName, 1)).isEqualTo(second.getNegativeStabilityScore(decisionName, 1));
    assertThat(first.getPositiveStabilityScore(decisionName, 1)).isEqualTo(second.getPositiveStabilityScore(decisionName, 1));
    assertThat(first.getNegativeStabilityScore(decisionName, 2)).isEqualTo(second.getNegativeStabilityScore(decisionName, 2));
    assertThat(first.getPositiveStabilityScore(decisionName, 2)).isEqualTo(second.getPositiveStabilityScore(decisionName, 2));
}
Also used: SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), ArrayList (java.util.ArrayList), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Feature (org.kie.kogito.explainability.model.Feature), LinkedList (java.util.LinkedList), Random (java.util.Random), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), LocalSaliencyStability (org.kie.kogito.explainability.utils.LocalSaliencyStability), ValueSource (org.junit.jupiter.params.provider.ValueSource), ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)
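testStabilityDeterministic builds two explainers, each from a fresh new Random() plus the same seed, and expects identical stability scores, which relies on the seed determining the perturbations. The sketch below assumes PerturbationContext applies the seed to the Random it is given (for example via setSeed); the zero-out perturbation and the class SeededPerturbationSketch are hypothetical stand-ins for Kogito's real sampler.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical perturbation scheme, used only to illustrate seeding; not Kogito's sampler.
class SeededPerturbationSketch {

    // Zeroes each feature independently with probability 0.5.
    static double[] perturb(double[] features, Random random) {
        double[] out = features.clone();
        for (int i = 0; i < out.length; i++) {
            if (random.nextBoolean()) {
                out[i] = 0.0;
            }
        }
        return out;
    }

    // A fresh, unseeded Random that is then seeded explicitly
    // (assumed to resemble what PerturbationContext does with (seed, random)).
    static double[][] samples(double[] features, long seed, int n) {
        Random random = new Random();
        random.setSeed(seed);
        double[][] result = new double[n][];
        for (int k = 0; k < n; k++) {
            result[k] = perturb(features, random);
        }
        return result;
    }

    public static void main(String[] args) {
        double[] features = {1.0, 2.0, 3.0, 4.0, 5.0};
        boolean same = Arrays.deepEquals(samples(features, 0L, 10), samples(features, 0L, 10));
        System.out.println("identical samples for equal seeds: " + same);
    }
}
```

Reseeding makes the unseeded constructor irrelevant: java.util.Random produces the same sequence after setSeed(seed) as a Random constructed with that seed.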

Aggregations

PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 158
Prediction (org.kie.kogito.explainability.model.Prediction): 134
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 134
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 126
Test (org.junit.jupiter.api.Test): 109
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 99
Random (java.util.Random): 91
Feature (org.kie.kogito.explainability.model.Feature): 76
ArrayList (java.util.ArrayList): 73
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 69
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 64
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 59
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 54
Output (org.kie.kogito.explainability.model.Output): 45
Saliency (org.kie.kogito.explainability.model.Saliency): 45
LinkedList (java.util.LinkedList): 41
Value (org.kie.kogito.explainability.model.Value): 41
List (java.util.List): 37
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 33
ValueSource (org.junit.jupiter.params.provider.ValueSource): 32