Search in sources :

Example 51 with PerturbationContext

use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

The class LimeExplainerTest, method testNormalizedWeights.

/**
 * Explains a prediction over four mocked numeric features with a LIME explainer
 * configured to normalize weights, then checks that every feature importance
 * score reported for the "sum-but0" output lies between 0 and 1.
 */
@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    PerturbationContext perturbationContext = new PerturbationContext(4L, new Random(), 2);
    LimeConfig config = new LimeConfig()
            .withNormalizeWeights(true)
            .withPerturbationContext(perturbationContext)
            .withSamples(10);
    LimeExplainer explainer = new LimeExplainer(config);

    // Build the input out of mocked numeric features.
    final int featureCount = 4;
    List<Feature> mockedFeatures = new ArrayList<>();
    int index = 0;
    while (index < featureCount) {
        mockedFeatures.add(TestUtils.getMockedNumericFeature(index));
        index++;
    }
    PredictionInput predictionInput = new PredictionInput(mockedFeatures);

    // Ask the model for the output that is going to be explained.
    PredictionProvider provider = TestUtils.getSumSkipModel(0);
    PredictionOutput predictionOutput = provider.predictAsync(List.of(predictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
            .get(0);
    Prediction prediction = new SimplePrediction(predictionInput, predictionOutput);

    Map<String, Saliency> saliencies = explainer.explainAsync(prediction, provider)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencies).isNotNull();

    Saliency sumSaliency = saliencies.get("sum-but0");
    sumSaliency.getPerFeatureImportance()
            .forEach(importance -> assertThat(importance.getScore()).isBetween(0d, 1d));
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 52 with PerturbationContext

use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

The class LimeStabilityTest, method testStabilityWithTextData.

/**
 * Builds five mocked text features ("foo 0" .. "foo 3" plus "money") and hands
 * them, together with the dummy text classifier and a LIME explainer configured
 * with 10 samples and a perturbation context created from {@code seed}, to
 * {@code assertStable}.
 *
 * @param seed value passed to the {@link PerturbationContext} constructor
 */
@ParameterizedTest
@ValueSource(longs = { 0 })
void testStabilityWithTextData(long seed) throws Exception {
    Random rng = new Random();
    PredictionProvider textClassifier = TestUtils.getDummyTextClassifier();

    List<Feature> textFeatures = new LinkedList<>();
    int index = 0;
    while (index < 4) {
        textFeatures.add(TestUtils.getMockedTextFeature("foo " + index));
        index++;
    }
    textFeatures.add(TestUtils.getMockedTextFeature("money"));

    PerturbationContext perturbationContext = new PerturbationContext(seed, rng, 1);
    LimeConfig config = new LimeConfig()
            .withSamples(10)
            .withPerturbationContext(perturbationContext);
    assertStable(new LimeExplainer(config), textClassifier, textFeatures);
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Random(java.util.Random) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 53 with PerturbationContext

use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

The class LimeStabilityTest, method testAdaptiveVariance.

/**
 * Configures a LIME explainer with adaptive variance enabled (1 sample,
 * 4 retries, perturbation context created from {@code seed}) and passes it to
 * {@code assertStable} along with the even-sum model and four numerical
 * features named "f-0" .. "f-3", each holding the value 2.
 *
 * @param seed value passed to the {@link PerturbationContext} constructor
 */
@ParameterizedTest
@ValueSource(longs = { 0 })
void testAdaptiveVariance(long seed) throws Exception {
    PerturbationContext context = new PerturbationContext(seed, new Random(), 1);
    LimeConfig config = new LimeConfig()
            .withSamples(1)
            .withPerturbationContext(context)
            .withRetries(4)
            .withAdaptiveVariance(true);
    LimeExplainer explainer = new LimeExplainer(config);

    List<Feature> numericFeatures = new LinkedList<>();
    int index = 0;
    while (index < 4) {
        numericFeatures.add(FeatureFactory.newNumericalFeature("f-" + index, 2));
        index++;
    }

    PredictionProvider evenSumModel = TestUtils.getEvenSumModel(0);
    assertStable(explainer, evenSumModel, numericFeatures);
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Random(java.util.Random) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 54 with PerturbationContext

use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

The class LimeStabilityTest, method testStabilityWithNumericData.

/**
 * Builds five mocked numeric features (indices 0 .. 4) and hands them, together
 * with the sum-skip model and a LIME explainer configured with 10 samples and a
 * perturbation context created from {@code seed}, to {@code assertStable}.
 *
 * @param seed value passed to the {@link PerturbationContext} constructor
 */
@ParameterizedTest
@ValueSource(longs = { 0 })
void testStabilityWithNumericData(long seed) throws Exception {
    Random rng = new Random();
    PredictionProvider model = TestUtils.getSumSkipModel(0);

    List<Feature> numericFeatures = new LinkedList<>();
    int index = 0;
    while (index < 5) {
        numericFeatures.add(TestUtils.getMockedNumericFeature(index));
        index++;
    }

    PerturbationContext perturbationContext = new PerturbationContext(seed, rng, 1);
    LimeConfig config = new LimeConfig()
            .withSamples(10)
            .withPerturbationContext(perturbationContext);
    assertStable(new LimeExplainer(config), model, numericFeatures);
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Random(java.util.Random) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 55 with PerturbationContext

use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup.

The class LimeExplainer, method getPerturbedInputs.

/**
 * Creates {@code executionConfig.getNoOfSamples()} perturbed versions of the given features.
 * <p>
 * When the configured data distribution is not empty, feature distributions are first
 * bootstrapped from it (optionally with high score numeric feature zones, if enabled in the
 * configuration) and supplied to every perturbation; otherwise an empty map is supplied.
 *
 * @param features the features to perturb
 * @param executionConfig configuration providing the sample count, data distribution,
 *        perturbation context, bootstrap input limit and high score feature zones flag
 * @param predictionProvider model passed to the high score feature zones provider
 * @return one {@link PredictionInput} per sample, each wrapping a perturbed feature list
 */
private List<PredictionInput> getPerturbedInputs(List<Feature> features, LimeConfig executionConfig, PredictionProvider predictionProvider) {
    final int noOfSamples = executionConfig.getNoOfSamples();
    final DataDistribution dataDistribution = executionConfig.getDataDistribution();
    final PerturbationContext perturbationContext = executionConfig.getPerturbationContext();

    final Map<String, FeatureDistribution> featureDistributions;
    if (dataDistribution.isEmpty()) {
        // no data to bootstrap from: perturb without feature distributions
        featureDistributions = new HashMap<>();
    } else {
        int maxBootstrapInputs = executionConfig.getBoostrapInputs();
        Map<String, HighScoreNumericFeatureZones> featureZones = executionConfig.isHighScoreFeatureZones()
                ? HighScoreNumericFeatureZonesProvider.getHighScoreFeatureZones(dataDistribution, predictionProvider, features, maxBootstrapInputs)
                : new HashMap<>();
        // generate feature distributions, if possible
        featureDistributions = DataUtils.boostrapFeatureDistributions(dataDistribution, perturbationContext,
                2 * noOfSamples, 1, Math.min(noOfSamples, maxBootstrapInputs), featureZones);
    }

    List<PredictionInput> perturbedInputs = new ArrayList<>();
    for (int remaining = noOfSamples; remaining > 0; remaining--) {
        perturbedInputs.add(new PredictionInput(DataUtils.perturbFeatures(features, perturbationContext, featureDistributions)));
    }
    return perturbedInputs;
}
Also used : PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution)

Aggregations

PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 73, Random (java.util.Random): 64, PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 61, PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 59, Prediction (org.kie.kogito.explainability.model.Prediction): 58, PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 58, SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 57, Test (org.junit.jupiter.api.Test): 46, LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 45, LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 33, Feature (org.kie.kogito.explainability.model.Feature): 30, LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 28, ArrayList (java.util.ArrayList): 27, Saliency (org.kie.kogito.explainability.model.Saliency): 25, ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 24, DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 24, PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution): 20, ValueSource (org.junit.jupiter.params.provider.ValueSource): 17, LinkedList (java.util.LinkedList): 16, FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 12