Search in sources :

Example 1 with Saliency

use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From class `AggregatedLimeExplainerTest`, method `testExplainWithMetadata`.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
    // seeded so every parameterized run draws a reproducible random data distribution
    Random random = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    PredictionProviderMetadata metadata = new PredictionProviderMetadata() {

        @Override
        public DataDistribution getDataDistribution() {
            return DataUtils.generateRandomDataDistribution(3, 100, random);
        }

        @Override
        public PredictionInput getInputShape() {
            List<Feature> shapeFeatures = new LinkedList<>();
            for (String featureName : new String[] { "f0", "f1", "f2" }) {
                shapeFeatures.add(FeatureFactory.newNumericalFeature(featureName, 0));
            }
            return new PredictionInput(shapeFeatures);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shapeOutputs = new LinkedList<>();
            shapeOutputs.add(new Output("sum-but1", Type.BOOLEAN, new Value(false), 0d));
            return new PredictionOutput(shapeOutputs);
        }
    };
    Map<String, Saliency> saliencyByOutput = new AggregatedLimeExplainer().explainFromMetadata(model, metadata).get();
    assertNotNull(saliencyByOutput);
    assertEquals(1, saliencyByOutput.size());
    assertTrue(saliencyByOutput.containsKey("sum-but1"));
    Saliency saliency = saliencyByOutput.get("sum-but1");
    assertNotNull(saliency);
    List<String> topPositiveNames = saliency.getPositiveFeatures(2)
            .stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());
    // the skipped feature "f1" must not rank among the two most positive features
    assertFalse(topPositiveNames.contains("f1"));
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionProviderMetadata(org.kie.kogito.explainability.model.PredictionProviderMetadata) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 2 with Saliency

use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From class `ExplainabilityMetrics`, method `getLocalSaliencyRecall`.

/**
 * Evaluate the recall of a local saliency explainer on a given model.
 * Get the predictions having outputs with the highest score for the given decision and pair them with predictions
 * whose outputs have the lowest score for the same decision.
 * Get the top k (most important) features (according to the saliency) for the most important outputs and
 * "paste" them on each paired input corresponding to an output with low score (for the target decision).
 * Perform prediction on the "masked" input: if the output on the masked input is equal to the output for the
 * input the mask features were taken from, that's considered a true positive, otherwise it's a false negative.
 * The recall is {@code truePositives / (truePositives + falseNegatives)}.
 * see Section 3.2.1 of https://openreview.net/attachment?id=B1xBAA4FwH&name=original_pdf
 *
 * @param outputName decision to evaluate recall for
 * @param predictionProvider the prediction provider to test
 * @param localExplainer the explainer to evaluate
 * @param dataDistribution the data distribution used to obtain inputs for evaluation
 * @param k the no. of features to extract
 * @param chunkSize the size of the chunk of predictions to use for evaluation
 * @return the saliency recall, or {@link Double#NaN} if no prediction pair could be evaluated
 * @throws InterruptedException if interrupted while waiting for an explanation or a prediction
 * @throws ExecutionException if computing an explanation or a prediction fails
 * @throws TimeoutException if an explanation or a prediction is not available within the configured async timeout
 */
public static double getLocalSaliencyRecall(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
    // get all samples from the data distribution
    List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
    // get the top and bottom 'chunkSize' predictions
    List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
    List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
    double truePositives = 0;
    double falseNegatives = 0;
    int currentChunk = 0;
    // for each top scored prediction, paste its k most important features onto the paired bottom scored
    // input, then feed the model with this masked input and check the output is equals to the top scored one.
    for (Prediction prediction : topChunk) {
        Optional<Output> optionalOutput = prediction.getOutput().getByName(outputName);
        if (optionalOutput.isPresent()) {
            Output output = optionalOutput.get();
            Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            if (stringSaliencyMap.containsKey(outputName)) {
                Saliency saliency = stringSaliencyMap.get(outputName);
                // most important first: sort by descending saliency score
                List<FeatureImportance> topFeatures = saliency.getPerFeatureImportance().stream().sorted((f1, f2) -> Double.compare(f2.getScore(), f1.getScore())).limit(k).collect(Collectors.toList());
                PredictionInput input = bottomChunk.get(currentChunk).getInput();
                PredictionInput maskedInput = maskInput(topFeatures, input);
                List<PredictionOutput> predictionOutputList = predictionProvider.predictAsync(List.of(maskedInput)).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
                if (!predictionOutputList.isEmpty()) {
                    PredictionOutput predictionOutput = predictionOutputList.get(0);
                    Optional<Output> optionalNewOutput = predictionOutput.getByName(outputName);
                    if (optionalNewOutput.isPresent()) {
                        // the output predicted on the masked input, compared against the original top output
                        Output newOutput = optionalNewOutput.get();
                        if (output.getValue().equals(newOutput.getValue())) {
                            truePositives++;
                        } else {
                            falseNegatives++;
                        }
                    }
                }
                currentChunk++;
            }
        }
    }
    if ((truePositives + falseNegatives) > 0) {
        return truePositives / (truePositives + falseNegatives);
    } else {
        // if topChunk is empty or the target output (by name) is not an output of the model.
        return Double.NaN;
    }
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output)

