Use of org.kie.kogito.explainability.model.Prediction in the kogito-apps project by kiegroup.
Class RecordingLimeExplainerTest, method testExplainNonOptimized.
/**
 * Explains a single prediction of the sum-skip model with a {@code RecordingLimeExplainer}
 * and checks that a (non-null) saliency map is returned.
 */
@Test
void testExplainNonOptimized() throws ExecutionException, InterruptedException, TimeoutException {
    RecordingLimeExplainer recordingExplainer = new RecordingLimeExplainer(10);

    // four mocked numeric features, built from the values 0..3
    List<Feature> inputFeatures = new ArrayList<>();
    int value = 0;
    while (value < 4) {
        inputFeatures.add(TestUtils.getMockedNumericFeature(value));
        value++;
    }
    PredictionInput predictionInput = new PredictionInput(inputFeatures);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);

    // run the model once to obtain the output that will be explained
    List<PredictionOutput> modelOutputs = sumSkipModel.predictAsync(List.of(predictionInput))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction toExplain = new SimplePrediction(predictionInput, modelOutputs.get(0));

    Map<String, Saliency> saliencies = recordingExplainer.explainAsync(toExplain, sumSkipModel)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertNotNull(saliencies);
}
Use of org.kie.kogito.explainability.model.Prediction in the kogito-apps project by kiegroup.
Class RecordingLimeExplainerTest, method testAutomaticConfigOptimization.
/**
 * Feeds 50 explanations of random three-feature inputs through a {@code RecordingLimeExplainer}
 * (against the sum-threshold model), asserting every returned saliency is non-null, and finally
 * checks that the explainer's execution config differs from a freshly built config.
 *
 * @param seed seed handed to the {@code PerturbationContext} used to draw the random feature values
 */
