Usage of `org.kie.kogito.explainability.model.PredictionProvider` in the kogito-apps project by kiegroup.
Class `LimeExplainerTest`, method `testSparseBalance`:
/**
 * Checks that enabling LIME's sparse-balance penalty never increases the magnitude of any
 * per-feature saliency score compared with the same explanation computed without the penalty.
 * The comparison is repeated for inputs with 1, 2 and 3 mocked numeric features.
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testSparseBalance(long seed) throws InterruptedException, ExecutionException, TimeoutException {
    // Output name the tests in this class read from TestUtils.getSumSkipModel(0); loop-invariant.
    String decisionName = "sum-but0";
    int noOfSamples = 100;
    for (int nf = 1; nf < 4; nf++) {
        // NOTE(review): the same Random instance backs both configs below; whether each explainer
        // reseeds it from `seed` depends on PerturbationContext/LimeExplainer — confirm.
        Random random = new Random();
        LimeConfig limeConfigNoPenalty = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(noOfSamples)
                .withPenalizeBalanceSparse(false);
        LimeExplainer limeExplainerNoPenalty = new LimeExplainer(limeConfigNoPenalty);
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < nf; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        Prediction prediction = new SimplePrediction(input, output);

        // Baseline explanation without the sparse-balance penalty.
        Map<String, Saliency> saliencyMapNoPenalty = limeExplainerNoPenalty.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        assertThat(saliencyMapNoPenalty).isNotNull();
        Saliency saliencyNoPenalty = saliencyMapNoPenalty.get(decisionName);
        // Fail with an assertion (not an NPE below) if the decision is missing.
        assertThat(saliencyNoPenalty).isNotNull();

        // Same explanation with the penalty enabled.
        LimeConfig limeConfig = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(noOfSamples)
                .withPenalizeBalanceSparse(true);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        assertThat(saliencyMap).isNotNull();
        Saliency saliency = saliencyMap.get(decisionName);
        assertThat(saliency).isNotNull();

        // The penalty may only shrink (or keep) each feature's absolute importance.
        for (int i = 0; i < features.size(); i++) {
            double score = saliency.getPerFeatureImportance().get(i).getScore();
            double scoreNoPenalty = saliencyNoPenalty.getPerFeatureImportance().get(i).getScore();
            assertThat(Math.abs(score)).isLessThanOrEqualTo(Math.abs(scoreNoPenalty));
        }
    }
}
Usage of `org.kie.kogito.explainability.model.PredictionProvider` in the kogito-apps project by kiegroup.
Class `LimeExplainerTest`, method `testEmptyPrediction`:
/**
 * Asking LIME to explain a prediction whose input has no features must be rejected with a
 * {@link LocalExplanationException} raised by the {@code explainAsync} call itself.
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testEmptyPrediction(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random rng = new Random();
    PerturbationContext context = new PerturbationContext(seed, rng, DEFAULT_NO_OF_PERTURBATIONS);
    LimeExplainer explainer = new LimeExplainer(new LimeConfig().withPerturbationContext(context).withSamples(10));

    // A prediction built from an input that carries zero features.
    PredictionInput emptyInput = new PredictionInput(Collections.emptyList());
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);
    List<PredictionOutput> outputs = sumSkipModel.predictAsync(List.of(emptyInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction emptyPrediction = new SimplePrediction(emptyInput, outputs.get(0));

    assertThrows(LocalExplanationException.class, () -> explainer.explainAsync(emptyPrediction, sumSkipModel));
}
Usage of `org.kie.kogito.explainability.model.PredictionProvider` in the kogito-apps project by kiegroup.
Class `LimeExplainerTest`, method `testNonEmptyInput`:
/**
 * Explaining a prediction over four mocked numeric features must complete within the configured
 * async timeout and yield a saliency map that contains the model's {@code "sum-but0"} decision.
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testNonEmptyInput(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig()
            .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS))
            .withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
            .get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // AssertJ, like the sibling tests; also check the expected decision was explained.
    assertThat(saliencyMap).isNotNull().containsKey("sum-but0");
}
Usage of `org.kie.kogito.explainability.model.PredictionProvider` in the kogito-apps project by kiegroup.
Class `LimeExplainerTest`, method `testEmptyInput`:
/**
 * A default-configured explainer handed a bare mocked {@link Prediction} (no stubbed input)
 * must fail immediately with the "input is empty" message.
 */
@Test
void testEmptyInput() {
    PredictionProvider mockedModel = mock(PredictionProvider.class);
    Prediction mockedPrediction = mock(Prediction.class);
    LimeExplainer explainer = new LimeExplainer();

    assertThatCode(() -> explainer.explainAsync(mockedPrediction, mockedModel))
            .hasMessage("cannot explain a prediction whose input is empty");
}
Usage of `org.kie.kogito.explainability.model.PredictionProvider` in the kogito-apps project by kiegroup.
Class `LimeStabilityTest`, method `testStabilityDeterministic`:
/**
 * Computes the local saliency stability twice, from scratch, with the same seed, and checks that
 * the positive and negative stability scores for the top-1 and top-2 features are identical,
 * i.e. that a fixed seed makes the LIME-based stability metric reproducible.
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testStabilityDeterministic(long seed) throws Exception {
    List<LocalSaliencyStability> stabilities = new ArrayList<>();
    // Rebuild every object on each run so no state can leak between the two computations.
    for (int j = 0; j < 2; j++) {
        Random random = new Random();
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        List<Feature> featureList = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            featureList.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(featureList);
        List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new SimplePrediction(input, predictionOutputs.get(0));
        LimeConfig limeConfig = new LimeConfig()
                .withSamples(10)
                .withPerturbationContext(new PerturbationContext(seed, random, 1));
        LimeExplainer explainer = new LimeExplainer(limeConfig);
        LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, explainer, 2, 10);
        stabilities.add(stability);
    }
    LocalSaliencyStability first = stabilities.get(0);
    LocalSaliencyStability second = stabilities.get(1);
    String decisionName = "sum-but0";
    // Same checks as before, for k = 1 and k = 2; descriptions identify the failing k and sign.
    for (int topK = 1; topK <= 2; topK++) {
        assertThat(first.getNegativeStabilityScore(decisionName, topK))
                .as("negative stability score, top %d", topK)
                .isEqualTo(second.getNegativeStabilityScore(decisionName, topK));
        assertThat(first.getPositiveStabilityScore(decisionName, topK))
                .as("positive stability score, top %d", topK)
                .isEqualTo(second.getPositiveStabilityScore(decisionName, topK));
    }
}
Aggregations