Example 3 with Saliency

use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From class `ExplainabilityMetrics`, method `getLocalSaliencyPrecision`.

/**
 * Evaluate the precision of a local saliency explainer on a given model.
 * The predictions with the lowest score for the given decision are paired with the predictions with the
 * highest score for the same decision. For each low scored prediction, its bottom k (least important)
 * features according to the saliency are "pasted" on the paired high scored input, and the model is queried
 * on this "masked" input: if the output for the target decision is unchanged with respect to the high scored
 * prediction, the pair counts as a true positive, otherwise as a false positive.
 * The precision is {@code truePositives / (truePositives + falsePositives)}.
 * see Section 3.2.1 of https://openreview.net/attachment?id=B1xBAA4FwH&name=original_pdf
 *
 * @param outputName decision to evaluate precision for
 * @param predictionProvider the prediction provider to test
 * @param localExplainer the explainer to evaluate
 * @param dataDistribution the data distribution used to obtain inputs for evaluation
 * @param k the no. of features to extract
 * @param chunkSize the size of the chunk of predictions to use for evaluation
 * @return the saliency precision, or {@link Double#NaN} if no prediction pair could be evaluated
 * @throws InterruptedException if interrupted while waiting for an explanation or a prediction
 * @throws ExecutionException if computing an explanation or a prediction fails
 * @throws TimeoutException if an explanation or a prediction is not available within the configured async timeout
 */
public static double getLocalSaliencyPrecision(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
    List<Prediction> scoreSorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
    int available = scoreSorted.size();
    // the first / last 'chunkSize' predictions are the top / bottom scored ones
    List<Prediction> highScored = new ArrayList<>(scoreSorted.subList(0, chunkSize));
    List<Prediction> lowScored = new ArrayList<>(scoreSorted.subList(available - chunkSize, available));
    double hits = 0;
    double misses = 0;
    int pairIndex = 0;
    for (Prediction lowPrediction : lowScored) {
        Map<String, Saliency> saliencyByOutput = localExplainer.explainAsync(lowPrediction, predictionProvider)
                .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
        if (!saliencyByOutput.containsKey(outputName)) {
            continue;
        }
        // least important first: sort by ascending saliency score
        List<FeatureImportance> leastImportant = saliencyByOutput.get(outputName)
                .getPerFeatureImportance()
                .stream()
                .sorted(Comparator.comparingDouble(FeatureImportance::getScore))
                .limit(k)
                .collect(Collectors.toList());
        Prediction highPrediction = highScored.get(pairIndex);
        // advance the pairing for every explained prediction, whether or not it ends up being evaluated
        pairIndex++;
        PredictionInput maskedInput = maskInput(leastImportant, highPrediction.getInput());
        List<PredictionOutput> maskedOutputs = predictionProvider.predictAsync(List.of(maskedInput))
                .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
        if (maskedOutputs.isEmpty()) {
            continue;
        }
        Optional<Output> maskedOutput = maskedOutputs.get(0).getByName(outputName);
        if (!maskedOutput.isPresent()) {
            continue;
        }
        Optional<Output> referenceOutput = highPrediction.getOutput().getByName(outputName);
        if (referenceOutput.isPresent()) {
            boolean unchanged = referenceOutput.get().getValue().equals(maskedOutput.get().getValue());
            if (unchanged) {
                hits++;
            } else {
                misses++;
            }
        }
    }
    double evaluated = hits + misses;
    // NaN when lowScored is empty or the target output (by name) is not an output of the model
    return evaluated > 0 ? hits / evaluated : Double.NaN;
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output)

Example 4 with Saliency

use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From class `LimeImpactScoreCalculator`, method `getImpactScore`.

/**
 * Computes the average impact score of the LIME explanations produced with {@code config} over {@code predictions}.
 * Each prediction is explained with a {@link LimeExplainer} built from {@code config}; for every output saliency
 * with a non-empty list of top {@code TOP_FEATURES} features, {@code ExplainabilityMetrics.impactScore} is
 * evaluated on the solution's model and accumulated. Explanations that fail or time out are logged and skipped;
 * an interruption stops the evaluation early (re-asserting the interrupt flag) and averages what was computed.
 *
 * @param solution the solution providing the model to explain
 * @param config the LIME configuration under evaluation
 * @param predictions the predictions to explain
 * @return the mean impact score over the successful evaluations (divided with {@link RoundingMode#CEILING}
 *         at the scale of the accumulated sum), or {@link BigDecimal#ZERO} if none succeeded
 */
