
Example 6 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class ExplainabilityMetrics method getLocalSaliencyPrecision.

/**
 * Evaluate the precision of a local saliency explainer on a given model.
 * Get the predictions having outputs with the lowest score for the given decision and pair them with predictions
 * whose outputs have the highest score for the same decision.
 * Get the bottom k (least important) features (according to the saliency) for the least important outputs and
 * "paste" them onto each paired input corresponding to an output with a high score (for the target decision).
 * Perform prediction on the "masked" input; if the output changes, that's counted as a false positive, otherwise
 * it's a true positive.
 * see Section 3.2.1 of https://openreview.net/attachment?id=B1xBAA4FwH&name=original_pdf
 *
 * @param outputName decision to evaluate precision for
 * @param predictionProvider the prediction provider to test
 * @param localExplainer the explainer to evaluate
 * @param dataDistribution the data distribution used to obtain inputs for evaluation
 * @param k the number of features to extract
 * @param chunkSize the size of the chunk of predictions to use for evaluation
 * @return the saliency precision
 */
public static double getLocalSaliencyPrecision(String outputName, PredictionProvider predictionProvider, LocalExplainer<Map<String, Saliency>> localExplainer, DataDistribution dataDistribution, int k, int chunkSize) throws InterruptedException, ExecutionException, TimeoutException {
    List<Prediction> sorted = DataUtils.getScoreSortedPredictions(outputName, predictionProvider, dataDistribution);
    // get the top and bottom 'chunkSize' predictions
    List<Prediction> topChunk = new ArrayList<>(sorted.subList(0, chunkSize));
    List<Prediction> bottomChunk = new ArrayList<>(sorted.subList(sorted.size() - chunkSize, sorted.size()));
    double truePositives = 0;
    double falsePositives = 0;
    int currentChunk = 0;
    for (Prediction prediction : bottomChunk) {
        Map<String, Saliency> stringSaliencyMap = localExplainer.explainAsync(prediction, predictionProvider).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
        if (stringSaliencyMap.containsKey(outputName)) {
            Saliency saliency = stringSaliencyMap.get(outputName);
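            // note: sorted ascending by score, so despite the name these are the k *least* important features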
            List<FeatureImportance> topFeatures = saliency.getPerFeatureImportance().stream().sorted(Comparator.comparingDouble(FeatureImportance::getScore)).limit(k).collect(Collectors.toList());
            Prediction topPrediction = topChunk.get(currentChunk);
            PredictionInput input = topPrediction.getInput();
            PredictionInput maskedInput = maskInput(topFeatures, input);
            List<PredictionOutput> predictionOutputList = predictionProvider.predictAsync(List.of(maskedInput)).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            if (!predictionOutputList.isEmpty()) {
                PredictionOutput predictionOutput = predictionOutputList.get(0);
                Optional<Output> newOptionalOutput = predictionOutput.getByName(outputName);
                if (newOptionalOutput.isPresent()) {
                    Output newOutput = newOptionalOutput.get();
                    Optional<Output> optionalOutput = topPrediction.getOutput().getByName(outputName);
                    if (optionalOutput.isPresent()) {
                        Output output = optionalOutput.get();
                        if (output.getValue().equals(newOutput.getValue())) {
                            truePositives++;
                        } else {
                            falsePositives++;
                        }
                    }
                }
            }
            currentChunk++;
        }
    }
    if ((truePositives + falsePositives) > 0) {
        return truePositives / (truePositives + falsePositives);
    } else {
        // if bottomChunk is empty or the target output (by name) is not an output of the model.
        return Double.NaN;
    }
}
Also used: PredictionInput (org.kie.kogito.explainability.model.PredictionInput), Prediction (org.kie.kogito.explainability.model.Prediction), ArrayList (java.util.ArrayList), Saliency (org.kie.kogito.explainability.model.Saliency), FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Output (org.kie.kogito.explainability.model.Output)
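
For reference, a minimal sketch of how this metric might be invoked. The decision name "approved", the LIME sample count, and the method name saliencyPrecisionFor are hypothetical; the model and data distribution are assumed to be supplied by the caller:

import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.DataDistribution;
import org.kie.kogito.explainability.model.PredictionProvider;

public static double saliencyPrecisionFor(PredictionProvider model, DataDistribution distribution)
        throws InterruptedException, ExecutionException, TimeoutException {
    // LIME is one LocalExplainer<Map<String, Saliency>> implementation that fits this metric
    LimeExplainer lime = new LimeExplainer(new LimeConfig().withSamples(100));
    // check how reliably the bottom-2 features are truly unimportant, over 10 prediction pairs
    return ExplainabilityMetrics.getLocalSaliencyPrecision("approved", model, lime, distribution, 2, 10);
}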

Example 7 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class CounterfactualExplainer method explainAsync.

@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(), solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet());
    });
}
Also used: SolverConfig (org.optaplanner.core.config.solver.SolverConfig), Feature (org.kie.kogito.explainability.model.Feature), Prediction (org.kie.kogito.explainability.model.Prediction), LoggerFactory (org.slf4j.LoggerFactory), CompletableFuture (java.util.concurrent.CompletableFuture), CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity), SolverManager (org.optaplanner.core.api.solver.SolverManager), Function (java.util.function.Function), CompositeFeatureUtils (org.kie.kogito.explainability.utils.CompositeFeatureUtils), Duration (java.time.Duration), CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Logger (org.slf4j.Logger), Executor (java.util.concurrent.Executor), LocalExplainer (org.kie.kogito.explainability.local.LocalExplainer), UUID (java.util.UUID), Collectors (java.util.stream.Collectors), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Objects (java.util.Objects), ExecutionException (java.util.concurrent.ExecutionException), Consumer (java.util.function.Consumer), AtomicLong (java.util.concurrent.atomic.AtomicLong), CounterfactualEntityFactory (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), List (java.util.List), Output (org.kie.kogito.explainability.model.Output), SolverJob (org.optaplanner.core.api.solver.SolverJob)
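
A minimal sketch of calling this method from client code. Constructing the CounterfactualPrediction is version-specific and assumed done by the caller; the one-minute wait and the method name findCounterfactual are arbitrary:

import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.PredictionProvider;

public static CounterfactualResult findCounterfactual(CounterfactualExplainer explainer,
        CounterfactualPrediction prediction, PredictionProvider model)
        throws InterruptedException, ExecutionException, TimeoutException {
    // the consumer receives each improved (possibly still infeasible) solution while the solver runs
    Consumer<CounterfactualResult> onIntermediate =
            result -> System.out.println("received an intermediate counterfactual candidate");
    return explainer.explainAsync(prediction, model, onIntermediate)
            .get(1, TimeUnit.MINUTES); // arbitrary cap; the solver's own limit comes from maxRunningTimeSeconds
}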

Example 8 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class ShapKernelExplainer method explain.

/**
 * Compute the shap values for a specific prediction
 *
 * @param prediction: The ShapPrediction to be explained.
 * @param model: The PredictionProvider we are explaining.
 *
 * @return the shap values for this prediction, of shape [n_model_outputs x n_features]
 */
private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider model) {
    ShapDataCarrier sdc = this.initialize(model);
    sdc.setSamplesAdded(new ArrayList<>());
    PredictionInput pi = prediction.getInput();
    PredictionOutput po = prediction.getOutput();
    if (pi.getFeatures().size() != sdc.getCols()) {
        throw new IllegalArgumentException(String.format("Prediction input feature count (%d) does not match background data feature count (%d)", pi.getFeatures().size(), sdc.getCols()));
    }
    int cols = sdc.getCols();
    CompletableFuture<RealMatrix> output = sdc.getOutputSize().thenApply(os -> {
        if (po.getOutputs().size() != os) {
            throw new IllegalArgumentException(String.format("Prediction output size (%d) does not match background data output size (%d)", po.getOutputs().size(), os));
        }
        return MatrixUtils.createRealMatrix(new double[os][cols]);
    });
    RealVector poVector = MatrixUtilsExtensions.vectorFromPredictionOutput(po);
    // first find varying features
    this.setVaryingFeatureGroups(pi, sdc);
    // if no features vary, then the features do not affect the output, and all shap values are zero.
    if (sdc.getNumVarying() == 0) {
        return output.thenApply(o -> saliencyFromMatrix(o, pi, po)).thenCombine(sdc.getFnull(), ShapResults::new);
    } else if (sdc.getNumVarying() == 1) {
        // if exactly 1 feature varies, this feature has all the effect
        CompletableFuture<RealVector> diff = sdc.getLinkNull().thenApply(poVector::subtract);
        return output.thenCompose(o -> diff.thenCombine(sdc.getOutputSize(), (df, os) -> {
            RealMatrix out = MatrixUtils.createRealMatrix(new double[os][cols]);
            for (int i = 0; i < os; i++) {
                out.setEntry(i, sdc.getVaryingFeatureGroups(0), df.getEntry(i));
            }
            return saliencyFromMatrix(out, pi, po);
        })).thenCombine(sdc.getFnull(), ShapResults::new);
    } else {
        // if more than 1 feature varies, we need to perform WLR
        // establish sizes of feature permutations (called subsets)
        ShapStatistics shapStats = this.computeSubsetStatistics(sdc);
        // weight each subset by number of features
        this.initializeWeights(shapStats, sdc);
        // add all fully enumerated subsets
        this.addCompleteSubsets(shapStats, pi, sdc);
        // renormalize weights after full subsets have been added
        this.renormalizeWeights(shapStats);
        // sample non-fully enumerated subsets
        this.addNonCompleteSubsets(shapStats, pi, sdc);
        // run the synthetic data generated through the model
        CompletableFuture<RealMatrix> expectations = this.runSyntheticData(sdc);
        // run the wlr model over the synthetic data results
        return output.thenCompose(o -> this.solveSystem(expectations, poVector, sdc).thenApply(wo -> saliencyFromMatrix(wo[0], wo[1], pi, po))).thenCombine(sdc.getFnull(), ShapResults::new);
    }
}
Also used: IntStream (java.util.stream.IntStream), Arrays (java.util.Arrays), LarsPath (org.kie.kogito.explainability.utils.LarsPath), Prediction (org.kie.kogito.explainability.model.Prediction), LoggerFactory (org.slf4j.LoggerFactory), HashMap (java.util.HashMap), CompletableFuture (java.util.concurrent.CompletableFuture), RealVector (org.apache.commons.math3.linear.RealVector), WeightedLinearRegression (org.kie.kogito.explainability.utils.WeightedLinearRegression), Saliency (org.kie.kogito.explainability.model.Saliency), ArrayList (java.util.ArrayList), MathArithmeticException (org.apache.commons.math3.exception.MathArithmeticException), MatrixUtils (org.apache.commons.math3.linear.MatrixUtils), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), LassoLarsIC (org.kie.kogito.explainability.utils.LassoLarsIC), CombinatoricsUtils (org.apache.commons.math3.util.CombinatoricsUtils), Logger (org.slf4j.Logger), Iterator (java.util.Iterator), LocalExplainer (org.kie.kogito.explainability.local.LocalExplainer), AnyMatrix (org.apache.commons.math3.linear.AnyMatrix), WeightedLinearRegressionResults (org.kie.kogito.explainability.utils.WeightedLinearRegressionResults), FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance), Collectors (java.util.stream.Collectors), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Consumer (java.util.function.Consumer), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), List (java.util.List), MatrixUtilsExtensions (org.kie.kogito.explainability.utils.MatrixUtilsExtensions), RandomChoice (org.kie.kogito.explainability.utils.RandomChoice), RealMatrix (org.apache.commons.math3.linear.RealMatrix), Collections (java.util.Collections)
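
Since explain is private, callers go through the explainer's public LocalExplainer entry point. A minimal sketch, assuming a background dataset is at hand; the ShapConfig builder calls are indicative of this codebase's API rather than guaranteed, and the method name shapValuesFor is hypothetical:

import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionProvider;

public static ShapResults shapValuesFor(Prediction prediction, PredictionProvider model,
        List<PredictionInput> background)
        throws InterruptedException, ExecutionException, TimeoutException {
    // the background inputs fix the expected feature count checked in explain (sdc.getCols())
    ShapConfig config = ShapConfig.builder()
            .withBackground(background)
            .withLink(ShapConfig.LinkType.IDENTITY)
            .build();
    ShapKernelExplainer explainer = new ShapKernelExplainer(config);
    return explainer.explainAsync(prediction, model)
            .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
}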

Example 9 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class ShapKernelExplainer method solveSystem.

/**
 * Run WLRs in parallel, with each parallel thread computing the shap values for a particular output of the model
 *
 * @param expectations: The expectations of each sample
 * @param poVector: The predictionOutputs for this explanation's prediction
 *
 * @return the shap values as found by the WLR
 */
private CompletableFuture<RealMatrix[]> solveSystem(CompletableFuture<RealMatrix> expectations, RealVector poVector, ShapDataCarrier sdc) {
    return expectations.thenCompose(exps -> sdc.getFnull().thenCompose(fn -> sdc.getOutputSize().thenCompose(os -> {
        HashMap<Integer, CompletableFuture<RealVector[]>> shapSlices = new HashMap<>();
        for (int output = 0; output < os; output++) {
            int finalOutput = output;
            shapSlices.put(output, CompletableFuture.supplyAsync(() -> solve(exps, finalOutput, poVector, fn, sdc), this.config.getExecutor()));
        }
        // reduce parallel operations into single array
        RealMatrix outputMatrix = MatrixUtils.createRealMatrix(new double[os][sdc.getCols()]);
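        // a single-element array so the forEach lambda below can reassign the accumulated future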
        final CompletableFuture<RealMatrix[]>[] shapVals = new CompletableFuture[] { CompletableFuture.supplyAsync(() -> new RealMatrix[] { outputMatrix.copy(), outputMatrix.copy() }, this.config.getExecutor()) };
        shapSlices.forEach((idx, slice) -> shapVals[0] = shapVals[0].thenCompose(e -> slice.thenApply(s -> {
            // store shap values in first matrix
            e[0].setRowVector(idx, s[0]);
            // shap value confidences go in the second
            e[1].setRowVector(idx, s[1]);
            return e;
        })));
        return shapVals[0];
    })));
}
Also used: IntStream (java.util.stream.IntStream), Arrays (java.util.Arrays), LarsPath (org.kie.kogito.explainability.utils.LarsPath), Prediction (org.kie.kogito.explainability.model.Prediction), LoggerFactory (org.slf4j.LoggerFactory), HashMap (java.util.HashMap), CompletableFuture (java.util.concurrent.CompletableFuture), RealVector (org.apache.commons.math3.linear.RealVector), WeightedLinearRegression (org.kie.kogito.explainability.utils.WeightedLinearRegression), Saliency (org.kie.kogito.explainability.model.Saliency), ArrayList (java.util.ArrayList), MathArithmeticException (org.apache.commons.math3.exception.MathArithmeticException), MatrixUtils (org.apache.commons.math3.linear.MatrixUtils), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), LassoLarsIC (org.kie.kogito.explainability.utils.LassoLarsIC), CombinatoricsUtils (org.apache.commons.math3.util.CombinatoricsUtils), Logger (org.slf4j.Logger), Iterator (java.util.Iterator), LocalExplainer (org.kie.kogito.explainability.local.LocalExplainer), AnyMatrix (org.apache.commons.math3.linear.AnyMatrix), WeightedLinearRegressionResults (org.kie.kogito.explainability.utils.WeightedLinearRegressionResults), FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance), Collectors (java.util.stream.Collectors), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Consumer (java.util.function.Consumer), PredictionInput (org.kie.kogito.explainability.model.PredictionInput), List (java.util.List), MatrixUtilsExtensions (org.kie.kogito.explainability.utils.MatrixUtilsExtensions), RandomChoice (org.kie.kogito.explainability.utils.RandomChoice), RealMatrix (org.apache.commons.math3.linear.RealMatrix), Collections (java.util.Collections)
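
The single-element shapVals array above is a workaround: the forEach lambda must reassign the accumulated future, and Java lambdas can only capture effectively final locals. The same fold written with a plain loop and local reassignment looks like this self-contained sketch (double[][] stands in for RealMatrix; names are illustrative):

import java.util.Map;
import java.util.concurrent.CompletableFuture;

static CompletableFuture<double[][]> assembleRows(Map<Integer, CompletableFuture<double[]>> rowFutures,
        int rows, int cols) {
    // start from an already-completed future holding the empty result matrix
    CompletableFuture<double[][]> acc = CompletableFuture.completedFuture(new double[rows][cols]);
    for (Map.Entry<Integer, CompletableFuture<double[]>> entry : rowFutures.entrySet()) {
        // chain each row computation onto the accumulator: once both complete,
        // copy the row into its slot and pass the matrix along the chain
        acc = acc.thenCompose(matrix -> entry.getValue().thenApply(row -> {
            matrix[entry.getKey()] = row;
            return matrix;
        }));
    }
    return acc;
}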

Example 10 with Prediction

use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

the class LimeImpactScoreCalculator method getImpactScore.

private BigDecimal getImpactScore(LimeConfigSolution solution, LimeConfig config, List<Prediction> predictions) {
    double succeededEvaluations = 0;
    BigDecimal impactScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, solution.getModel()).get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
            for (Map.Entry<String, Saliency> entry : saliencyMap.entrySet()) {
                List<FeatureImportance> topFeatures = entry.getValue().getTopFeatures(TOP_FEATURES);
                if (!topFeatures.isEmpty()) {
                    double v = ExplainabilityMetrics.impactScore(solution.getModel(), prediction, topFeatures);
                    impactScore = impactScore.add(BigDecimal.valueOf(v));
                    succeededEvaluations++;
                }
            }
        } catch (ExecutionException e) {
            LOGGER.error("Saliency impact-score calculation returned an error {}", e.getMessage());
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency impact-score calculation {}", e.getMessage());
            Thread.currentThread().interrupt();
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency impact-score calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        impactScore = impactScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return impactScore;
}
Also used: LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer), Prediction (org.kie.kogito.explainability.model.Prediction), Saliency (org.kie.kogito.explainability.model.Saliency), BigDecimal (java.math.BigDecimal), FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance), ExecutionException (java.util.concurrent.ExecutionException), Map (java.util.Map), TimeoutException (java.util.concurrent.TimeoutException)
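
For context, a minimal sketch of the per-prediction metric this calculator averages. The model and prediction are assumed to come from the caller, the method name singleImpactScore is hypothetical, and taking two top features mirrors a typical TOP_FEATURES value:

import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import org.kie.kogito.explainability.Config;
import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.FeatureImportance;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;

public static double singleImpactScore(PredictionProvider model, Prediction prediction)
        throws InterruptedException, ExecutionException, TimeoutException {
    LimeExplainer lime = new LimeExplainer(new LimeConfig());
    Map<String, Saliency> saliencyMap = lime.explainAsync(prediction, model)
            .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
    double total = 0;
    for (Saliency saliency : saliencyMap.values()) {
        // impactScore masks the given features and measures how much the model output changes
        List<FeatureImportance> topFeatures = saliency.getTopFeatures(2);
        total += ExplainabilityMetrics.impactScore(model, prediction, topFeatures);
    }
    // average across the model's decisions (an empty map yields 0 rather than NaN)
    return saliencyMap.isEmpty() ? 0 : total / saliencyMap.size();
}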

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction): 134 usages
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 117 usages
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 107 usages
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 105 usages
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 96 usages
Test (org.junit.jupiter.api.Test): 95 usages
Random (java.util.Random): 65 usages
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 61 usages
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 57 usages
ArrayList (java.util.ArrayList): 51 usages
Feature (org.kie.kogito.explainability.model.Feature): 48 usages
Saliency (org.kie.kogito.explainability.model.Saliency): 48 usages
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 42 usages
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 40 usages
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 28 usages
DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 24 usages
ValueSource (org.junit.jupiter.params.provider.ValueSource): 22 usages
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 22 usages
Output (org.kie.kogito.explainability.model.Output): 22 usages
LinkedList (java.util.LinkedList): 21 usages