Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
Class AggregatedLimeExplainerTest, method testExplainWithMetadata.
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
    // Seeded generator so that every parameterized run draws a reproducible data distribution.
    Random rng = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);

    // Metadata describing a 3-feature numeric input and a single boolean output named "sum-but1".
    PredictionProviderMetadata providerMetadata = new PredictionProviderMetadata() {
        @Override
        public DataDistribution getDataDistribution() {
            return DataUtils.generateRandomDataDistribution(3, 100, rng);
        }

        @Override
        public PredictionInput getInputShape() {
            List<Feature> shape = new LinkedList<>();
            for (String featureName : List.of("f0", "f1", "f2")) {
                shape.add(FeatureFactory.newNumericalFeature(featureName, 0));
            }
            return new PredictionInput(shape);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shape = new LinkedList<>();
            shape.add(new Output("sum-but1", Type.BOOLEAN, new Value(false), 0d));
            return new PredictionOutput(shape);
        }
    };

    Map<String, Saliency> saliencies = new AggregatedLimeExplainer().explainFromMetadata(model, providerMetadata).get();
    assertNotNull(saliencies);
    assertEquals(1, saliencies.size());
    assertTrue(saliencies.containsKey("sum-but1"));

    Saliency sumSaliency = saliencies.get("sum-but1");
    assertNotNull(sumSaliency);
    // f1 is the feature skipped by getSumSkipModel(1), so it must not rank among the two most positive features
    List<String> topPositiveNames = sumSaliency.getPositiveFeatures(2).stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());
    assertFalse(topPositiveNames.contains("f1"));
}
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
Class ExplainabilityMetrics, method getLocalSaliencyRecall.
/**
 * Evaluate the recall of a local saliency explainer on a given model.
 * Get the predictions having outputs with the highest score for the given decision and pair them with predictions
 * whose outputs have the lowest score for the same decision.
 * Get the top k (most important) features (according to the saliency) for the most important outputs and
 * "paste" them on each paired input corresponding to an output with low score (for the target decision).
 * Perform prediction on the "masked" input: if the output on the masked input is equal to the output for the
 * input the mask features were taken from, that's considered a true positive, otherwise it's a false negative.
 * Recall is then {@code truePositives / (truePositives + falseNegatives)}.
 * see Section 3.2.1 of https://openreview.net/attachment?id=B1xBAA4FwH&name=original_pdf
 *
 * @param outputName decision to evaluate recall for
 * @param predictionProvider the prediction provider to test
 * @param localExplainer the explainer to evaluate
 * @param dataDistribution the data distribution used to obtain inputs for evaluation
 * @param k the no. of features to extract
 * @param chunkSize the size of the chunk of predictions to use for evaluation; {@code List.subList} throws
 *        {@link IndexOutOfBoundsException} if it exceeds the number of score-sorted predictions
 * @return the saliency recall, or {@link Double#NaN} if no prediction could be evaluated
 * @throws InterruptedException if interrupted while waiting for an explanation or a prediction
 * @throws ExecutionException if an explanation or a prediction completes exceptionally
 * @throws TimeoutException if an explanation or a prediction does not complete within
 *         {@code Config.DEFAULT_ASYNC_TIMEOUT}
 */
