Use of org.kie.kogito.explainability.model.Output in the kogito-apps project (kiegroup).
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModelReordered:
@Test
public void testGetPredictionWithFlatOutputModelReordered() {
    // Supplies the model outputs in a different order than they are asserted below, to
    // exercise how getPrediction(..) orders the resulting Output list.
    // NOTE(review): fixed a typo in the second output's type literal ("booelan" -> "boolean");
    // the corresponding assertion on output3 already expects Type.BOOLEAN.
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(new NamedTypedValue("inputsAreValid", new UnitValue("boolean", BooleanNode.FALSE)),
                    new NamedTypedValue("canRequestLoan", new UnitValue("boolean", BooleanNode.TRUE)),
                    new NamedTypedValue("my-scoring-function", new UnitValue("number", new DoubleNode(0.85)))),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);

    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    // Three outputs expected; asserted in the order produced by the handler, which
    // differs from the order supplied above (hence "Reordered" in the test name).
    List<Output> outputs = counterfactualPrediction.getOutput().getOutputs();
    assertEquals(3, outputs.size());

    Output output1 = outputs.get(0);
    assertEquals("my-scoring-function", output1.getName());
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(0.85, output1.getValue().asNumber());

    Output output2 = outputs.get(1);
    assertEquals("inputsAreValid", output2.getName());
    assertEquals(Type.BOOLEAN, output2.getType());
    assertEquals(Boolean.FALSE, output2.getValue().getUnderlyingObject());

    Output output3 = outputs.get(2);
    assertEquals("canRequestLoan", output3.getName());
    assertEquals(Type.BOOLEAN, output3.getType());
    assertEquals(Boolean.TRUE, output3.getValue().getUnderlyingObject());

    // No input features were supplied, so the prediction input must be empty.
    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    assertEquals(counterfactualPrediction.getMaxRunningTimeSeconds(), request.getMaxRunningTimeSeconds());
}
Use of org.kie.kogito.explainability.model.Output in the kogito-apps project (kiegroup).
From class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResult:
@Test
public void testCreateSucceededResult() {
    // Request with no inputs, outputs or search domains.
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // One numeric counterfactual entity ("input1" = 123.0) and one numeric prediction
    // output ("output1" = 555.0).
    Feature inputFeature = new Feature("input1", Type.NUMBER, new Value(123.0d));
    List<CounterfactualEntity> entities = List.of(DoubleEntity.from(inputFeature, 0, 1000));
    List<Feature> entityFeatures = entities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)));
    CounterfactualResult counterfactuals = new CounterfactualResult(entities,
            entityFeatures,
            List.of(predictionOutput),
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    BaseExplainabilityResult base = handler.createSucceededResult(request, counterfactuals);

    assertTrue(base instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult result = (CounterfactualExplainabilityResult) base;

    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, result.getStage());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, result.getCounterfactualId());

    // Single input mirrors the counterfactual entity's feature.
    assertEquals(1, result.getInputs().size());
    assertTrue(result.getInputs().stream().anyMatch(i -> i.getName().equals("input1")));
    NamedTypedValue input = result.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), input.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, input.getValue().getKind());
    assertEquals(123.0, input.getValue().toUnit().getValue().asDouble());

    // Single output mirrors the prediction output.
    assertEquals(1, result.getOutputs().size());
    assertTrue(result.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));
    NamedTypedValue output = result.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), output.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, output.getValue().getKind());
    assertEquals(555.0, output.getValue().toUnit().getValue().asDouble());
}
Use of org.kie.kogito.explainability.model.Output in the kogito-apps project (kiegroup).
From class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResultWithMoreThanOnePrediction:
@Test
public void testCreateSucceededResultWithMoreThanOnePrediction() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // Two PredictionOutputs in one result: createSucceededResult(..) supports only one
    // and is expected to reject this with an IllegalStateException.
    PredictionOutput firstPrediction = new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)));
    PredictionOutput secondPrediction = new PredictionOutput(List.of(new Output("output2", Type.NUMBER, new Value(777.0d), 2.0)));
    CounterfactualResult counterfactuals = new CounterfactualResult(Collections.emptyList(),
            Collections.emptyList(),
            List.of(firstPrediction, secondPrediction),
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    assertThrows(IllegalStateException.class, () -> handler.createSucceededResult(request, counterfactuals));
}
Use of org.kie.kogito.explainability.model.Output in the kogito-apps project (kiegroup).
From class LimeExplainerServiceHandlerTest, method testCreateSucceededResult:
@Test
public void testCreateSucceededResult() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            Collections.emptyList(),
            Collections.emptyList());

    // One saliency ("s1") for the "salary" output, with two feature importances.
    FeatureImportance ageImportance = new FeatureImportance(new Feature("age", Type.NUMBER, new Value(25.0)), 5.0);
    FeatureImportance dependentsImportance = new FeatureImportance(new Feature("dependents", Type.NUMBER, new Value(2)), -11.0);
    Map<String, Saliency> saliencies = Map.of("s1",
            new Saliency(new Output("salary", Type.NUMBER), List.of(ageImportance, dependentsImportance)));

    BaseExplainabilityResult base = handler.createSucceededResult(request, saliencies);

    assertTrue(base instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult result = (LIMEExplainabilityResult) base;

    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(1, result.getSaliencies().size());

    // Importances are expected in insertion order with their original scores.
    SaliencyModel model = result.getSaliencies().iterator().next();
    assertEquals(2, model.getFeatureImportance().size());
    assertEquals("age", model.getFeatureImportance().get(0).getFeatureName());
    assertEquals(5.0, model.getFeatureImportance().get(0).getFeatureScore());
    assertEquals("dependents", model.getFeatureImportance().get(1).getFeatureName());
    assertEquals(-11.0, model.getFeatureImportance().get(1).getFeatureScore());
}
Use of org.kie.kogito.explainability.model.Output in the kogito-apps project (kiegroup).
From class LocalDMNPredictionProvider, method toPredictionOutput:
/**
 * Converts a {@link DMNResult} into a {@link PredictionOutput} by mapping each
 * decision result to an {@link Output} via {@code buildOutput(..)}, preserving
 * the order of {@code dmnResult.getDecisionResults()}.
 *
 * @param dmnResult the DMN evaluation result to convert
 * @return a {@link PredictionOutput} containing one {@link Output} per decision result
 */
public static PredictionOutput toPredictionOutput(DMNResult dmnResult) {
    List<Output> outputs = new ArrayList<>();
    dmnResult.getDecisionResults().forEach(decisionResult -> outputs.add(buildOutput(decisionResult)));
    return new PredictionOutput(outputs);
}
Aggregations