Usage of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.
Class TestUtils, method getSymbolicArithmeticModel.
/**
 * Builds a model that evaluates a simple arithmetic expression over each input's features.
 * <p>
 * Every input must contain a feature named {@code "operand"} whose string value is one of
 * {@code "+"}, {@code "-"}, {@code "*"} or {@code "/"}. That operator is folded left to right
 * over the remaining features, in list order: {@code f1 op f2 op ... op fn}. The first
 * non-operand feature seeds the accumulator, so multiplication and division are meaningful.
 * An input with no non-operand features evaluates to 0. Each input yields a single numeric
 * output named {@code "result"} with score 1.
 *
 * @return a {@link PredictionProvider} computing the expression for every input
 * @throws IllegalArgumentException (surfaced through the returned future) when an input has no
 *         operand feature or its operand is not one of the supported operators
 */
public static PredictionProvider getSymbolicArithmeticModel() {
    return inputs -> supplyAsync(() -> {
        final String OPERAND_FEATURE_NAME = "operand";
        List<PredictionOutput> predictionOutputs = new LinkedList<>();
        for (PredictionInput predictionInput : inputs) {
            List<Feature> features = predictionInput.getFeatures();
            // Find a valid operand feature, if any
            final String operandValue = features.stream()
                    .filter(f -> OPERAND_FEATURE_NAME.equals(f.getName()))
                    .map(f -> f.getValue().asString())
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("No valid operand found in features"));
            // Reject unknown operators explicitly instead of silently producing 0
            if (!List.of("+", "-", "*", "/").contains(operandValue)) {
                throw new IllegalArgumentException("Unsupported operand: " + operandValue);
            }
            double result = 0;
            boolean seeded = false;
            // Apply the found operand to the rest of the features, left to right
            for (Feature feature : features) {
                if (OPERAND_FEATURE_NAME.equals(feature.getName())) {
                    continue;
                }
                double value = feature.getValue().asNumber();
                if (!seeded) {
                    // Seed with the first operand: starting from 0 would make "*" and "/"
                    // evaluate to 0 (or NaN) regardless of the inputs
                    result = value;
                    seeded = true;
                    continue;
                }
                switch (operandValue) {
                    case "+":
                        result += value;
                        break;
                    case "-":
                        result -= value;
                        break;
                    case "*":
                        result *= value;
                        break;
                    case "/":
                        result /= value;
                        break;
                }
            }
            PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("result", Type.NUMBER, new Value(result), 1d)));
            predictionOutputs.add(predictionOutput);
        }
        return predictionOutputs;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.
Class TestUtils, method getNoisySumModel.
/**
 * Builds a model that sums each input's numeric features, perturbing every feature with noise.
 * <p>
 * For each feature, a value drawn as {@code (rn.nextDouble() - .5) * noiseMagnitude} is added
 * to the feature's numeric value before accumulating. Each input yields a single numeric output
 * named {@code "noisy-sum"} with score 1.
 *
 * @param rn             source of randomness; one draw is consumed per feature, in feature order
 * @param noiseMagnitude width of the uniform noise interval applied per feature
 * @return a {@link PredictionProvider} producing the noisy sum of every input
 */
public static PredictionProvider getNoisySumModel(Random rn, double noiseMagnitude) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> outputs = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double noisySum = 0;
            for (Feature feature : input.getFeatures()) {
                double value = feature.getValue().asNumber();
                double jitter = (rn.nextDouble() - .5) * noiseMagnitude;
                noisySum += value + jitter;
            }
            Output sumOutput = new Output("noisy-sum", Type.NUMBER, new Value(noisySum), 1d);
            outputs.add(new PredictionOutput(List.of(sumOutput)));
        }
        return outputs;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.
Class TestUtils, method getSumThresholdModel.
/**
 * Builds a model that checks whether the sum of each input's features falls near a center value.
 * <p>
 * Each input yields one boolean output named {@code "inside"}. It is true when the feature sum
 * lies in the closed interval {@code [center - epsilon, center + epsilon]}. The output's score is
 * {@code 1.0 - |sum - center|}, which is not clamped and can therefore be negative.
 *
 * @param center  midpoint of the accepted interval
 * @param epsilon half-width of the accepted interval
 * @return a {@link PredictionProvider} classifying every input
 */
