Search in sources:

Example 1 with PerturbationContext

Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

In the class DataUtils, method boostrapFeatureDistributions.

/**
 * Generate feature distributions from an existing (possibly small) {@link DataDistribution} for each numeric
 * {@link Feature}.
 * For each feature, the interval bounds (min, max) and density information (mean, standard deviation) are estimated
 * using bootstrap. Data points are then sampled from a normal distribution
 * (see {@link #generateData(double, double, int, Random)}) and clamped to the estimated interval.
 * Features whose type is not {@link Type#NUMBER} are skipped and do not appear in the returned map.
 *
 * @param dataDistribution data distribution to take feature values from
 * @param perturbationContext perturbation context; its random generator drives both resampling and data generation
 * @param featureDistributionSize desired size of generated feature distributions
 * @param draws number of times sampling from feature values is performed
 * @param sampleSize size of each sample draw
 * @param numericFeatureZonesMap high feature score zones, keyed by feature name; features without an entry are not
 *        filtered
 * @return a map feature name -> generated feature distribution
 */
public static Map<String, FeatureDistribution> boostrapFeatureDistributions(DataDistribution dataDistribution, PerturbationContext perturbationContext, int featureDistributionSize, int draws, int sampleSize, Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap) {
    Map<String, FeatureDistribution> featureDistributions = new HashMap<>();
    for (FeatureDistribution featureDistribution : dataDistribution.asFeatureDistributions()) {
        Feature feature = featureDistribution.getFeature();
        if (!Type.NUMBER.equals(feature.getType())) {
            // only numeric features can be bootstrapped into a normal distribution
            continue;
        }
        List<Value> values = featureDistribution.getAllSamples();
        double[] means = new double[draws];
        double[] variances = new double[draws];
        double[] mins = new double[draws];
        double[] maxs = new double[draws];
        for (int i = 0; i < draws; i++) {
            List<Value> sampledValues = DataUtils.sampleWithReplacement(values, sampleSize, perturbationContext.getRandom());
            double[] data = sampledValues.stream().mapToDouble(Value::asNumber).toArray();
            double mean = DataUtils.getMean(data);
            means[i] = mean;
            // keep the variance (squared std dev) of each draw: the pooled std dev below is sqrt(mean of variances)
            variances[i] = Math.pow(DataUtils.getStdDev(data, mean), 2);
            // an empty sample leaves the interval unbounded on both sides
            // (Double.MIN_VALUE is the smallest *positive* double, so -Double.MAX_VALUE is the correct lower fallback)
            mins[i] = Arrays.stream(data).min().orElse(-Double.MAX_VALUE);
            maxs[i] = Arrays.stream(data).max().orElse(Double.MAX_VALUE);
        }
        double finalMean = DataUtils.getMean(means);
        // pooled standard deviation: square root of the average bootstrap variance
        double finalStdDev = Math.sqrt(DataUtils.getMean(variances));
        double finalMin = DataUtils.getMean(mins);
        double finalMax = DataUtils.getMean(maxs);
        double[] generated = DataUtils.generateData(finalMean, finalStdDev, featureDistributionSize, perturbationContext.getRandom());
        double[] boundedData = Arrays.stream(generated).map(d -> Math.min(Math.max(d, finalMin), finalMax)).toArray();
        double[] finalData = filterByHighScoreZones(boundedData, numericFeatureZonesMap.get(feature.getName()), featureDistributionSize);
        featureDistributions.put(feature.getName(), new NumericFeatureDistribution(feature, finalData));
    }
    return featureDistributions;
}

/**
 * Keep only the generated points that fall within the given high score zones, unless that discards too many points.
 *
 * @param boundedData generated (and clamped) data points
 * @param zones high score zones for the feature, or {@code null} when none are available
 * @param featureDistributionSize requested distribution size, used as the reference for the 50% threshold
 * @return the filtered points when more than {@code featureDistributionSize / 2} of them survive the filter,
 *         {@code boundedData} otherwise (including when {@code zones} is {@code null})
 */
private static double[] filterByHighScoreZones(double[] boundedData, HighScoreNumericFeatureZones zones, int featureDistributionSize) {
    if (zones == null) {
        return boundedData;
    }
    double[] filteredData = DoubleStream.of(boundedData).filter(zones::test).toArray();
    // only use the filtered data if it's not discarding more than 50% of the points
    return filteredData.length > featureDistributionSize / 2 ? filteredData : boundedData;
}
Also used : IntStream(java.util.stream.IntStream) FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Arrays(java.util.Arrays) MalformedInputException(java.nio.charset.MalformedInputException) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CSVRecord(org.apache.commons.csv.CSVRecord) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Random(java.util.Random) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) CSVFormat(org.apache.commons.csv.CSVFormat) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) LinkedList(java.util.LinkedList) Path(java.nio.file.Path) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) IndependentFeaturesDataDistribution(org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Files(java.nio.file.Files) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DoubleStream(java.util.stream.DoubleStream) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Writer(java.io.Writer) Optional(java.util.Optional) 
HighScoreNumericFeatureZones(org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones) BufferedReader(java.io.BufferedReader) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections) CSVPrinter(org.apache.commons.csv.CSVPrinter) HashMap(java.util.HashMap) Feature(org.kie.kogito.explainability.model.Feature) HighScoreNumericFeatureZones(org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) Value(org.kie.kogito.explainability.model.Value) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution)

