Example use of `org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue` in the kogito-apps project (kiegroup).
From class `CounterfactualExplainabilityRequestMarshallerTest`, method `testWriteAndRead`:
/**
 * Round-trips a {@link CounterfactualExplainabilityRequest} through
 * {@code CounterfactualExplainabilityRequestMarshaller.writeTo} / {@code readFrom} and checks that
 * the execution id, counterfactual id, first goal name, first search-domain name and lower bound,
 * and the max running time survive.
 *
 * <p>NOTE(review): {@code writer} and {@code reader} are not declared in this method; presumably
 * they are fixture fields sharing one underlying buffer — confirm against the enclosing class.
 */
@Test
public void testWriteAndRead() throws IOException {
    ModelIdentifier modelIdentifier = new ModelIdentifier("resourceType", "resourceId");
    List<NamedTypedValue> originalInputs = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> goals = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<CounterfactualSearchDomain> searchDomains = Collections.singletonList(new CounterfactualSearchDomain("age", new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, new CounterfactualDomainRange(JsonNodeFactory.instance.numberNode(0), JsonNodeFactory.instance.numberNode(10)))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest("executionId", "serviceUrl", modelIdentifier, "counterfactualId", originalInputs, goals, searchDomains, 60L);
    CounterfactualExplainabilityRequestMarshaller marshaller = new CounterfactualExplainabilityRequestMarshaller(new ObjectMapper());

    marshaller.writeTo(writer, request);
    CounterfactualExplainabilityRequest retrieved = marshaller.readFrom(reader);

    Assertions.assertEquals(request.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(request.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(goals.get(0).getName(), retrieved.getGoals().stream().findFirst().get().getName());
    Assertions.assertEquals(searchDomains.get(0).getName(), retrieved.getSearchDomains().stream().findFirst().get().getName());
    Assertions.assertEquals(0, ((CounterfactualDomainRange) retrieved.getSearchDomains().stream().findFirst().get().getValue().toUnit().getDomain()).getLowerBound().asInt());
    // Assert on the deserialized copy: checking `request` here would pass regardless of
    // whether the marshaller preserved the time limit.
    Assertions.assertEquals(60L, retrieved.getMaxRunningTimeSeconds());
}
Example use of `org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue` in the kogito-apps project (kiegroup).
From class `CounterfactualExplainerServiceHandlerTest`, method `testGetPredictionWithStructuredSearchDomains`:
/**
 * Builds a request whose only search domain ("input1") is a structure wrapping a single unit
 * value ("input2b", range 10..20), and expects {@code handler.getPrediction} to reject it with
 * an {@link IllegalArgumentException}.
 *
 * <p>NOTE(review): presumably structured search domains are unsupported by the handler — confirm.
 */
@Test
public void testGetPredictionWithStructuredSearchDomains() {
    CounterfactualDomainRange input2bRange = new CounterfactualDomainRange(new IntNode(10), new IntNode(20));
    CounterfactualSearchDomainUnitValue input2b = new CounterfactualSearchDomainUnitValue("number", "number", true, input2bRange);
    CounterfactualSearchDomain structuredDomain = new CounterfactualSearchDomain(
            "input1",
            new CounterfactualSearchDomainStructureValue("number", Map.of("input2b", input2b)));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            List.of(structuredDomain),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Example use of `org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue` in the kogito-apps project (kiegroup).
From class `ConversionUtilsTest`, method `testToFeatureDomain_UnitFixedNumber`:
/**
 * A unit search-domain value with a {@code null} domain must convert to an
 * {@link EmptyFeatureDomain}.
 */
@Test
void testToFeatureDomain_UnitFixedNumber() {
    CounterfactualSearchDomainUnitValue integerWithoutDomain = new CounterfactualSearchDomainUnitValue("integer", "integer", true, null);
    assertTrue(ConversionUtils.toFeatureDomain(integerWithoutDomain) instanceof EmptyFeatureDomain);
}
Example use of `org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue` in the kogito-apps project (kiegroup).
From class `ConversionUtilsTest`, method `testToFeatureDomain_UnitRangeDouble`:
/**
 * A unit search-domain value with a double range [-273.15, Double.MAX_VALUE] must convert to a
 * {@link NumericalFeatureDomain} carrying the same bounds and no categories.
 */
@Test
void testToFeatureDomain_UnitRangeDouble() {
    CounterfactualDomainRange range = new CounterfactualDomainRange(DoubleNode.valueOf(-273.15), DoubleNode.valueOf(Double.MAX_VALUE));
    FeatureDomain converted = ConversionUtils.toFeatureDomain(new CounterfactualSearchDomainUnitValue("double", "double", true, range));

    assertTrue(converted instanceof NumericalFeatureDomain);
    NumericalFeatureDomain numeric = (NumericalFeatureDomain) converted;
    assertEquals(-273.15, numeric.getLowerBound());
    assertEquals(Double.MAX_VALUE, numeric.getUpperBound());
    assertNull(numeric.getCategories());
}
Example use of `org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue` in the kogito-apps project (kiegroup).
From class `TrustyServiceTest`, method `doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest`:
/**
 * Shared test body, parameterized by the {@code domain} used for every search-domain leaf.
 *
 * <p>Stubs the decision storage so that {@code TEST_EXECUTION_ID} exists and resolves to
 * {@code TrustyServiceTestUtils.buildCorrectDecision(TEST_EXECUTION_ID)}. It then calls
 * {@code trustyService.requestCounterfactuals} with "Fine" / "Should the driver be suspended?"
 * goals and "Violation" / "Driver" search domains. Finally it verifies that
 * {@code explainabilityRequestProducerMock.sendEvent} was invoked (Mockito's default
 * {@code times(1)}) with a {@link CounterfactualExplainabilityRequest} for that execution id.
 *
 * @param domain the domain assigned to every unit inside the search-domain structures
 */
@SuppressWarnings("unchecked")
void doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest(CounterfactualDomain domain) {
    Storage<String, Decision> decisionStorage = mock(Storage.class);
    Storage<String, CounterfactualExplainabilityRequest> counterfactualStorage = mock(Storage.class);
    ArgumentCaptor<BaseExplainabilityRequest> eventCaptor = ArgumentCaptor.forClass(BaseExplainabilityRequest.class);

    when(decisionStorage.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(decisionStorage);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(counterfactualStorage);
    when(decisionStorage.get(eq(TEST_EXECUTION_ID))).thenReturn(TrustyServiceTestUtils.buildCorrectDecision(TEST_EXECUTION_ID));

    // Goal structures must be comparable to the original decision's outcomes.
    StructureValue fine = new StructureValue("tFine", Map.of(
            "Amount", new UnitValue("number", "number", new IntNode(0)),
            "Points", new UnitValue("number", "number", new IntNode(0))));
    List<NamedTypedValue> goals = List.of(
            new NamedTypedValue("Fine", fine),
            new NamedTypedValue("Should the driver be suspended?", new UnitValue("string", "string", new TextNode("No"))));

    // Search-domain structures must be identical to those of the original decision's inputs.
    CounterfactualSearchDomain violation = new CounterfactualSearchDomain("Violation",
            new CounterfactualSearchDomainStructureValue("tViolation", Map.of(
                    "Type", new CounterfactualSearchDomainUnitValue("string", "string", true, domain),
                    "Actual Speed", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Speed Limit", new CounterfactualSearchDomainUnitValue("number", "number", true, domain))));
    CounterfactualSearchDomain driver = new CounterfactualSearchDomain("Driver",
            new CounterfactualSearchDomainStructureValue("tDriver", Map.of(
                    "Age", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Points", new CounterfactualSearchDomainUnitValue("number", "number", true, domain))));

    trustyService.requestCounterfactuals(TEST_EXECUTION_ID, goals, List.of(violation, driver));

    verify(explainabilityRequestProducerMock).sendEvent(eventCaptor.capture());
    BaseExplainabilityRequest event = eventCaptor.getValue();
    assertNotNull(event);
    assertTrue(event instanceof CounterfactualExplainabilityRequest);
    assertEquals(TEST_EXECUTION_ID, ((CounterfactualExplainabilityRequest) event).getExecutionId());
}
Aggregations