Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.
From class AbstractTrustyServiceIT, method testCounterfactuals_StoreSingleAndRetrieveSingleWithSearchDomainCategorical.
/**
 * Stores a counterfactual request with a single categorical search domain and then checks
 * that it round-trips through {@code trustyService}.
 *
 * <p>The test checks three things:
 * <ul>
 *   <li>the request returned by {@code requestCounterfactuals} carries the execution id, a
 *       counterfactual id and exactly one search domain equal to the one submitted;</li>
 *   <li>fetching the request via {@code getCounterfactualRequest} yields the same execution
 *       id and counterfactual id;</li>
 *   <li>the fetched request also carries exactly one matching search domain.</li>
 * </ul>
 */
@Test
public void testCounterfactuals_StoreSingleAndRetrieveSingleWithSearchDomainCategorical() {
    String executionId = "myCFExecution1";
    storeExecution(executionId, 0L);

    // The Goals structures must be comparable to the original decision outcomes.
    // The Search Domain structures must be identical to those of the original decision inputs.
    Collection<JsonNode> categories = List.of(new TextNode("1"), new TextNode("2"));
    CounterfactualSearchDomain searchDomain = buildSearchDomainUnit("test", "number", new CounterfactualDomainCategorical(categories));

    // Submit a request with no goals and a single categorical search domain.
    CounterfactualExplainabilityRequest request = trustyService.requestCounterfactuals(executionId, Collections.emptyList(), Collections.singletonList(searchDomain));
    assertNotNull(request);
    // assertEquals is (expected, actual), so failure messages report the right roles.
    assertEquals(executionId, request.getExecutionId());
    assertNotNull(request.getCounterfactualId());
    assertEquals(1, request.getSearchDomains().size());
    List<CounterfactualSearchDomain> requestSearchDomains = new ArrayList<>(request.getSearchDomains());
    assertCounterfactualSearchDomainCategorical(searchDomain, requestSearchDomains.get(0));

    // Read the stored request back and compare it with what was returned on submission.
    CounterfactualExplainabilityRequest result = trustyService.getCounterfactualRequest(executionId, request.getCounterfactualId());
    assertNotNull(result);
    assertEquals(request.getExecutionId(), result.getExecutionId());
    assertEquals(request.getCounterfactualId(), result.getCounterfactualId());
    assertEquals(1, result.getSearchDomains().size());
    List<CounterfactualSearchDomain> resultSearchDomains = new ArrayList<>(result.getSearchDomains());
    assertCounterfactualSearchDomainCategorical(searchDomain, resultSearchDomains.get(0));
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.
From class CounterfactualDomainSerialisationTest, method testCounterfactualSearchDomain_Categorical_RoundTrip.
/**
 * Serialises a fixed integer search domain backed by a categorical domain, reads it back with
 * {@code mapper}, and verifies that the data survives the JSON round trip. The checked data is
 * the value kind, the name, the type, the base type, the fixed flag and the set of categories.
 */
@Test
public void testCounterfactualSearchDomain_Categorical_RoundTrip() throws Exception {
    CounterfactualDomainCategorical categorical =
            new CounterfactualDomainCategorical(List.of(new TextNode("A"), new TextNode("B")));
    CounterfactualSearchDomain original =
            new CounterfactualSearchDomain("age",
                    new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, categorical));

    // Write to the test's writer, then parse the captured text back into a search domain.
    mapper.writeValue(writer, original);
    String json = writer.toString();
    assertNotNull(json);
    CounterfactualSearchDomain restored = mapper.readValue(json, CounterfactualSearchDomain.class);

    // Value-level properties.
    assertTrue(restored.getValue() instanceof CounterfactualSearchDomainUnitValue);
    assertEquals(original.getValue().getKind(), restored.getValue().getKind());
    assertEquals(original.getName(), restored.getName());
    assertEquals(original.getValue().getType(), restored.getValue().getType());

    // Unit-level properties.
    assertEquals(original.getValue().toUnit().getBaseType(), restored.getValue().toUnit().getBaseType());
    assertEquals(original.getValue().toUnit().isFixed(), restored.getValue().toUnit().isFixed());

    // Domain: same concrete type and the same categories, compared without relying on order.
    assertTrue(restored.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical restoredCategorical =
            (CounterfactualDomainCategorical) restored.getValue().toUnit().getDomain();
    assertEquals(categorical.getCategories().size(), restoredCategorical.getCategories().size());
    assertTrue(restoredCategorical.getCategories().containsAll(categorical.getCategories()));
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in project kogito-apps by kiegroup.
From class CounterfactualDomainSerialisationTest, method testCounterfactualSearchDomain_Range_RoundTrip.
/**
 * Serialises a fixed integer search domain constrained to the range [18, 65] and reads it back
 * with {@code mapper}. It then checks the value kind, name, type, base type, fixed flag and
 * both range bounds against the source object.
 */
@Test
public void testCounterfactualSearchDomain_Range_RoundTrip() throws Exception {
    CounterfactualDomainRange range = new CounterfactualDomainRange(new IntNode(18), new IntNode(65));
    CounterfactualSearchDomain expected =
            new CounterfactualSearchDomain("age",
                    new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, range));

    // Round trip: object -> JSON text (via the test's writer) -> object.
    mapper.writeValue(writer, expected);
    String serialised = writer.toString();
    assertNotNull(serialised);
    CounterfactualSearchDomain actual = mapper.readValue(serialised, CounterfactualSearchDomain.class);

    // The value must still be a unit value with the same kind, name and type.
    assertTrue(actual.getValue() instanceof CounterfactualSearchDomainUnitValue);
    assertEquals(expected.getValue().getKind(), actual.getValue().getKind());
    assertEquals(expected.getName(), actual.getName());
    assertEquals(expected.getValue().getType(), actual.getValue().getType());

    // Unit details.
    assertEquals(expected.getValue().toUnit().getBaseType(), actual.getValue().toUnit().getBaseType());
    assertEquals(expected.getValue().toUnit().isFixed(), actual.getValue().toUnit().isFixed());

    // The domain must come back as a range with identical bounds.
    assertTrue(actual.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange actualRange =
            (CounterfactualDomainRange) actual.getValue().toUnit().getDomain();
    assertEquals(range.getLowerBound(), actualRange.getLowerBound());
    assertEquals(range.getUpperBound(), actualRange.getUpperBound());
}
Aggregations