Use of org.kie.kogito.explainability.api.CounterfactualDomainCategorical in the kogito-apps project by kiegroup.
Class ConversionUtils, method toCounterfactualSearchDomain.
/**
 * Converts an API-level {@link CounterfactualDomain} into the {@link FeatureDomain} used by the
 * counterfactual search.
 *
 * @param domain the domain to convert; may be {@code null}
 * @return {@link EmptyFeatureDomain} when {@code domain} is {@code null}; a numerical domain for a
 *         {@link CounterfactualDomainRange}; a categorical domain for a
 *         {@link CounterfactualDomainCategorical}; {@link Optional#empty()} for any other subtype
 * @throws IllegalArgumentException if a range has a non-numeric bound, or a categorical domain
 *         contains a non-textual category
 */
static Optional<FeatureDomain> toCounterfactualSearchDomain(CounterfactualDomain domain) {
    if (Objects.isNull(domain)) {
        return Optional.of(EmptyFeatureDomain.create());
    } else if (domain instanceof CounterfactualDomainRange) {
        CounterfactualDomainRange range = (CounterfactualDomainRange) domain;
        JsonNode lb = range.getLowerBound();
        JsonNode ub = range.getUpperBound();
        if (lb.isNumber() && ub.isNumber()) {
            return Optional.of(NumericalFeatureDomain.create(lb.asDouble(), ub.asDouble()));
        }
        throw new IllegalArgumentException(String.format("Unsupported CounterfactualDomainRange [%s, %s]", lb.asText(), ub.asText()));
    } else if (domain instanceof CounterfactualDomainCategorical) {
        CounterfactualDomainCategorical categorical = (CounterfactualDomainCategorical) domain;
        Collection<JsonNode> jsonCategories = categorical.getCategories();
        if (jsonCategories.stream().allMatch(JsonNode::isTextual)) {
            String[] categories = jsonCategories.stream().map(JsonNode::asText).toArray(String[]::new);
            return Optional.of(CategoricalFeatureDomain.create(categories));
        }
        // Report the categories exactly as received. String.valueOf renders each node's JSON form,
        // which stays informative for non-textual nodes.
        String received = jsonCategories.stream().map(String::valueOf).collect(Collectors.joining(", "));
        throw new IllegalArgumentException(String.format("Unsupported CounterfactualDomainCategorical [%s]", received));
    }
    return Optional.empty();
}
Use of org.kie.kogito.explainability.api.CounterfactualDomainCategorical in the kogito-apps project by kiegroup.
Class ConversionUtilsTest, method testToFeatureDomain_UnitCategoricalString.
@Test
void testToFeatureDomain_UnitCategoricalString() {
    // A unit value with a categorical domain of two strings is expected to map to a
    // CategoricalFeatureDomain that holds exactly those two categories.
    CounterfactualDomainCategorical colours = new CounterfactualDomainCategorical(List.of(TextNode.valueOf("Black"), TextNode.valueOf("White")));
    CounterfactualSearchDomainUnitValue unitValue = new CounterfactualSearchDomainUnitValue("string", "string", true, colours);
    FeatureDomain converted = ConversionUtils.toFeatureDomain(unitValue);

    assertTrue(converted instanceof CategoricalFeatureDomain);
    CategoricalFeatureDomain categorical = (CategoricalFeatureDomain) converted;
    assertEquals(2, categorical.getCategories().size());
    assertTrue(categorical.getCategories().containsAll(List.of("White", "Black")));

    // The test expects a categorical domain to expose no lower or upper bound.
    assertNull(categorical.getLowerBound());
    assertNull(categorical.getUpperBound());
}
Use of org.kie.kogito.explainability.api.CounterfactualDomainCategorical in the kogito-apps project by kiegroup.
Class ExplainabilityApiV1IT, method testCounterfactualRequestWithStructuredModel.
/**
 * Posts a counterfactual request whose goal and search domain are structured values. It then
 * verifies that the execution service receives the expected goal ("Fine") and search domain
 * ("Violation"), including each child's type, fixed flag and domain definition.
 */