Example 2 with PerturbationContext

Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

In the class LimeExplainer, method getNewPerturbationContext.

/**
 * Build the perturbation context for the next LIME retry, with a larger number of perturbations.
 *
 * <p>The new number of perturbations is the larger of "current + 1" and
 * {@code linearizedTargetInputFeatures.size() / noOfRetries}. It is then capped at
 * {@code linearizedTargetInputFeatures.size() - 1}. The random generator of the current context is reused, and
 * its seed is carried over when present.
 *
 * @param linearizedTargetInputFeatures the linearized features of the input being explained
 * @param noOfRetries the current number of retries; must be non-zero, as it is used as an integer divisor
 * @param perturbationContext the current perturbation context
 * @return a new perturbation context with the adjusted number of perturbations
 */
private PerturbationContext getNewPerturbationContext(List<Feature> linearizedTargetInputFeatures, int noOfRetries, PerturbationContext perturbationContext) {
    int noOfFeatures = linearizedTargetInputFeatures.size();
    // grow the perturbation size, but stay within the max no. of features boundaries
    final int nextPerturbationSize = Math.min(noOfFeatures - 1,
            Math.max(perturbationContext.getNoOfPerturbations() + 1, noOfFeatures / noOfRetries));
    Optional<Long> optionalSeed = perturbationContext.getSeed();
    return optionalSeed
            .map(seed -> new PerturbationContext(seed, perturbationContext.getRandom(), nextPerturbationSize))
            .orElseGet(() -> new PerturbationContext(perturbationContext.getRandom(), nextPerturbationSize));
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext)

Example 3 with PerturbationContext

Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

In the class LimeExplainer, method adjustAndRetry.

/**
 * Run the next explanation attempt with one retry fewer. When dataset-variance adaptation is enabled, more
 * samples and more perturbations are used first.
 *
 * <p>With {@code limeConfig.isAdaptDatasetVariance()}, the sample count grows by
 * {@code executionConfig.getNoOfSamples() / limeConfig.getNoOfRetries()}. The perturbation context is also
 * replaced by the one built by {@code getNewPerturbationContext}.
 *
 * @param model the model being explained
 * @param originalInput the input whose prediction is explained
 * @param linearizedTargetInputFeatures the linearized features of {@code originalInput}
 * @param actualOutputs the outputs produced by the model for {@code originalInput}
 * @param executionConfig the configuration of the attempt that is being retried
 * @return the result of {@code explainRetryCycle} with the adjusted configuration
 */
private CompletableFuture<Map<String, Saliency>> adjustAndRetry(PredictionProvider model, PredictionInput originalInput, List<Feature> linearizedTargetInputFeatures, List<Output> actualOutputs, LimeConfig executionConfig) {
    LimeConfig retryConfig = executionConfig;
    if (limeConfig.isAdaptDatasetVariance()) {
        PerturbationContext widerContext = getNewPerturbationContext(linearizedTargetInputFeatures, executionConfig.getNoOfRetries(), executionConfig.getPerturbationContext());
        int currentSamples = executionConfig.getNoOfSamples();
        // NOTE(review): the increment divides by the original limeConfig retry count, while the perturbation
        // context above uses executionConfig's retry count; confirm this asymmetry is intended
        int increasedSamples = currentSamples + currentSamples / limeConfig.getNoOfRetries();
        retryConfig = executionConfig.withSamples(increasedSamples).withPerturbationContext(widerContext);
    }
    int remainingRetries = retryConfig.getNoOfRetries() - 1;
    return explainRetryCycle(model, originalInput, linearizedTargetInputFeatures, actualOutputs, retryConfig.withRetries(remainingRetries));
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext)

Example 4 with PerturbationContext

Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

In the class JITDMNServiceImpl, method evaluateModelAndExplain.

/**
 * Evaluate a DMN model on the given context and explain the result with LIME.
 *
 * <p>The explanation is computed asynchronously. It is awaited for at most
 * {@code Config.INSTANCE.getAsyncTimeout()} units of {@code Config.INSTANCE.getAsyncTimeUnit()}. If the explanation
 * times out or fails, or the waiting thread is interrupted, the DMN result is still returned. In that case it is
 * paired with an {@code EXPLAINABILITY_FAILED} saliencies response.
 *
 * @param dmnEvaluator evaluator of the DMN model
 * @param context input context for the evaluation
 * @return the DMN evaluation result together with its explanation, or with the explanation failure
 */
