Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getSymbolicArithmeticModel.
/**
 * Builds a test model that evaluates a simple arithmetic expression for each input.
 * <p>
 * Each {@link PredictionInput} must contain a feature named {@code "operand"} whose string value is one of
 * {@code "+"}, {@code "-"}, {@code "*"} or {@code "/"}. All other features are read as numbers and combined
 * left to right with that operator, i.e. {@code f0 op f1 op f2 ...}. An input with no numeric features yields
 * {@code 0}.
 * <p>
 * The model emits one output per input: a {@code "result"} {@code NUMBER} output with score {@code 1}.
 * <p>
 * If the operand feature is missing or holds an unsupported operator, an {@link IllegalArgumentException} is
 * thrown. It is thrown inside the {@code supplyAsync} task, so callers observe it through the returned future.
 *
 * @return a prediction provider implementing the arithmetic model
 */
public static PredictionProvider getSymbolicArithmeticModel() {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> predictionOutputs = new LinkedList<>();
        final String OPERAND_FEATURE_NAME = "operand";
        final List<String> SUPPORTED_OPERANDS = List.of("+", "-", "*", "/");
        for (PredictionInput predictionInput : inputs) {
            List<Feature> features = predictionInput.getFeatures();
            // Find a valid operand feature, if any
            final String operandValue = features.stream()
                    .filter(f -> OPERAND_FEATURE_NAME.equals(f.getName()))
                    .map(f -> f.getValue().asString())
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("No valid operand found in features"));
            // Validate up front so an unknown operator fails even when there is only one numeric feature
            if (!SUPPORTED_OPERANDS.contains(operandValue)) {
                throw new IllegalArgumentException("Unsupported operand: " + operandValue);
            }
            double result = 0;
            boolean seeded = false;
            // Apply the found operand to the rest of the features, left to right
            for (Feature feature : features) {
                if (OPERAND_FEATURE_NAME.equals(feature.getName())) {
                    continue;
                }
                final double value = feature.getValue().asNumber();
                if (!seeded) {
                    // Seed with the first numeric feature rather than 0, otherwise "*" and "/" are always 0
                    result = value;
                    seeded = true;
                } else {
                    result = applyOperand(operandValue, result, value);
                }
            }
            PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("result", Type.NUMBER, new Value(result), 1d)));
            predictionOutputs.add(predictionOutput);
        }
        return predictionOutputs;
    });
}

/**
 * Applies a single binary arithmetic operator to two operands.
 *
 * @param operand one of {@code "+"}, {@code "-"}, {@code "*"}, {@code "/"}
 * @param left the accumulated left-hand value
 * @param right the next feature value
 * @return {@code left operand right}
 * @throws IllegalArgumentException if {@code operand} is not a supported operator
 */
private static double applyOperand(String operand, double left, double right) {
    switch (operand) {
        case "+":
            return left + right;
        case "-":
            return left - right;
        case "*":
            return left * right;
        case "/":
            return left / right;
        default:
            throw new IllegalArgumentException("Unsupported operand: " + operand);
    }
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getNoisySumModel.
/**
 * Builds a test model that sums the numeric features of each input, adding random noise to every term.
 * <p>
 * For each feature, one value is drawn from {@code rn} and the term {@code (rn.nextDouble() - .5) * noiseMagnitude}
 * is added together with the feature value. For a non-negative magnitude, each noise term therefore lies in
 * {@code [-noiseMagnitude / 2, noiseMagnitude / 2)}.
 *
 * @param rn source of randomness; one {@code nextDouble()} is consumed per feature, in feature order
 * @param noiseMagnitude width of the noise interval applied to each feature
 * @return a provider emitting a single {@code "noisy-sum"} {@code NUMBER} output with score {@code 1} per input
 */
public static PredictionProvider getNoisySumModel(Random rn, double noiseMagnitude) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> outputs = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double noisySum = 0;
            for (Feature feature : input.getFeatures()) {
                double addend = feature.getValue().asNumber();
                double noise = (rn.nextDouble() - .5) * noiseMagnitude;
                // Keep "sum += (value + noise)" grouping so floating-point results are unchanged
                noisySum += addend + noise;
            }
            Output noisySumOutput = new Output("noisy-sum", Type.NUMBER, new Value(noisySum), 1d);
            outputs.add(new PredictionOutput(List.of(noisySumOutput)));
        }
        return outputs;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class TestUtils, method getSumThresholdModel.
/**
 * Builds a test model that checks whether the sum of an input's numeric features falls within
 * {@code [center - epsilon, center + epsilon]}.
 * <p>
 * Each input yields a single {@code "inside"} {@code BOOLEAN} output. Its score is {@code 1.0 - |sum - center|}.
 * NOTE(review): that score becomes negative once the sum is more than 1 away from {@code center}; confirm
 * consumers tolerate this.
 *
 * @param center midpoint of the accepted band
 * @param epsilon half-width of the accepted band (bounds inclusive)
 * @return a provider implementing the threshold model
 */
