
Example 56 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class DatasetEncoder method getColumnData.

private List<List<double[]>> getColumnData(List<PredictionInput> perturbedInputs, EncodingParams params) {
    List<List<double[]>> columnData = new LinkedList<>();
    for (int t = 0; t < targetInputFeatures.size(); t++) {
        Feature targetFeature = targetInputFeatures.get(t);
        int finalT = t;
        // encode all inputs with respect to the target, based on their type
    List<double[]> encode = targetFeature.getType().encode(
            params,
            targetFeature.getValue(),
            perturbedInputs.stream()
                    .map(predictionInput -> predictionInput.getFeatures().get(finalT).getValue())
                    .toArray(Value[]::new));
        columnData.add(encode);
    }
    return columnData;
}
Also used : Value(org.kie.kogito.explainability.model.Value) List(java.util.List) LinkedList(java.util.LinkedList) Feature(org.kie.kogito.explainability.model.Feature)
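
The actual encoding is delegated to each feature's Type#encode. As a rough illustration of what a single column-encoding step can look like, here is a minimal sketch of a one-vs-target categorical encoding; the class and method names are hypothetical, not the real Type.encode implementation:

import java.util.List;
import java.util.stream.Collectors;

class ColumnEncodingSketch {

    // Encode a column of raw values against the target value: 1.0 when a
    // perturbed value matches the target, 0.0 otherwise (one-vs-target).
    static List<double[]> encodeCategoricalColumn(Object targetValue, List<Object> perturbedValues) {
        return perturbedValues.stream()
                .map(v -> new double[] { targetValue.equals(v) ? 1.0 : 0.0 })
                .collect(Collectors.toList());
    }
}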

Example 57 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class HighScoreNumericFeatureZonesProvider method getHighScoreFeatureZones.

/**
 * Get a map of feature name -> high score feature zones. Predictions in the data distribution are sorted by
 * (descending) score, the (aggregated) mean score is calculated, and all data points associated with a prediction
 * whose score lies between the mean and the maximum are selected (feature-wise), with an associated tolerance
 * (half the standard deviation of the high score feature points).
 *
 * @param dataDistribution a data distribution
 * @param predictionProvider the model used to score the inputs
 * @param features the list of features to associate high score points with
 * @param maxNoOfSamples max no. of inputs used for discovering high score zones
 * @return a map of feature name -> high score numeric feature zones
 */
public static Map<String, HighScoreNumericFeatureZones> getHighScoreFeatureZones(DataDistribution dataDistribution, PredictionProvider predictionProvider, List<Feature> features, int maxNoOfSamples) {
    Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap = new HashMap<>();
    List<Prediction> scoreSortedPredictions = new ArrayList<>();
    try {
        scoreSortedPredictions.addAll(DataUtils.getScoreSortedPredictions(
                predictionProvider,
                new PredictionInputsDataDistribution(dataDistribution.sample(maxNoOfSamples))));
    } catch (ExecutionException e) {
        LOGGER.error("Could not sort predictions by score: {}", e.getMessage());
    } catch (InterruptedException e) {
        LOGGER.error("Interrupted while waiting for sorting predictions by score: {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        LOGGER.error("Timed out while waiting for sorting predictions by score", e);
    }
    if (!scoreSortedPredictions.isEmpty()) {
        // calculate min, max and mean (threshold) scores
        double max = scoreSortedPredictions.get(0).getOutput().getOutputs().stream()
                .mapToDouble(Output::getScore).sum();
        double min = scoreSortedPredictions.get(scoreSortedPredictions.size() - 1).getOutput().getOutputs().stream()
                .mapToDouble(Output::getScore).sum();
        if (max != min) {
            double threshold = scoreSortedPredictions.stream()
                    .map(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum())
                    .mapToDouble(d -> d).average().orElse((max + min) / 2);
            // filter out predictions whose score is in [min, threshold]
            scoreSortedPredictions = scoreSortedPredictions.stream()
                    .filter(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum() > threshold)
                    .collect(Collectors.toList());
            for (int j = 0; j < features.size(); j++) {
                Feature feature = features.get(j);
                if (Type.NUMBER.equals(feature.getType())) {
                    int finalJ = j;
                    // get feature values associated with high score inputs
                    List<Double> topValues = scoreSortedPredictions.stream()
                            .map(prediction -> prediction.getInput().getFeatures().get(finalJ).getValue().asNumber())
                            .distinct().collect(Collectors.toList());
                    // get high score points and tolerance
                    double[] highScoreFeaturePoints = topValues.stream().flatMapToDouble(DoubleStream::of).toArray();
                    double center = DataUtils.getMean(highScoreFeaturePoints);
                    double tolerance = DataUtils.getStdDev(highScoreFeaturePoints, center) / 2;
                    HighScoreNumericFeatureZones highScoreNumericFeatureZones = new HighScoreNumericFeatureZones(highScoreFeaturePoints, tolerance);
                    numericFeatureZonesMap.put(feature.getName(), highScoreNumericFeatureZones);
                }
            }
        }
    }
    return numericFeatureZonesMap;
}
Also used : DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Collectors(java.util.stream.Collectors) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) DoubleStream(java.util.stream.DoubleStream) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) Map(java.util.Map) Output(org.kie.kogito.explainability.model.Output)
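
The core of the method is the thresholding logic: keep points scored above the mean, then derive a tolerance of half the standard deviation of the surviving points. A minimal sketch of just that step, with DataUtils.getMean/getStdDev replaced by inline computations and hypothetical names:

import java.util.Arrays;

class HighScoreZoneSketch {

    // Keep only the points whose score exceeds the mean score.
    static double[] selectHighScorePoints(double[] points, double[] scores) {
        double threshold = Arrays.stream(scores).average().orElse(0.0);
        double[] selected = new double[points.length];
        int n = 0;
        for (int i = 0; i < points.length; i++) {
            if (scores[i] > threshold) {
                selected[n++] = points[i];
            }
        }
        return Arrays.copyOf(selected, n);
    }

    // Half the standard deviation of the surviving points, as in the method above.
    static double tolerance(double[] highScorePoints) {
        double mean = Arrays.stream(highScorePoints).average().orElse(0.0);
        double variance = Arrays.stream(highScorePoints)
                .map(p -> (p - mean) * (p - mean))
                .average().orElse(0.0);
        return Math.sqrt(variance) / 2;
    }
}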

Example 58 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class LimeExplainer method prepareInputs.

/**
 * Check the perturbed inputs so that the dataset of perturbed inputs / outputs contains more than just one output
 * class, otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be used
 * as feature importance scores.
 * The check can be {@code strict} or not; when strict, a {@code DatasetNotSeparableException} is thrown if the
 * dataset for a given output is not separable.
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
    if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
        Map<Double, Long> rawClassesBalance;
        // calculate the no. of samples belonging to each output class
        Value fv = currentOutput.getValue();
        rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        Long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / (double) perturbedInputs.size();
        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        if (strict) {
            // check if the dataset is separable and also if the linear model should fit a regressor or a classifier
            if (rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio()) {
                // the dataset is separable enough: use it to train the linear model
                return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
            } else {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
        } else {
            LOGGER.warn("Using an hardly separable dataset for output '{}' of type '{}' with value '{}' ({})", currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
            return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
        }
    } else {
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }
}
Also used : Arrays(java.util.Arrays) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CompletableFuture.completedFuture(java.util.concurrent.CompletableFuture.completedFuture) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional)
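
The separability test hinges on the class balance: if one output class dominates the perturbed samples, a linear model cannot learn a useful boundary. A minimal sketch of that check, approximating the getClassBalance call above with a stream grouping (names are illustrative):

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

class SeparabilitySketch {

    // Count samples per output class and compare the dominant class share
    // against a configured ratio (e.g. LimeConfig#getSeparableDatasetRatio).
    static boolean isSeparable(List<Double> outputValues, double separableDatasetRatio) {
        Map<Double, Long> classBalance = outputValues.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        long dominant = classBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) dominant / outputValues.size();
        return classBalance.size() > 1 && separationRatio < separableDatasetRatio;
    }
}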

Example 59 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class LimeExplainer method getSaliency.

private void getSaliency(List<Feature> linearizedTargetInputFeatures, Map<String, Saliency> result, LimeInputs limeInputs, Output originalOutput, LimeConfig executionConfig) {
    List<FeatureImportance> featureImportanceList = new ArrayList<>();
    // encode the training data so that it can be fed into the linear model
    DatasetEncoder datasetEncoder = new DatasetEncoder(limeInputs.getPerturbedInputs(), limeInputs.getPerturbedOutputs(), linearizedTargetInputFeatures, originalOutput, executionConfig.getEncodingParams());
    List<Pair<double[], Double>> trainingSet = datasetEncoder.getEncodedTrainingSet();
    // weight the training samples based on the proximity to the target input to explain
    double kernelWidth = executionConfig.getProximityKernelWidth() * Math.sqrt(linearizedTargetInputFeatures.size());
    double[] sampleWeights = SampleWeighter.getSampleWeights(linearizedTargetInputFeatures, trainingSet, kernelWidth);
    int ts = linearizedTargetInputFeatures.size();
    double[] featureWeights = new double[ts];
    Arrays.fill(featureWeights, 1);
    if (executionConfig.isPenalizeBalanceSparse()) {
        IndependentSparseFeatureBalanceFilter sparseFeatureBalanceFilter = new IndependentSparseFeatureBalanceFilter();
        sparseFeatureBalanceFilter.apply(featureWeights, linearizedTargetInputFeatures, trainingSet);
    }
    if (executionConfig.isProximityFilter()) {
        ProximityFilter proximityFilter = new ProximityFilter(executionConfig.getProximityThreshold(), executionConfig.getProximityFilteredDatasetMinimum().doubleValue());
        proximityFilter.apply(trainingSet, sampleWeights);
    }
    LinearModel linearModel = new LinearModel(linearizedTargetInputFeatures.size(), limeInputs.isClassification());
    double loss = linearModel.fit(trainingSet, sampleWeights);
    if (!Double.isNaN(loss)) {
        // create the output saliency
        double[] weights = linearModel.getWeights();
        if (limeConfig.isNormalizeWeights() && weights.length > 0) {
            normalizeWeights(weights);
        }
        int i = 0;
        for (Feature linearizedFeature : linearizedTargetInputFeatures) {
            FeatureImportance featureImportance = new FeatureImportance(linearizedFeature, weights[i] * featureWeights[i]);
            featureImportanceList.add(featureImportance);
            i++;
        }
    }
    Saliency saliency = new Saliency(originalOutput, featureImportanceList);
    result.put(originalOutput.getName(), saliency);
}
Also used : ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair)
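
Sample weighting gives more influence to perturbed points that lie close to the input being explained. The exact kernel lives in SampleWeighter, but a standard LIME-style exponential kernel over an encoded distance looks like the following sketch, which assumes a Euclidean distance on the encoded features:

class ProximityKernelSketch {

    // Exponential kernel: samples near the target (small distance) get a
    // weight close to 1, distant samples decay towards 0.
    static double weight(double[] encodedSample, double[] encodedTarget, double kernelWidth) {
        double squaredDistance = 0;
        for (int i = 0; i < encodedSample.length; i++) {
            double d = encodedSample[i] - encodedTarget[i];
            squaredDistance += d * d;
        }
        return Math.exp(-squaredDistance / (kernelWidth * kernelWidth));
    }
}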

Example 60 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class LimeExplainer method explainAsync.

@Override
public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction, PredictionProvider model, Consumer<Map<String, Saliency>> intermediateResultsConsumer) {
    PredictionInput originalInput = prediction.getInput();
    if (originalInput == null || originalInput.getFeatures() == null || originalInput.getFeatures().isEmpty()) {
        throw new LocalExplanationException("cannot explain a prediction whose input is empty");
    }
    List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
    PredictionInput targetInput = linearizedInputs.get(0);
    List<Feature> linearizedTargetInputFeatures = targetInput.getFeatures();
    if (linearizedTargetInputFeatures.isEmpty()) {
        throw new LocalExplanationException("input features linearization failed");
    }
    List<Output> actualOutputs = prediction.getOutput().getOutputs();
    LimeConfig executionConfig = limeConfig.copy();
    return explainWithExecutionConfig(model, originalInput, linearizedTargetInputFeatures, actualOutputs, executionConfig);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Feature(org.kie.kogito.explainability.model.Feature)
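
For context, a typical call site looks like the following sketch. It assumes a configured LimeExplainer and model, a SimplePrediction(PredictionInput, PredictionOutput) constructor, and the two-argument explainAsync overload (without the intermediate-results consumer) from the LocalExplainer contract:

import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.*;

class LimeUsageSketch {

    // Build a Prediction from a single input, then ask LIME for a saliency map.
    static Map<String, Saliency> explain(LimeExplainer limeExplainer, PredictionProvider model,
            PredictionInput input) throws Exception {
        PredictionOutput output = model.predictAsync(List.of(input)).get().get(0);
        Prediction prediction = new SimplePrediction(input, output);
        return limeExplainer.explainAsync(prediction, model).get(10, TimeUnit.SECONDS);
    }
}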

Aggregations

Feature (org.kie.kogito.explainability.model.Feature) 233
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 118
Test (org.junit.jupiter.api.Test) 107
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 107
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 104
Output (org.kie.kogito.explainability.model.Output) 102
ArrayList (java.util.ArrayList) 97
Random (java.util.Random) 92
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 78
Value (org.kie.kogito.explainability.model.Value) 74
LinkedList (java.util.LinkedList) 72
ValueSource (org.junit.jupiter.params.provider.ValueSource) 71
Prediction (org.kie.kogito.explainability.model.Prediction) 67
List (java.util.List) 51
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) 46
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext) 42
Type (org.kie.kogito.explainability.model.Type) 39
NumericalFeatureDomain (org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) 37
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction) 35
FeatureDomain (org.kie.kogito.explainability.model.domain.FeatureDomain) 33