@ParameterizedTest
@ValueSource(longs = { 0 })
void testAutomaticConfigOptimization(long seed) throws Exception {
    PredictionProvider model = TestUtils.getSumThresholdModel(10, 10);
    PerturbationContext perturbationContext = new PerturbationContext(seed, new Random(), 1);
    LimeConfig baselineConfig = new LimeConfig().withPerturbationContext(perturbationContext);
    RecordingLimeExplainer limeExplainer = new RecordingLimeExplainer(2);

    for (int round = 0; round < 50; round++) {
        // three numeric features with values drawn through the perturbation context
        List<Feature> randomFeatures = new LinkedList<>();
        for (int f = 0; f < 3; f++) {
            randomFeatures.add(TestUtils.getMockedNumericFeature(Type.NUMBER.randomValue(perturbationContext).asNumber()));
        }
        PredictionInput randomInput = new PredictionInput(randomFeatures);
        List<PredictionOutput> modelOutputs = model.predictAsync(List.of(randomInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction toExplain = new SimplePrediction(randomInput, modelOutputs.get(0));
        Map<String, Saliency> saliencies = limeExplainer.explainAsync(toExplain, model)
                .toCompletableFuture()
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        saliencies.values().forEach(saliency -> assertNotNull(saliency));
    }

    LimeConfig executionConfig = limeExplainer.getExecutionConfig();
    assertThat(executionConfig).isNotEqualTo(baselineConfig);
}
Use of org.kie.kogito.explainability.model.Prediction in the kogito-apps project by kiegroup.
Class RecordingLimeExplainerTest, method testEmptyInput.
/**
 * Checks that asking a {@code RecordingLimeExplainer} to explain an unstubbed mocked
 * {@code Prediction} fails with the "input is empty" message.
 */
@Test
void testEmptyInput() {
    PredictionProvider mockedModel = mock(PredictionProvider.class);
    Prediction mockedPrediction = mock(Prediction.class);
    RecordingLimeExplainer explainer = new RecordingLimeExplainer(10);
    String expectedMessage = "cannot explain a prediction whose input is empty";
    assertThatCode(() -> explainer.explainAsync(mockedPrediction, mockedModel))
            .hasMessage(expectedMessage);
}
Use of org.kie.kogito.explainability.model.Prediction in the kogito-apps project by kiegroup.
Class ShapKernelExplainerTest, method shapTestCase.
/**
 * Explains every row of {@code toExplainRaw} with the given SHAP kernel explainer and asserts
 * that the computed SHAP values match {@code expected} (absolute tolerance 1e-6).
 *
 * @param model        model used to produce the predictions and passed to the explainer
 * @param ske          the configured SHAP kernel explainer (background data lives in its config)
 * @param toExplainRaw one row per input to explain
 * @param expected     {@code expected[i][j]} is the expected saliency row {@code j} for input row {@code i}
 */
private void shapTestCase(PredictionProvider model, ShapKernelExplainer ske, double[][] toExplainRaw, double[][][] expected) throws InterruptedException, TimeoutException, ExecutionException {
    List<PredictionInput> toExplain = createPIFromMatrix(toExplainRaw);
    List<PredictionOutput> predictionOutputs = model.predictAsync(toExplain).get(5, TimeUnit.SECONDS);

    // pair each input with the output the model produced for it
    List<Prediction> predictions = new ArrayList<>();
    int idx = 0;
    for (PredictionOutput output : predictionOutputs) {
        predictions.add(new SimplePrediction(toExplain.get(idx), output));
        idx++;
    }

    for (int row = 0; row < toExplain.size(); row++) {
        Saliency[] saliencies = ske.explainAsync(predictions.get(row), model).get(5, TimeUnit.SECONDS).getSaliencies();
        // explanation matrix shape: outputSize x nfeatures
        RealMatrix explanationMatrix = saliencyToMatrix(saliencies)[0];
        int outputRows = explanationMatrix.getRowDimension();
        for (int r = 0; r < outputRows; r++) {
            assertArrayEquals(expected[row][r], explanationMatrix.getRow(r), 1e-6);
        }
    }
}
Use of org.kie.kogito.explainability.model.Prediction in the kogito-apps project by kiegroup.
Class ShapKernelExplainerTest, method testParallel.
/**
 * Explains one prediction of the sum-skip model against a 100-row background and checks the
 * resulting SHAP values on a thread of the common fork-join pool.
 *
 * <p>The check is chained onto the explanation future and then awaited. An assertion failure
 * (or an explainer error) therefore surfaces as an {@link ExecutionException} and fails the test.
 * Submitting the check and discarding the returned future would silently swallow the
 * AssertionError, and would let the test finish before the explanation was computed.
 */
@Test
void testParallel() throws InterruptedException, ExecutionException {
    // establish background data: cell (i, j) holds i / 100 + j
    double[][] largeBackground = new double[100][10];
    for (int i = 0; i < 100; i++) {
        for (int j = 0; j < 10; j++) {
            largeBackground[i][j] = i / 100. + j;
        }
    }
    double[][] toExplainLargeBackground = { { 0, 1., -2., 3.5, -4.1, 5.5, -12., .8, .11, 15. } };
    double[][][] expected = { { { -0.495, 0., -4.495, 0.005, -8.595, 0.005, -18.495, -6.695, -8.385, 5.505 } } };
    List<PredictionInput> background = createPIFromMatrix(largeBackground);
    List<PredictionInput> toExplain = createPIFromMatrix(toExplainLargeBackground);
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    ShapConfig skConfig = testConfig.withBackground(background).build();

    // compute the model outputs that will be explained
    List<PredictionOutput> predictionOutputs = model.predictAsync(toExplain).get();
    List<Prediction> predictions = new ArrayList<>();
    for (int i = 0; i < predictionOutputs.size(); i++) {
        predictions.add(new SimplePrediction(toExplain.get(i), predictionOutputs.get(i)));
    }

    ShapKernelExplainer ske = new ShapKernelExplainer(skConfig);
    CompletableFuture<ShapResults> explanationsCF = ske.explainAsync(predictions.get(0), model);
    ExecutorService executor = ForkJoinPool.commonPool();
    // verify on the executor, then block until the verification has finished so failures propagate
    explanationsCF.thenAcceptAsync(results -> {
        Saliency[] explanationSaliencies = results.getSaliencies();
        RealMatrix explanations = saliencyToMatrix(explanationSaliencies)[0];
        assertArrayEquals(expected[0][0], explanations.getRow(0), 1e-2);
    }, executor).get();
}
Aggregations