@Test
@SuppressWarnings("unchecked")
void testCounterfactualRequestWithStructuredModel() {
    ArgumentCaptor<List<NamedTypedValue>> goalsCaptor = ArgumentCaptor.forClass(List.class);
    ArgumentCaptor<List<CounterfactualSearchDomain>> searchDomainsCaptor = ArgumentCaptor.forClass(List.class);
    mockServiceWithCounterfactualRequest();
    CounterfactualRequestResponse response = given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).contentType(MediaType.APPLICATION_JSON).body(getCounterfactualWithStructuredModelJsonRequest()).when().post("/executions/decisions/" + TEST_EXECUTION_ID + "/explanations/counterfactuals").as(CounterfactualRequestResponse.class);
    assertNotNull(response);
    assertNotNull(response.getExecutionId());
    assertNotNull(response.getCounterfactualId());
    // JUnit's assertEquals takes (expected, actual); this order keeps failure messages accurate.
    assertEquals(TEST_EXECUTION_ID, response.getExecutionId());
    assertEquals(TEST_COUNTERFACTUAL_ID, response.getCounterfactualId());
    verify(executionService).requestCounterfactuals(eq(TEST_EXECUTION_ID), goalsCaptor.capture(), searchDomainsCaptor.capture());

    List<NamedTypedValue> goalsParameter = goalsCaptor.getValue();
    assertNotNull(goalsParameter);
    assertEquals(1, goalsParameter.size());
    NamedTypedValue goal1 = goalsParameter.get(0);
    assertEquals(TypedValue.Kind.STRUCTURE, goal1.getValue().getKind());
    assertEquals("Fine", goal1.getName());
    assertEquals("tFine", goal1.getValue().getType());
    assertEquals(2, goal1.getValue().toStructure().getValue().size());
    // NOTE(review): the positional next() calls below assume the structure map preserves the
    // request's field order (e.g. a LinkedHashMap) — confirm against the deserializer.
    Iterator<Map.Entry<String, TypedValue>> goal1ChildIterator = goal1.getValue().toStructure().getValue().entrySet().iterator();
    Map.Entry<String, TypedValue> goal1Child1 = goal1ChildIterator.next();
    Map.Entry<String, TypedValue> goal1Child2 = goal1ChildIterator.next();
    assertEquals(TypedValue.Kind.UNIT, goal1Child1.getValue().getKind());
    assertEquals("Amount", goal1Child1.getKey());
    assertEquals("number", goal1Child1.getValue().getType());
    assertEquals(100, goal1Child1.getValue().toUnit().getValue().asInt());
    assertEquals(TypedValue.Kind.UNIT, goal1Child2.getValue().getKind());
    assertEquals("Points", goal1Child2.getKey());
    assertEquals("number", goal1Child2.getValue().getType());
    assertEquals(0, goal1Child2.getValue().toUnit().getValue().asInt());

    List<CounterfactualSearchDomain> searchDomainsParameter = searchDomainsCaptor.getValue();
    assertNotNull(searchDomainsParameter);
    assertEquals(1, searchDomainsParameter.size());
    CounterfactualSearchDomain domain1 = searchDomainsParameter.get(0);
    assertEquals(TypedValue.Kind.STRUCTURE, domain1.getValue().getKind());
    assertEquals("Violation", domain1.getName());
    assertEquals("tViolation", domain1.getValue().getType());
    assertEquals(3, domain1.getValue().toStructure().getValue().size());
    Iterator<Map.Entry<String, CounterfactualSearchDomainValue>> domain1ChildIterator = domain1.getValue().toStructure().getValue().entrySet().iterator();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child1 = domain1ChildIterator.next();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child2 = domain1ChildIterator.next();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child3 = domain1ChildIterator.next();

    // "Type": free, categorical string domain.
    assertEquals(TypedValue.Kind.UNIT, domain1Child1.getValue().getKind());
    assertFalse(domain1Child1.getValue().toUnit().isFixed());
    assertEquals("Type", domain1Child1.getKey());
    assertEquals("string", domain1Child1.getValue().getType());
    assertNotNull(domain1Child1.getValue().toUnit().getDomain());
    assertTrue(domain1Child1.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical domain1Child1Def = (CounterfactualDomainCategorical) domain1Child1.getValue().toUnit().getDomain();
    assertEquals(2, domain1Child1Def.getCategories().size());
    assertTrue(domain1Child1Def.getCategories().stream().map(JsonNode::asText).collect(Collectors.toList()).containsAll(Arrays.asList("speed", "driving under the influence")));

    // "Actual Speed": free, numeric range [0, 100].
    assertEquals(TypedValue.Kind.UNIT, domain1Child2.getValue().getKind());
    assertFalse(domain1Child2.getValue().toUnit().isFixed());
    assertEquals("Actual Speed", domain1Child2.getKey());
    assertEquals("number", domain1Child2.getValue().getType());
    assertNotNull(domain1Child2.getValue().toUnit().getDomain());
    assertTrue(domain1Child2.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain1Child2Def = (CounterfactualDomainRange) domain1Child2.getValue().toUnit().getDomain();
    assertEquals(0, domain1Child2Def.getLowerBound().asInt());
    assertEquals(100, domain1Child2Def.getUpperBound().asInt());

    // "Speed Limit": free, numeric range [0, 100].
    assertEquals(TypedValue.Kind.UNIT, domain1Child3.getValue().getKind());
    assertFalse(domain1Child3.getValue().toUnit().isFixed());
    assertEquals("Speed Limit", domain1Child3.getKey());
    assertEquals("number", domain1Child3.getValue().getType());
    assertNotNull(domain1Child3.getValue().toUnit().getDomain());
    assertTrue(domain1Child3.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain1Child3Def = (CounterfactualDomainRange) domain1Child3.getValue().toUnit().getDomain();
    assertEquals(0, domain1Child3Def.getLowerBound().asInt());
    assertEquals(100, domain1Child3Def.getUpperBound().asInt());
}
Use of org.kie.kogito.explainability.api.CounterfactualDomainCategorical in the kogito-apps project by kiegroup.
Class ExplainabilityApiV1IT, method testCounterfactualRequest.
/**
 * Posts a counterfactual request with unit-valued goals and search domains. It then verifies
 * that the execution service receives the two goals ("deposit", "approved") and the three search
 * domains ("age" fixed; "income" range; "taxCode" categorical).
 */
@Test
@SuppressWarnings("unchecked")
void testCounterfactualRequest() {
    ArgumentCaptor<List<NamedTypedValue>> goalsCaptor = ArgumentCaptor.forClass(List.class);
    ArgumentCaptor<List<CounterfactualSearchDomain>> searchDomainsCaptor = ArgumentCaptor.forClass(List.class);
    mockServiceWithCounterfactualRequest();
    CounterfactualRequestResponse response = given().filter(new RequestLoggingFilter()).filter(new ResponseLoggingFilter()).contentType(MediaType.APPLICATION_JSON).body(getCounterfactualJsonRequest()).when().post("/executions/decisions/" + TEST_EXECUTION_ID + "/explanations/counterfactuals").as(CounterfactualRequestResponse.class);
    assertNotNull(response);
    assertNotNull(response.getExecutionId());
    assertNotNull(response.getCounterfactualId());
    // JUnit's assertEquals takes (expected, actual); this order keeps failure messages accurate.
    assertEquals(TEST_EXECUTION_ID, response.getExecutionId());
    assertEquals(TEST_COUNTERFACTUAL_ID, response.getCounterfactualId());
    verify(executionService).requestCounterfactuals(eq(TEST_EXECUTION_ID), goalsCaptor.capture(), searchDomainsCaptor.capture());

    List<NamedTypedValue> goalsParameter = goalsCaptor.getValue();
    assertNotNull(goalsParameter);
    assertEquals(2, goalsParameter.size());
    NamedTypedValue goal1 = goalsParameter.get(0);
    assertEquals(TypedValue.Kind.UNIT, goal1.getValue().getKind());
    assertEquals("deposit", goal1.getName());
    assertEquals("number", goal1.getValue().getType());
    assertEquals(5000, goal1.getValue().toUnit().getValue().asInt());
    NamedTypedValue goal2 = goalsParameter.get(1);
    assertEquals(TypedValue.Kind.UNIT, goal2.getValue().getKind());
    assertEquals("approved", goal2.getName());
    assertEquals("boolean", goal2.getValue().getType());
    assertTrue(goal2.getValue().toUnit().getValue().asBoolean());

    List<CounterfactualSearchDomain> searchDomainsParameter = searchDomainsCaptor.getValue();
    assertNotNull(searchDomainsParameter);
    assertEquals(3, searchDomainsParameter.size());

    // "age": fixed, so no domain is attached.
    CounterfactualSearchDomain domain1 = searchDomainsParameter.get(0);
    assertEquals(TypedValue.Kind.UNIT, domain1.getValue().getKind());
    assertTrue(domain1.getValue().toUnit().isFixed());
    assertEquals("age", domain1.getName());
    assertEquals("number", domain1.getValue().getType());
    assertNull(domain1.getValue().toUnit().getDomain());

    // "income": free, numeric range [0, 1000].
    CounterfactualSearchDomain domain2 = searchDomainsParameter.get(1);
    assertEquals(TypedValue.Kind.UNIT, domain2.getValue().getKind());
    assertFalse(domain2.getValue().toUnit().isFixed());
    assertEquals("income", domain2.getName());
    assertEquals("number", domain2.getValue().getType());
    assertNotNull(domain2.getValue().toUnit().getDomain());
    assertTrue(domain2.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain2Def = (CounterfactualDomainRange) domain2.getValue().toUnit().getDomain();
    assertEquals(0, domain2Def.getLowerBound().asInt());
    assertEquals(1000, domain2Def.getUpperBound().asInt());

    // "taxCode": free, categorical string domain {A, B, C}.
    CounterfactualSearchDomain domain3 = searchDomainsParameter.get(2);
    assertEquals(TypedValue.Kind.UNIT, domain3.getValue().getKind());
    assertFalse(domain3.getValue().toUnit().isFixed());
    assertEquals("taxCode", domain3.getName());
    assertEquals("string", domain3.getValue().getType());
    assertNotNull(domain3.getValue().toUnit().getDomain());
    assertTrue(domain3.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical domain3Def = (CounterfactualDomainCategorical) domain3.getValue().toUnit().getDomain();
    assertEquals(3, domain3Def.getCategories().size());
    assertTrue(domain3Def.getCategories().stream().map(JsonNode::asText).collect(Collectors.toList()).containsAll(Arrays.asList("A", "B", "C")));
}
Use of org.kie.kogito.explainability.api.CounterfactualDomainCategorical in the kogito-apps project by kiegroup.
Class ExplainabilityApiV1Test, method testGetCounterfactualResultsWhenExecutionDoesExistAndResultsHaveBeenCreated.
@Test
public void testGetCounterfactualResultsWhenExecutionDoesExistAndResultsHaveBeenCreated() {
    // Arrange: a stored request with one goal and one categorical search domain, plus two stored
    // solutions (an INTERMEDIATE one followed by a FINAL one).
    NamedTypedValue goal = buildGoalUnit("unit", "string", new TextNode("hello"));
    CounterfactualDomainCategorical greetings = new CounterfactualDomainCategorical(List.of(new TextNode("hello"), new TextNode("goodbye")));
    CounterfactualSearchDomain searchDomain = buildSearchDomainUnit("unit", "string", greetings);
    CounterfactualExplainabilityResult intermediate = new CounterfactualExplainabilityResult(EXECUTION_ID, COUNTERFACTUAL_ID, "solution1", 0L, ExplainabilityStatus.SUCCEEDED, "", true, CounterfactualExplainabilityResult.Stage.INTERMEDIATE, Collections.emptyList(), Collections.emptyList());
    CounterfactualExplainabilityResult finalSolution = new CounterfactualExplainabilityResult(EXECUTION_ID, COUNTERFACTUAL_ID, "solution2", 1L, ExplainabilityStatus.SUCCEEDED, "", true, CounterfactualExplainabilityResult.Stage.FINAL, Collections.emptyList(), Collections.emptyList());
    ModelIdentifier model = new ModelIdentifier("resourceType", "resourceIdentifier");
    CounterfactualExplainabilityRequest storedRequest = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, model, COUNTERFACTUAL_ID, Collections.emptyList(), List.of(goal), List.of(searchDomain), MAX_RUNNING_TIME_SECONDS);
    when(trustyService.getCounterfactualRequest(anyString(), anyString())).thenReturn(storedRequest);
    when(trustyService.getCounterfactualResults(anyString(), anyString())).thenReturn(List.of(intermediate, finalSolution));

    // Act
    Response httpResponse = explainabilityEndpoint.getCounterfactualDetails(EXECUTION_ID, COUNTERFACTUAL_ID);

    // Assert: 200 OK whose entity echoes the request data and lists both solutions in order.
    assertNotNull(httpResponse);
    assertEquals(Response.Status.OK.getStatusCode(), httpResponse.getStatus());
    Object body = httpResponse.getEntity();
    assertNotNull(body);
    assertTrue(body instanceof CounterfactualResultsResponse);
    CounterfactualResultsResponse details = (CounterfactualResultsResponse) body;
    assertEquals(EXECUTION_ID, details.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, details.getCounterfactualId());
    assertEquals(MAX_RUNNING_TIME_SECONDS, details.getMaxRunningTimeSeconds());
    assertEquals(1, details.getGoals().size());
    assertEquals(goal, details.getGoals().iterator().next());
    assertEquals(1, details.getSearchDomains().size());
    assertEquals(searchDomain, details.getSearchDomains().iterator().next());
    assertEquals(2, details.getSolutions().size());
    assertEquals(intermediate, details.getSolutions().get(0));
    assertEquals(finalSolution, details.getSolutions().get(1));
}
Aggregations