Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in the kiegroup/kogito-apps project.
Class ConversionUtilsTest, method getDoubleSearchDomain.
/**
 * Creates a search domain for a "double" unit value bounded by the given range.
 *
 * @param name       name of the search domain
 * @param lowerBound lower bound of the range, stored as a {@code DoubleNode}
 * @param upperBound upper bound of the range, stored as a {@code DoubleNode}
 * @return a search domain whose unit value has type/base type "double" and the fixed flag set to false
 */
private static CounterfactualSearchDomain getDoubleSearchDomain(String name, double lowerBound, double upperBound) {
    CounterfactualDomainRange doubleRange = new CounterfactualDomainRange(
            DoubleNode.valueOf(lowerBound),
            DoubleNode.valueOf(upperBound));
    return new CounterfactualSearchDomain(name,
            new CounterfactualSearchDomainUnitValue("double", "double", Boolean.FALSE, doubleRange));
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in the kiegroup/kogito-apps project.
Class ConversionUtilsTest, method testToFeatureDomain_UnitFixedString.
/**
 * A fixed "string" unit value with no domain should convert to an {@link EmptyFeatureDomain}.
 */
@Test
void testToFeatureDomain_UnitFixedString() {
    CounterfactualSearchDomainUnitValue fixedStringValue =
            new CounterfactualSearchDomainUnitValue("string", "string", true, null);

    FeatureDomain converted = ConversionUtils.toFeatureDomain(fixedStringValue);

    assertTrue(converted instanceof EmptyFeatureDomain);
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in the kiegroup/kogito-apps project.
Class ConversionUtilsTest, method testToFeatureDomain_UnitRangeInteger.
/**
 * An "int" unit value with an integer range [18, 65] should convert to a
 * {@link NumericalFeatureDomain} with matching numeric bounds and no categories.
 */
@Test
void testToFeatureDomain_UnitRangeInteger() {
    CounterfactualDomainRange intRange = new CounterfactualDomainRange(IntNode.valueOf(18), IntNode.valueOf(65));
    CounterfactualSearchDomainUnitValue intValue =
            new CounterfactualSearchDomainUnitValue("int", "int", true, intRange);

    FeatureDomain converted = ConversionUtils.toFeatureDomain(intValue);
    assertTrue(converted instanceof NumericalFeatureDomain);

    NumericalFeatureDomain numeric = (NumericalFeatureDomain) converted;
    // Integer bounds are expected to come back as their double equivalents.
    assertEquals(18.0, numeric.getLowerBound());
    assertEquals(65.0, numeric.getUpperBound());
    assertNull(numeric.getCategories());
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in the kiegroup/kogito-apps project.
Class TrustyServiceTest, method givenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventHasCorrectPayload.
/**
 * Requesting counterfactuals for a stored execution should publish a
 * {@link CounterfactualExplainabilityRequest} that carries the execution id, the service URL,
 * the stored decision's original inputs, the requested goals and search domains, and the
 * configured maximum running time.
 */
@Test
@SuppressWarnings("unchecked")
void givenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventHasCorrectPayload() {
    Storage<String, Decision> decisionStorage = mock(Storage.class);
    Storage<String, CounterfactualExplainabilityRequest> counterfactualStorage = mock(Storage.class);
    ArgumentCaptor<BaseExplainabilityRequest> eventCaptor = ArgumentCaptor.forClass(BaseExplainabilityRequest.class);

    // Stored decision: one integer input (yearsOfService = 10) and one integer outcome (salary = 1000).
    Decision decision = new Decision(
            TEST_EXECUTION_ID,
            TEST_SOURCE_URL,
            TEST_SERVICE_URL,
            0L,
            true,
            null,
            "model",
            "modelNamespace",
            List.of(new DecisionInput("IN1", "yearsOfService",
                    new UnitValue("integer", "integer", new IntNode(10)))),
            List.of(new DecisionOutcome("OUT1", "salary", "SUCCEEDED",
                    new UnitValue("integer", "integer", new IntNode(1000)),
                    Collections.emptyList(),
                    Collections.emptyList())));

    when(decisionStorage.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(decisionStorage);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(counterfactualStorage);
    when(decisionStorage.get(eq(TEST_EXECUTION_ID))).thenReturn(decision);

    // Goal: salary = 2000; search domain: yearsOfService in [10, 30], not fixed.
    List<NamedTypedValue> goals = List.of(
            new NamedTypedValue("salary", new UnitValue("integer", "integer", new IntNode(2000))));
    CounterfactualDomainRange yearsOfServiceRange = new CounterfactualDomainRange(new IntNode(10), new IntNode(30));
    List<CounterfactualSearchDomain> searchDomains = List.of(
            new CounterfactualSearchDomain("yearsOfService",
                    new CounterfactualSearchDomainUnitValue("integer", "integer", false, yearsOfServiceRange)));

    trustyService.requestCounterfactuals(TEST_EXECUTION_ID, goals, searchDomains);

    verify(explainabilityRequestProducerMock).sendEvent(eventCaptor.capture());
    CounterfactualExplainabilityRequest request = (CounterfactualExplainabilityRequest) eventCaptor.getValue();

    assertEquals(TEST_EXECUTION_ID, request.getExecutionId());
    assertEquals(TEST_SERVICE_URL, request.getServiceUrl());

    // Original inputs are copied from the stored decision.
    // Each collection below holds exactly one element, so a bare iterator().next() is safe.
    assertEquals(1, request.getOriginalInputs().size());
    assertTrue(request.getOriginalInputs().stream().anyMatch(input -> input.getName().equals("yearsOfService")));
    assertEquals(
            decision.getInputs().iterator().next().getValue().toUnit().getValue().asInt(),
            request.getOriginalInputs().iterator().next().getValue().toUnit().getValue().asInt());

    // Goals are copied from the request arguments.
    assertEquals(1, request.getGoals().size());
    assertTrue(request.getGoals().stream().anyMatch(goal -> goal.getName().equals("salary")));
    assertEquals(2000, request.getGoals().iterator().next().getValue().toUnit().getValue().asInt());

    // Search domains are copied from the request arguments.
    assertEquals(1, request.getSearchDomains().size());
    assertTrue(request.getSearchDomains().stream().anyMatch(domain -> domain.getName().equals("yearsOfService")));
    CounterfactualSearchDomainValue domainValue = request.getSearchDomains().iterator().next().getValue();
    assertTrue(domainValue instanceof CounterfactualSearchDomainUnitValue);
    CounterfactualSearchDomainUnitValue unitValue = (CounterfactualSearchDomainUnitValue) domainValue;
    assertFalse(unitValue.isFixed());
    assertNotNull(unitValue.getDomain());
    assertTrue(unitValue.getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange copiedRange = (CounterfactualDomainRange) unitValue.getDomain();
    assertEquals(10, copiedRange.getLowerBound().asInt());
    assertEquals(30, copiedRange.getUpperBound().asInt());

    // The configured maximum running time is propagated.
    assertEquals(MAX_RUNNING_TIME_SECONDS, request.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in the kiegroup/kogito-apps project.
Class CounterfactualDomainSerialisationTest, method testCounterfactualSearchDomain_Categorical_RoundTrip.
/**
 * Serialises a search domain holding a categorical unit value to JSON and reads it back.
 * Checks that the name, kind, type, base type, fixed flag and categories survive the round trip.
 */
@Test
public void testCounterfactualSearchDomain_Categorical_RoundTrip() throws Exception {
    CounterfactualDomainCategorical originalCategories =
            new CounterfactualDomainCategorical(List.of(new TextNode("A"), new TextNode("B")));
    CounterfactualSearchDomain original = new CounterfactualSearchDomain("age",
            new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, originalCategories));

    // Serialise.
    mapper.writeValue(writer, original);
    String json = writer.toString();
    assertNotNull(json);

    // Deserialise and compare.
    CounterfactualSearchDomain restored = mapper.readValue(json, CounterfactualSearchDomain.class);
    assertTrue(restored.getValue() instanceof CounterfactualSearchDomainUnitValue);
    assertEquals(original.getValue().getKind(), restored.getValue().getKind());
    assertEquals(original.getName(), restored.getName());
    assertEquals(original.getValue().getType(), restored.getValue().getType());
    assertEquals(original.getValue().toUnit().getBaseType(), restored.getValue().toUnit().getBaseType());
    assertEquals(original.getValue().toUnit().isFixed(), restored.getValue().toUnit().isFixed());

    assertTrue(restored.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical restoredCategories =
            (CounterfactualDomainCategorical) restored.getValue().toUnit().getDomain();
    // Equal size plus containsAll means both sides hold the same categories.
    assertEquals(originalCategories.getCategories().size(), restoredCategories.getCategories().size());
    assertTrue(restoredCategories.getCategories().containsAll(originalCategories.getCategories()));
}
Aggregations