Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
Class AbstractTrustyServiceIT, method testCounterfactuals_StoreSingleAndRetrieveSingleWithSearchDomainRange.
/**
 * Requests counterfactuals for a stored execution using a single search domain backed by a
 * {@link CounterfactualDomainRange}. Verifies that the returned request and the request read back
 * via {@code getCounterfactualRequest} agree on execution id, counterfactual id and search domain.
 */
@Test
public void testCounterfactuals_StoreSingleAndRetrieveSingleWithSearchDomainRange() {
    String executionId = "myCFExecution1";
    storeExecution(executionId, 0L);

    // The Goals structures must be comparable to the original decisions outcomes.
    // The Search Domain structures must be identical to those of the original decision inputs.
    CounterfactualSearchDomain searchDomain = buildSearchDomainUnit("test", "number", new CounterfactualDomainRange(new IntNode(1), new IntNode(2)));

    CounterfactualExplainabilityRequest request = trustyService.requestCounterfactuals(executionId, Collections.emptyList(), Collections.singletonList(searchDomain));
    assertNotNull(request);
    // JUnit convention is (expected, actual); keeping that order makes failure messages accurate.
    assertEquals(executionId, request.getExecutionId());
    assertNotNull(request.getCounterfactualId());
    assertEquals(1, request.getSearchDomains().size());
    List<CounterfactualSearchDomain> requestSearchDomains = new ArrayList<>(request.getSearchDomains());
    assertCounterfactualSearchDomainRange(searchDomain, requestSearchDomains.get(0));

    // Read the request back and check it matches what was returned on creation.
    CounterfactualExplainabilityRequest result = trustyService.getCounterfactualRequest(executionId, request.getCounterfactualId());
    assertNotNull(result);
    assertEquals(request.getExecutionId(), result.getExecutionId());
    assertEquals(request.getCounterfactualId(), result.getCounterfactualId());
    assertEquals(1, result.getSearchDomains().size());
    List<CounterfactualSearchDomain> resultSearchDomains = new ArrayList<>(result.getSearchDomains());
    assertCounterfactualSearchDomainRange(searchDomain, resultSearchDomains.get(0));
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
Class AbstractTrustyServiceIT, method assertCounterfactualDomainRange.
/**
 * Asserts that both domains are {@link CounterfactualDomainRange} instances
 * whose lower bounds are equal and whose upper bounds are equal.
 *
 * @param expectedDomain the domain the test built
 * @param actualDomain   the domain returned by the service under test
 */
private void assertCounterfactualDomainRange(CounterfactualDomain expectedDomain, CounterfactualDomain actualDomain) {
    // Type checks come first so the casts below cannot fail.
    assertTrue(expectedDomain instanceof CounterfactualDomainRange);
    assertTrue(actualDomain instanceof CounterfactualDomainRange);

    CounterfactualDomainRange wanted = (CounterfactualDomainRange) expectedDomain;
    CounterfactualDomainRange got = (CounterfactualDomainRange) actualDomain;

    assertEquals(wanted.getLowerBound(), got.getLowerBound());
    assertEquals(wanted.getUpperBound(), got.getUpperBound());
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
Class AbstractTrustyServiceIT, method testCounterfactuals_StoreMultipleAndRetrieveAllWithEmptyDefinition.
/**
 * Issues two counterfactual requests for the same stored execution (empty goals, identical range
 * search domain). Verifies that {@code getCounterfactualRequests} returns exactly two requests,
 * one matching each returned counterfactual id.
 */
@Test
public void testCounterfactuals_StoreMultipleAndRetrieveAllWithEmptyDefinition() {
    String executionId = "myCFExecution2";
    storeExecution(executionId, 0L);

    // The Goals structures must be comparable to the original decisions outcomes.
    // The Search Domain structures must be identical to those of the original decision inputs.
    CounterfactualSearchDomain searchDomain = buildSearchDomainUnit("test", "number", new CounterfactualDomainRange(new IntNode(1), new IntNode(2)));

    CounterfactualExplainabilityRequest request1 = trustyService.requestCounterfactuals(executionId, Collections.emptyList(), Collections.singletonList(searchDomain));
    CounterfactualExplainabilityRequest request2 = trustyService.requestCounterfactuals(executionId, Collections.emptyList(), Collections.singletonList(searchDomain));
    // Fail with a clear assertion instead of an NPE inside the lambdas below.
    assertNotNull(request1);
    assertNotNull(request2);

    List<CounterfactualExplainabilityRequest> result = trustyService.getCounterfactualRequests(executionId);
    assertNotNull(result);
    assertEquals(2, result.size());
    assertTrue(result.stream().anyMatch(c -> c.getCounterfactualId().equals(request1.getCounterfactualId())));
    assertTrue(result.stream().anyMatch(c -> c.getCounterfactualId().equals(request2.getCounterfactualId())));
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionSearchDomains.
/**
 * Builds a request whose only search domain ("input1") wraps a collection value containing a
 * single ranged unit value, and expects {@code handler.getPrediction} to reject it with an
 * {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionSearchDomains() {
    CounterfactualSearchDomainUnitValue nestedUnit = new CounterfactualSearchDomainUnitValue("number", "number", true, new CounterfactualDomainRange(new IntNode(10), new IntNode(20)));
    CounterfactualSearchDomain collectionDomain = new CounterfactualSearchDomain("input1", new CounterfactualSearchDomainCollectionValue("number", List.of(nestedUnit)));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            List.of(collectionDomain),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Usage of org.kie.kogito.explainability.api.CounterfactualDomainRange in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed.
/**
 * Builds a request with one unit-valued, non-fixed search domain ranging over [10, 20]. Verifies
 * that {@code handler.getPrediction} produces a {@link CounterfactualPrediction} with:
 * <ul>
 *   <li>a single feature whose domain is a {@link NumericalFeatureDomain} with those bounds;</li>
 *   <li>the request's max running time.</li>
 * </ul>
 */
@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            List.of(new CounterfactualSearchDomain("output1", new CounterfactualSearchDomainUnitValue("number", "number", false, new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());

    // Expected value (the request's configuration) first, actual (the produced prediction) second.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Aggregations