use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
the class ExplainabilityMetrics method getLocalSaliencyPrecision.
/**
 * Evaluate the precision of a local saliency explainer on a given model.
 * Get the predictions whose outputs have the lowest score for the given decision and pair them with predictions
 * whose outputs have the highest score for the same decision.
 * Get the bottom k (least important) features (according to the saliency) for the low-scoring outputs and
 * "paste" them onto each paired input corresponding to an output with a high score (for the target decision).
 * Perform prediction on the "masked" input: if the output changes, that is counted as a false positive, otherwise
 * as a true positive.
 * See Section 3.2.1 of https://openreview.net/attachment?id=B1xBAA4FwH&name=original_pdf
 *
 * @param outputName decision to evaluate precision for
 * @param predictionProvider the prediction provider to test
 * @param localExplainer the explainer to evaluate
 * @param dataDistribution the data distribution used to obtain inputs for evaluation
 * @param k the number of features to extract
 * @param chunkSize the size of the chunk of predictions to use for evaluation
 * @return the saliency precision
 */
public static double getLocalSaliencyPrecision(String outputName, PredictionProvider predictionProvider,
        LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k,
        int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
    List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
    // get the top and bottom 'chunkSize' predictions
    List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
    List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
    double truePositives = 0;
    double falsePositives = 0;
    int currentChunk = 0;
    for (Prediction prediction : bottomChunk) {
        Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider)
                .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
        if (stringSaliencyMap.containsKey(outputName)) {
            Saliency saliency = stringSaliencyMap.get(outputName);
            // the k least important features (lowest saliency scores) for this low-scoring prediction
            List<FeatureImportance> topFeatures = saliency.getPerFeatureImportance().stream()
                    .sorted(Comparator.comparingDouble(FeatureImportance::getScore))
                    .limit(k)
                    .collect(Collectors.toList());
            Prediction topPrediction = topChunk.get(currentChunk);
            PredictionInput input = topPrediction.getInput();
            PredictionInput maskedInput = maskInput(topFeatures, input);
            List<PredictionOutput> predictionOutputList = predictionProvider.predictAsync(List.of(maskedInput))
                    .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            if (!predictionOutputList.isEmpty()) {
                PredictionOutput predictionOutput = predictionOutputList.get(0);
                Optional<Output> newOptionalOutput = predictionOutput.getByName(outputName);
                if (newOptionalOutput.isPresent()) {
                    Output newOutput = newOptionalOutput.get();
                    Optional<Output> optionalOutput = topPrediction.getOutput().getByName(outputName);
                    if (optionalOutput.isPresent()) {
                        Output output = optionalOutput.get();
                        if (output.getValue().equals(newOutput.getValue())) {
                            truePositives++;
                        } else {
                            falsePositives++;
                        }
                    }
                }
            }
            currentChunk++;
        }
    }
    if ((truePositives + falsePositives) > 0) {
        return truePositives / (truePositives + falsePositives);
    } else {
        // bottomChunk is empty, or the target output (by name) is not an output of the model
        return Double.NaN;
    }
}
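For context, a minimal sketch of how this metric might be invoked. LimeExplainer (as the LocalExplainer<Map<String, Saliency>> implementation) and PredictionInputsDataDistribution are assumptions drawn from the wider kogito-apps explainability API rather than this snippet, and the decision name, k, and chunk size are arbitrary; checked-exception handling is elided.
// hedged usage sketch, not part of the snippet above
LocalExplainer<Map<String, Saliency>> explainer = new LimeExplainer(new LimeConfig());
DataDistribution distribution = new PredictionInputsDataDistribution(inputs); // 'inputs' assumed available
double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(
        "approved",          // hypothetical decision name
        predictionProvider,  // the model under test
        explainer,
        distribution,
        2,                   // mask the 2 least important features
        10);                 // evaluate 10 top/bottom prediction pairs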
use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
the class CounterfactualExplainer method explainAsync.
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model,
        Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial =
            uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(),
                    executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager =
                this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial,
                    assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs =
            cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(),
                solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(),
                sequenceId.incrementAndGet());
    });
}
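A sketch of how a caller might consume this API. The construction of the CounterfactualPrediction is elided because its constructor arguments vary across versions, and the constructor and getters used here are assumptions, not confirmed by this snippet:
// hedged usage sketch; 'cfPrediction', 'model' and the result getters are assumptions
CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig);
explainer.explainAsync(cfPrediction, model,
        intermediate -> logger.debug("intermediate counterfactual: {}", intermediate))
        .thenAccept(result -> {
            // the entities carry the (possibly changed) feature values of the counterfactual
            result.getEntities().forEach(entity -> logger.info("entity: {}", entity));
        });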
use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
the class ShapKernelExplainer method explain.
/**
 * Compute the SHAP values for a specific prediction.
 *
 * @param prediction the ShapPrediction to be explained
 * @param model the PredictionProvider we are explaining
 *
 * @return the SHAP values for this prediction, of shape [n_model_outputs x n_features]
 */
private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider model) {
    ShapDataCarrier sdc = this.initialize(model);
    sdc.setSamplesAdded(new ArrayList<>());
    PredictionInput pi = prediction.getInput();
    PredictionOutput po = prediction.getOutput();
    if (pi.getFeatures().size() != sdc.getCols()) {
        throw new IllegalArgumentException(String.format(
                "Prediction input feature count (%d) does not match background data feature count (%d)",
                pi.getFeatures().size(), sdc.getCols()));
    }
    int cols = sdc.getCols();
    CompletableFuture<RealMatrix> output = sdc.getOutputSize().thenApply(os -> {
        if (po.getOutputs().size() != os) {
            throw new IllegalArgumentException(String.format(
                    "Prediction output size (%d) does not match background data output size (%d)",
                    po.getOutputs().size(), os));
        }
        return MatrixUtils.createRealMatrix(new double[os][cols]);
    });
    RealVector poVector = MatrixUtilsExtensions.vectorFromPredictionOutput(po);
    // first find varying features
    this.setVaryingFeatureGroups(pi, sdc);
    if (sdc.getNumVarying() == 0) {
        // if no features vary, the features do not affect the output, and all SHAP values are zero
        return output.thenApply(o -> saliencyFromMatrix(o, pi, po)).thenCombine(sdc.getFnull(), ShapResults::new);
    } else if (sdc.getNumVarying() == 1) {
        // if exactly one feature varies, that feature carries all of the effect
        CompletableFuture<RealVector> diff = sdc.getLinkNull().thenApply(poVector::subtract);
        return output.thenCompose(o -> diff.thenCombine(sdc.getOutputSize(), (df, os) -> {
            RealMatrix out = MatrixUtils.createRealMatrix(new double[os][cols]);
            for (int i = 0; i < os; i++) {
                out.setEntry(i, sdc.getVaryingFeatureGroups(0), df.getEntry(i));
            }
            return saliencyFromMatrix(out, pi, po);
        })).thenCombine(sdc.getFnull(), ShapResults::new);
    } else {
        // if more than one feature varies, we need to perform weighted linear regression (WLR)
        // establish sizes of feature permutations (called subsets)
        ShapStatistics shapStats = this.computeSubsetStatistics(sdc);
        // weight each subset by number of features
        this.initializeWeights(shapStats, sdc);
        // add all fully enumerated subsets
        this.addCompleteSubsets(shapStats, pi, sdc);
        // renormalize weights after full subsets have been added
        this.renormalizeWeights(shapStats);
        // sample non-fully enumerated subsets
        this.addNonCompleteSubsets(shapStats, pi, sdc);
        // run the synthetic data generated through the model
        CompletableFuture<RealMatrix> expectations = this.runSyntheticData(sdc);
        // run the WLR model over the synthetic data results
        return output.thenCompose(o -> this.solveSystem(expectations, poVector, sdc)
                .thenApply(wo -> saliencyFromMatrix(wo[0], wo[1], pi, po)))
                .thenCombine(sdc.getFnull(), ShapResults::new);
    }
}
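The public entry point wrapping this private method is explainAsync. A hedged sketch of calling it, where the ShapConfig builder method names and the background-data setup are assumptions based on this SHAP implementation's configuration API, not confirmed by the snippet above:
// hedged usage sketch; builder method names are assumptions
ShapConfig shapConfig = ShapConfig.builder()
        .withLink(ShapConfig.LinkType.IDENTITY)
        .withBackground(backgroundInputs) // background inputs define the "feature missing" baseline
        .build();
ShapKernelExplainer shap = new ShapKernelExplainer(shapConfig);
ShapResults results = shap.explainAsync(new SimplePrediction(input, output), model)
        .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);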
use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
the class ShapKernelExplainer method solveSystem.
/**
 * Run the weighted linear regressions (WLRs) in parallel, with each parallel thread computing the SHAP values
 * for a particular output of the model.
 *
 * @param expectations the expectations of each sample
 * @param poVector the prediction outputs for this explanation's prediction
 * @param sdc the data carrier for this explanation
 *
 * @return the SHAP values as found by the WLR
 */
private CompletableFuture<RealMatrix[]> solveSystem(CompletableFuture<RealMatrix> expectations, RealVector poVector,
        ShapDataCarrier sdc) {
    return expectations.thenCompose(exps -> sdc.getFnull().thenCompose(fn -> sdc.getOutputSize().thenCompose(os -> {
        HashMap<Integer, CompletableFuture<RealVector[]>> shapSlices = new HashMap<>();
        for (int output = 0; output < os; output++) {
            int finalOutput = output;
            shapSlices.put(output, CompletableFuture.supplyAsync(
                    () -> solve(exps, finalOutput, poVector, fn, sdc), this.config.getExecutor()));
        }
        // reduce parallel operations into single array
        RealMatrix outputMatrix = MatrixUtils.createRealMatrix(new double[os][sdc.getCols()]);
        final CompletableFuture<RealMatrix[]>[] shapVals = new CompletableFuture[] {
                CompletableFuture.supplyAsync(() -> new RealMatrix[] { outputMatrix.copy(), outputMatrix.copy() },
                        this.config.getExecutor()) };
        shapSlices.forEach((idx, slice) -> shapVals[0] = shapVals[0].thenCompose(e -> slice.thenApply(s -> {
            // store shap values in first matrix
            e[0].setRowVector(idx, s[0]);
            // shap value confidences go in the second
            e[1].setRowVector(idx, s[1]);
            return e;
        })));
        return shapVals[0];
    })));
}
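The per-output futures are folded into a single future through repeated thenCompose/thenApply calls on an accumulator. The same reduction pattern in isolation, using only java.util and java.util.concurrent.CompletableFuture with hypothetical values:
// standalone sketch of the future-reduction pattern used above
Map<Integer, CompletableFuture<Double>> slices = new HashMap<>();
slices.put(0, CompletableFuture.supplyAsync(() -> 1.0));
slices.put(1, CompletableFuture.supplyAsync(() -> 2.0));
CompletableFuture<double[]> acc = CompletableFuture.completedFuture(new double[slices.size()]);
for (Map.Entry<Integer, CompletableFuture<Double>> e : slices.entrySet()) {
    CompletableFuture<double[]> prev = acc;
    // wait for the previous accumulator state, then write this slice's result into its row
    acc = prev.thenCompose(arr -> e.getValue().thenApply(v -> {
        arr[e.getKey()] = v;
        return arr;
    }));
}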
use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
the class LimeImpactScoreCalculator method getImpactScore.
private BigDecimal getImpactScore(LimeConfigSolution solution, LimeConfig config, List<Prediction> predictions) {
    double succeededEvaluations = 0;
    BigDecimal impactScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, solution.getModel())
                    .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            for (Map.Entry<String, Saliency> entry : saliencyMap.entrySet()) {
                List<FeatureImportance> topFeatures = entry.getValue().getTopFeatures(TOP_FEATURES);
                if (!topFeatures.isEmpty()) {
                    double v = ExplainabilityMetrics.impactScore(solution.getModel(), prediction, topFeatures);
                    impactScore = impactScore.add(BigDecimal.valueOf(v));
                    succeededEvaluations++;
                }
            }
        } catch (ExecutionException e) {
            LOGGER.error("Saliency impact-score calculation returned an error {}", e.getMessage());
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency impact-score calculation {}", e.getMessage());
            Thread.currentThread().interrupt();
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency impact-score calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        impactScore = impactScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return impactScore;
}
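For reference, the same per-prediction impact score can be computed outside the OptaPlanner solution context. This hedged sketch reuses only the calls that appear above; 'prediction' and 'model' are assumed available, the feature count is arbitrary, and checked-exception handling is elided:
// hedged sketch built from the calls shown in getImpactScore
LimeExplainer limeExplainer = new LimeExplainer(new LimeConfig());
Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
        .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
for (Map.Entry<String, Saliency> entry : saliencyMap.entrySet()) {
    List<FeatureImportance> topFeatures = entry.getValue().getTopFeatures(2);
    // impactScore measures how much the model's output changes when the top features are perturbed
    double impact = ExplainabilityMetrics.impactScore(model, prediction, topFeatures);
}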