Use of `org.kie.kogito.explainability.model.PredictionInput` in project kogito-apps by kiegroup.
Class `DataUtilsTest`, method `testTexify`.
/**
 * Verifies that textifying an input made of two text features joins their
 * values with a single space, in feature order.
 */
@Test
void testTexify() {
    List<Feature> textFeatures = new ArrayList<>();
    textFeatures.add(TestUtils.getMockedTextFeature("we go here and there"));
    textFeatures.add(TestUtils.getMockedTextFeature("as you go there and here"));

    String joined = DataUtils.textify(new PredictionInput(textFeatures));

    assertThat(joined)
            .isNotNull()
            .isEqualTo("we go here and there as you go there and here");
}
Use of `org.kie.kogito.explainability.model.PredictionInput` in project kogito-apps by kiegroup.
Class `DataUtilsTest`, method `testPerturbDropString`.
/**
 * Runs the drop-string perturbation check on an input of four text features
 * (named f0..f3) and perturbation counts 0 through 3. The values include a
 * blank string and one with a trailing space.
 *
 * @param param value passed through to {@code assertPerturbDropString}
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3 })
void testPerturbDropString(int param) {
    String[] textValues = { "foo", "foo bar", " ", "foo bar " };
    List<Feature> textFeatures = new LinkedList<>();
    // Feature names are positional: "f0", "f1", ... matching the value index.
    for (int idx = 0; idx < textValues.length; idx++) {
        textFeatures.add(FeatureFactory.newTextFeature("f" + idx, textValues[idx]));
    }
    assertPerturbDropString(new PredictionInput(textFeatures), param);
}
Use of `org.kie.kogito.explainability.model.PredictionInput` in project kogito-apps by kiegroup.
Class `DataUtilsTest`, method `testPerturbDropNumeric`.
/**
 * Runs the drop-numeric perturbation check on an input of three numerical
 * features (f0..f2), for each parameter value 0 through 3.
 *
 * @param param value passed through to {@code assertPerturbDropNumeric}
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3 })
void testPerturbDropNumeric(int param) {
    List<Feature> numericFeatures = new LinkedList<>();
    // Explicit calls keep the exact FeatureFactory overloads used for each literal.
    Feature first = FeatureFactory.newNumericalFeature("f0", 1);
    Feature second = FeatureFactory.newNumericalFeature("f1", 3.14);
    Feature third = FeatureFactory.newNumericalFeature("f2", 0.55);
    numericFeatures.add(first);
    numericFeatures.add(second);
    numericFeatures.add(third);
    assertPerturbDropNumeric(new PredictionInput(numericFeatures), param);
}
Use of `org.kie.kogito.explainability.model.PredictionInput` in project kogito-apps by kiegroup.
Class `ExplainabilityMetricsTest`, method `testFidelityWithTextClassifier`.
/**
 * Explains a dummy text classifier's prediction with LIME (10 samples), then
 * checks that {@code ExplainabilityMetrics.classificationFidelity} does not
 * throw. Each resulting saliency is paired with the single prediction.
 */
@Test
void testFidelityWithTextClassifier() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider textClassifier = TestUtils.getDummyTextClassifier();
    LimeExplainer explainer = new LimeExplainer(new LimeConfig().withSamples(10));

    // One full-text feature tokenized on single spaces, plus one plain text feature.
    List<Feature> inputFeatures = new LinkedList<>();
    inputFeatures.add(FeatureFactory.newFulltextFeature("f-0", "brown fox", s -> Arrays.asList(s.split(" "))));
    inputFeatures.add(FeatureFactory.newTextFeature("f-1", "money"));
    PredictionInput predictionInput = new PredictionInput(inputFeatures);

    Prediction prediction = new SimplePrediction(predictionInput,
            textClassifier.predictAsync(List.of(predictionInput))
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                    .get(0));
    Map<String, Saliency> saliencies = explainer.explainAsync(prediction, textClassifier)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

    List<Pair<Saliency, Prediction>> pairs = new LinkedList<>();
    saliencies.values().forEach(saliency -> pairs.add(Pair.of(saliency, prediction)));

    // Block lambda keeps overload resolution on the void Executable variant.
    Assertions.assertDoesNotThrow(() -> {
        ExplainabilityMetrics.classificationFidelity(pairs);
    });
}
Use of `org.kie.kogito.explainability.model.PredictionInput` in project kogito-apps by kiegroup.
Class `PmmlScorecardCategoricalLimeExplainerTest`, method `testExplanationImpactScoreWithOptimization`.
/**
 * Optimizes a LIME configuration for impact score over five scorecard
 * predictions. It asserts that the optimizer returns a different config
 * instance than the initial one.
 */
@Test
void testExplanationImpactScoreWithOptimization() throws ExecutionException, InterruptedException {
    PredictionProvider model = getModel();
    // Score only the first five samples, and pair exactly those inputs with their
    // outputs. Previously the full sample list was passed to getPredictions while
    // only five outputs existed, leaving inputs and outputs mismatched in size.
    List<PredictionInput> samples = getSamples().subList(0, 5);
    List<PredictionOutput> predictionOutputs = model.predictAsync(samples).get();
    List<Prediction> predictions = DataUtils.getPredictions(samples, predictionOutputs);
    long seed = 0;
    LimeConfigOptimizer limeConfigOptimizer = new LimeConfigOptimizer().withDeterministicExecution(true).forImpactScore();
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(seed, random, 1);
    LimeConfig initialConfig = new LimeConfig().withSamples(10).withPerturbationContext(perturbationContext);
    LimeConfig optimizedConfig = limeConfigOptimizer.optimize(initialConfig, predictions, model);
    assertThat(optimizedConfig).isNotSameAs(initialConfig);
}
Aggregations