Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class LimeExplainerTest, method testEmptyInput.
@Test
void testEmptyInput() {
    LimeExplainer recordingLimeExplainer = new LimeExplainer();
    PredictionProvider model = mock(PredictionProvider.class);
    Prediction prediction = mock(Prediction.class);
    // a mocked Prediction carries no input features, so the explainer is expected to reject it
    assertThatCode(() -> recordingLimeExplainer.explainAsync(prediction, model))
            .hasMessage("cannot explain a prediction whose input is empty");
}
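By contrast, a prediction built from a non-empty input can be explained. A minimal sketch, reusing the TestUtils helpers, the Config timeout pattern, and the SimplePrediction construction shown in the snippets below; the feature count is an arbitrary illustrative choice, and the Map<String, Saliency> result type is assumed from the other LIME examples on this page:

    // build a small non-empty input using the project's test helpers (illustrative sketch)
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < 3; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
            .get(0);
    Prediction nonEmptyPrediction = new SimplePrediction(input, output);
    // with a non-empty input, explainAsync is expected to complete with a saliency per output
    Map<String, Saliency> saliencyMap = new LimeExplainer().explainAsync(nonEmptyPrediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());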
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class LimeStabilityTest, method testStabilityDeterministic.
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testStabilityDeterministic(long seed) throws Exception {
    List<LocalSaliencyStability> stabilities = new ArrayList<>();
    // run the same stability evaluation twice with the same perturbation seed
    for (int j = 0; j < 2; j++) {
        Random random = new Random();
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        List<Feature> featureList = new LinkedList<>();
        for (int i = 0; i < 5; i++) {
            featureList.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(featureList);
        List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new SimplePrediction(input, predictionOutputs.get(0));
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(10)
                .withPerturbationContext(new PerturbationContext(seed, random, 1));
        LimeExplainer explainer = new LimeExplainer(limeConfig);
        LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, explainer, 2, 10);
        stabilities.add(stability);
    }
    // with a fixed seed, both runs must produce identical stability scores
    LocalSaliencyStability first = stabilities.get(0);
    LocalSaliencyStability second = stabilities.get(1);
    String decisionName = "sum-but0";
    assertThat(first.getNegativeStabilityScore(decisionName, 1)).isEqualTo(second.getNegativeStabilityScore(decisionName, 1));
    assertThat(first.getPositiveStabilityScore(decisionName, 1)).isEqualTo(second.getPositiveStabilityScore(decisionName, 1));
    assertThat(first.getNegativeStabilityScore(decisionName, 2)).isEqualTo(second.getNegativeStabilityScore(decisionName, 2));
    assertThat(first.getPositiveStabilityScore(decisionName, 2)).isEqualTo(second.getPositiveStabilityScore(decisionName, 2));
}
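Besides checking determinism, the same metric can be read as a score. A hedged sketch of a threshold check, assuming the stability scores are rates in [0, 1] where higher means the top-k features recur more consistently across runs; the 0.9 threshold is an arbitrary illustration, not a project default:

    double minimumStability = 0.9; // illustrative threshold, not a project default
    assertThat(first.getPositiveStabilityScore(decisionName, 2)).isGreaterThanOrEqualTo(minimumStability);
    assertThat(first.getNegativeStabilityScore(decisionName, 2)).isGreaterThanOrEqualTo(minimumStability);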
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class CountingOptimizationStrategyTest, method testMaybeOptimize.
@Test
void testMaybeOptimize() {
    LimeOptimizationService optimizationService = mock(LimeOptimizationService.class);
    CountingOptimizationStrategy strategy = new CountingOptimizationStrategy(10, optimizationService);
    List<Prediction> recordedPredictions = Collections.emptyList();
    PredictionProvider model = mock(PredictionProvider.class);
    LimeExplainer explainer = new LimeExplainer();
    LimeConfig config = new LimeConfig();
    // calling maybeOptimize with no recorded predictions must not fail
    assertThatCode(() -> strategy.maybeOptimize(recordedPredictions, model, explainer, config)).doesNotThrowAnyException();
}
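Since the service is a Mockito mock, a follow-up check is possible. A hedged sketch, under the assumption that the counting strategy only delegates to the optimization service once its threshold of recorded explanations (10 here) is reached, so this single call should not touch the mock:

    // hedged follow-up: assumes the strategy has not yet reached its counting threshold
    verifyNoInteractions(optimizationService);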
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class AggregatedLimeExplainerTest, method testExplainWithPredictions.
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
    Random random = new Random();
    random.setSeed(seed);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
    DataDistribution dataDistribution = DataUtils.generateRandomDataDistribution(3, 100, random);
    List<PredictionInput> samples = dataDistribution.sample(10);
    List<PredictionOutput> predictionOutputs = sumSkipModel.predictAsync(samples).get();
    List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);
    AggregatedLimeExplainer aggregatedLimeExplainer = new AggregatedLimeExplainer();
    Map<String, Saliency> explain = aggregatedLimeExplainer.explainFromPredictions(sumSkipModel, predictions).get();
    assertNotNull(explain);
    assertEquals(1, explain.size());
    assertTrue(explain.containsKey("sum-but1"));
    Saliency saliency = explain.get("sum-but1");
    assertNotNull(saliency);
    List<String> collect = saliency.getPositiveFeatures(2).stream()
            .map(FeatureImportance::getFeature)
            .map(Feature::getName)
            .collect(Collectors.toList());
    // skipped feature should not appear in top two positive features
    assertFalse(collect.contains("f1"));
}
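To see why "f1" stays out of the top positive features, the aggregated saliency can be inspected feature by feature. A minimal sketch, assuming Saliency exposes its per-feature scores via getPerFeatureImportance() and FeatureImportance exposes getScore():

    // print every feature's aggregated importance (assumed accessors, see lead-in)
    for (FeatureImportance featureImportance : saliency.getPerFeatureImportance()) {
        System.out.printf("%s -> %.3f%n", featureImportance.getFeature().getName(), featureImportance.getScore());
    }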
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class PartialDependencePlotExplainerTest, method testTextClassifier.
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testTextClassifier(int seed) throws Exception {
    Random random = new Random();
    random.setSeed(seed);
    PartialDependencePlotExplainer partialDependencePlotExplainer = new PartialDependencePlotExplainer();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    Collection<Prediction> predictions = new ArrayList<>(3);
    List<String> texts = List.of("we want your money", "please reply quickly", "you are the lucky winner",
            "huge donation for you!", "bitcoin for you");
    // build a prediction for each sample text from the model's own outputs
    for (String text : texts) {
        List<Feature> features = new ArrayList<>();
        features.add(FeatureFactory.newFulltextFeature("text", text));
        PredictionInput predictionInput = new PredictionInput(features);
        PredictionOutput predictionOutput = model.predictAsync(List.of(predictionInput)).get().get(0);
        predictions.add(new SimplePrediction(predictionInput, predictionOutput));
    }
    List<PartialDependenceGraph> pdps = partialDependencePlotExplainer.explainFromPredictions(model, predictions);
    assertThat(pdps).isNotEmpty();
}
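Each returned graph pairs a feature with the series of values the partial dependence was computed over. A minimal sketch for dumping the first graph, assuming PartialDependenceGraph exposes accessors getFeature(), getX() and getY():

    // inspect the first partial dependence graph (assumed accessors, see lead-in)
    PartialDependenceGraph pdp = pdps.get(0);
    System.out.println("feature: " + pdp.getFeature().getName());
    for (int i = 0; i < pdp.getX().size(); i++) {
        System.out.println(pdp.getX().get(i) + " -> " + pdp.getY().get(i));
    }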