Use of org.kie.kogito.explainability.model.SimplePrediction in the kogito-apps project by kiegroup.
Class LimeExplainerTest, method testZeroSampleSize:
/**
 * Runs LIME with a configuration that requests zero samples and checks that the explainer
 * still returns a non-null saliency map for a 4-feature input to the sum-skip model.
 */
@Test
void testZeroSampleSize() throws ExecutionException, InterruptedException, TimeoutException {
    LimeExplainer explainer = new LimeExplainer(new LimeConfig().withSamples(0));

    List<Feature> inputFeatures = new ArrayList<>();
    int featureIndex = 0;
    while (featureIndex < 4) {
        inputFeatures.add(TestUtils.getMockedNumericFeature(featureIndex));
        featureIndex++;
    }
    PredictionInput predictionInput = new PredictionInput(inputFeatures);

    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
    List<PredictionOutput> outputs = sumSkipModel.predictAsync(List.of(predictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(predictionInput, outputs.get(0));

    Map<String, Saliency> result = explainer.explainAsync(prediction, sumSkipModel)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertNotNull(result);
}
Use of org.kie.kogito.explainability.model.SimplePrediction in the kogito-apps project by kiegroup.
Class LimeExplainerTest, method testDeterministic:
/**
 * Checks that LIME explanations are reproducible. Two explainers are each built with a
 * PerturbationContext carrying the same seed, and the per-feature scores they report for the
 * "sum-but0" output must be identical.
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testDeterministic(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    List<Saliency> runs = new ArrayList<>();
    for (int run = 0; run < 2; run++) {
        // each run gets a fresh Random, but the PerturbationContext is always given the same seed
        PerturbationContext context = new PerturbationContext(seed, new Random(), DEFAULT_NO_OF_PERTURBATIONS);
        LimeExplainer explainer = new LimeExplainer(new LimeConfig().withPerturbationContext(context).withSamples(10));

        List<Feature> inputFeatures = new ArrayList<>();
        for (int idx = 0; idx < 4; idx++) {
            inputFeatures.add(TestUtils.getMockedNumericFeature(idx));
        }
        PredictionInput predictionInput = new PredictionInput(inputFeatures);
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);

        List<PredictionOutput> outputs = sumSkipModel.predictAsync(List.of(predictionInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new SimplePrediction(predictionInput, outputs.get(0));

        Map<String, Saliency> explanation = explainer.explainAsync(prediction, sumSkipModel)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        runs.add(explanation.get("sum-but0"));
    }

    List<Double> firstRunScores = runs.get(0).getPerFeatureImportance().stream()
            .map(FeatureImportance::getScore)
            .collect(Collectors.toList());
    List<Double> secondRunScores = runs.get(1).getPerFeatureImportance().stream()
            .map(FeatureImportance::getScore)
            .collect(Collectors.toList());
    assertThat(firstRunScores).isEqualTo(secondRunScores);
}
Use of org.kie.kogito.explainability.model.SimplePrediction in the kogito-apps project by kiegroup.
Class LimeExplainerTest, method testNormalizedWeights:
/**
 * With weight normalization enabled, every per-feature score LIME reports for the "sum-but0"
 * output must lie in the closed interval [0, 1].
 */
@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    LimeConfig normalizingConfig = new LimeConfig()
            .withNormalizeWeights(true)
            .withPerturbationContext(new PerturbationContext(4L, new Random(), 2))
            .withSamples(10);
    LimeExplainer explainer = new LimeExplainer(normalizingConfig);

    int featureCount = 4;
    List<Feature> inputFeatures = new ArrayList<>(featureCount);
    for (int idx = 0; idx < featureCount; idx++) {
        inputFeatures.add(TestUtils.getMockedNumericFeature(idx));
    }
    PredictionInput predictionInput = new PredictionInput(inputFeatures);

    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
    List<PredictionOutput> outputs = sumSkipModel.predictAsync(List.of(predictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(predictionInput, outputs.get(0));

    Map<String, Saliency> saliencyMap = explainer.explainAsync(prediction, sumSkipModel)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencyMap).isNotNull();

    Saliency sumButZero = saliencyMap.get("sum-but0");
    sumButZero.getPerFeatureImportance()
            .forEach(importance -> assertThat(importance.getScore()).isBetween(0d, 1d));
}
Use of org.kie.kogito.explainability.model.SimplePrediction in the kogito-apps project by kiegroup.
Class LimeStabilityTest, method assertStable:
/**
 * Asserts that repeated LIME explanations of the same prediction are stable.
 * <p>
 * For every output the model produces for {@code featureList}, the explanation is computed 100
 * times and all resulting saliencies are pooled. The method then requires two things:
 * <ul>
 *   <li>Some feature name is the single top positive feature in at least
 *       {@code TOP_FEATURE_THRESHOLD} of those saliencies.</li>
 *   <li>Some impact score, computed over each saliency's top two features, occurs at least
 *       {@code TOP_FEATURE_THRESHOLD} times.</li>
 * </ul>
 *
 * @param limeExplainer the explainer under test
 * @param model the model being explained
 * @param featureList the input features
 * @throws Exception if a prediction or explanation future fails or times out
 */
private void assertStable(LimeExplainer limeExplainer, PredictionProvider model, List<Feature> featureList) throws Exception {
    PredictionInput input = new PredictionInput(featureList);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (PredictionOutput predictionOutput : predictionOutputs) {
        Prediction prediction = new SimplePrediction(input, predictionOutput);
        List<Saliency> saliencies = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            saliencies.addAll(saliencyMap.values());
        }

        // check that the topmost important feature is stable:
        // saliencies with no positive feature are skipped
        List<String> names = saliencies.stream()
                .map(s -> s.getPositiveFeatures(1))
                .filter(f -> !f.isEmpty())
                .map(f -> f.get(0).getFeature().getName())
                .collect(Collectors.toList());
        boolean topFeature = names.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .values().stream()
                .anyMatch(count -> count >= TOP_FEATURE_THRESHOLD);
        assertTrue(topFeature);

        // check that the impact is stable; a plain loop is kept because impactScore is called
        // from a method that declares checked exceptions
        List<Double> impacts = new ArrayList<>(saliencies.size());
        for (Saliency saliency : saliencies) {
            impacts.add(ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2)));
        }
        // impacts are grouped by exact Double equality, as before
        boolean topImpact = impacts.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .values().stream()
                .anyMatch(count -> count >= TOP_FEATURE_THRESHOLD);
        assertTrue(topImpact);
    }
}
Use of org.kie.kogito.explainability.model.SimplePrediction in the kogito-apps project by kiegroup.
Class LimeCombinedScoreCalculatorTest, method testNonZeroScore:
/**
 * Scores a LimeConfigSolution built from a single dummy-text-classifier prediction and checks
 * that the combined LIME score is present and numerically different from zero.
 */
