Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in project kogito-apps by kiegroup.
From class TrustyServiceTest, method doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenRequestIsStoredTest.
/**
 * Checks that a counterfactual request made for an execution that is present in the
 * decision storage is written to the counterfactual request storage, and that the
 * stored request carries the execution id it was made for.
 *
 * @param domain the domain attached to every unit value of the requested search domains
 */
@SuppressWarnings("unchecked")
void doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenRequestIsStoredTest(CounterfactualDomain domain) {
    Storage<String, Decision> decisions = mock(Storage.class);
    Storage<String, CounterfactualExplainabilityRequest> counterfactualRequests = mock(Storage.class);
    ArgumentCaptor<CounterfactualExplainabilityRequest> storedRequest =
            ArgumentCaptor.forClass(CounterfactualExplainabilityRequest.class);

    // The storage service hands out the two mocked storages.
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(decisions);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(counterfactualRequests);

    // The execution under test is known to the decision storage.
    when(decisions.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(decisions.get(eq(TEST_EXECUTION_ID)))
            .thenReturn(TrustyServiceTestUtils.buildCorrectDecision(TEST_EXECUTION_ID));

    // The goal structures must be comparable to the original decision's outcomes.
    NamedTypedValue fineGoal = new NamedTypedValue(
            "Fine",
            new StructureValue(
                    "tFine",
                    Map.of(
                            "Amount", new UnitValue("number", "number", new IntNode(0)),
                            "Points", new UnitValue("number", "number", new IntNode(0)))));
    NamedTypedValue suspensionGoal = new NamedTypedValue(
            "Should the driver be suspended?",
            new UnitValue("string", "string", new TextNode("No")));

    // The search domain structures must be identical to those of the original decision's inputs.
    CounterfactualSearchDomain violationDomain = new CounterfactualSearchDomain(
            "Violation",
            new CounterfactualSearchDomainStructureValue(
                    "tViolation",
                    Map.of(
                            "Type", new CounterfactualSearchDomainUnitValue("string", "string", true, domain),
                            "Actual Speed", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                            "Speed Limit", new CounterfactualSearchDomainUnitValue("number", "number", true, domain))));
    CounterfactualSearchDomain driverDomain = new CounterfactualSearchDomain(
            "Driver",
            new CounterfactualSearchDomainStructureValue(
                    "tDriver",
                    Map.of(
                            "Age", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                            "Points", new CounterfactualSearchDomainUnitValue("number", "number", true, domain))));

    trustyService.requestCounterfactuals(
            TEST_EXECUTION_ID,
            List.of(fineGoal, suspensionGoal),
            List.of(violationDomain, driverDomain));

    // Capture whatever was persisted and inspect it.
    verify(counterfactualRequests).put(anyString(), storedRequest.capture());
    CounterfactualExplainabilityRequest persisted = storedRequest.getValue();
    assertNotNull(persisted);
    assertEquals(TEST_EXECUTION_ID, persisted.getExecutionId());
}
Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in project kogito-apps by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionSearchDomains.
/**
 * A request whose only search domain holds a collection value is expected to make
 * {@code handler.getPrediction} throw an {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionSearchDomains() {
    CounterfactualDomainRange range = new CounterfactualDomainRange(new IntNode(10), new IntNode(20));
    CounterfactualSearchDomain collectionDomain = new CounterfactualSearchDomain(
            "input1",
            new CounterfactualSearchDomainCollectionValue(
                    "number",
                    List.of(new CounterfactualSearchDomainUnitValue("number", "number", true, range))));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            List.of(collectionDomain),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in project kogito-apps by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed.
/**
 * Builds a request with a single flat search domain over the range 10..20 whose unit
 * value is created with its boolean flag set to {@code false} (the "not fixed" case per
 * the test name), and checks that the resulting prediction:
 * <ul>
 * <li>is a {@code CounterfactualPrediction} with exactly one input feature,</li>
 * <li>exposes that feature's domain as a {@code NumericalFeatureDomain} bounded by 10 and 20,</li>
 * <li>keeps the maximum running time given in the request.</li>
 * </ul>
 */
@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    NamedTypedValue originalInput = new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)));
    CounterfactualSearchDomain searchDomain = new CounterfactualSearchDomain(
            "output1",
            new CounterfactualSearchDomainUnitValue(
                    "number",
                    "number",
                    false,
                    new CounterfactualDomainRange(new IntNode(10), new IntNode(20))));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(originalInput),
            Collections.emptyList(),
            List.of(searchDomain),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);

    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);

    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());

    // Expected value first: the request is the reference, the prediction is under test.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in project kogito-apps by kiegroup.
From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsFixed.
/**
 * Builds a request with a single flat search domain over the range 10..20 whose unit
 * value is created with its boolean flag set to {@code true} (the "fixed" case per the
 * test name), and checks that the resulting prediction:
 * <ul>
 * <li>is a {@code CounterfactualPrediction} with exactly one input feature,</li>
 * <li>exposes that feature's domain as an {@code EmptyFeatureDomain},</li>
 * <li>keeps the maximum running time given in the request.</li>
 * </ul>
 */
@Test
public void testGetPredictionWithFlatSearchDomainsFixed() {
    NamedTypedValue originalInput = new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)));
    CounterfactualSearchDomain searchDomain = new CounterfactualSearchDomain(
            "output1",
            new CounterfactualSearchDomainUnitValue(
                    "number",
                    "number",
                    true,
                    new CounterfactualDomainRange(new IntNode(10), new IntNode(20))));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(originalInput),
            Collections.emptyList(),
            List.of(searchDomain),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);

    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof EmptyFeatureDomain);

    // Expected value first: the request is the reference, the prediction is under test.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Use of org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue in project kogito-apps by kiegroup.
From class ConversionUtilsTest, method testToFeatureDomain_UnitCategoricalString.
/**
 * A unit search domain value built on a categorical domain of the two text values
 * "Black" and "White" is expected to convert into a {@code CategoricalFeatureDomain}
 * holding exactly those two categories and reporting {@code null} for both bounds.
 */
@Test
void testToFeatureDomain_UnitCategoricalString() {
    CounterfactualDomainCategorical colours =
            new CounterfactualDomainCategorical(List.of(TextNode.valueOf("Black"), TextNode.valueOf("White")));
    CounterfactualSearchDomainUnitValue unitValue =
            new CounterfactualSearchDomainUnitValue("string", "string", true, colours);

    FeatureDomain converted = ConversionUtils.toFeatureDomain(unitValue);

    assertTrue(converted instanceof CategoricalFeatureDomain);
    CategoricalFeatureDomain categorical = (CategoricalFeatureDomain) converted;

    // Size 2 plus containment of both names pins the category set exactly.
    assertEquals(2, categorical.getCategories().size());
    assertTrue(categorical.getCategories().containsAll(List.of("White", "Black")));

    // A categorical domain carries no numeric bounds.
    assertNull(categorical.getLowerBound());
    assertNull(categorical.getUpperBound());
}
Aggregations