
Example 16 with PredictionOutput

Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

From the class FairnessMetrics, method individualConsistency.

/**
 * Calculate individual fairness in terms of consistency of predictions across similar inputs.
 *
 * @param proximityFunction a function that finds the top k similar inputs, given a reference input and a list of inputs
 * @param samples a list of inputs to be tested for consistency
 * @param predictionProvider the model under inspection
 * @return the consistency measure
 * @throws ExecutionException if any error occurs during model prediction
 * @throws InterruptedException if timeout or other interruption issues occur during model prediction
 */
public static double individualConsistency(BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> proximityFunction, List<PredictionInput> samples, PredictionProvider predictionProvider) throws ExecutionException, InterruptedException {
    double consistency = 1;
    for (PredictionInput input : samples) {
        List<PredictionOutput> predictionOutputs = predictionProvider.predictAsync(List.of(input)).get();
        PredictionOutput predictionOutput = predictionOutputs.get(0);
        List<PredictionInput> neighbors = proximityFunction.apply(input, samples);
        List<PredictionOutput> neighborsOutputs = predictionProvider.predictAsync(neighbors).get();
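        // compare each output of this sample against the corresponding outputs of its neighbors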
        for (Output output : predictionOutput.getOutputs()) {
            Value originalValue = output.getValue();
            for (PredictionOutput neighborOutput : neighborsOutputs) {
                Output currentOutput = neighborOutput.getByName(output.getName()).orElse(null);
                if (currentOutput != null && !originalValue.equals(currentOutput.getValue())) {
                    consistency -= 1d / ((double) neighbors.size() * predictionOutput.getOutputs().size() * samples.size());
                }
            }
        }
    }
    return consistency;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value)
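
A minimal usage sketch (illustrative, not project code): the toy model, the feature name "score", and the identity proximity function are assumptions, as are the exact Output/Value constructor shapes, which here follow the kogito-apps explainability model:

// A toy model that thresholds a single numeric feature; the checked exceptions thrown by
// individualConsistency must be handled or declared by the enclosing method.
PredictionProvider model = inputs -> CompletableFuture.completedFuture(inputs.stream()
        .map(in -> new PredictionOutput(List.of(new Output("approved", Type.BOOLEAN,
                new Value(in.getFeatures().get(0).getValue().asNumber() > 0.5), 1d))))
        .collect(Collectors.toList()));

// Trivial proximity function: a real one would return only the top-k nearest inputs.
BiFunction<PredictionInput, List<PredictionInput>, List<PredictionInput>> proximity =
        (reference, all) -> all;

List<PredictionInput> samples = List.of(
        new PredictionInput(List.of(FeatureFactory.newNumericalFeature("score", 0.3))),
        new PredictionInput(List.of(FeatureFactory.newNumericalFeature("score", 0.9))));

double consistency = FairnessMetrics.individualConsistency(proximity, samples, model);

With the identity proximity function every sample counts as a neighbor of every other, so each disagreeing prediction lowers the consistency below 1.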

Example 17 with PredictionOutput

Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

From the class CounterFactualScoreCalculator, method calculateScore.

/**
 * Calculates the counterfactual score for each proposed solution.
 * This method assumes that each model used as a {@link org.kie.kogito.explainability.model.PredictionProvider} is
 * consistent, in the sense that repeated invocations return a collection of {@link PredictionOutput}
 * of the same size, as long as the size of the {@link PredictionInput} doesn't change.
 *
 * @param solution Proposed solution
 * @return A {@link BendableBigDecimalScore} with three "hard" levels and one "soft" level
 */
@Override
public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
    BendableBigDecimalScore currentScore = calculateInputScore(solution);
    final List<Feature> flattenedFeatures = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final List<Feature> input = CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, solution.getOriginalFeatures());
    final List<PredictionInput> inputs = List.of(new PredictionInput(input));
    final CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
    try {
        List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        solution.setPredictionOutputs(predictions);
        final BendableBigDecimalScore outputScore = calculateOutputScore(solution);
        currentScore = currentScore.add(outputScore);
    } catch (ExecutionException e) {
        logger.error("Prediction returned an error {}", e.getMessage());
    } catch (InterruptedException e) {
        logger.error("Interrupted while waiting for prediction {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        logger.error("Timed out while waiting for prediction");
    }
    return currentScore;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) List(java.util.List) ExecutionException(java.util.concurrent.ExecutionException) Feature(org.kie.kogito.explainability.model.Feature) TimeoutException(java.util.concurrent.TimeoutException)
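
For readers unfamiliar with bendable scores, a minimal sketch of how two scores with three hard levels and one soft level combine, mirroring currentScore.add(outputScore) above; the numeric values are illustrative, and only OptaPlanner's BendableBigDecimalScore.of, add, and isFeasible are assumed:

// Three hard levels and one soft level, matching the structure described in the javadoc above.
BendableBigDecimalScore inputScore = BendableBigDecimalScore.of(
        new BigDecimal[] { BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.valueOf(-1) },
        new BigDecimal[] { new BigDecimal("-0.25") });
BendableBigDecimalScore outputScore = BendableBigDecimalScore.of(
        new BigDecimal[] { new BigDecimal("-2"), BigDecimal.ZERO, BigDecimal.ZERO },
        new BigDecimal[] { BigDecimal.ZERO });
BendableBigDecimalScore total = inputScore.add(outputScore); // level-wise addition
boolean feasible = total.isFeasible(); // true only when every hard level is non-negative

Because hard levels are compared lexicographically, a violation at a harder level dominates any improvement at a softer one.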

Example 18 with PredictionOutput

Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

From the class CounterfactualExplainer, method explainAsync.

@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(), solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet());
    });
}
Also used : SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) CompletableFuture(java.util.concurrent.CompletableFuture) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) SolverManager(org.optaplanner.core.api.solver.SolverManager) Function(java.util.function.Function) CompositeFeatureUtils(org.kie.kogito.explainability.utils.CompositeFeatureUtils) Duration(java.time.Duration) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Executor(java.util.concurrent.Executor) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) ExecutionException(java.util.concurrent.ExecutionException) Consumer(java.util.function.Consumer) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterfactualEntityFactory(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) SolverJob(org.optaplanner.core.api.solver.SolverJob)
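
A driving sketch, with assumptions flagged: the CounterfactualConfig-based constructor, the getter name on CounterfactualResult, and the outer five-minute guard are all illustrative; prediction must be a CounterfactualPrediction, as the cast above requires:

// Blocks for the final solution; intermediate solutions arrive through the consumer.
CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig); // constructor shape is an assumption
CounterfactualResult result = explainer
        .explainAsync(prediction, model,
                intermediate -> logger.debug("intermediate counterfactual received"))
        .get(5, TimeUnit.MINUTES); // outer guard; solver runtime itself is capped via maxRunningTimeSeconds
List<CounterfactualEntity> entities = result.getEntities(); // getter name is an assumption

The consumer receives every intermediate best solution found by the solver (via solveAndListen), which is useful for streaming partial results while the search runs.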

Example 19 with PredictionOutput

Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

From the class ShapKernelExplainer, method explain.

/**
 * Compute the SHAP values for a specific prediction.
 *
 * @param prediction the ShapPrediction to be explained
 * @param model the PredictionProvider being explained
 *
 * @return the SHAP values for this prediction, of shape [n_model_outputs x n_features]
 */
private CompletableFuture<ShapResults> explain(Prediction prediction, PredictionProvider model) {
    ShapDataCarrier sdc = this.initialize(model);
    sdc.setSamplesAdded(new ArrayList<>());
    PredictionInput pi = prediction.getInput();
    PredictionOutput po = prediction.getOutput();
    if (pi.getFeatures().size() != sdc.getCols()) {
        throw new IllegalArgumentException(String.format("Prediction input feature count (%d) does not match background data feature count (%d)", pi.getFeatures().size(), sdc.getCols()));
    }
    int cols = sdc.getCols();
    CompletableFuture<RealMatrix> output = sdc.getOutputSize().thenApply(os -> {
        if (po.getOutputs().size() != os) {
            throw new IllegalArgumentException(String.format("Prediction output size (%d) does not match background data output size (%d)", po.getOutputs().size(), os));
        }
        return MatrixUtils.createRealMatrix(new double[os][cols]);
    });
    RealVector poVector = MatrixUtilsExtensions.vectorFromPredictionOutput(po);
    // first find varying features
    this.setVaryingFeatureGroups(pi, sdc);
    // if no features vary, then the features do not affect the output, and all SHAP values are zero.
    if (sdc.getNumVarying() == 0) {
        return output.thenApply(o -> saliencyFromMatrix(o, pi, po)).thenCombine(sdc.getFnull(), ShapResults::new);
    } else if (sdc.getNumVarying() == 1) {
        // if exactly one feature varies, that feature carries all of the effect
        CompletableFuture<RealVector> diff = sdc.getLinkNull().thenApply(poVector::subtract);
        return output.thenCompose(o -> diff.thenCombine(sdc.getOutputSize(), (df, os) -> {
            RealMatrix out = MatrixUtils.createRealMatrix(new double[os][cols]);
            for (int i = 0; i < os; i++) {
                out.setEntry(i, sdc.getVaryingFeatureGroups(0), df.getEntry(i));
            }
            return saliencyFromMatrix(out, pi, po);
        })).thenCombine(sdc.getFnull(), ShapResults::new);
    } else {
        // if more than one feature varies, we need to perform weighted linear regression (WLR)
        // establish sizes of feature permutations (called subsets)
        ShapStatistics shapStats = this.computeSubsetStatistics(sdc);
        // weight each subset by number of features
        this.initializeWeights(shapStats, sdc);
        // add all fully enumerated subsets
        this.addCompleteSubsets(shapStats, pi, sdc);
        // renormalize weights after full subsets have been added
        this.renormalizeWeights(shapStats);
        // sample non-fully enumerated subsets
        this.addNonCompleteSubsets(shapStats, pi, sdc);
        // run the synthetic data generated through the model
        CompletableFuture<RealMatrix> expectations = this.runSyntheticData(sdc);
        // run the wlr model over the synthetic data results
        return output.thenCompose(o -> this.solveSystem(expectations, poVector, sdc).thenApply(wo -> saliencyFromMatrix(wo[0], wo[1], pi, po))).thenCombine(sdc.getFnull(), ShapResults::new);
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) LarsPath(org.kie.kogito.explainability.utils.LarsPath) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) RealVector(org.apache.commons.math3.linear.RealVector) WeightedLinearRegression(org.kie.kogito.explainability.utils.WeightedLinearRegression) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) MathArithmeticException(org.apache.commons.math3.exception.MathArithmeticException) MatrixUtils(org.apache.commons.math3.linear.MatrixUtils) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LassoLarsIC(org.kie.kogito.explainability.utils.LassoLarsIC) CombinatoricsUtils(org.apache.commons.math3.util.CombinatoricsUtils) Logger(org.slf4j.Logger) Iterator(java.util.Iterator) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) AnyMatrix(org.apache.commons.math3.linear.AnyMatrix) WeightedLinearRegressionResults(org.kie.kogito.explainability.utils.WeightedLinearRegressionResults) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) MatrixUtilsExtensions(org.kie.kogito.explainability.utils.MatrixUtilsExtensions) RandomChoice(org.kie.kogito.explainability.utils.RandomChoice) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Collections(java.util.Collections)
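
A usage sketch for the public entry point that wraps this private method; the ShapConfig builder method names and ShapResults.getSaliencies() follow the kogito-apps API as far as I can tell and should be treated as assumptions (exception handling omitted):

// The background inputs fix sdc.getCols(); their feature count must match the prediction's.
ShapConfig config = ShapConfig.builder()
        .withBackground(backgroundInputs) // a List<PredictionInput> of representative data
        .withLink(ShapConfig.LinkType.IDENTITY)
        .build();
ShapKernelExplainer explainer = new ShapKernelExplainer(config);
ShapResults results = explainer.explainAsync(prediction, model).get(1, TimeUnit.MINUTES);
Saliency[] saliencies = results.getSaliencies(); // one Saliency per model output, i.e. [n_model_outputs x n_features]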

Example 20 with PredictionOutput

Use of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.

From the class LimeExplainer, method prepareInputs.

/**
 * Check the perturbed inputs so that the dataset of perturbed inputs/outputs contains more than just one output
 * class, otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be used as
 * feature importance scores.
 * The check can be {@code strict} or not: if strict, a {@code DatasetNotSeparableException} is thrown when the dataset
 * for a given output is not separable.
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
    if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
        Map<Double, Long> rawClassesBalance;
        // calculate the no. of samples belonging to each output class
        Value fv = currentOutput.getValue();
        rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        Long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / (double) perturbedInputs.size();
        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        if (strict) {
            // check if the dataset is separable and also if the linear model should fit a regressor or a classifier
            if (rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio()) {
                // if dataset creation process succeeds use it to train the linear model
                return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
            } else {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
        } else {
            LOGGER.warn("Using an hardly separable dataset for output '{}' of type '{}' with value '{}' ({})", currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
            return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
        }
    } else {
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }
}
Also used : Arrays(java.util.Arrays) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CompletableFuture.completedFuture(java.util.concurrent.CompletableFuture.completedFuture) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional)
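
To make the separability check concrete, here is an illustrative reimplementation of the class-balance computation (a sketch of what getClassBalance plausibly does, not the project's code), using the same local names as the method above:

// Group the perturbed outputs at index o by whether they match the original value fv.
Map<Double, Long> rawClassesBalance = perturbedOutputs.stream()
        .map(po -> po.getOutputs().get(o).getValue())
        .collect(Collectors.groupingBy(
                v -> fv.equals(v) ? 1d : 0d, // 1 = same class as the original output, 0 = different
                Collectors.counting()));
long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
double separationRatio = (double) max / perturbedInputs.size();

If one class accounts for nearly all perturbed samples, separationRatio approaches 1 and the strict branch above rejects the dataset as not separable.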

Aggregations

PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 155
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 137
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 124
Prediction (org.kie.kogito.explainability.model.Prediction): 122
Random (java.util.Random): 90
Test (org.junit.jupiter.api.Test): 90
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 89
Feature (org.kie.kogito.explainability.model.Feature): 80
ArrayList (java.util.ArrayList): 74
Output (org.kie.kogito.explainability.model.Output): 65
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 65
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 55
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 52
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 50
Saliency (org.kie.kogito.explainability.model.Saliency): 48
Value (org.kie.kogito.explainability.model.Value): 47
LinkedList (java.util.LinkedList): 37
List (java.util.List): 36
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 33
ValueSource (org.junit.jupiter.params.provider.ValueSource): 32