Search in sources :

Example 81 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class DummyModelsLimeExplainerTest, method testMapOneFeatureToOutputRegression.

/**
 * Explains, with LIME, a model built by {@code TestUtils.getFeaturePassModel(idx)} (presumably a model
 * whose output is driven by the feature at index {@code idx} — verify against TestUtils).
 *
 * <p>Checks that:
 * <ul>
 *   <li>at least one saliency is returned, and each one ranks all three input features;</li>
 *   <li>the impact score of those top features is exactly 1;</li>
 *   <li>LIME explanations pass the stability check for the top-1 feature;</li>
 *   <li>local saliency precision is 0, recall is 1 and F1 is 0 over a random data distribution.</li>
 * </ul>
 *
 * @param seed seed handed to the LIME perturbation context
 * @throws Exception if an async prediction/explanation fails or times out
 */
@ParameterizedTest
@ValueSource(longs = { 0 })
void testMapOneFeatureToOutputRegression(long seed) throws Exception {
    Random random = new Random();
    // Index of the feature the model under test is built around.
    int idx = 1;
    List<Feature> features = new LinkedList<>();
    features.add(TestUtils.getMockedNumericFeature(100));
    features.add(TestUtils.getMockedNumericFeature(20));
    features.add(TestUtils.getMockedNumericFeature(0.1));
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getFeaturePassModel(idx);
    List<PredictionOutput> outputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(input, outputs.get(0));
    LimeConfig limeConfig = new LimeConfig().withSamples(100).withPerturbationContext(new PerturbationContext(seed, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // Guard against a vacuous pass: the loop below asserts nothing if no saliency comes back.
    assertThat(saliencyMap).isNotEmpty();
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<FeatureImportance> topFeatures = saliency.getTopFeatures(3);
        assertEquals(3, topFeatures.size());
        assertEquals(1d, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
    }
    int topK = 1;
    double minimumPositiveStabilityRate = 0.5;
    double minimumNegativeStabilityRate = 0.5;
    TestUtils.assertLimeStability(model, prediction, limeExplainer, topK, minimumPositiveStabilityRate, minimumNegativeStabilityRate);
    // Build a distribution of 100 random three-feature inputs for the local saliency metrics.
    List<PredictionInput> inputs = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        List<Feature> fs = new LinkedList<>();
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        inputs.add(new PredictionInput(fs));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 10;
    String decision = "feature-" + idx;
    double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(precision).isZero();
    double recall = ExplainabilityMetrics.getLocalSaliencyRecall(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(recall).isEqualTo(1);
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(f1).isZero();
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 82 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class ExplainabilityMetricsTest, method testBrokenPredict.

/**
 * Verifies that {@code ExplainabilityMetrics.impactScore} throws {@link IllegalStateException} when the
 * prediction provider cannot answer within the (deliberately tiny) async timeout.
 *
 * <p>The shared {@code Config.INSTANCE} timeout is shrunk to 1 ms for this test. The mutation happens
 * inside the {@code try}, so the {@code finally} always restores the defaults, even if setup fails,
 * and the short timeout cannot leak into other tests.
 */
@Test
void testBrokenPredict() {
    try {
        Config.INSTANCE.setAsyncTimeout(1);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        Prediction emptyPrediction = new SimplePrediction(new PredictionInput(emptyList()), new PredictionOutput(emptyList()));
        // Provider whose future never completes normally: it waits at least 1 s on a condition that
        // is never true, far longer than the 1 ms timeout configured above.
        PredictionProvider brokenProvider = inputs -> supplyAsync(() -> {
            await().atLeast(1, TimeUnit.SECONDS).until(() -> false);
            throw new RuntimeException("this should never happen");
        });
        List<FeatureImportance> emptyFeatures = emptyList();
        Assertions.assertThrows(IllegalStateException.class, () -> ExplainabilityMetrics.impactScore(brokenProvider, emptyPrediction, emptyFeatures));
    } finally {
        Config.INSTANCE.setAsyncTimeout(Config.DEFAULT_ASYNC_TIMEOUT);
        Config.INSTANCE.setAsyncTimeUnit(Config.DEFAULT_ASYNC_TIMEUNIT);
    }
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Arrays(java.util.Arrays) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) TimeoutException(java.util.concurrent.TimeoutException) Saliency(org.kie.kogito.explainability.model.Saliency) Pair(org.apache.commons.lang3.tuple.Pair) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) CompletableFuture.supplyAsync(java.util.concurrent.CompletableFuture.supplyAsync) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Awaitility.await(org.awaitility.Awaitility.await) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) TestUtils(org.kie.kogito.explainability.TestUtils) Assertions(org.junit.jupiter.api.Assertions) Config(org.kie.kogito.explainability.Config)

Example 83 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class ExplainabilityMetricsTest, method testFidelityWithEvenSumModel.

/**
 * Smoke test: explains one prediction of the even-sum model with LIME and checks that
 * {@code ExplainabilityMetrics.classificationFidelity} can be computed over the resulting
 * (saliency, prediction) pairs without throwing. The fidelity value itself is not asserted.
 */
@Test
void testFidelityWithEvenSumModel() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = TestUtils.getEvenSumModel(1);
    LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig().withSamples(10));

    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f-1", 1));
    features.add(FeatureFactory.newNumericalFeature("f-2", 2));
    features.add(FeatureFactory.newNumericalFeature("f-3", 3));
    PredictionInput input = new PredictionInput(features);

    // Run the model once to obtain the output that LIME will explain.
    List<PredictionOutput> outputs = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(input, outputs.get(0));

    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

    // Every saliency explains the same prediction.
    List<Pair<Saliency, Prediction>> pairs = new LinkedList<>();
    saliencyMap.values().forEach(saliency -> pairs.add(Pair.of(saliency, prediction)));

    Assertions.assertDoesNotThrow(() -> {
        ExplainabilityMetrics.classificationFidelity(pairs);
    });
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Pair(org.apache.commons.lang3.tuple.Pair) Test(org.junit.jupiter.api.Test)

Example 84 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class LimeConfigOptimizerTest, method testSameConfig.

/**
 * Runs {@code LimeConfigOptimizer} twice from the same seeded initial config, with deterministic
 * execution enabled, over the same predictions. Checks that both runs produce configurations that
 * agree on every inspected parameter.
 *
 * @throws ExecutionException if the model's async prediction fails
 * @throws InterruptedException if interrupted while waiting for the model
 * @throws java.util.concurrent.TimeoutException if the model does not answer within the configured async timeout
 */
@Test
void testSameConfig() throws ExecutionException, InterruptedException, java.util.concurrent.TimeoutException {
    long seed = 0;
    List<LimeConfig> optimizedConfigs = new ArrayList<>();
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    DataDistribution dataDistribution = DataUtils.generateRandomDataDistribution(5, 100, new Random());
    List<PredictionInput> samples = dataDistribution.sample(3);
    // Bound the wait with the configured async timeout so a stuck model cannot hang the test run.
    List<PredictionOutput> predictionOutputs = model.predictAsync(samples)
            .get(org.kie.kogito.explainability.Config.INSTANCE.getAsyncTimeout(),
                    org.kie.kogito.explainability.Config.INSTANCE.getAsyncTimeUnit());
    List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);
    // Both optimization runs see the same predictions and the same seed, so the results should match.
    for (int i = 0; i < 2; i++) {
        Random random = new Random();
        LimeConfig initialConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(seed, random, 1));
        LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer().withDeterministicExecution(true).withStepCountLimit(10).withTimeLimit(10);
        LimeConfig optimizedConfig = limeConfigOptimizer.optimize(initialConfig, predictions, model);
        optimizedConfigs.add(optimizedConfig);
    }
    LimeConfig first = optimizedConfigs.get(0);
    LimeConfig second = optimizedConfigs.get(1);
    assertThat(first.getNoOfRetries()).isEqualTo(second.getNoOfRetries());
    assertThat(first.getNoOfSamples()).isEqualTo(second.getNoOfSamples());
    assertThat(first.getProximityFilteredDatasetMinimum()).isEqualTo(second.getProximityFilteredDatasetMinimum());
    assertThat(first.getProximityKernelWidth()).isEqualTo(second.getProximityKernelWidth());
    assertThat(first.getProximityThreshold()).isEqualTo(second.getProximityThreshold());
    assertThat(first.isProximityFilter()).isEqualTo(second.isProximityFilter());
    assertThat(first.isAdaptDatasetVariance()).isEqualTo(second.isAdaptDatasetVariance());
    assertThat(first.isPenalizeBalanceSparse()).isEqualTo(second.isPenalizeBalanceSparse());
    assertThat(first.getEncodingParams().getNumericTypeClusterGaussianFilterWidth()).isEqualTo(second.getEncodingParams().getNumericTypeClusterGaussianFilterWidth());
    assertThat(first.getEncodingParams().getNumericTypeClusterThreshold()).isEqualTo(second.getEncodingParams().getNumericTypeClusterThreshold());
    assertThat(first.getSeparableDatasetRatio()).isEqualTo(second.getSeparableDatasetRatio());
    assertThat(first.getPerturbationContext().getNoOfPerturbations()).isEqualTo(second.getPerturbationContext().getNoOfPerturbations());
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test)