public static double getLocalSaliencyRecall(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
    // get all samples from the data distribution
    List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
    // get the top and bottom 'chunkSize' predictions
    List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
    List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
    double truePositives = 0;
    double falseNegatives = 0;
    // index into bottomChunk; advances only when a saliency for outputName was obtained
    int currentChunk = 0;
    // for each top-scored prediction, paste its k most important features onto the paired bottom-scored
    // input, then feed the model with this masked input and check the output is equal to the top scored one.
    for (Prediction prediction : topChunk) {
        Optional<Output> optionalOutput = prediction.getOutput().getByName(outputName);
        if (optionalOutput.isPresent()) {
            Output output = optionalOutput.get();
            Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            if (stringSaliencyMap.containsKey(outputName)) {
                Saliency saliency = stringSaliencyMap.get(outputName);
                // descending score order: the first k entries are the most important features
                List<FeatureImportance> topFeatures = saliency.getPerFeatureImportance().stream()
                        .sorted((f1, f2) -> Double.compare(f2.getScore(), f1.getScore()))
                        .limit(k)
                        .collect(Collectors.toList());
                PredictionInput input = bottomChunk.get(currentChunk).getInput();
                PredictionInput maskedInput = maskInput(topFeatures, input);
                List<PredictionOutput> predictionOutputList = predictionProvider.predictAsync(List.of(maskedInput)).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
                if (!predictionOutputList.isEmpty()) {
                    PredictionOutput predictionOutput = predictionOutputList.get(0);
                    Optional<Output> optionalNewOutput = predictionOutput.getByName(outputName);
                    if (optionalNewOutput.isPresent()) {
                        // compare the top-scored output against the output obtained for the MASKED input
                        Output newOutput = optionalNewOutput.get();
                        if (output.getValue().equals(newOutput.getValue())) {
                            truePositives++;
                        } else {
                            falseNegatives++;
                        }
                    }
                }
                currentChunk++;
            }
        }
    }
    if ((truePositives + falseNegatives) > 0) {
        return truePositives / (truePositives + falseNegatives);
    } else {
        // if topChunk is empty or the target output (by name) is not an output of the model.
        return Double.NaN;
    }
}
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
Class ExplainabilityMetrics, method getLocalSaliencyPrecision.
/**
 * Evaluate the precision of a local saliency explainer on a given model.
 * Get the predictions having outputs with the lowest score for the given decision and pair them with predictions
 * whose outputs have the highest score for the same decision.
 * Get the bottom k (least important) features (according to the saliency) for the low-score outputs and
 * "paste" them on each paired input corresponding to an output with high score (for the target decision).
 * Perform prediction on the "masked" input: if the output stays equal to the high-score output, that's considered
 * a true positive; if the output changes, that's considered a false positive.
 * Precision is then {@code truePositives / (truePositives + falsePositives)}.
 * see Section 3.2.1 of https://openreview.net/attachment?id=B1xBAA4FwH&name=original_pdf
 *
 * @param outputName decision to evaluate precision for
 * @param predictionProvider the prediction provider to test
 * @param localExplainer the explainer to evaluate
 * @param dataDistribution the data distribution used to obtain inputs for evaluation
 * @param k the no. of features to extract
 * @param chunkSize the size of the chunk of predictions to use for evaluation; {@code List.subList} throws
 *        {@link IndexOutOfBoundsException} if it exceeds the number of score-sorted predictions
 * @return the saliency precision, or {@link Double#NaN} if no prediction could be evaluated
 * @throws InterruptedException if interrupted while waiting for an explanation or a prediction
 * @throws ExecutionException if an explanation or a prediction completes exceptionally
 * @throws TimeoutException if an explanation or a prediction does not complete within
 *         {@code Config.DEFAULT_ASYNC_TIMEOUT}
 */
public static double getLocalSaliencyPrecision(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
    List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
    // get the top and bottom 'chunkSize' predictions
    List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
    List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
    double truePositives = 0;
    double falsePositives = 0;
    // index into topChunk; advances only when a saliency for outputName was obtained
    int currentChunk = 0;
    // for each bottom-scored prediction, paste its k least important features onto the paired top-scored
    // input, then check whether the model output on this masked input is still the top-scored one.
    for (Prediction prediction : bottomChunk) {
        Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
        if (stringSaliencyMap.containsKey(outputName)) {
            Saliency saliency = stringSaliencyMap.get(outputName);
            // ascending score order: the first k entries are the least important features
            List<FeatureImportance> bottomFeatures = saliency.getPerFeatureImportance().stream()
                    .sorted(Comparator.comparingDouble(FeatureImportance::getScore))
                    .limit(k)
                    .collect(Collectors.toList());
            Prediction topPrediction = topChunk.get(currentChunk);
            PredictionInput input = topPrediction.getInput();
            PredictionInput maskedInput = maskInput(bottomFeatures, input);
            List<PredictionOutput> predictionOutputList = predictionProvider.predictAsync(List.of(maskedInput)).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            if (!predictionOutputList.isEmpty()) {
                PredictionOutput predictionOutput = predictionOutputList.get(0);
                Optional<Output> newOptionalOutput = predictionOutput.getByName(outputName);
                if (newOptionalOutput.isPresent()) {
                    Output newOutput = newOptionalOutput.get();
                    Optional<Output> optionalOutput = topPrediction.getOutput().getByName(outputName);
                    if (optionalOutput.isPresent()) {
                        Output output = optionalOutput.get();
                        if (output.getValue().equals(newOutput.getValue())) {
                            truePositives++;
                        } else {
                            falsePositives++;
                        }
                    }
                }
            }
            currentChunk++;
        }
    }
    if ((truePositives + falsePositives) > 0) {
        return truePositives / (truePositives + falsePositives);
    } else {
        // if bottomChunk is empty or the target output (by name) is not an output of the model.
        return Double.NaN;
    }
}
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
Class LimeImpactScoreCalculator, method getImpactScore.
/**
 * Compute the average impact score of the LIME explanations produced with the given configuration.
 * Each prediction is explained by a {@link LimeExplainer} built from {@code config} against the solution's model.
 * For every explained output with a non-empty list of top features (at most {@code TOP_FEATURES}), the impact
 * score of those features is added to a running sum. The result is that sum divided by the number of such
 * evaluations, rounded with {@link RoundingMode#CEILING} at the scale of the sum.
 * Predictions whose explanation fails or times out are logged and skipped. On interruption the thread's
 * interrupt flag is restored and no further predictions are evaluated.
 *
 * @param solution the solution providing the model to explain
 * @param config the LIME configuration to evaluate
 * @param predictions the predictions to explain
 * @return the average impact score, or {@link BigDecimal#ZERO} if no evaluation succeeded
 */
