Usage of org.kie.kogito.explainability.model.PredictionOutput in the kogito-apps project by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateIntermediateResult:
/**
 * Verifies that an intermediate counterfactual result is mapped to a SUCCEEDED, INTERMEDIATE-stage
 * {@link CounterfactualExplainabilityResult} carrying the single input and single output of the solution.
 */
@Test
public void testCreateIntermediateResult() {
    CounterfactualExplainabilityRequest cfRequest = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // One numeric entity (value 123.0, created with range arguments 0 and 1000) and one numeric output (555.0).
    List<CounterfactualEntity> cfEntities =
            List.of(DoubleEntity.from(new Feature("input1", Type.NUMBER, new Value(123.0d)), 0, 1000));
    List<Feature> cfFeatures = cfEntities.stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    List<PredictionOutput> cfOutputs =
            List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0))));
    CounterfactualResult cfResult = new CounterfactualResult(cfEntities,
            cfFeatures,
            cfOutputs,
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    BaseExplainabilityResult explainabilityResult = handler.createIntermediateResult(cfRequest, cfResult);
    assertTrue(explainabilityResult instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult counterfactual = (CounterfactualExplainabilityResult) explainabilityResult;

    // Envelope: status, stage and identifiers.
    assertEquals(ExplainabilityStatus.SUCCEEDED, counterfactual.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, counterfactual.getStage());
    assertEquals(EXECUTION_ID, counterfactual.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, counterfactual.getCounterfactualId());

    // Inputs: exactly one unit-kind Double named "input1".
    assertEquals(1, counterfactual.getInputs().size());
    assertTrue(counterfactual.getInputs().stream().anyMatch(namedInput -> namedInput.getName().equals("input1")));
    NamedTypedValue firstInput = counterfactual.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstInput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstInput.getValue().getKind());
    assertEquals(123.0, firstInput.getValue().toUnit().getValue().asDouble());

    // Outputs: exactly one unit-kind Double named "output1".
    assertEquals(1, counterfactual.getOutputs().size());
    assertTrue(counterfactual.getOutputs().stream().anyMatch(namedOutput -> namedOutput.getName().equals("output1")));
    NamedTypedValue firstOutput = counterfactual.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstOutput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstOutput.getValue().getKind());
    assertEquals(555.0, firstOutput.getValue().toUnit().getValue().asDouble());
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in the kogito-apps project by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResult:
/**
 * Verifies that a final counterfactual result is mapped to a SUCCEEDED, FINAL-stage
 * {@link CounterfactualExplainabilityResult} carrying the single input and single output of the solution.
 */
@Test
public void testCreateSucceededResult() {
    CounterfactualExplainabilityRequest cfRequest = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // One numeric entity (value 123.0, created with range arguments 0 and 1000) and one numeric output (555.0).
    List<CounterfactualEntity> cfEntities =
            List.of(DoubleEntity.from(new Feature("input1", Type.NUMBER, new Value(123.0d)), 0, 1000));
    List<Feature> cfFeatures = cfEntities.stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    List<PredictionOutput> cfOutputs =
            List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0))));
    CounterfactualResult cfResult = new CounterfactualResult(cfEntities,
            cfFeatures,
            cfOutputs,
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    BaseExplainabilityResult explainabilityResult = handler.createSucceededResult(cfRequest, cfResult);
    assertTrue(explainabilityResult instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult counterfactual = (CounterfactualExplainabilityResult) explainabilityResult;

    // Envelope: status, stage and identifiers.
    assertEquals(ExplainabilityStatus.SUCCEEDED, counterfactual.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, counterfactual.getStage());
    assertEquals(EXECUTION_ID, counterfactual.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, counterfactual.getCounterfactualId());

    // Inputs: exactly one unit-kind Double named "input1".
    assertEquals(1, counterfactual.getInputs().size());
    assertTrue(counterfactual.getInputs().stream().anyMatch(namedInput -> namedInput.getName().equals("input1")));
    NamedTypedValue firstInput = counterfactual.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstInput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstInput.getValue().getKind());
    assertEquals(123.0, firstInput.getValue().toUnit().getValue().asDouble());

    // Outputs: exactly one unit-kind Double named "output1".
    assertEquals(1, counterfactual.getOutputs().size());
    assertTrue(counterfactual.getOutputs().stream().anyMatch(namedOutput -> namedOutput.getName().equals("output1")));
    NamedTypedValue firstOutput = counterfactual.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstOutput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstOutput.getValue().getKind());
    assertEquals(555.0, firstOutput.getValue().toUnit().getValue().asDouble());
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in the kogito-apps project by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResultWithMoreThanOnePrediction:
/**
 * Verifies that building a succeeded result from a counterfactual solution holding two predictions
 * fails with {@link IllegalStateException}.
 */
@Test
public void testCreateSucceededResultWithMoreThanOnePrediction() {
    CounterfactualExplainabilityRequest cfRequest = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // Two separate predictions; the handler must refuse to build a single result from them.
    List<PredictionOutput> twoPredictions = List.of(
            new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0))),
            new PredictionOutput(List.of(new Output("output2", Type.NUMBER, new Value(777.0d), 2.0))));
    CounterfactualResult cfResult = new CounterfactualResult(Collections.emptyList(),
            Collections.emptyList(),
            twoPredictions,
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    assertThrows(IllegalStateException.class, () -> handler.createSucceededResult(cfRequest, cfResult));
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in the kogito-apps project by kiegroup.
Class LocalDMNPredictionProvider, method toPredictionOutput:
/**
 * Converts a DMN evaluation result into a {@link PredictionOutput}.
 *
 * <p>Each decision result is turned into one {@link Output} via {@code buildOutput}. The outputs keep the
 * iteration order of {@link DMNResult#getDecisionResults()}.
 *
 * @param dmnResult the DMN evaluation result to convert
 * @return a prediction output holding one output per decision result
 */
public static PredictionOutput toPredictionOutput(DMNResult dmnResult) {
    List<Output> decisionOutputs = new ArrayList<>();
    dmnResult.getDecisionResults().forEach(decision -> decisionOutputs.add(buildOutput(decision)));
    return new PredictionOutput(decisionOutputs);
}
Usage of org.kie.kogito.explainability.model.PredictionOutput in the kogito-apps project by kiegroup.
Class LocalDMNPredictionProvider, method predictAsync:
/**
 * Evaluates every input with the local DMN evaluator and returns the predictions.
 *
 * <p>Evaluation happens synchronously on the calling thread. The returned future is already completed.
 * Outputs are in the same order as {@code inputs}.
 *
 * @param inputs the prediction inputs. The DMN context of each input is read from the entry keyed by
 *               {@code DUMMY_DMN_CONTEXT_KEY} in the map built from its features.
 * @return a completed future holding one {@link PredictionOutput} per input
 */
@Override
public CompletableFuture<List<PredictionOutput>> predictAsync(List<PredictionInput> inputs) {
    List<PredictionOutput> predictionOutputs = new ArrayList<>(inputs.size());
    for (PredictionInput input : inputs) {
        // The context entry is stored as an untyped value, so this cast is unchecked. The warning is suppressed
        // only for this declaration rather than for the whole method.
        @SuppressWarnings("unchecked")
        Map<String, Object> contextVariables = (Map<String, Object>) toMap(input.getFeatures()).get(DUMMY_DMN_CONTEXT_KEY);
        predictionOutputs.add(toPredictionOutput(dmnEvaluator.evaluate(contextVariables)));
    }
    return completedFuture(predictionOutputs);
}
Aggregations