Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getDummyTextClassifier.
/**
 * Builds a toy spam classifier for tests.
 * <p>
 * For every input, each feature value is rendered as a string and split on single spaces; the
 * input is flagged as spam as soon as one token equals one of the blacklisted words
 * ("money", "$", "£", "bitcoin"). Once an input is flagged, its remaining features are not inspected.
 *
 * @return a {@link PredictionProvider} producing, per input, a single boolean output named
 *         {@code "spam"} with score 1
 */
public static PredictionProvider getDummyTextClassifier() {
    List<String> spamWords = Arrays.asList("money", "$", "£", "bitcoin");
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput predictionInput : inputs) {
            boolean flagged = false;
            for (Feature feature : predictionInput.getFeatures()) {
                if (flagged) {
                    // already classified as spam: skip the remaining features
                    continue;
                }
                String[] tokens = feature.getValue().asString().split(" ");
                flagged = Arrays.stream(tokens).anyMatch(spamWords::contains);
            }
            Output verdict = new Output("spam", Type.BOOLEAN, new Value(flagged), 1d);
            results.add(new PredictionOutput(List.of(verdict)));
        }
        return results;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getFeaturePassModel.
/**
 * Builds a pass-through test model that echoes one input feature back as its only output.
 *
 * @param featureIndex position, within each input's feature list, of the feature to echo
 * @return a {@link PredictionProvider} producing, per input, one output named
 *         {@code "feature-<featureIndex>"} carrying that feature's type and value with score 1
 */
public static PredictionProvider getFeaturePassModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            Feature echoed = input.getFeatures().get(featureIndex);
            Output mirror = new Output("feature-" + featureIndex, echoed.getType(), echoed.getValue(), 1d);
            results.add(new PredictionOutput(List.of(mirror)));
        }
        return results;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getFeatureSkipModel.
/**
 * Test model that mirrors every input feature as an output, except the one at the given index.
 * Each mirrored output keeps the feature's name, type and value, and is given a score of 1.0.
 *
 * @param featureIndex Index of the input feature to omit from output
 * @return A {@link PredictionProvider} model
 */
public static PredictionProvider getFeatureSkipModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            List<Feature> features = input.getFeatures();
            List<Output> mirrored = new ArrayList<>();
            for (int position = 0; position < features.size(); position++) {
                if (position == featureIndex) {
                    // this is the feature the model deliberately ignores
                    continue;
                }
                Feature kept = features.get(position);
                mirrored.add(new Output(kept.getName(), kept.getType(), kept.getValue(), 1.0));
            }
            results.add(new PredictionOutput(mirrored));
        }
        return results;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getEvenFeatureModel.
/**
 * Builds a parity test model: for each input, it reports whether the numeric value of one
 * feature is even (i.e. {@code value % 2 == 0}).
 *
 * @param featureIndex position, within each input's feature list, of the feature to test
 * @return a {@link PredictionProvider} producing, per input, one boolean output named
 *         {@code "feature-<featureIndex>"} with score 1
 */
public static PredictionProvider getEvenFeatureModel(int featureIndex) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> results = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double number = input.getFeatures().get(featureIndex).getValue().asNumber();
            boolean even = number % 2 == 0;
            Output parity = new Output("feature-" + featureIndex, Type.BOOLEAN, new Value(even), 1d);
            results.add(new PredictionOutput(List.of(parity)));
        }
        return results;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class AggregatedLimeExplainerTest, method testExplainWithPredictions.
/**
 * Explains predictions of a model built by {@code TestUtils.getSumSkipModel(1)} with the
 * {@link AggregatedLimeExplainer} and checks that a single saliency, keyed {@code "sum-but1"},
 * is returned and that the skipped feature {@code "f1"} is not among its top two positive features.
 *
 * @param seed seed for the random data distribution, so each parameterized run is reproducible
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithPredictions(int seed) throws ExecutionException, InterruptedException {
    // equivalent to new Random() followed by setSeed(seed)
    Random rng = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    DataDistribution distribution = DataUtils.generateRandomDataDistribution(3, 100, rng);
    List<PredictionInput> inputs = distribution.sample(10);
    List<Prediction> predictions = DataUtils.getPredictions(inputs, model.predictAsync(inputs).get());
    Map<String, Saliency> saliencies = new AggregatedLimeExplainer().explainFromPredictions(model, predictions).get();

    assertNotNull(saliencies);
    assertEquals(1, saliencies.size());
    assertTrue(saliencies.containsKey("sum-but1"));
    Saliency sumSaliency = saliencies.get("sum-but1");
    assertNotNull(sumSaliency);

    // f1 is the feature skipped by the model, so it should not rank among the top two positive features
    List<String> topPositiveNames = sumSaliency.getPositiveFeatures(2).stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());
    assertFalse(topPositiveNames.contains("f1"));
}
Aggregations