use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testGetPredictionWithFlatOutputModelReordered.
/**
 * Builds a counterfactual request carrying three flat (unit-valued) goals and checks the
 * prediction returned by {@code handler.getPrediction}: it must be a
 * {@link CounterfactualPrediction} exposing three outputs in the asserted order, no input
 * features, and the same maximum running time as the request.
 */
@Test
public void testGetPredictionWithFlatOutputModelReordered() {
    NamedTypedValue inputsAreValid =
            new NamedTypedValue("inputsAreValid", new UnitValue("boolean", BooleanNode.FALSE));
    NamedTypedValue canRequestLoan =
            new NamedTypedValue("canRequestLoan", new UnitValue("boolean", BooleanNode.TRUE));
    NamedTypedValue scoringFunction =
            new NamedTypedValue("my-scoring-function", new UnitValue("number", new DoubleNode(0.85)));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(inputsAreValid, canRequestLoan, scoringFunction),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);

    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    List<Output> outputs = counterfactualPrediction.getOutput().getOutputs();
    assertEquals(3, outputs.size());

    // The outputs are asserted in a different order from the one the goals were supplied in.
    Output output1 = outputs.get(0);
    assertEquals("my-scoring-function", output1.getName());
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(0.85, output1.getValue().asNumber());

    Output output2 = outputs.get(1);
    assertEquals("inputsAreValid", output2.getName());
    assertEquals(Type.BOOLEAN, output2.getType());
    assertEquals(Boolean.FALSE, output2.getValue().getUnderlyingObject());

    Output output3 = outputs.get(2);
    assertEquals("canRequestLoan", output3.getName());
    assertEquals(Type.BOOLEAN, output3.getType());
    assertEquals(Boolean.TRUE, output3.getValue().getUnderlyingObject());

    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    // Expected value (from the request) goes first so a failure message reads correctly.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testGetPredictionWithCollectionOutputModel.
/**
 * A request whose single goal is a collection value must make
 * {@code handler.getPrediction} throw {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionOutputModel() {
    NamedTypedValue collectionGoal = new NamedTypedValue(
            "input1",
            new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(collectionGoal),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.
the class ConversionUtilsTest method toOutputTypedValue.
/**
 * Exercises {@code ConversionUtils.toOutput} with three kinds of typed value: a numeric unit
 * value, a structure value holding one text entry, and a collection of numeric unit values.
 */
@Test
void toOutputTypedValue() {
    // Unit value: expected to yield a NUMBER output carrying the same double.
    UnitValue numericUnit = new UnitValue("number", new DoubleNode(10d));
    Output numberOutput = ConversionUtils.toOutput("name", numericUnit);
    assertNotNull(numberOutput);
    assertEquals("name", numberOutput.getName());
    assertEquals(Type.NUMBER, numberOutput.getType());
    assertEquals(10d, numberOutput.getValue().getUnderlyingObject());

    // Structure value: expected to yield a COMPOSITE output wrapping a list of nested outputs.
    StructureValue structure =
            new StructureValue("complex", singletonMap("key", new UnitValue("string1", new TextNode("stringValue"))));
    Output compositeOutput = ConversionUtils.toOutput("name1", structure);
    assertNotNull(compositeOutput);
    assertEquals("name1", compositeOutput.getName());
    assertEquals(Type.COMPOSITE, compositeOutput.getType());
    assertTrue(compositeOutput.getValue().getUnderlyingObject() instanceof List);
    @SuppressWarnings("unchecked")
    List<Output> nestedOutputs = (List<Output>) compositeOutput.getValue().getUnderlyingObject();
    assertEquals(1, nestedOutputs.size());
    Output nested = nestedOutputs.get(0);
    assertEquals(Type.TEXT, nested.getType());
    assertEquals("stringValue", nested.getValue().getUnderlyingObject());

    // Collection value: only checked for a non-null conversion result.
    List<TypedValue> numbers = List.of(
            new UnitValue("number", new DoubleNode(0d)),
            new UnitValue("number", new DoubleNode(1d)));
    CollectionValue collection = new CollectionValue("list", numbers);
    assertNotNull(ConversionUtils.toOutput("name", collection));
}
use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.
the class TrustyServiceTest method doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest.
/**
 * Shared scenario: with a stored decision for {@code TEST_EXECUTION_ID}, requesting
 * counterfactuals must send a {@link CounterfactualExplainabilityRequest} for that execution
 * through the explainability request producer.
 *
 * @param domain the search domain applied to every unit value of the requested search domains
 */
