Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
In class TestUtils, method getSymbolicArithmeticModel:
/**
 * Returns a {@link PredictionProvider} that evaluates a simple symbolic arithmetic expression per input.
 * <p>
 * Each {@link PredictionInput} must contain a feature named {@code "operand"} whose string value is one of
 * {@code "+"}, {@code "-"}, {@code "*"} or {@code "/"}. That operator is folded left to right over the
 * remaining features in their original order, starting from the first of them
 * (e.g. {@code f0 - f1 - f2}, {@code f0 * f1 * f2}). An input without any non-operand features yields
 * {@code 0}. Each result is emitted as a single {@link Type#NUMBER} output named {@code "result"} with
 * score {@code 1}.
 *
 * @return the arithmetic prediction provider; its future completes exceptionally with an
 *         {@link IllegalArgumentException} if an input has no {@code "operand"} feature or the operator
 *         is not one of the supported ones
 */
public static PredictionProvider getSymbolicArithmeticModel() {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> predictionOutputs = new LinkedList<>();
        final String OPERAND_FEATURE_NAME = "operand";
        for (PredictionInput predictionInput : inputs) {
            List<Feature> features = predictionInput.getFeatures();
            // Find a valid operand feature, if any
            final String operandValue = features.stream()
                    .filter(f -> OPERAND_FEATURE_NAME.equals(f.getName()))
                    .map(f -> f.getValue().asString())
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("No valid operand found in features"));
            // Validate up front so an unknown operator fails even when there is only one value to fold.
            if (!List.of("+", "-", "*", "/").contains(operandValue)) {
                throw new IllegalArgumentException("Unsupported operand: " + operandValue);
            }
            // Fold the operand left-to-right over the remaining features. Seeding with the first value
            // (instead of a constant 0) keeps "*" and "/" from always collapsing to 0.
            double result = 0;
            boolean seeded = false;
            for (Feature feature : features) {
                if (OPERAND_FEATURE_NAME.equals(feature.getName())) {
                    continue;
                }
                double value = feature.getValue().asNumber();
                if (seeded) {
                    result = applyArithmeticOperand(operandValue, result, value);
                } else {
                    result = value;
                    seeded = true;
                }
            }
            PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("result", Type.NUMBER, new Value(result), 1d)));
            predictionOutputs.add(predictionOutput);
        }
        return predictionOutputs;
    });
}

/**
 * Applies one binary arithmetic operator ({@code +}, {@code -}, {@code *}, {@code /}) to two values.
 *
 * @throws IllegalArgumentException if {@code operand} is not a supported operator
 */
private static double applyArithmeticOperand(String operand, double left, double right) {
    switch (operand) {
        case "+":
            return left + right;
        case "-":
            return left - right;
        case "*":
            return left * right;
        case "/":
            return left / right;
        default:
            throw new IllegalArgumentException("Unsupported operand: " + operand);
    }
}
Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
In class TestUtils, method getNoisySumModel:
/**
 * Builds a {@link PredictionProvider} that sums every feature value of each input, perturbing each term with
 * uniform noise drawn from {@code rn} in {@code [-noiseMagnitude / 2, noiseMagnitude / 2)} (for a positive
 * magnitude). One random draw is consumed per feature per input, in feature order.
 *
 * @param rn             source of the noise; shared across all predictions made by the provider
 * @param noiseMagnitude width of the noise interval applied to each feature value
 * @return a provider emitting one {@link Type#NUMBER} output named {@code "noisy-sum"} with score {@code 1}
 */
public static PredictionProvider getNoisySumModel(Random rn, double noiseMagnitude) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> outputs = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double noisySum = 0;
            for (Feature feature : input.getFeatures()) {
                double value = feature.getValue().asNumber();
                double noise = (rn.nextDouble() - .5) * noiseMagnitude;
                noisySum += value + noise;
            }
            Output noisySumOutput = new Output("noisy-sum", Type.NUMBER, new Value(noisySum), 1d);
            outputs.add(new PredictionOutput(List.of(noisySumOutput)));
        }
        return outputs;
    });
}
Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
In class TestUtils, method getSumThresholdModel:
/**
 * Builds a {@link PredictionProvider} that sums all feature values of each input and reports whether the sum
 * lies within the closed interval {@code [center - epsilon, center + epsilon]}.
 * <p>
 * The output score is {@code 1 - |sum - center|}; note it is not clamped, so it becomes negative once the
 * sum is more than {@code 1} away from {@code center}.
 *
 * @param center  midpoint of the accepted interval
 * @param epsilon half-width of the accepted interval
 * @return a provider emitting one {@link Type#BOOLEAN} output named {@code "inside"}
 */