private BigDecimal getImpactScore(LimeConfigSolution solution, LimeConfig config, List<Prediction> predictions) {
    int succeededEvaluations = 0;
    BigDecimal impactScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, solution.getModel())
                    .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            for (Saliency saliency : saliencyMap.values()) {
                List<FeatureImportance> topFeatures = saliency.getTopFeatures(TOP_FEATURES);
                if (!topFeatures.isEmpty()) {
                    double v = ExplainabilityMetrics.impactScore(solution.getModel(), prediction, topFeatures);
                    impactScore = impactScore.add(BigDecimal.valueOf(v));
                    succeededEvaluations++;
                }
            }
        } catch (ExecutionException e) {
            // pass the exception as last argument so the stack trace is logged too
            LOGGER.error("Saliency impact-score calculation returned an error {}", e.getMessage(), e);
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency impact-score calculation {}", e.getMessage(), e);
            Thread.currentThread().interrupt();
            // with the interrupt flag set, later blocking waits would typically fail straight away:
            // stop launching new explanations and score what has been evaluated so far
            break;
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency impact-score calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        impactScore = impactScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return impactScore;
}
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
Class LimeExplainer, method getSaliency.
/**
 * Fit a weighted linear surrogate model on the perturbed LIME samples and store the resulting saliency
 * for {@code originalOutput} in {@code result}, keyed by the output name.
 * The perturbed inputs/outputs are encoded with a {@link DatasetEncoder}. Samples are weighted by
 * {@link SampleWeighter} using a kernel width of {@code proximityKernelWidth * sqrt(#features)}. The sparse-balance
 * and proximity filters run only when enabled in {@code executionConfig}. If the fit loss is NaN, a saliency with
 * an empty feature-importance list is stored.
 *
 * @param linearizedTargetInputFeatures the (linearized) features of the input being explained
 * @param result map receiving the saliency; an existing entry for the same output name is overwritten
 * @param limeInputs the perturbed samples and whether the task is a classification
 * @param originalOutput the output being explained
 * @param executionConfig the LIME configuration for this execution; every setting is read from it
 */
private void getSaliency(List<Feature> linearizedTargetInputFeatures, Map<String, Saliency> result, LimeInputs limeInputs, Output originalOutput, LimeConfig executionConfig) {
    int featureCount = linearizedTargetInputFeatures.size();
    List<FeatureImportance> featureImportanceList = new ArrayList<>(featureCount);
    // encode the training data so that it can be fed into the linear model
    DatasetEncoder datasetEncoder = new DatasetEncoder(limeInputs.getPerturbedInputs(), limeInputs.getPerturbedOutputs(), linearizedTargetInputFeatures, originalOutput, executionConfig.getEncodingParams());
    List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
    // weight the training samples based on the proximity to the target input to explain
    double kernelWidth = executionConfig.getProximityKernelWidth() * Math.sqrt(featureCount);
    double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedTargetInputFeatures, trainingSet, kernelWidth);
    // per-feature multipliers applied to the fitted weights, initialised to 1
    // (presumably adjusted in place by the sparse-balance filter below — TODO confirm)
    double[] featureWeights = new double[featureCount];
    Arrays.fill(featureWeights, 1);
    if (executionConfig.isPenalizeBalanceSparse()) {
        IndependentSparseFeatureBalanceFilter sparseFeatureBalanceFilter = new IndependentSparseFeatureBalanceFilter();
        sparseFeatureBalanceFilter.apply(featureWeights, linearizedTargetInputFeatures, trainingSet);
    }
    if (executionConfig.isProximityFilter()) {
        ProximityFilter proximityFilter = new ProximityFilter(executionConfig.getProximityThreshold(), executionConfig.getProximityFilteredDatasetMinimum().doubleValue());
        proximityFilter.apply(trainingSet, sampleWeights);
    }
    LinearModel linearModel = new LinearModel(featureCount, limeInputs.isClassification());
    double loss = linearModel.fit(trainingSet, sampleWeights);
    if (!Double.isNaN(loss)) {
        // create the output saliency
        double[] weights = linearModel.getWeights();
        // read the flag from executionConfig, like every other setting used by this method
        if (executionConfig.isNormalizeWeights() && weights.length > 0) {
            normalizeWeights(weights);
        }
        int i = 0;
        for (Feature linearizedFeature : linearizedTargetInputFeatures) {
            featureImportanceList.add(new FeatureImportance(linearizedFeature, weights[i] * featureWeights[i]));
            i++;
        }
    }
    Saliency saliency = new Saliency(originalOutput, featureImportanceList);
    result.put(originalOutput.getName(), saliency);
}
Aggregations