@SuppressWarnings("unchecked")
void doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest(CounterfactualDomain domain) {
    ArgumentCaptor<BaseExplainabilityRequest> sentRequestCaptor = ArgumentCaptor.forClass(BaseExplainabilityRequest.class);

    // Decision storage: reports the execution as present and returns a correct decision for it.
    Storage<String, Decision> storedDecisions = mock(Storage.class);
    when(storedDecisions.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(storedDecisions.get(eq(TEST_EXECUTION_ID))).thenReturn(TrustyServiceTestUtils.buildCorrectDecision(TEST_EXECUTION_ID));
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(storedDecisions);

    // Counterfactual request storage: a plain mock with no stubbed behaviour.
    Storage<String, CounterfactualExplainabilityRequest> storedCounterfactuals = mock(Storage.class);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(storedCounterfactuals);

    // The Goals structures must be comparable to the original decision's outcomes.
    NamedTypedValue fineGoal = new NamedTypedValue(
            "Fine",
            new StructureValue("tFine", Map.of(
                    "Amount", new UnitValue("number", "number", new IntNode(0)),
                    "Points", new UnitValue("number", "number", new IntNode(0)))));
    NamedTypedValue suspensionGoal = new NamedTypedValue(
            "Should the driver be suspended?",
            new UnitValue("string", "string", new TextNode("No")));

    // The Search Domain structures must be identical to those of the original decision's inputs.
    CounterfactualSearchDomain violationDomain = new CounterfactualSearchDomain(
            "Violation",
            new CounterfactualSearchDomainStructureValue("tViolation", Map.of(
                    "Type", new CounterfactualSearchDomainUnitValue("string", "string", true, domain),
                    "Actual Speed", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Speed Limit", new CounterfactualSearchDomainUnitValue("number", "number", true, domain))));
    CounterfactualSearchDomain driverDomain = new CounterfactualSearchDomain(
            "Driver",
            new CounterfactualSearchDomainStructureValue("tDriver", Map.of(
                    "Age", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Points", new CounterfactualSearchDomainUnitValue("number", "number", true, domain))));

    trustyService.requestCounterfactuals(
            TEST_EXECUTION_ID,
            List.of(fineGoal, suspensionGoal),
            List.of(violationDomain, driverDomain));

    verify(explainabilityRequestProducerMock).sendEvent(sentRequestCaptor.capture());
    BaseExplainabilityRequest sentRequest = sentRequestCaptor.getValue();
    assertNotNull(sentRequest);
    assertTrue(sentRequest instanceof CounterfactualExplainabilityRequest);
    CounterfactualExplainabilityRequest counterfactualRequest = (CounterfactualExplainabilityRequest) sentRequest;
    assertEquals(TEST_EXECUTION_ID, counterfactualRequest.getExecutionId());
}
use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.
the class TrustyServiceTest method givenADecisionToProcessWhenExplainabilityIsEnabledThenRequestIsSent.
/**
 * With explainability enabled and no decision yet stored for {@code TEST_EXECUTION_ID},
 * processing a decision must send an event through the explainability request producer.
 *
 * @throws JsonProcessingException declared because {@code toJsonNode} is used to build the fixture values
 */
@Test
@SuppressWarnings("unchecked")
void givenADecisionToProcessWhenExplainabilityIsEnabledThenRequestIsSent() throws JsonProcessingException {
    trustyService.enableExplainability();

    // Decision storage reports that the execution has not been stored yet.
    Storage<String, Decision> storage = mock(Storage.class);
    when(storage.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(false);
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(storage);

    // Fixture: one collection input, one structured input and a single outcome.
    DecisionInput listInput = new DecisionInput(
            "1",
            "Input1",
            new CollectionValue("string", List.of(
                    new UnitValue("string", "string", toJsonNode("\"ONE\"")),
                    new UnitValue("string", "string", toJsonNode("\"TWO\"")))));
    DecisionInput personInput = new DecisionInput(
            "2",
            "Input2",
            new StructureValue("Person", Map.of(
                    "Name", new UnitValue("string", "string", toJsonNode("\"George Orwell\"")),
                    "Age", new UnitValue("number", "number", toJsonNode("45")))));
    DecisionOutcome outcome = new DecisionOutcome(
            "OUT1",
            "Result",
            "SUCCEEDED",
            new UnitValue("string", "string", toJsonNode("\"YES\"")),
            Collections.emptyList(),
            Collections.emptyList());

    Decision decision = new Decision(
            TEST_EXECUTION_ID,
            TEST_SOURCE_URL,
            TEST_SERVICE_URL,
            1591692950000L,
            true,
            null,
            "model",
            "modelNamespace",
            List.of(listInput, personInput),
            List.of(outcome));

    trustyService.processDecision(TEST_EXECUTION_ID, decision);

    verify(explainabilityRequestProducerMock).sendEvent(any());
}
Aggregations