public static PredictionProvider getSumThresholdModel(double center, double epsilon) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> outputs = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double sum = 0;
            for (Feature feature : input.getFeatures()) {
                sum += feature.getValue().asNumber();
            }
            boolean withinBand = sum >= center - epsilon && sum <= center + epsilon;
            double score = 1.0 - Math.abs(sum - center);
            Output flag = new Output("inside", Type.BOOLEAN, new Value(withinBand), score);
            outputs.add(new PredictionOutput(List.of(flag)));
        }
        return outputs;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.
Class AggregatedLimeExplainerTest, method testExplainWithMetadata.
/**
 * Explains a sum-skip model from metadata alone and checks that the skipped feature "f1" is not
 * ranked among the top two positive features of the single "sum-but1" output.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
    // Seeded generator shared by every call to the metadata's data distribution
    Random random = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    PredictionProviderMetadata metadata = new PredictionProviderMetadata() {
        @Override
        public DataDistribution getDataDistribution() {
            return DataUtils.generateRandomDataDistribution(3, 100, random);
        }

        @Override
        public PredictionInput getInputShape() {
            // Three numerical features named f0, f1, f2, all zero-valued
            List<Feature> shapeFeatures = new LinkedList<>();
            for (int i = 0; i < 3; i++) {
                shapeFeatures.add(FeatureFactory.newNumericalFeature("f" + i, 0));
            }
            return new PredictionInput(shapeFeatures);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shapeOutputs = new LinkedList<>();
            shapeOutputs.add(new Output("sum-but1", Type.BOOLEAN, new Value(false), 0d));
            return new PredictionOutput(shapeOutputs);
        }
    };

    Map<String, Saliency> saliencies = new AggregatedLimeExplainer().explainFromMetadata(model, metadata).get();
    assertNotNull(saliencies);
    assertEquals(1, saliencies.size());
    assertTrue(saliencies.containsKey("sum-but1"));

    Saliency saliency = saliencies.get("sum-but1");
    assertNotNull(saliency);
    List<String> topPositiveNames = saliency.getPositiveFeatures(2).stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());
    // skipped feature should not appear in top two positive features
    assertFalse(topPositiveNames.contains("f1"));
}
Usage of org.kie.kogito.explainability.model.PredictionProvider in the kogito-apps project by kiegroup.
Class PartialDependencePlotExplainerTest, method testBrokenPredict.
/**
 * Checks that explaining against a provider which never returns a result fails with a
 * {@link TimeoutException} once the global async timeout is shortened to 1 ms.
 * <p>
 * The global {@code Config} is mutated inside the {@code try} block. The {@code finally} block
 * therefore always restores the defaults, even if setup (e.g. constructing the explainer) throws,
 * so the shortened timeout cannot leak into other tests.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testBrokenPredict(int seed) {
    Random random = new Random(seed);
    try {
        Config.INSTANCE.setAsyncTimeout(1);
        Config.INSTANCE.setAsyncTimeUnit(TimeUnit.MILLISECONDS);
        PartialDependencePlotExplainer partialDependencePlotProvider = new PartialDependencePlotExplainer();
        // Blocks for at least one second and never produces a result
        PredictionProvider brokenProvider = inputs -> supplyAsync(() -> {
            await().atLeast(1, TimeUnit.SECONDS).until(() -> false);
            throw new RuntimeException("this should never happen");
        });
        Assertions.assertThrows(TimeoutException.class,
                () -> partialDependencePlotProvider.explainFromMetadata(brokenProvider, getMetadata(random)));
    } finally {
        Config.INSTANCE.setAsyncTimeout(Config.DEFAULT_ASYNC_TIMEOUT);
        Config.INSTANCE.setAsyncTimeUnit(Config.DEFAULT_ASYNC_TIMEUNIT);
    }
}
Aggregations