Usage of `org.kie.kogito.explainability.api.NamedTypedValue` in the kogito-apps project by kiegroup.
Class `CounterfactualExplainabilityRequestMarshallerTest`, method `testWriteAndRead`:
/**
 * Round-trips a {@link CounterfactualExplainabilityRequest} through
 * {@link CounterfactualExplainabilityRequestMarshaller} and checks that the identifiers,
 * the first goal, the first search domain (including its lower bound) and the maximum
 * running time survive serialization.
 */
@Test
public void testWriteAndRead() throws IOException {
    ModelIdentifier modelIdentifier = new ModelIdentifier("resourceType", "resourceId");
    List<NamedTypedValue> originalInputs = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> goals = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<CounterfactualSearchDomain> searchDomains = Collections.singletonList(new CounterfactualSearchDomain("age", new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, new CounterfactualDomainRange(JsonNodeFactory.instance.numberNode(0), JsonNodeFactory.instance.numberNode(10)))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest("executionId", "serviceUrl", modelIdentifier, "counterfactualId", originalInputs, goals, searchDomains, 60L);
    CounterfactualExplainabilityRequestMarshaller marshaller = new CounterfactualExplainabilityRequestMarshaller(new ObjectMapper());

    // NOTE(review): `writer` and `reader` are declared outside this method; presumably they
    // share a backing buffer so that what is written can be read back — confirm in the fixture.
    marshaller.writeTo(writer, request);
    CounterfactualExplainabilityRequest retrieved = marshaller.readFrom(reader);

    Assertions.assertEquals(request.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(request.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(goals.get(0).getName(), retrieved.getGoals().stream().findFirst().get().getName());
    Assertions.assertEquals(searchDomains.get(0).getName(), retrieved.getSearchDomains().stream().findFirst().get().getName());
    Assertions.assertEquals(0, ((CounterfactualDomainRange) retrieved.getSearchDomains().stream().findFirst().get().getValue().toUnit().getDomain()).getLowerBound().asInt());
    // Assert on the deserialized request. Asserting on `request` (the input) would pass even if
    // the marshaller dropped maxRunningTimeSeconds.
    Assertions.assertEquals(60L, retrieved.getMaxRunningTimeSeconds());
}
Usage of `org.kie.kogito.explainability.api.NamedTypedValue` in the kogito-apps project by kiegroup.
Class `DecisionMarshallerTest`, method `testWriteAndRead`:
/**
 * Serializes a {@link Decision} holding a single input and a single outcome with
 * {@link DecisionMarshaller}, reads it back, and compares the identifying fields plus the
 * first input and first outcome of the deserialized copy against the originals.
 */
@Test
public void testWriteAndRead() throws IOException {
    // Fixture: one numeric input.
    UnitValue inputValue = new UnitValue("nameIn", "number", JsonNodeFactory.instance.numberNode(10));
    List<DecisionInput> inputs = Collections.singletonList(new DecisionInput("id", "in", inputValue));

    // Fixture: one successful outcome with one named value and no messages.
    String succeeded = DMNDecisionResult.DecisionEvaluationStatus.SUCCEEDED.toString();
    UnitValue outcomeValue = new UnitValue("nameOut", "number", JsonNodeFactory.instance.numberNode(10));
    List<NamedTypedValue> namedOutcomeValues = List.of(new NamedTypedValue("nameOut", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<DecisionOutcome> outcomes = Collections.singletonList(new DecisionOutcome("id", "out", succeeded, outcomeValue, namedOutcomeValues, new ArrayList<>()));

    Decision decision = new Decision("executionId", "source", "serviceUrl", 0L, true, "executor", "model", "namespace", inputs, outcomes);

    DecisionMarshaller decisionMarshaller = new DecisionMarshaller(new ObjectMapper());
    decisionMarshaller.writeTo(writer, decision);
    Decision roundTripped = decisionMarshaller.readFrom(reader);

    // Top-level identifying fields.
    Assertions.assertEquals(decision.getExecutionId(), roundTripped.getExecutionId());
    Assertions.assertEquals(decision.getSourceUrl(), roundTripped.getSourceUrl());
    Assertions.assertEquals(decision.getServiceUrl(), roundTripped.getServiceUrl());
    Assertions.assertEquals(decision.getExecutedModelName(), roundTripped.getExecutedModelName());

    // First input: name, value type and unit base type.
    DecisionInput expectedInput = inputs.get(0);
    DecisionInput actualInput = roundTripped.getInputs().stream().findFirst().get();
    Assertions.assertEquals(expectedInput.getName(), actualInput.getName());
    Assertions.assertEquals(expectedInput.getValue().getType(), actualInput.getValue().getType());
    Assertions.assertEquals(expectedInput.getValue().toUnit().getBaseType(), actualInput.getValue().toUnit().getBaseType());

    // First outcome: id, and the empty message list must stay empty.
    String expectedOutcomeId = outcomes.get(0).getOutcomeId();
    DecisionOutcome actualOutcome = roundTripped.getOutcomes().stream().findFirst().get();
    Assertions.assertEquals(expectedOutcomeId, actualOutcome.getOutcomeId());
    Assertions.assertTrue(actualOutcome.getMessages().isEmpty());
}
Usage of `org.kie.kogito.explainability.api.NamedTypedValue` in the kogito-apps project by kiegroup.
Class `CounterfactualExplainerServiceHandlerTest`, method `testCreateIntermediateResult`:
@Test
public void testCreateIntermediateResult() {
CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
List<CounterfactualEntity> entities = List.of(DoubleEntity.from(new Feature("input1", Type.NUMBER, new Value(123.0d)), 0, 1000));
CounterfactualResult counterfactuals = new CounterfactualResult(entities, entities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList()), List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)))), true, UUID.fromString(SOLUTION_ID), UUID.fromString(EXECUTION_ID), 0);
BaseExplainabilityResult base = handler.createIntermediateResult(request, counterfactuals);
assertTrue(base instanceof CounterfactualExplainabilityResult);
CounterfactualExplainabilityResult result = (CounterfactualExplainabilityResult) base;
assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, result.getStage());
assertEquals(EXECUTION_ID, result.getExecutionId());
assertEquals(COUNTERFACTUAL_ID, result.getCounterfactualId());
assertEquals(1, result.getInputs().size());
assertTrue(result.getInputs().stream().anyMatch(i -> i.getName().equals("input1")));
NamedTypedValue input1 = result.getInputs().iterator().next();
assertEquals(Double.class.getSimpleName(), input1.getValue().getType());
assertEquals(TypedValue.Kind.UNIT, input1.getValue().getKind());
assertEquals(123.0, input1.getValue().toUnit().getValue().asDouble());
assertEquals(1, result.getOutputs().size());
assertTrue(result.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));
NamedTypedValue output1 = result.getOutputs().iterator().next();
assertEquals(Double.class.getSimpleName(), output1.getValue().getType());
assertEquals(TypedValue.Kind.UNIT, output1.getValue().getKind());
assertEquals(555.0, output1.getValue().toUnit().getValue().asDouble());
}
Usage of `org.kie.kogito.explainability.api.NamedTypedValue` in the kogito-apps project by kiegroup.
Class `CounterfactualExplainerServiceHandlerTest`, method `testGetPredictionWithStructuredInputModel`:
@Test
public void testGetPredictionWithStructuredInputModel() {
CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, List.of(new NamedTypedValue("input1", new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55)))))), Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Usage of `org.kie.kogito.explainability.api.NamedTypedValue` in the kogito-apps project by kiegroup.
Class `CounterfactualExplainerServiceHandlerTest`, method `testGetPredictionWithFlatOutputModelReordered`:
@Test
public void testGetPredictionWithFlatOutputModelReordered() {
CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), List.of(new NamedTypedValue("inputsAreValid", new UnitValue("boolean", BooleanNode.FALSE)), new NamedTypedValue("canRequestLoan", new UnitValue("booelan", BooleanNode.TRUE)), new NamedTypedValue("my-scoring-function", new UnitValue("number", new DoubleNode(0.85)))), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
Prediction prediction = handler.getPrediction(request);
assertTrue(prediction instanceof CounterfactualPrediction);
CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
List<Output> outputs = counterfactualPrediction.getOutput().getOutputs();
assertEquals(3, outputs.size());
Output output1 = outputs.get(0);
assertEquals("my-scoring-function", output1.getName());
assertEquals(Type.NUMBER, output1.getType());
assertEquals(0.85, output1.getValue().asNumber());
Output output2 = outputs.get(1);
assertEquals("inputsAreValid", output2.getName());
assertEquals(Type.BOOLEAN, output2.getType());
assertEquals(Boolean.FALSE, output2.getValue().getUnderlyingObject());
Output output3 = outputs.get(2);
assertEquals("canRequestLoan", output3.getName());
assertEquals(Type.BOOLEAN, output3.getType());
assertEquals(Boolean.TRUE, output3.getValue().getUnderlyingObject());
assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
assertEquals(counterfactualPrediction.getMaxRunningTimeSeconds(), request.getMaxRunningTimeSeconds());
}
Aggregations