
Example 11 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

From class DummyModelsLimeExplainerTest, method testTextSpamClassification.

@ParameterizedTest
@ValueSource(longs = { 0 })
void testTextSpamClassification(long seed) throws Exception {
    Random random = new Random();
    // Three free-text features, tokenized on single spaces
    List<Feature> features = new LinkedList<>();
    Function<String, List<String>> tokenizer = s -> Arrays.asList(s.split(" "));
    features.add(FeatureFactory.newFulltextFeature("f1", "we go here and there", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("f2", "please give me some money", tokenizer));
    features.add(FeatureFactory.newFulltextFeature("f3", "dear friend, please reply", tokenizer));
    PredictionInput input = new PredictionInput(features);
    // Obtain the prediction to explain from the dummy text classifier
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    List<PredictionOutput> outputs = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(input, outputs.get(0));
    // LIME with 100 samples and a seeded perturbation context
    LimeConfig limeConfig = new LimeConfig().withSamples(100)
            .withPerturbationContext(new PerturbationContext(seed, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).toCompletableFuture()
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<FeatureImportance> topFeatures = saliency.getPositiveFeatures(1);
        assertEquals(1, topFeatures.size());
        assertEquals(1d, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
    }
    int topK = 1;
    double minimumPositiveStabilityRate = 0.5;
    double minimumNegativeStabilityRate = 0.2;
    TestUtils.assertLimeStability(model, prediction, limeExplainer, topK, minimumPositiveStabilityRate, minimumNegativeStabilityRate);
    List<PredictionInput> inputs = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        List<Feature> fs = new LinkedList<>();
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        inputs.add(new PredictionInput(fs));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 10;
    String decision = "spam";
    double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(precision).isEqualTo(1);
    double recall = ExplainabilityMetrics.getLocalSaliencyRecall(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(recall).isEqualTo(1);
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(f1).isEqualTo(1);
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) Arrays(java.util.Arrays) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) Random(java.util.Random) Function(java.util.function.Function) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) Map(java.util.Map) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) AssertionsForClassTypes.assertThat(org.assertj.core.api.AssertionsForClassTypes.assertThat) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) TestUtils(org.kie.kogito.explainability.TestUtils) ExplainabilityMetrics(org.kie.kogito.explainability.utils.ExplainabilityMetrics) Config(org.kie.kogito.explainability.Config)
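
Examples 11 and 15 on this page both assert on `ExplainabilityMetrics.impactScore`. The following is a plain-Java illustration of the idea under a simplifying assumption: the impact score is taken to be the fraction of outputs whose value changes once the top-ranked features are dropped from the input (here, blanked to an empty string). It is not Kogito's implementation, and the toy keyword classifier is a hypothetical stand-in for `TestUtils.getDummyTextClassifier()`, whose actual logic is not shown here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical, simplified impact score (an assumption, not Kogito's code):
// the fraction of outputs whose value changes after the top features are dropped.
final class ImpactScoreSketch {

    // Toy "spam" classifier: one boolean output, true when any text mentions "money".
    static List<Boolean> classify(List<String> texts) {
        boolean spam = texts.stream().anyMatch(t -> t.contains("money"));
        return List.of(spam);
    }

    // "Drops" the features at the given indices by replacing their text with "".
    static List<String> drop(List<String> texts, List<Integer> topIndices) {
        List<String> copy = new ArrayList<>(texts);
        for (int idx : topIndices) {
            copy.set(idx, "");
        }
        return copy;
    }

    static double impactScore(Function<List<String>, List<Boolean>> model,
                              List<String> texts, List<Integer> topIndices) {
        List<Boolean> original = model.apply(texts);
        List<Boolean> perturbed = model.apply(drop(texts, topIndices));
        int changed = 0;
        for (int i = 0; i < original.size(); i++) {
            if (!original.get(i).equals(perturbed.get(i))) {
                changed++;
            }
        }
        return (double) changed / original.size();
    }

    public static void main(String[] args) {
        List<String> texts = List.of("we go here and there",
                "please give me some money", "dear friend, please reply");
        // Dropping f2 (index 1) flips the toy classifier: prints 1.0
        System.out.println(impactScore(ImpactScoreSketch::classify, texts, List.of(1)));
        // Dropping f1 (index 0) changes nothing: prints 0.0
        System.out.println(impactScore(ImpactScoreSketch::classify, texts, List.of(0)));
    }
}
```

With this definition, a model whose output never changes (such as the fixed-output classifier in Example 15) always scores 0.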

Example 12 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

From class ShapKernelExplainerTest, method saliencyToMatrix.

private RealMatrix[] saliencyToMatrix(Saliency[] saliencies) {
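    // One row per saliency, one column per feature (column count taken from the first saliency)
    RealMatrix emptyMatrix = MatrixUtils.createRealMatrix(new double[saliencies.length][saliencies[0].getPerFeatureImportance().size()]);
    // out[0] collects the scores, out[1] the confidences
    RealMatrix[] out = new RealMatrix[] { emptyMatrix.copy(), emptyMatrix.copy() };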
    for (int i = 0; i < saliencies.length; i++) {
        List<FeatureImportance> fis = saliencies[i].getPerFeatureImportance();
        for (int j = 0; j < fis.size(); j++) {
            out[0].setEntry(i, j, fis.get(j).getScore());
            out[1].setEntry(i, j, fis.get(j).getConfidence());
        }
    }
    return out;
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance)
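
`saliencyToMatrix` lays the importances out as one row per saliency and one column per feature, with scores in `out[0]` and confidences in `out[1]`. A dependency-free sketch of the same layout follows; the `Importance` class is a hypothetical stand-in for `FeatureImportance`, and plain `double[][]` arrays replace Commons Math's `RealMatrix`.

```java
import java.util.List;

final class SaliencyMatrixSketch {

    // Hypothetical stand-in for FeatureImportance: only a score and a confidence.
    static final class Importance {
        final double score;
        final double confidence;

        Importance(double score, double confidence) {
            this.score = score;
            this.confidence = confidence;
        }
    }

    // Row i holds saliency i, column j holds feature j.
    // Returns {scores, confidences}, mirroring out[0] and out[1] in saliencyToMatrix.
    static double[][][] toMatrices(List<List<Importance>> saliencies) {
        int rows = saliencies.size();
        int cols = saliencies.get(0).size(); // sized from the first saliency, as in the original
        double[][] scores = new double[rows][cols];
        double[][] confidences = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            List<Importance> fis = saliencies.get(i);
            for (int j = 0; j < fis.size(); j++) {
                scores[i][j] = fis.get(j).score;
                confidences[i][j] = fis.get(j).confidence;
            }
        }
        return new double[][][] { scores, confidences };
    }

    public static void main(String[] args) {
        List<List<Importance>> saliencies = List.of(
                List.of(new Importance(0.5, 0.9), new Importance(-0.2, 0.8)),
                List.of(new Importance(1.0, 0.7), new Importance(0.0, 0.6)));
        double[][][] out = toMatrices(saliencies);
        System.out.println(out[0][1][0] + " " + out[1][0][1]); // prints "1.0 0.8"
    }
}
```

Like the original helper, this assumes a non-empty input in which every saliency lists the same number of features; a longer later row would index past the allocated columns.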

Example 13 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

From class ShapResultsTest, method buildShapResults.

ShapResults buildShapResults(int nOutputs, int nFeatures, int scalar1, int scalar2) {
    Saliency[] saliencies = new Saliency[nOutputs];
    for (int i = 0; i < nOutputs; i++) {
        List<FeatureImportance> fis = new ArrayList<>();
        for (int j = 0; j < nFeatures; j++) {
            fis.add(new FeatureImportance(new Feature("f" + j, Type.NUMBER, new Value(j)), i * j * scalar1));
        }
        saliencies[i] = new Saliency(new Output("o" + i, Type.NUMBER, new Value(i), 1.0), fis);
    }
    // Baseline vector: a zero vector with scalar2 added to every entry
    RealVector fnull = MatrixUtils.createRealVector(new double[nOutputs]);
    fnull.mapAddToSelf(scalar2);
    return new ShapResults(saliencies, fnull);
}
Also used : FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) RealVector(org.apache.commons.math3.linear.RealVector) Output(org.kie.kogito.explainability.model.Output) ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature)
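
In `buildShapResults`, the score of feature `j` for output `i` is `i * j * scalar1`, and `fnull` ends up with every entry equal to `scalar2`, because `mapAddToSelf` adds the scalar to a zero vector. SHAP's local-accuracy property says an explained output equals its baseline plus the sum of its feature attributions. The sketch below applies that sum to the same synthetic values using plain arrays; the class and method names are hypothetical, and nothing here claims that `ShapResults` performs this computation.

```java
import java.util.Arrays;

final class ShapAdditivitySketch {

    // Same synthetic layout as buildShapResults: phi[i][j] = i * j * scalar1.
    static double[][] scores(int nOutputs, int nFeatures, int scalar1) {
        double[][] phi = new double[nOutputs][nFeatures];
        for (int i = 0; i < nOutputs; i++) {
            for (int j = 0; j < nFeatures; j++) {
                phi[i][j] = i * j * scalar1;
            }
        }
        return phi;
    }

    // Local accuracy: output_i = fnull_i + sum over j of phi[i][j].
    static double[] reconstructOutputs(double[][] phi, double[] fnull) {
        double[] out = new double[phi.length];
        for (int i = 0; i < phi.length; i++) {
            double sum = fnull[i];
            for (double v : phi[i]) {
                sum += v;
            }
            out[i] = sum;
        }
        return out;
    }

    public static void main(String[] args) {
        int nOutputs = 2, nFeatures = 3, scalar1 = 2, scalar2 = 5;
        double[] fnull = new double[nOutputs];
        Arrays.fill(fnull, scalar2); // same effect as mapAddToSelf on a zero vector
        double[] out = reconstructOutputs(scores(nOutputs, nFeatures, scalar1), fnull);
        System.out.println(Arrays.toString(out)); // prints "[5.0, 11.0]"
    }
}
```

Output `o0` gets only its baseline because every score in row 0 is zero; output `o1` adds `2 * (0 + 1 + 2) = 6` to the baseline of 5.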

Example 14 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

From class AggregatedLimeExplainerTest, method testExplainWithPredictions.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
    Random random = new Random();
    random.setSeed(seed);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
    DataDistribution dataDistribution = DataUtils.generateRandomDataDistribution(3, 100, random);
    List<PredictionInput> samples = dataDistribution.sample(10);
    List<PredictionOutput> predictionOutputs = sumSkipModel.predictAsync(samples).get();
    List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);
    AggregatedLimeExplainer aggregatedLimeExplainer = new AggregatedLimeExplainer();
    Map<String, Saliency> explain = aggregatedLimeExplainer.explainFromPredictions(sumSkipModel, predictions).get();
    assertNotNull(explain);
    assertEquals(1, explain.size());
    assertTrue(explain.containsKey("sum-but1"));
    Saliency saliency = explain.get("sum-but1");
    assertNotNull(saliency);
    List<String> collect = saliency.getPositiveFeatures(2).stream()
            .map(FeatureImportance::getFeature)
            .map(Feature::getName)
            .collect(Collectors.toList());
    // skipped feature should not appear in top two positive features
    assertFalse(collect.contains("f1"));
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Random(java.util.Random) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
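
The final assertion holds because the model ignores one input. The sketch below assumes that `getSumSkipModel(1)` outputs the sum of all numeric features except the one at index 1, and that the features are named `f0`, `f1`, `f2`; neither detail is confirmed on this page. To stay self-contained it swaps LIME for a simpler occlusion importance (set one feature to 0 and measure the drop in the output). Under that scheme the skipped feature scores exactly 0, so a positive-only top-k filter, loosely analogous to `Saliency.getPositiveFeatures(k)`, can never return it.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

final class SumSkipSketch {

    // Hypothetical sum-skip model: sum of all features except the one at skipIndex.
    static double sumSkip(double[] x, int skipIndex) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            if (i != skipIndex) {
                sum += x[i];
            }
        }
        return sum;
    }

    // Occlusion importance (a swapped-in technique, not LIME):
    // output drop when feature i alone is set to 0.
    static double[] occlusion(double[] x, int skipIndex) {
        double base = sumSkip(x, skipIndex);
        double[] scores = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double[] copy = x.clone();
            copy[i] = 0;
            scores[i] = base - sumSkip(copy, skipIndex);
        }
        return scores;
    }

    // Names of the top-k features with strictly positive scores, highest first.
    static List<String> topPositive(double[] scores, int k) {
        List<Integer> positive = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > 0) {
                positive.add(i);
            }
        }
        return positive.stream()
                .sorted(Comparator.comparingDouble((Integer idx) -> scores[idx]).reversed())
                .limit(k)
                .map(idx -> "f" + idx)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        double[] x = {3.0, 7.0, 2.0};
        double[] scores = occlusion(x, 1);           // {3.0, 0.0, 2.0}
        System.out.println(topPositive(scores, 2));  // prints "[f0, f2]"
    }
}
```

Note that `f1` is excluded even though it has the largest raw value, because the model never reads it.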

Example 15 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

From class DummyModelsLimeExplainerTest, method testFixedOutput.

@ParameterizedTest
@ValueSource(longs = { 0 })
void testFixedOutput(long seed) throws Exception {
    Random random = new Random();
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f1", 6));
    features.add(FeatureFactory.newNumericalFeature("f2", 3));
    features.add(FeatureFactory.newNumericalFeature("f3", 5));
    PredictionProvider model = TestUtils.getFixedOutputClassifier();
    PredictionInput input = new PredictionInput(features);
    List<PredictionOutput> outputs = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(input, outputs.get(0));
    LimeConfig limeConfig = new LimeConfig().withSamples(10)
            .withPerturbationContext(new PerturbationContext(seed, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<FeatureImportance> topFeatures = saliency.getTopFeatures(3);
        assertEquals(3, topFeatures.size());
        for (FeatureImportance featureImportance : topFeatures) {
            assertEquals(0d, featureImportance.getScore());
        }
        assertEquals(0d, ExplainabilityMetrics.impactScore(model, prediction, topFeatures));
    }
    int topK = 1;
    double minimumPositiveStabilityRate = 0.5;
    double minimumNegativeStabilityRate = 0.5;
    TestUtils.assertLimeStability(model, prediction, limeExplainer, topK, minimumPositiveStabilityRate, minimumNegativeStabilityRate);
    List<PredictionInput> inputs = new ArrayList<>();
    for (int i = 0; i < 100; i++) {
        List<Feature> fs = new LinkedList<>();
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        fs.add(TestUtils.getMockedNumericFeature());
        inputs.add(new PredictionInput(fs));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 10;
    String decision = "class";
    double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(precision).isEqualTo(1);
    double recall = ExplainabilityMetrics.getLocalSaliencyRecall(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(recall).isEqualTo(1);
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(f1).isEqualTo(1);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
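
The precision, recall and F1 assertions in Examples 11 and 15 are consistent with the usual definition of F1 as the harmonic mean of precision and recall, F1 = 2PR / (P + R). Assuming `getLocalSaliencyF1` follows that standard formula (not verified here), perfect precision and recall give an F1 of exactly 1. The guard for P + R = 0 is a choice made in this sketch.

```java
final class SaliencyF1Sketch {

    // Standard F1 score: harmonic mean of precision and recall.
    // Returns 0 when both are 0 to avoid dividing by zero (a choice of this sketch).
    static double f1(double precision, double recall) {
        if (precision + recall == 0) {
            return 0d;
        }
        return 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        System.out.println(f1(1, 1));   // prints "1.0", matching the assertions above
        System.out.println(f1(0.5, 1)); // harmonic mean, roughly 0.667
    }
}
```

The harmonic mean is pulled toward the smaller of the two values, so a saliency method cannot reach a high F1 by excelling at only one of precision or recall.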

Aggregations

FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 25
Saliency (org.kie.kogito.explainability.model.Saliency): 23
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 19
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 19
ArrayList (java.util.ArrayList): 18
Prediction (org.kie.kogito.explainability.model.Prediction): 18
Feature (org.kie.kogito.explainability.model.Feature): 17
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 16
Random (java.util.Random): 14
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 13
DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 12
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 12
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 10
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution): 10
ValueSource (org.junit.jupiter.params.provider.ValueSource): 9
LinkedList (java.util.LinkedList): 8
Test (org.junit.jupiter.api.Test): 7
Output (org.kie.kogito.explainability.model.Output): 7
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 6
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 5