public static PredictionProvider getSumThresholdModel(double center, double epsilon) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> outputs = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double sum = 0;
            for (Feature feature : input.getFeatures()) {
                sum += feature.getValue().asNumber();
            }
            final boolean withinBand = sum >= center - epsilon && sum <= center + epsilon;
            final double distance = Math.abs(sum - center);
            Output insideOutput = new Output("inside", Type.BOOLEAN, new Value(withinBand), 1.0 - distance);
            outputs.add(new PredictionOutput(List.of(insideOutput)));
        }
        return outputs;
    });
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class AggregatedLimeExplainerTest, method testExplainWithMetadata.
/**
 * Runs the aggregated LIME explainer against a sum-skip model using only provider metadata
 * (a random data distribution plus input/output shapes). It checks that exactly one saliency,
 * keyed {@code "sum-but1"}, is produced and that the skipped feature {@code "f1"} is not among
 * the two most positive features.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
    // new Random(seed) is specified as equivalent to new Random() followed by setSeed(seed)
    Random random = new Random(seed);
    PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(1);
    PredictionProviderMetadata metadata = new PredictionProviderMetadata() {
        @Override
        public DataDistribution getDataDistribution() {
            return DataUtils.generateRandomDataDistribution(3, 100, random);
        }

        @Override
        public PredictionInput getInputShape() {
            List<Feature> shapeFeatures = new LinkedList<>(List.of(
                    FeatureFactory.newNumericalFeature("f0", 0),
                    FeatureFactory.newNumericalFeature("f1", 0),
                    FeatureFactory.newNumericalFeature("f2", 0)));
            return new PredictionInput(shapeFeatures);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shapeOutputs = new LinkedList<>(List.of(
                    new Output("sum-but1", Type.BOOLEAN, new Value(false), 0d)));
            return new PredictionOutput(shapeOutputs);
        }
    };
    Map<String, Saliency> saliencyByOutput = new AggregatedLimeExplainer()
            .explainFromMetadata(sumSkipModel, metadata)
            .get();
    assertNotNull(saliencyByOutput);
    assertEquals(1, saliencyByOutput.size());
    assertTrue(saliencyByOutput.containsKey("sum-but1"));
    Saliency sumBut1Saliency = saliencyByOutput.get("sum-but1");
    assertNotNull(sumBut1Saliency);
    List<String> topPositiveNames = sumBut1Saliency.getPositiveFeatures(2).stream()
            .map(FeatureImportance::getFeature)
            .map(Feature::getName)
            .collect(Collectors.toList());
    // skipped feature should not appear in top two positive features
    assertFalse(topPositiveNames.contains("f1"));
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in project kogito-apps by kiegroup.
Class CounterfactualExplainerTest, method testFinalUniqueIds.
/**
 * Verifies the ids reported by the counterfactual explainer:
 * intermediate solution ids are unique, every intermediate result carries the caller-supplied execution id,
 * and the final result's solution id differs from the last intermediate one.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testFinalUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    final List<Output> goal = new ArrayList<>();
    List<Feature> features = List.of(FeatureFactory.newNumericalFeature("f-num1", 10.0, NumericalFeatureDomain.create(0, 20)));
    PredictionProvider model = TestUtils.getFeaturePassModel(0);
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(100_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    // The parameterized seed is consumed by the solver config only
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final List<UUID> intermediateIds = new ArrayList<>();
    final List<UUID> executionIds = new ArrayList<>();
    final Consumer<CounterfactualResult> captureIntermediateIds = counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
    final Consumer<CounterfactualResult> captureExecutionIds = counterfactual -> executionIds.add(counterfactual.getExecutionId());
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    solverConfig.withEasyScoreCalculatorClass(MockCounterFactualScoreCalculator.class);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    final UUID executionId = UUID.randomUUID();
    Prediction prediction = new CounterfactualPrediction(input, output, null, executionId, null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer
            .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }
    // All intermediate ids should be unique (expected = total count, actual = distinct count)
    assertEquals(intermediateIds.size(), (int) intermediateIds.stream().distinct().count());
    // There should be at least one intermediate id
    assertTrue(intermediateIds.size() > 0);
    // There should be at least one execution id
    assertTrue(executionIds.size() > 0);
    // We should have the same number of execution ids as intermediate ids (captured from intermediate results)
    assertEquals(intermediateIds.size(), executionIds.size());
    // All execution ids should be the same
    assertEquals(1, (int) executionIds.stream().distinct().count());
    // The last intermediate id must be different from the final result id
    assertNotEquals(intermediateIds.get(intermediateIds.size() - 1), counterfactualResult.getSolutionId());
    // Captured execution ids should be the same as the one provided
    assertEquals(executionId, executionIds.get(0));
}
Aggregations