Example 41 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup, from the class LimeExplainerTest, method testZeroSampleSize.

@Test
void testZeroSampleSize() throws ExecutionException, InterruptedException, TimeoutException {
    // With zero samples LIME cannot perturb the input, but the explainer
    // should still complete and return a non-null saliency map.
    LimeConfig limeConfig = new LimeConfig().withSamples(0);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertNotNull(saliencyMap);
}

Example 42 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup, from the class LimeExplainerTest, method testDeterministic.

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testDeterministic(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    List<Saliency> saliencies = new ArrayList<>();
    for (int j = 0; j < 2; j++) {
        // A fixed seed in the PerturbationContext must make the explanation
        // deterministic: both loop iterations should produce identical scores.
        Random random = new Random();
        LimeConfig limeConfig = new LimeConfig().withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS)).withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
        Prediction prediction = new SimplePrediction(input, output);
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        saliencies.add(saliencyMap.get("sum-but0"));
    }
    assertThat(saliencies.get(0).getPerFeatureImportance().stream().map(FeatureImportance::getScore).collect(Collectors.toList())).isEqualTo(saliencies.get(1).getPerFeatureImportance().stream().map(FeatureImportance::getScore).collect(Collectors.toList()));
}

Example 43 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup, from the class LimeExplainerTest, method testNormalizedWeights.

@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    Random random = new Random();
    // With weight normalization enabled, every per-feature importance score
    // is expected to fall within [0, 1].
    LimeConfig limeConfig = new LimeConfig().withNormalizeWeights(true).withPerturbationContext(new PerturbationContext(4L, random, 2)).withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    int nf = 4;
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < nf; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencyMap).isNotNull();
    String decisionName = "sum-but0";
    Saliency saliency = saliencyMap.get(decisionName);
    List<FeatureImportance> perFeatureImportance = saliency.getPerFeatureImportance();
    for (FeatureImportance featureImportance : perFeatureImportance) {
        assertThat(featureImportance.getScore()).isBetween(0d, 1d);
    }
}

Example 44 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup, from the class LimeStabilityTest, method assertStable.

private void assertStable(LimeExplainer limeExplainer, PredictionProvider model, List<Feature> featureList) throws Exception {
    PredictionInput input = new PredictionInput(featureList);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (PredictionOutput predictionOutput : predictionOutputs) {
        Prediction prediction = new SimplePrediction(input, predictionOutput);
        List<Saliency> saliencies = new LinkedList<>();
        for (int i = 0; i < 100; i++) {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            saliencies.addAll(saliencyMap.values());
        }
        // check that the topmost important feature is stable
        List<String> names = new LinkedList<>();
        saliencies.stream().map(s -> s.getPositiveFeatures(1)).filter(f -> !f.isEmpty()).forEach(f -> names.add(f.get(0).getFeature().getName()));
        Map<String, Long> frequencyMap = names.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        boolean topFeature = false;
        for (Map.Entry<String, Long> entry : frequencyMap.entrySet()) {
            if (entry.getValue() >= TOP_FEATURE_THRESHOLD) {
                topFeature = true;
                break;
            }
        }
        assertTrue(topFeature);
        // check that the impact is stable
        List<Double> impacts = new ArrayList<>(saliencies.size());
        for (Saliency saliency : saliencies) {
            double v = ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2));
            impacts.add(v);
        }
        Map<Double, Long> impactMap = impacts.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        boolean topImpact = false;
        for (Map.Entry<Double, Long> entry : impactMap.entrySet()) {
            if (entry.getValue() >= TOP_FEATURE_THRESHOLD) {
                topImpact = true;
                break;
            }
        }
        assertTrue(topImpact);
    }
}
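The frequency-based check in assertStable can be sketched in isolation. The helper below is a hypothetical, self-contained illustration (not part of kogito-apps): it counts how often each feature name appears as the top-ranked feature across repeated explanations and checks whether any feature's count reaches a threshold, mirroring the loop over frequencyMap above.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Illustrative sketch of the stability check used in assertStable:
// count how often each feature name ranks first across repeated
// explanations, then require that some feature dominates.
public class StabilityCheckSketch {

    // Returns true if any name occurs at least `threshold` times.
    static boolean hasDominantFeature(List<String> topFeatureNames, long threshold) {
        Map<String, Long> frequencyMap = topFeatureNames.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        return frequencyMap.values().stream().anyMatch(count -> count >= threshold);
    }

    public static void main(String[] args) {
        // 10 runs: "f1" is top 8 times, "f2" twice.
        List<String> names = List.of("f1", "f1", "f2", "f1", "f1",
                "f1", "f2", "f1", "f1", "f1");
        System.out.println(hasDominantFeature(names, 7)); // prints "true"
        System.out.println(hasDominantFeature(names, 9)); // prints "false"
    }
}
```

The same groupingBy/counting idiom is used twice in assertStable, once over feature names and once over impact scores.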

Example 45 with Saliency

Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup, from the class ExplainabilityMetrics, method classificationFidelity.

/**
 * Calculate the fidelity (accuracy) of boolean classification outputs, using the saliency-based predictor function sign(sum(saliency.scores)).
 * See papers:
 * - Guidotti Riccardo, et al. "A survey of methods for explaining black box models." ACM computing surveys (2018).
 * - Bodria, Francesco, et al. "Explainability Methods for Natural Language Processing: Applications to Sentiment Analysis (Discussion Paper)."
 *
 * @param pairs pairs composed by the saliency and the related prediction
 * @return the fidelity accuracy
 */
public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
    double acc = 0;
    double evals = 0;
    for (Pair<Saliency, Prediction> pair : pairs) {
        Saliency saliency = pair.getLeft();
        Prediction prediction = pair.getRight();
        for (Output output : prediction.getOutput().getOutputs()) {
            Type type = output.getType();
            if (Type.BOOLEAN.equals(type)) {
                double predictorOutput = saliency.getPerFeatureImportance().stream().map(FeatureImportance::getScore).mapToDouble(d -> d).sum();
                double v = output.getValue().asNumber();
                if ((v >= 0 && predictorOutput >= 0) || (v < 0 && predictorOutput < 0)) {
                    acc++;
                }
                evals++;
            }
        }
    }
    return evals == 0 ? 0 : acc / evals;
}
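The sign-of-sum predictor can be illustrated with plain doubles. The class below is an illustrative sketch, not part of the kogito-apps API: each prediction is reduced to its summed saliency scores plus the model's output encoded as a number (>= 0 for true, < 0 for false), and fidelity is the fraction of sign agreements, matching the acc/evals computation above.

```java
import java.util.List;

// Illustrative sketch of the fidelity metric: the surrogate "predicts"
// positive when the sum of saliency scores is >= 0, and fidelity is the
// fraction of outputs where that sign matches the model's output value.
public class FidelitySketch {

    static double fidelity(List<double[]> scoresPerPrediction, List<Double> outputValues) {
        double acc = 0;
        int evals = scoresPerPrediction.size();
        for (int i = 0; i < evals; i++) {
            double sum = 0;
            for (double s : scoresPerPrediction.get(i)) {
                sum += s;
            }
            double v = outputValues.get(i);
            // Count a hit when the surrogate's sign agrees with the output.
            if ((v >= 0 && sum >= 0) || (v < 0 && sum < 0)) {
                acc++;
            }
        }
        return evals == 0 ? 0 : acc / evals;
    }

    public static void main(String[] args) {
        List<double[]> scores = List.of(
                new double[] { 0.5, 0.2, -0.1 }, // sum = 0.6, sign +
                new double[] { -0.4, -0.3 });    // sum = -0.7, sign -
        // Both true outputs are positive: first agrees, second disagrees.
        System.out.println(fidelity(scores, List.of(1.0, 1.0))); // prints "0.5"
    }
}
```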

Aggregations

Saliency (org.kie.kogito.explainability.model.Saliency): 51
Prediction (org.kie.kogito.explainability.model.Prediction): 44
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 43
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 43
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 39
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 39
ArrayList (java.util.ArrayList): 34
Random (java.util.Random): 28
Feature (org.kie.kogito.explainability.model.Feature): 26
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 26
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 25
Test (org.junit.jupiter.api.Test): 23
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 23
DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 21
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution): 18
ValueSource (org.junit.jupiter.params.provider.ValueSource): 16
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 16
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 16
LinkedList (java.util.LinkedList): 13
RealMatrix (org.apache.commons.math3.linear.RealMatrix): 9