Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainValue in project kogito-apps by kiegroup.
From the class ExplainabilityApiV1IT, method testCounterfactualRequestWithStructuredModel.
/**
 * Posts a counterfactual request built from a structured model and verifies both the
 * REST response and the arguments handed to {@code executionService.requestCounterfactuals}.
 *
 * <p>The captured goal is expected to be a single {@code Fine} structure with two numeric
 * children, and the captured search domain a single {@code Violation} structure with one
 * categorical child and two range children.
 */
@Test
@SuppressWarnings("unchecked")
void testCounterfactualRequestWithStructuredModel() {
    // Raw List.class is the only class token available for a generic captor, hence the unchecked suppression.
    ArgumentCaptor<List<NamedTypedValue>> goalsCaptor = ArgumentCaptor.forClass(List.class);
    ArgumentCaptor<List<CounterfactualSearchDomain>> searchDomainsCaptor = ArgumentCaptor.forClass(List.class);

    mockServiceWithCounterfactualRequest();

    CounterfactualRequestResponse response = given()
            .filter(new RequestLoggingFilter())
            .filter(new ResponseLoggingFilter())
            .contentType(MediaType.APPLICATION_JSON)
            .body(getCounterfactualWithStructuredModelJsonRequest())
            .when()
            .post("/executions/decisions/" + TEST_EXECUTION_ID + "/explanations/counterfactuals")
            .as(CounterfactualRequestResponse.class);

    assertNotNull(response);
    assertNotNull(response.getExecutionId());
    assertNotNull(response.getCounterfactualId());
    // Expected value first, actual second, so a failure message labels the two values correctly.
    assertEquals(TEST_EXECUTION_ID, response.getExecutionId());
    assertEquals(TEST_COUNTERFACTUAL_ID, response.getCounterfactualId());

    verify(executionService).requestCounterfactuals(eq(TEST_EXECUTION_ID), goalsCaptor.capture(), searchDomainsCaptor.capture());

    // --- Goals: one "Fine" structure with two unit children ---
    List<NamedTypedValue> goalsParameter = goalsCaptor.getValue();
    assertNotNull(goalsParameter);
    assertEquals(1, goalsParameter.size());

    NamedTypedValue goal1 = goalsParameter.get(0);
    assertEquals(TypedValue.Kind.STRUCTURE, goal1.getValue().getKind());
    assertEquals("Fine", goal1.getName());
    assertEquals("tFine", goal1.getValue().getType());
    assertEquals(2, goal1.getValue().toStructure().getValue().size());

    // NOTE(review): the child assertions rely on the structure's map iterating in the order
    // Amount, Points — presumably an insertion-ordered map; confirm against the deserializer.
    Iterator<Map.Entry<String, TypedValue>> goal1ChildIterator = goal1.getValue().toStructure().getValue().entrySet().iterator();
    Map.Entry<String, TypedValue> goal1Child1 = goal1ChildIterator.next();
    Map.Entry<String, TypedValue> goal1Child2 = goal1ChildIterator.next();

    assertEquals(TypedValue.Kind.UNIT, goal1Child1.getValue().getKind());
    assertEquals("Amount", goal1Child1.getKey());
    assertEquals("number", goal1Child1.getValue().getType());
    assertEquals(100, goal1Child1.getValue().toUnit().getValue().asInt());

    assertEquals(TypedValue.Kind.UNIT, goal1Child2.getValue().getKind());
    assertEquals("Points", goal1Child2.getKey());
    assertEquals("number", goal1Child2.getValue().getType());
    assertEquals(0, goal1Child2.getValue().toUnit().getValue().asInt());

    // --- Search domains: one "Violation" structure with three unit children ---
    List<CounterfactualSearchDomain> searchDomainsParameter = searchDomainsCaptor.getValue();
    assertNotNull(searchDomainsParameter);
    assertEquals(1, searchDomainsParameter.size());

    CounterfactualSearchDomain domain1 = searchDomainsParameter.get(0);
    assertEquals(TypedValue.Kind.STRUCTURE, domain1.getValue().getKind());
    assertEquals("Violation", domain1.getName());
    assertEquals("tViolation", domain1.getValue().getType());
    assertEquals(3, domain1.getValue().toStructure().getValue().size());

    // NOTE(review): same iteration-order assumption as for the goal children above
    // (Type, Actual Speed, Speed Limit).
    Iterator<Map.Entry<String, CounterfactualSearchDomainValue>> domain1ChildIterator = domain1.getValue().toStructure().getValue().entrySet().iterator();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child1 = domain1ChildIterator.next();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child2 = domain1ChildIterator.next();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child3 = domain1ChildIterator.next();

    // Child 1: "Type" — non-fixed string with a categorical domain of two values
    assertEquals(TypedValue.Kind.UNIT, domain1Child1.getValue().getKind());
    assertFalse(domain1Child1.getValue().toUnit().isFixed());
    assertEquals("Type", domain1Child1.getKey());
    assertEquals("string", domain1Child1.getValue().getType());
    assertNotNull(domain1Child1.getValue().toUnit().getDomain());
    assertTrue(domain1Child1.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical domain1Child1Def = (CounterfactualDomainCategorical) domain1Child1.getValue().toUnit().getDomain();
    assertEquals(2, domain1Child1Def.getCategories().size());
    assertTrue(domain1Child1Def.getCategories().stream().map(JsonNode::asText).collect(Collectors.toList()).containsAll(Arrays.asList("speed", "driving under the influence")));

    // Child 2: "Actual Speed" — non-fixed number with range [0, 100]
    assertEquals(TypedValue.Kind.UNIT, domain1Child2.getValue().getKind());
    assertFalse(domain1Child2.getValue().toUnit().isFixed());
    assertEquals("Actual Speed", domain1Child2.getKey());
    assertEquals("number", domain1Child2.getValue().getType());
    assertNotNull(domain1Child2.getValue().toUnit().getDomain());
    assertTrue(domain1Child2.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain1Child2Def = (CounterfactualDomainRange) domain1Child2.getValue().toUnit().getDomain();
    assertEquals(0, domain1Child2Def.getLowerBound().asInt());
    assertEquals(100, domain1Child2Def.getUpperBound().asInt());

    // Child 3: "Speed Limit" — non-fixed number with range [0, 100]
    assertEquals(TypedValue.Kind.UNIT, domain1Child3.getValue().getKind());
    assertFalse(domain1Child3.getValue().toUnit().isFixed());
    assertEquals("Speed Limit", domain1Child3.getKey());
    assertEquals("number", domain1Child3.getValue().getType());
    assertNotNull(domain1Child3.getValue().toUnit().getDomain());
    assertTrue(domain1Child3.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain1Child3Def = (CounterfactualDomainRange) domain1Child3.getValue().toUnit().getDomain();
    assertEquals(0, domain1Child3Def.getLowerBound().asInt());
    assertEquals(100, domain1Child3Def.getUpperBound().asInt());
}
Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainValue in project kogito-apps by kiegroup.
From the class TrustyServiceTest, method givenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventHasCorrectPayload.
/**
 * Requests counterfactuals for an execution held in the mocked decision storage and
 * inspects the event passed to {@code explainabilityRequestProducerMock.sendEvent}:
 * execution id, service URL, original inputs, goals, search domains and the maximum
 * running time must all match what the test supplied.
 */
@Test
@SuppressWarnings("unchecked")
void givenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventHasCorrectPayload() {
    // Mocked storages and a captor for whatever is sent to the producer
    Storage<String, Decision> decisions = mock(Storage.class);
    Storage<String, CounterfactualExplainabilityRequest> cfRequests = mock(Storage.class);
    ArgumentCaptor<BaseExplainabilityRequest> sentEventCaptor = ArgumentCaptor.forClass(BaseExplainabilityRequest.class);

    // Stored decision: one integer input (yearsOfService = 10), one integer outcome (salary = 1000)
    DecisionInput storedInput = new DecisionInput("IN1", "yearsOfService", new UnitValue("integer", "integer", new IntNode(10)));
    DecisionOutcome storedOutcome = new DecisionOutcome("OUT1", "salary", "SUCCEEDED", new UnitValue("integer", "integer", new IntNode(1000)), Collections.emptyList(), Collections.emptyList());
    Decision storedDecision = new Decision(TEST_EXECUTION_ID, TEST_SOURCE_URL, TEST_SERVICE_URL, 0L, true, null, "model", "modelNamespace", List.of(storedInput), List.of(storedOutcome));

    when(decisions.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(decisions);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(cfRequests);
    when(decisions.get(eq(TEST_EXECUTION_ID))).thenReturn(storedDecision);

    // Counterfactual request: goal salary = 2000, yearsOfService free to vary within [10, 30]
    NamedTypedValue salaryGoal = new NamedTypedValue("salary", new UnitValue("integer", "integer", new IntNode(2000)));
    CounterfactualSearchDomain yearsOfServiceDomain = new CounterfactualSearchDomain("yearsOfService", new CounterfactualSearchDomainUnitValue("integer", "integer", false, new CounterfactualDomainRange(new IntNode(10), new IntNode(30))));
    trustyService.requestCounterfactuals(TEST_EXECUTION_ID, List.of(salaryGoal), List.of(yearsOfServiceDomain));

    verify(explainabilityRequestProducerMock).sendEvent(sentEventCaptor.capture());
    CounterfactualExplainabilityRequest sent = (CounterfactualExplainabilityRequest) sentEventCaptor.getValue();

    assertEquals(TEST_EXECUTION_ID, sent.getExecutionId());
    assertEquals(TEST_SERVICE_URL, sent.getServiceUrl());

    // The decision's original input must be carried over into the CF request
    assertEquals(1, sent.getOriginalInputs().size());
    assertTrue(sent.getOriginalInputs().stream().anyMatch(input -> input.getName().equals("yearsOfService")));
    // Single-element collection (size asserted above), so taking the first iterator item is safe
    int storedYearsOfService = storedDecision.getInputs().iterator().next().getValue().toUnit().getValue().asInt();
    int copiedYearsOfService = sent.getOriginalInputs().iterator().next().getValue().toUnit().getValue().asInt();
    assertEquals(storedYearsOfService, copiedYearsOfService);

    // The CF goals must be carried over into the CF request
    assertEquals(1, sent.getGoals().size());
    assertTrue(sent.getGoals().stream().anyMatch(goal -> goal.getName().equals("salary")));
    // Single-element collection (size asserted above), so taking the first iterator item is safe
    assertEquals(2000, sent.getGoals().iterator().next().getValue().toUnit().getValue().asInt());

    // The CF search domains must be carried over into the CF request
    assertEquals(1, sent.getSearchDomains().size());
    assertTrue(sent.getSearchDomains().stream().anyMatch(domain -> domain.getName().equals("yearsOfService")));
    // Single-element collection (size asserted above), so taking the first iterator item is safe
    CounterfactualSearchDomainValue copiedDomainValue = sent.getSearchDomains().iterator().next().getValue();
    assertTrue(copiedDomainValue instanceof CounterfactualSearchDomainUnitValue);
    CounterfactualSearchDomainUnitValue copiedUnit = (CounterfactualSearchDomainUnitValue) copiedDomainValue;
    assertFalse(copiedUnit.isFixed());
    assertNotNull(copiedUnit.getDomain());
    assertTrue(copiedUnit.getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange copiedRange = (CounterfactualDomainRange) copiedUnit.getDomain();
    assertEquals(10, copiedRange.getLowerBound().asInt());
    assertEquals(30, copiedRange.getUpperBound().asInt());

    // Maximum running time (seconds) must match the configured constant
    assertEquals(MAX_RUNNING_TIME_SECONDS, sent.getMaxRunningTimeSeconds());
}
Aggregations