public DMNResultWithExplanation evaluateModelAndExplain(DMNEvaluator dmnEvaluator, Map<String, Object> context) {
    LocalDMNPredictionProvider localDMNPredictionProvider = new LocalDMNPredictionProvider(dmnEvaluator);
    DMNResult dmnResult = dmnEvaluator.evaluate(context);
    Prediction prediction = new SimplePrediction(LocalDMNPredictionProvider.toPredictionInput(context), LocalDMNPredictionProvider.toPredictionOutput(dmnResult));
    LimeConfig limeConfig = new LimeConfig().withSamples(explainabilityLimeSampleSize).withPerturbationContext(new PerturbationContext(new Random(), explainabilityLimeNoOfPerturbation));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap;
    try {
        saliencyMap = limeExplainer.explainAsync(prediction, localDMNPredictionProvider).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    } catch (InterruptedException e) {
        LOGGER.error("Critical InterruptedException occurred", e);
        // restore the interrupt flag so code further up the stack can observe it
        Thread.currentThread().interrupt();
        return buildExplanationFailedResult(dmnEvaluator, dmnResult);
    } catch (TimeoutException | ExecutionException e) {
        // log the cause instead of silently discarding it
        LOGGER.error("LIME explanation could not be computed", e);
        return buildExplanationFailedResult(dmnEvaluator, dmnResult);
    }
    List<SaliencyResponse> saliencyModelResponse = buildSalienciesResponse(dmnEvaluator.getDmnModel(), saliencyMap);
    return new DMNResultWithExplanation(new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult), new SalienciesResponse(EXPLAINABILITY_SUCCEEDED, null, saliencyModelResponse));
}

/**
 * Build the response returned when the LIME explanation could not be computed.
 *
 * @param dmnEvaluator evaluator of the DMN model
 * @param dmnResult the (successful) DMN evaluation result
 * @return the DMN result paired with an {@code EXPLAINABILITY_FAILED} saliencies response
 */
private DMNResultWithExplanation buildExplanationFailedResult(DMNEvaluator dmnEvaluator, DMNResult dmnResult) {
    return new DMNResultWithExplanation(new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult), new SalienciesResponse(EXPLAINABILITY_FAILED, EXPLAINABILITY_FAILED_MESSAGE, null));
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) SalienciesResponse(org.kie.kogito.trusty.service.common.responses.SalienciesResponse) DMNResult(org.kie.dmn.api.core.DMNResult) JITDMNResult(org.kie.kogito.jitexecutor.dmn.responses.JITDMNResult) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) SaliencyResponse(org.kie.kogito.trusty.service.common.responses.SaliencyResponse) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) DMNResultWithExplanation(org.kie.kogito.jitexecutor.dmn.responses.DMNResultWithExplanation) JITDMNResult(org.kie.kogito.jitexecutor.dmn.responses.JITDMNResult) Saliency(org.kie.kogito.explainability.model.Saliency) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Random(java.util.Random) ExecutionException(java.util.concurrent.ExecutionException) TimeoutException(java.util.concurrent.TimeoutException)

Example 5 with PerturbationContext

Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

In the class DmnTestUtils, method getPredictionInputs.

/**
 * Create 100 perturbed copies of the given prediction input.
 *
 * <p>Every copy is produced by {@code DataUtils.perturbFeatures}. The number of perturbations equals the number of
 * features of {@code predictionInput}, and all copies share one perturbation context.
 *
 * @param predictionInput the input whose features are perturbed
 * @return a new list holding the 100 perturbed inputs
 */
private static List<PredictionInput> getPredictionInputs(PredictionInput predictionInput) {
    final int noOfInputs = 100;
    Random rng = new Random();
    int perturbationCount = predictionInput.getFeatures().size();
    // NOTE(review): fixed seed 4L, presumably for reproducible test data; confirm how PerturbationContext applies it
    PerturbationContext context = new PerturbationContext(4L, rng, perturbationCount);
    List<PredictionInput> perturbedInputs = new ArrayList<>(noOfInputs);
    int produced = 0;
    while (produced < noOfInputs) {
        List<Feature> perturbed = DataUtils.perturbFeatures(predictionInput.getFeatures(), context);
        perturbedInputs.add(new PredictionInput(perturbed));
        produced++;
    }
    return perturbedInputs;
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Random(java.util.Random) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature)

Aggregations

PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext)73 Random (java.util.Random)64 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)61 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)59 Prediction (org.kie.kogito.explainability.model.Prediction)58 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)58 SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction)57 Test (org.junit.jupiter.api.Test)46 LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig)45 LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer)33 Feature (org.kie.kogito.explainability.model.Feature)30 LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer)28 ArrayList (java.util.ArrayList)27 Saliency (org.kie.kogito.explainability.model.Saliency)25 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)24 DataDistribution (org.kie.kogito.explainability.model.DataDistribution)24 PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution)20 ValueSource (org.junit.jupiter.params.provider.ValueSource)17 LinkedList (java.util.LinkedList)16 FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance)12