use of org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest in project kogito-apps by kiegroup.
the class TrustyServiceImpl method requestCounterfactuals.
/**
 * Builds a counterfactual explainability request for an already-stored decision, then hands it to
 * {@code storeCounterfactualRequest} and {@code sendCounterfactualRequestEvent} before returning it.
 *
 * @param executionId   key looked up in the decisions storage
 * @param goals         goals forwarded to {@code makeCounterfactualRequest}
 * @param searchDomains search domains forwarded to {@code makeCounterfactualRequest}
 * @return the request produced by {@code makeCounterfactualRequest}
 * @throws IllegalArgumentException if the decisions storage has no entry for {@code executionId}
 */
@Override
public CounterfactualExplainabilityRequest requestCounterfactuals(String executionId, List<NamedTypedValue> goals, List<CounterfactualSearchDomain> searchDomains) {
    // Reject unknown decisions up front; nothing is created, stored or sent in that case.
    if (!storageService.getDecisionsStorage().containsKey(executionId)) {
        final String message = String.format("A decision with ID %s is not present in the storage. Counterfactuals cannot be requested.", executionId);
        throw new IllegalArgumentException(message);
    }

    final CounterfactualExplainabilityRequest counterfactualRequest =
            makeCounterfactualRequest(executionId, goals, searchDomains, maxRunningTimeSeconds);
    storeCounterfactualRequest(counterfactualRequest);
    sendCounterfactualRequestEvent(counterfactualRequest);
    return counterfactualRequest;
}
use of org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testGetPredictionWithStructuredOutputModel.
/**
 * A request whose second value list (the one the flat-output test reads back as outputs) holds a
 * structured value must be rejected by {@code handler.getPrediction} with an
 * {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithStructuredOutputModel() {
    // Structure wrapping a single numeric member.
    StructureValue structure = new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55))));
    NamedTypedValue structuredValue = new NamedTypedValue("input1", structure);

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(structuredValue),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
use of org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testGetPredictionWithFlatInputModel.
/**
 * A request carrying a single flat numeric input and no goals or search domains must yield a
 * {@code CounterfactualPrediction} with exactly that one feature, no outputs, and the request's
 * maximum running time.
 */
@Test
public void testGetPredictionWithFlatInputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("input1", new UnitValue("number", new IntNode(20)))),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    // The single input must come back as one NUMBER feature named "input1" with value 20.
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = counterfactualPrediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());

    // No search domains were supplied: every feature has an empty domain and is constrained.
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(f -> f.getDomain().isEmpty()));
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(Feature::isConstrained));

    // No goals were supplied, so there are no outputs.
    assertTrue(counterfactualPrediction.getOutput().getOutputs().isEmpty());

    // Expected value (from the request) goes first so a failure labels the two values correctly.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
use of org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testGetPredictionWithFlatOutputModel.
@Test
public void testGetPredictionWithFlatOutputModel() {
CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(20)))), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
Prediction prediction = handler.getPrediction(request);
assertTrue(prediction instanceof CounterfactualPrediction);
CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
assertEquals(1, counterfactualPrediction.getOutput().getOutputs().size());
Optional<Output> oOutput1 = counterfactualPrediction.getOutput().getOutputs().stream().filter(f -> f.getName().equals("output1")).findFirst();
assertTrue(oOutput1.isPresent());
Output output1 = oOutput1.get();
assertEquals(Type.NUMBER, output1.getType());
assertEquals(20, output1.getValue().asNumber());
assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
assertEquals(counterfactualPrediction.getMaxRunningTimeSeconds(), request.getMaxRunningTimeSeconds());
}
use of org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testGetPredictionWithCollectionSearchDomains.
/**
 * A request whose only search domain is a collection-valued domain must be rejected by
 * {@code handler.getPrediction} with an {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionSearchDomains() {
    // Build the nested search domain from the inside out.
    CounterfactualDomainRange range = new CounterfactualDomainRange(new IntNode(10), new IntNode(20));
    CounterfactualSearchDomainUnitValue unitDomain = new CounterfactualSearchDomainUnitValue("number", "number", true, range);
    CounterfactualSearchDomainCollectionValue collectionDomain = new CounterfactualSearchDomainCollectionValue("number", List.of(unitDomain));
    CounterfactualSearchDomain searchDomain = new CounterfactualSearchDomain("input1", collectionDomain);

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            List.of(searchDomain),
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Aggregations