Example 85 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class LimeImpactScoreCalculatorTest, method testScoreWithEmptyPredictions.

/**
 * Checks that {@code LimeImpactScoreCalculator} returns a non-null, zero-valued score for a solution
 * that has a default config, no predictions and no config entities.
 */
@Test
void testScoreWithEmptyPredictions() {
    LimeImpactScoreCalculator scoreCalculator = new LimeImpactScoreCalculator();
    LimeConfig config = new LimeConfig();
    List<Prediction> predictions = Collections.emptyList();
    List<LimeConfigEntity> entities = Collections.emptyList();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    LimeConfigSolution solution = new LimeConfigSolution(config, predictions, entities, model);
    SimpleBigDecimalScore score = scoreCalculator.calculateScore(solution);
    assertThat(score).isNotNull();
    // Compare numerically: BigDecimal.equals is scale-sensitive, so 0 and 0.00 would not be "equal".
    assertThat(score.getScore()).isNotNull().isEqualByComparingTo(BigDecimal.ZERO);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) SimpleBigDecimalScore(org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Test(org.junit.jupiter.api.Test)

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction): 134, PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 117, PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 107, PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 105, SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 96, Test (org.junit.jupiter.api.Test): 95, Random (java.util.Random): 65, PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 61, LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 57, ArrayList (java.util.ArrayList): 51, Feature (org.kie.kogito.explainability.model.Feature): 48, Saliency (org.kie.kogito.explainability.model.Saliency): 48, ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 42, LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 40, LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 28, DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 24, ValueSource (org.junit.jupiter.params.provider.ValueSource): 22, FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 22, Output (org.kie.kogito.explainability.model.Output): 22, LinkedList (java.util.LinkedList): 21