public static PredictionProvider getSumThresholdModel(double center, double epsilon) {
    return inputs -> supplyAsync(() -> {
        List<PredictionOutput> outputs = new LinkedList<>();
        for (PredictionInput input : inputs) {
            double sum = 0;
            for (Feature feature : input.getFeatures()) {
                sum += feature.getValue().asNumber();
            }
            final boolean inside = center - epsilon <= sum && sum <= center + epsilon;
            double score = 1.0 - Math.abs(sum - center);
            outputs.add(new PredictionOutput(List.of(new Output("inside", Type.BOOLEAN, new Value(inside), score))));
        }
        return outputs;
    });
}
Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
In class AggregatedLimeExplainerTest, method testExplainWithMetadata:
/**
 * Explains a sum-skip model purely from provider metadata (random data distribution plus input/output
 * shapes) and checks that the skipped feature {@code f1} is not among the two most positive features.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testExplainWithMetadata(int seed) throws ExecutionException, InterruptedException {
    // Seeded so each parameterized run draws a reproducible data distribution.
    Random random = new Random(seed);
    PredictionProvider model = TestUtils.getSumSkipModel(1);
    PredictionProviderMetadata metadata = new PredictionProviderMetadata() {
        @Override
        public DataDistribution getDataDistribution() {
            return DataUtils.generateRandomDataDistribution(3, 100, random);
        }

        @Override
        public PredictionInput getInputShape() {
            List<Feature> shapeFeatures = new LinkedList<>();
            for (String featureName : new String[] { "f0", "f1", "f2" }) {
                shapeFeatures.add(FeatureFactory.newNumericalFeature(featureName, 0));
            }
            return new PredictionInput(shapeFeatures);
        }

        @Override
        public PredictionOutput getOutputShape() {
            List<Output> shapeOutputs = new LinkedList<>();
            shapeOutputs.add(new Output("sum-but1", Type.BOOLEAN, new Value(false), 0d));
            return new PredictionOutput(shapeOutputs);
        }
    };

    Map<String, Saliency> saliencyByOutput = new AggregatedLimeExplainer().explainFromMetadata(model, metadata).get();
    assertNotNull(saliencyByOutput);
    assertEquals(1, saliencyByOutput.size());
    assertTrue(saliencyByOutput.containsKey("sum-but1"));
    Saliency saliency = saliencyByOutput.get("sum-but1");
    assertNotNull(saliency);

    // getSumSkipModel(1) is expected to skip feature index 1 (f1), so f1 must not rank in the top two.
    List<String> topPositiveNames = saliency.getPositiveFeatures(2).stream()
            .map(importance -> importance.getFeature().getName())
            .collect(Collectors.toList());
    assertFalse(topPositiveNames.contains("f1"));
}
Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.
In class CounterfactualEntityFactoryTest, method testCreateFixedEntities:
/**
 * Verifies that numerical features without a domain become fixed (constrained) entities, while features
 * carrying a {@link NumericalFeatureDomain} become free {@link DoubleEntity} instances. Values must be
 * preserved in both cases.
 */
@Test
void testCreateFixedEntities() {
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("f-num1", 100.1));
    features.add(FeatureFactory.newNumericalFeature("f-num2", 100.2, NumericalFeatureDomain.create(0.0, 1000.0)));
    features.add(FeatureFactory.newNumericalFeature("f-num3", 100.3));
    features.add(FeatureFactory.newNumericalFeature("f-num4", 100.4, NumericalFeatureDomain.create(0.0, 1000.0)));
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(new PredictionInput(features));

    // Expectations, indexed like the features above.
    Class<?>[] expectedTypes = { FixedDoubleEntity.class, DoubleEntity.class, FixedDoubleEntity.class, DoubleEntity.class };
    double[] expectedValues = { 100.1, 100.2, 100.3, 100.4 };
    boolean[] expectedConstrained = { true, false, true, false };

    // Check types
    for (int i = 0; i < expectedTypes.length; i++) {
        assertTrue(expectedTypes[i].isInstance(entities.get(i)));
    }
    // Check values
    for (int i = 0; i < expectedValues.length; i++) {
        assertEquals(expectedValues[i], entities.get(i).asFeature().getValue().asNumber());
    }
    // Check constraints
    for (int i = 0; i < expectedConstrained.length; i++) {
        assertEquals(expectedConstrained[i], entities.get(i).isConstrained());
    }
}
Aggregations