private BigDecimal getImpactScore(LimeConfigSolution solution, LimeConfig config, List<Prediction> predictions) {
    int succeededEvaluations = 0;
    BigDecimal impactScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, solution.getModel()).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            for (Map.Entry<String, Saliency> entry : saliencyMap.entrySet()) {
                List<FeatureImportance> topFeatures = entry.getValue().getTopFeatures(TOP_FEATURES);
                if (!topFeatures.isEmpty()) {
                    double v = ExplainabilityMetrics.impactScore(solution.getModel(), prediction, topFeatures);
                    impactScore = impactScore.add(BigDecimal.valueOf(v));
                    succeededEvaluations++;
                }
            }
        } catch (ExecutionException e) {
            // pass the exception itself so the underlying cause and stack trace are logged too
            LOGGER.error("Saliency impact-score calculation returned an error {}", e.getMessage(), e);
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency impact-score calculation {}", e.getMessage());
            Thread.currentThread().interrupt();
            // the thread was asked to stop: do not launch further explanations, average what we have
            break;
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency impact-score calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        // divide(BigDecimal, RoundingMode) keeps the scale of the accumulated sum
        impactScore = impactScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return impactScore;
}
Also used : LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) Saliency(org.kie.kogito.explainability.model.Saliency) BigDecimal(java.math.BigDecimal) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) ExecutionException(java.util.concurrent.ExecutionException) Map(java.util.Map) TimeoutException(java.util.concurrent.TimeoutException)

Example 5 with Saliency

use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.

From class `LimeExplainer`, method `getSaliency`.

/**
 * Fits a LIME surrogate linear model for a single output and stores the resulting saliency in {@code result},
 * keyed by the output's name.
 * The perturbed inputs/outputs are encoded into a training set, samples are weighted by their proximity to the
 * target input, optional sparse-balance and proximity filters are applied (as enabled in
 * {@code executionConfig}), and the fitted weights, optionally normalized and multiplied by per-feature
 * weights, become the feature importances. If the fit returns a NaN loss, the stored saliency has an empty
 * feature-importance list.
 *
 * @param linearizedTargetInputFeatures the linearized features of the input being explained
 * @param result the map the computed saliency is put into
 * @param limeInputs the perturbed inputs/outputs used to train the surrogate model
 * @param originalOutput the output being explained
 * @param executionConfig the LIME configuration for this execution
 */
private void getSaliency(List<Feature> linearizedTargetInputFeatures, Map<String, Saliency> result, LimeInputs limeInputs, Output originalOutput, LimeConfig executionConfig) {
    int featureCount = linearizedTargetInputFeatures.size();
    List<FeatureImportance> featureImportanceList = new ArrayList<>(featureCount);
    // encode the training data so that it can be fed into the linear model
    DatasetEncoder datasetEncoder = new DatasetEncoder(limeInputs.getPerturbedInputs(), limeInputs.getPerturbedOutputs(), linearizedTargetInputFeatures, originalOutput, executionConfig.getEncodingParams());
    List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
    // weight the training samples based on the proximity to the target input to explain
    double kernelWidth = executionConfig.getProximityKernelWidth() * Math.sqrt(featureCount);
    double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedTargetInputFeatures, trainingSet, kernelWidth);
    // per-feature multipliers applied to the fitted weights; they start neutral (1) and may be adjusted by the
    // sparse-balance filter below
    double[] featureWeights = new double[featureCount];
    Arrays.fill(featureWeights, 1);
    if (executionConfig.isPenalizeBalanceSparse()) {
        IndependentSparseFeatureBalanceFilter sparseFeatureBalanceFilter = new IndependentSparseFeatureBalanceFilter();
        sparseFeatureBalanceFilter.apply(featureWeights, linearizedTargetInputFeatures, trainingSet);
    }
    if (executionConfig.isProximityFilter()) {
        ProximityFilter proximityFilter = new ProximityFilter(executionConfig.getProximityThreshold(), executionConfig.getProximityFilteredDatasetMinimum().doubleValue());
        proximityFilter.apply(trainingSet, sampleWeights);
    }
    LinearModel linearModel = new LinearModel(featureCount, limeInputs.isClassification());
    double loss = linearModel.fit(trainingSet, sampleWeights);
    if (!Double.isNaN(loss)) {
        // create the output saliency
        double[] weights = linearModel.getWeights();
        // read the normalization flag from the execution config, like every other setting in this method
        if (executionConfig.isNormalizeWeights() && weights.length > 0) {
            normalizeWeights(weights);
        }
        int i = 0;
        for (Feature linearizedFeature : linearizedTargetInputFeatures) {
            featureImportanceList.add(new FeatureImportance(linearizedFeature, weights[i] * featureWeights[i]));
            i++;
        }
    }
    Saliency saliency = new Saliency(originalOutput, featureImportanceList);
    result.put(originalOutput.getName(), saliency);
}
Also used : ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair)

Aggregations

Saliency (org.kie.kogito.explainability.model.Saliency)51 Prediction (org.kie.kogito.explainability.model.Prediction)44 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)43 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)43 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)39 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)39 ArrayList (java.util.ArrayList)34 Random (java.util.Random)28 Feature (org.kie.kogito.explainability.model.Feature)26 PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext)26 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)25 Test (org.junit.jupiter.api.Test)23 FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance)23 DataDistribution (org.kie.kogito.explainability.model.DataDistribution)21 PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution)18 ValueSource (org.junit.jupiter.params.provider.ValueSource)16 LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig)16 LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer)16 LinkedList (java.util.LinkedList)13 RealMatrix (org.apache.commons.math3.linear.RealMatrix)9