@Test
void testNonZeroScore() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    LimeCombinedScoreCalculator scoreCalculator = new LimeCombinedScoreCalculator();
    LimeConfig config = new LimeConfig();
    List<Feature> features = List.of(FeatureFactory.newFulltextFeature("text", "money so they say is the root of all evil today"));
    PredictionInput input = new PredictionInput(features);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
            .get(Config.DEFAULT_ASYNC_TIMEOUT, Config.DEFAULT_ASYNC_TIMEUNIT);
    // hasSize reports the actual contents on failure, unlike comparing size() with isEqualTo
    assertThat(predictionOutputs).isNotNull().hasSize(1);
    PredictionOutput output = predictionOutputs.get(0);
    Prediction prediction = new SimplePrediction(input, output);
    List<Prediction> predictions = List.of(prediction);
    List<LimeConfigEntity> entities = LimeConfigEntityFactory.createEncodingEntities(config);
    LimeConfigSolution solution = new LimeConfigSolution(config, predictions, entities, model);
    SimpleBigDecimalScore score = scoreCalculator.calculateScore(solution);
    assertThat(score).isNotNull();
    // BigDecimal.equals is scale-sensitive (0 and 0.00 are not equal), so a zero score with a
    // non-zero scale would slip past an equals-based check; compare numerically instead.
    assertThat(score.getScore()).isNotNull().isNotEqualByComparingTo(BigDecimal.ZERO);
}
Aggregations