Use of org.kie.kogito.explainability.model.PredictionProviderMetadata in the kogito-apps project by kiegroup.
From the class AggregatedLimeExplainerTest, method testExplainWithMetadata.
/**
 * Runs {@code AggregatedLimeExplainer.explainFromMetadata} against the model returned by
 * {@code TestUtils.getSumSkipModel(1)} and an inline {@code PredictionProviderMetadata}, then checks
 * that the result holds exactly one saliency, keyed "sum-but1", whose top two positive features
 * do not include "f1".
 *
 * @param seed seed for the random generator that feeds the metadata's data distribution
 * @throws ExecutionException if the asynchronous explanation fails
 * @throws InterruptedException if the thread is interrupted while waiting for the explanation
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
    // Generator seeded from the test parameter.
    Random rng = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);

    PredictionProviderMetadata modelMetadata = new PredictionProviderMetadata() {
        @Override
        public DataDistribution getDataDistribution() {
            // Drawn from the shared generator each time this method is called.
            return DataUtils.generateRandomDataDistribution(3, 100, rng);
        }

        @Override
        public PredictionInput getInputShape() {
            List<Feature> shape = new LinkedList<>();
            for (String featureName : new String[] { "f0", "f1", "f2" }) {
                shape.add(FeatureFactory.newNumericalFeature(featureName, 0));
            }
            return new PredictionInput(shape);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shape = new LinkedList<>();
            shape.add(new Output("sum-but1", Type.BOOLEAN, new Value(false), 0d));
            return new PredictionOutput(shape);
        }
    };

    Map<String, Saliency> saliencies =
            new AggregatedLimeExplainer().explainFromMetadata(model, modelMetadata).get();

    // Exactly one saliency is expected, keyed by the single declared output name.
    assertNotNull(saliencies);
    assertEquals(1, saliencies.size());
    assertTrue(saliencies.containsKey("sum-but1"));

    Saliency outputSaliency = saliencies.get("sum-but1");
    assertNotNull(outputSaliency);

    List<String> topPositiveNames = outputSaliency.getPositiveFeatures(2)
            .stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());

    // The skipped feature (presumably index 1 of the sum-skip model, i.e. "f1") must not
    // appear among the top two positive features.
    assertFalse(topPositiveNames.contains("f1"));
}
Use of org.kie.kogito.explainability.model.PredictionProviderMetadata in the kogito-apps project by kiegroup.
From the class PartialDependencePlotExplainerTest, method getMetadata.
/**
 * Builds an anonymous {@code PredictionProviderMetadata} for use in tests.
 *
 * <p>The returned metadata exposes three numerical features ("f0", "f1", "f2", each created with
 * value 0) as its input shape and a single boolean output named "sum-but0" as its output shape.
 * Its data distribution is produced by
 * {@code DataUtils.generateRandomDataDistribution(3, 100, random)} on every call.
 *
 * @param random generator handed to the data-distribution factory
 * @return the metadata instance described above
 */
private PredictionProviderMetadata getMetadata(Random random) {
    PredictionProviderMetadata metadata = new PredictionProviderMetadata() {
        @Override
        public DataDistribution getDataDistribution() {
            // Not cached: each invocation draws again from the supplied generator.
            return DataUtils.generateRandomDataDistribution(3, 100, random);
        }

        @Override
        public PredictionInput getInputShape() {
            List<Feature> shape = new LinkedList<>();
            for (String featureName : new String[] { "f0", "f1", "f2" }) {
                shape.add(FeatureFactory.newNumericalFeature(featureName, 0));
            }
            return new PredictionInput(shape);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shape = new LinkedList<>();
            shape.add(new Output("sum-but0", Type.BOOLEAN, new Value(false), 0d));
            return new PredictionOutput(shape);
        }
    };
    return metadata;
}
Aggregations