Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
The class FairnessMetrics, method individualConsistency.
/**
* Calculate individual fairness in terms of consistency of predictions across similar inputs.
*
* @param proximityFunction a function that finds the top k similar inputs, given a reference input and a list of inputs
* @param samples a list of inputs to be tested for consistency
* @param predictionProvider the model under inspection
* @return the consistency measure
* @throws ExecutionException if any error occurs during model prediction
* @throws InterruptedException if timeout or other interruption issues occur during model prediction
*/
public static double individualConsistency(BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> proximityFunction,
        List<PredictionInput> samples, PredictionProvider predictionProvider) throws ExecutionException, InterruptedException {
    double consistency = 1;
    for (PredictionInput input : samples) {
        List<PredictionOutput> predictionOutputs = predictionProvider.predictAsync(List.of(input)).get();
        PredictionOutput predictionOutput = predictionOutputs.get(0);
        List<PredictionInput> neighbors = proximityFunction.apply(input, samples);
        List<PredictionOutput> neighborsOutputs = predictionProvider.predictAsync(neighbors).get();
        for (Output output : predictionOutput.getOutputs()) {
            Value originalValue = output.getValue();
            for (PredictionOutput neighborOutput : neighborsOutputs) {
                Output currentOutput = neighborOutput.getByName(output.getName()).orElse(null);
                if (currentOutput != null && !originalValue.equals(currentOutput.getValue())) {
                    consistency -= 1f / (neighbors.size() * predictionOutput.getOutputs().size() * samples.size());
                }
            }
        }
    }
    return consistency;
}
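A minimal usage sketch follows; the two-nearest-neighbours proximityFunction, the hypothetical distance() helper, and the samples/predictionProvider variables are illustrative assumptions, not part of the kogito-apps API (checked exceptions are omitted).

// Illustrative only: distance(a, b) is a hypothetical similarity metric you would supply.
BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> proximityFunction =
        (reference, inputs) -> inputs.stream()
                .filter(candidate -> candidate != reference)
                .sorted(Comparator.comparingDouble(candidate -> distance(reference, candidate)))
                .limit(2) // top k = 2 most similar inputs
                .collect(Collectors.toList());
double consistency = FairnessMetrics.individualConsistency(proximityFunction, samples, predictionProvider);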
Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
The class CounterFactualScoreCalculator, method calculateScore.
/**
* Calculates the counterfactual score for each proposed solution.
* This method assumes that each model used as a {@link org.kie.kogito.explainability.model.PredictionProvider} is
* consistent, in the sense that, for repeated invocations with a {@link PredictionInput} of the same size,
* the size of the returned collection of {@link PredictionOutput} does not change.
*
* @param solution Proposed solution
* @return A {@link BendableBigDecimalScore} with three "hard" levels and one "soft" level
*/
@Override
public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
    BendableBigDecimalScore currentScore = calculateInputScore(solution);
    final List<Feature> flattenedFeatures = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final List<Feature> input = CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, solution.getOriginalFeatures());
    final List<PredictionInput> inputs = List.of(new PredictionInput(input));
    final CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
    try {
        List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        solution.setPredictionOutputs(predictions);
        final BendableBigDecimalScore outputScore = calculateOutputScore(solution);
        currentScore = currentScore.add(outputScore);
    } catch (ExecutionException e) {
        logger.error("Prediction returned an error {}", e.getMessage());
    } catch (InterruptedException e) {
        logger.error("Interrupted while waiting for prediction {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        logger.error("Timed out while waiting for prediction");
    }
    return currentScore;
}
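For context, a score with the three "hard" levels and one "soft" level described in the Javadoc can be built with the OptaPlanner BendableBigDecimalScore factory method; the concrete meaning of each level in kogito-apps is not shown here, so the values below are purely illustrative.

// Illustrative values only: three hard levels and one soft level.
BendableBigDecimalScore score = BendableBigDecimalScore.of(
        new BigDecimal[] { BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO }, // hard levels
        new BigDecimal[] { BigDecimal.valueOf(-1.5) }); // soft level
boolean feasible = score.isFeasible(); // feasible when no hard level is negative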
Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
The class CounterfactualExplainer, method explainAsync.
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial =
            uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob =
                    solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(), solution.getScore().isFeasible(), UUID.randomUUID(),
                solution.getExecutionId(), sequenceId.incrementAndGet());
    });
}
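A hedged usage sketch: the prediction (a CounterfactualPrediction), model and counterfactualConfig variables are assumed to be constructed elsewhere, the constructor shown is an assumption about the explainer's API, and the logging is purely illustrative.

// Illustrative only: explainer configuration and inputs are assumptions.
CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig);
explainer.explainAsync(prediction, model,
        intermediate -> logger.debug("intermediate counterfactual: {}", intermediate)) // consumer receives each intermediate CounterfactualResult
        .thenAccept(result -> logger.info("final counterfactual: {}", result));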
Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
The class ShapKernelExplainer, method explain.
/**
* Compute the SHAP values for a specific prediction.
*
* @param prediction the ShapPrediction to be explained
* @param model the PredictionProvider we are explaining
*
* @return the SHAP values for this prediction, of shape [n_model_outputs x n_features]
*/
private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider model) {
    ShapDataCarrier sdc = this.initialize(model);
    sdc.setSamplesAdded(new ArrayList<>());
    PredictionInput pi = prediction.getInput();
    PredictionOutput po = prediction.getOutput();
    if (pi.getFeatures().size() != sdc.getCols()) {
        throw new IllegalArgumentException(String.format("Prediction input feature count (%d) does not match background data feature count (%d)", pi.getFeatures().size(), sdc.getCols()));
    }
    int cols = sdc.getCols();
    CompletableFuture<RealMatrix> output = sdc.getOutputSize().thenApply(os -> {
        if (po.getOutputs().size() != os) {
            throw new IllegalArgumentException(String.format("Prediction output size (%d) does not match background data output size (%d)", po.getOutputs().size(), os));
        }
        return MatrixUtils.createRealMatrix(new double[os][cols]);
    });
    RealVector poVector = MatrixUtilsExtensions.vectorFromPredictionOutput(po);
    // first find varying features
    this.setVaryingFeatureGroups(pi, sdc);
    if (sdc.getNumVarying() == 0) {
        // if no features vary, the features do not affect the output, and all SHAP values are zero
        return output.thenApply(o -> saliencyFromMatrix(o, pi, po)).thenCombine(sdc.getFnull(), ShapResults::new);
    } else if (sdc.getNumVarying() == 1) {
        // if exactly one feature varies, that feature carries all of the effect
        CompletableFuture<RealVector> diff = sdc.getLinkNull().thenApply(poVector::subtract);
        return output.thenCompose(o -> diff.thenCombine(sdc.getOutputSize(), (df, os) -> {
            RealMatrix out = MatrixUtils.createRealMatrix(new double[os][cols]);
            for (int i = 0; i < os; i++) {
                out.setEntry(i, sdc.getVaryingFeatureGroups(0), df.getEntry(i));
            }
            return saliencyFromMatrix(out, pi, po);
        })).thenCombine(sdc.getFnull(), ShapResults::new);
    } else {
        // if more than one feature varies, we need to perform a weighted linear regression (WLR)
        // establish sizes of feature permutations (called subsets)
        ShapStatistics shapStats = this.computeSubsetStatistics(sdc);
        // weight each subset by number of features
        this.initializeWeights(shapStats, sdc);
        // add all fully enumerated subsets
        this.addCompleteSubsets(shapStats, pi, sdc);
        // renormalize weights after full subsets have been added
        this.renormalizeWeights(shapStats);
        // sample non-fully enumerated subsets
        this.addNonCompleteSubsets(shapStats, pi, sdc);
        // run the synthetic data generated through the model
        CompletableFuture<RealMatrix> expectations = this.runSyntheticData(sdc);
        // run the WLR model over the synthetic data results
        return output.thenCompose(o -> this.solveSystem(expectations, poVector, sdc).thenApply(wo -> saliencyFromMatrix(wo[0], wo[1], pi, po))).thenCombine(sdc.getFnull(), ShapResults::new);
    }
}
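This private method is reached through the explainer's public explainAsync entry point. A hedged usage sketch follows; the builder method names and configuration values are assumptions about the ShapConfig API, and backgroundInputs, prediction and model are assumed to be built elsewhere.

// Illustrative configuration: the background data set and link function are assumptions.
ShapConfig shapConfig = ShapConfig.builder()
        .withBackground(backgroundInputs) // List<PredictionInput> acting as the background data
        .withLink(ShapConfig.LinkType.IDENTITY)
        .build();
ShapKernelExplainer shapExplainer = new ShapKernelExplainer(shapConfig);
shapExplainer.explainAsync(prediction, model)
        .thenAccept(shapResults -> {
            for (Saliency saliency : shapResults.getSaliencies()) { // one Saliency per model output
                saliency.getPerFeatureImportance()
                        .forEach(fi -> logger.info("{} -> {}", fi.getFeature().getName(), fi.getScore()));
            }
        });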
Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
The class LimeExplainer, method prepareInputs.
/**
* Check the perturbed inputs so that the dataset of perturbed inputs / outputs contains more than just one output
* class, otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be used as
* feature importance scores.
* The check can be {@code strict} or not; if it is, a {@code DatasetNotSeparableException} is thrown when the dataset
* for a given output is not separable.
*/
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput,
        boolean strict) {
    if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
        Map<Double, Long> rawClassesBalance;
        // calculate the no. of samples belonging to each output class
        Value fv = currentOutput.getValue();
        rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        Long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / (double) perturbedInputs.size();
        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        if (strict) {
            // check if the dataset is separable and also if the linear model should fit a regressor or a classifier
            if (rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio()) {
                // the dataset creation process succeeded: use it to train the linear model
                return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
            } else {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
        } else {
            LOGGER.warn("Using a hardly separable dataset for output '{}' of type '{}' with value '{}' ({})", currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(),
                    rawClassesBalance);
            return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
        }
    } else {
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }
}
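For intuition, the class-balance computation referenced above can be sketched roughly as follows (a simplified, hypothetical re-implementation; the actual getClassBalance helper lives elsewhere in LimeExplainer).

// Simplified sketch: label each perturbed output 1.0 if it matches the reference value fv, 0.0 otherwise,
// then count how many samples fall into each class.
Map<Double, Long> rawClassesBalance = perturbedOutputs.stream()
        .map(po -> po.getOutputs().get(o).getValue())
        .map(v -> Objects.equals(v.getUnderlyingObject(), fv.getUnderlyingObject()) ? 1d : 0d)
        .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));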