Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup: class LimeExplainerTest, method testNormalizedWeights.
@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig()
            .withNormalizeWeights(true)
            .withPerturbationContext(new PerturbationContext(4L, random, 2))
            .withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    int nf = 4;
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < nf; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
            .get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencyMap).isNotNull();
    String decisionName = "sum-but0";
    Saliency saliency = saliencyMap.get(decisionName);
    List<FeatureImportance> perFeatureImportance = saliency.getPerFeatureImportance();
    for (FeatureImportance featureImportance : perFeatureImportance) {
        assertThat(featureImportance.getScore()).isBetween(0d, 1d);
    }
}
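For quick reference, the fragment below isolates just the PerturbationContext wiring used in the test above. The reading of the constructor arguments (a fixed seed, a Random source, and the number of features perturbed per sample) and the import paths for LimeConfig and LimeExplainer are assumptions for illustration, not taken from this page.

import java.util.Random;

import org.kie.kogito.explainability.local.lime.LimeConfig;    // assumed package
import org.kie.kogito.explainability.local.lime.LimeExplainer; // assumed package
import org.kie.kogito.explainability.model.PerturbationContext;

public class PerturbationContextSketch {
    public static void main(String[] args) {
        // Same arguments as the test above: seed 4L, a Random source, and 2,
        // read here as the number of features perturbed per generated sample (assumption).
        PerturbationContext perturbationContext = new PerturbationContext(4L, new Random(), 2);
        // The context is handed to LIME through the fluent configuration API shown in the tests.
        LimeConfig limeConfig = new LimeConfig()
                .withNormalizeWeights(true)
                .withPerturbationContext(perturbationContext)
                .withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    }
}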
Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup: class LimeStabilityTest, method testStabilityWithTextData.
@ParameterizedTest
@ValueSource(longs = { 0 })
void testStabilityWithTextData(long seed) throws Exception {
    Random random = new Random();
    PredictionProvider sumSkipModel = TestUtils.getDummyTextClassifier();
    List<Feature> featureList = new LinkedList<>();
    for (int i = 0; i < 4; i++) {
        featureList.add(TestUtils.getMockedTextFeature("foo " + i));
    }
    featureList.add(TestUtils.getMockedTextFeature("money"));
    LimeConfig limeConfig = new LimeConfig()
            .withSamples(10)
            .withPerturbationContext(new PerturbationContext(seed, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    assertStable(limeExplainer, sumSkipModel, featureList);
}
Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup: class LimeStabilityTest, method testAdaptiveVariance.
@ParameterizedTest
@ValueSource(longs = { 0 })
void testAdaptiveVariance(long seed) throws Exception {
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(seed, random, 1);
    int samples = 1;
    int retries = 4;
    LimeConfig limeConfig = new LimeConfig()
            .withSamples(samples)
            .withPerturbationContext(perturbationContext)
            .withRetries(retries)
            .withAdaptiveVariance(true);
    LimeExplainer adaptiveVarianceLE = new LimeExplainer(limeConfig);
    List<Feature> features = new LinkedList<>();
    for (int i = 0; i < 4; i++) {
        features.add(FeatureFactory.newNumericalFeature("f-" + i, 2));
    }
    PredictionProvider model = TestUtils.getEvenSumModel(0);
    assertStable(adaptiveVarianceLE, model, features);
}
Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup: class LimeStabilityTest, method testStabilityWithNumericData.
@ParameterizedTest
@ValueSource(longs = { 0 })
void testStabilityWithNumericData(long seed) throws Exception {
    Random random = new Random();
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
    List<Feature> featureList = new LinkedList<>();
    for (int i = 0; i < 5; i++) {
        featureList.add(TestUtils.getMockedNumericFeature(i));
    }
    LimeConfig limeConfig = new LimeConfig()
            .withSamples(10)
            .withPerturbationContext(new PerturbationContext(seed, random, 1));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    assertStable(limeExplainer, sumSkipModel, featureList);
}
Use of org.kie.kogito.explainability.model.PerturbationContext in project kogito-apps by kiegroup: class LimeExplainer, method getPerturbedInputs.
private List<PredictionInput> getPerturbedInputs(List<Feature> features, LimeConfig executionConfig, PredictionProvider predictionProvider) {
    List<PredictionInput> perturbedInputs = new ArrayList<>();
    int size = executionConfig.getNoOfSamples();
    DataDistribution dataDistribution = executionConfig.getDataDistribution();
    Map<String, FeatureDistribution> featureDistributionsMap;
    PerturbationContext perturbationContext = executionConfig.getPerturbationContext();
    if (!dataDistribution.isEmpty()) {
        Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap;
        int max = executionConfig.getBoostrapInputs();
        if (executionConfig.isHighScoreFeatureZones()) {
            numericFeatureZonesMap = HighScoreNumericFeatureZonesProvider.getHighScoreFeatureZones(dataDistribution, predictionProvider, features, max);
        } else {
            numericFeatureZonesMap = new HashMap<>();
        }
        // generate feature distributions, if possible
        featureDistributionsMap = DataUtils.boostrapFeatureDistributions(dataDistribution, perturbationContext, 2 * size, 1, Math.min(size, max), numericFeatureZonesMap);
    } else {
        featureDistributionsMap = new HashMap<>();
    }
    for (int i = 0; i < size; i++) {
        List<Feature> newFeatures = DataUtils.perturbFeatures(features, perturbationContext, featureDistributionsMap);
        perturbedInputs.add(new PredictionInput(newFeatures));
    }
    return perturbedInputs;
}
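When executionConfig.getDataDistribution() is empty, the method above falls back to an empty featureDistributionsMap, so every sample comes purely from DataUtils.perturbFeatures. The sketch below exercises that fallback path on its own; the import paths, the reuse of the FeatureFactory helper from the tests above, and the assumption that DataUtils.perturbFeatures is callable outside LimeExplainer are illustrative, not confirmed by this page.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Random;

import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;      // assumed package
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.utils.DataUtils;           // assumed package

public class PerturbedInputsSketch {
    public static void main(String[] args) {
        // Features built the same way as in testAdaptiveVariance above.
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(FeatureFactory.newNumericalFeature("f-" + i, 2));
        }
        PerturbationContext perturbationContext = new PerturbationContext(0L, new Random(), 1);

        // Fallback branch of getPerturbedInputs: no data distribution, empty distributions map.
        List<PredictionInput> perturbedInputs = new ArrayList<>();
        int size = 10;
        for (int i = 0; i < size; i++) {
            List<Feature> newFeatures = DataUtils.perturbFeatures(features, perturbationContext, new HashMap<>());
            perturbedInputs.add(new PredictionInput(newFeatures));
        }
    }
}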