Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResult.
/**
 * Verifies that a successful counterfactual search is turned into a FINAL, SUCCEEDED
 * {@link CounterfactualExplainabilityResult} carrying the single input and single output as unit doubles.
 */
@Test
public void testCreateSucceededResult() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);

    // One numeric counterfactual entity (bounded 0..1000) and one numeric prediction output.
    Feature inputFeature = new Feature("input1", Type.NUMBER, new Value(123.0d));
    List<CounterfactualEntity> entities = List.of(DoubleEntity.from(inputFeature, 0, 1000));
    List<Feature> counterfactualFeatures = entities.stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    PredictionOutput predictionOutput = new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)));
    CounterfactualResult counterfactuals = new CounterfactualResult(entities, counterfactualFeatures, List.of(predictionOutput), true, UUID.fromString(SOLUTION_ID), UUID.fromString(EXECUTION_ID), 0);

    BaseExplainabilityResult base = handler.createSucceededResult(request, counterfactuals);
    assertTrue(base instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult result = (CounterfactualExplainabilityResult) base;

    // Result metadata.
    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, result.getStage());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, result.getCounterfactualId());

    // Exactly one input, "input1", holding the unit double 123.0.
    assertEquals(1, result.getInputs().size());
    assertTrue(result.getInputs().stream().anyMatch(namedInput -> namedInput.getName().equals("input1")));
    TypedValue input1Value = result.getInputs().iterator().next().getValue();
    assertEquals(Double.class.getSimpleName(), input1Value.getType());
    assertEquals(TypedValue.Kind.UNIT, input1Value.getKind());
    assertEquals(123.0, input1Value.toUnit().getValue().asDouble());

    // Exactly one output, "output1", holding the unit double 555.0.
    assertEquals(1, result.getOutputs().size());
    assertTrue(result.getOutputs().stream().anyMatch(namedOutput -> namedOutput.getName().equals("output1")));
    TypedValue output1Value = result.getOutputs().iterator().next().getValue();
    assertEquals(Double.class.getSimpleName(), output1Value.getType());
    assertEquals(TypedValue.Kind.UNIT, output1Value.getKind());
    assertEquals(555.0, output1Value.toUnit().getValue().asDouble());
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionOutputModel.
/**
 * Verifies that {@code getPrediction} rejects a request containing a collection-valued entry
 * with an {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionOutputModel() {
    // Collection-valued entry passed in the sixth constructor position (the goals, per this test's name);
    // such a non-flat structure is expected to be rejected.
    List<NamedTypedValue> collectionGoals = List.of(new NamedTypedValue("input1", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100))))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), collectionGoals, Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandler, method getPrediction.
/**
 * Builds a {@link CounterfactualPrediction} from a counterfactual explainability request.
 * <p>
 * If the requested maximum running time is present and larger than {@code kafkaMaxRecordAgeSeconds},
 * it is capped to that value and the capping is logged. Requests rejected by {@code isUnsupportedModel}
 * (see FAI-473 / FAI-474) fail with an {@link IllegalArgumentException}.
 *
 * @param request the counterfactual explainability request
 * @return a prediction built from the request's original inputs, search domains and sorted goals
 * @throws IllegalArgumentException if the model is not flat
 */
@Override
public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
    Collection<NamedTypedValue> sortedGoals = toMapBasedSorting(request.getGoals());
    Collection<CounterfactualSearchDomain> domains = request.getSearchDomains();
    Collection<NamedTypedValue> inputs = request.getOriginalInputs();
    Long timeoutSeconds = request.getMaxRunningTimeSeconds();

    // Never let the search run longer than the messaging sub-system keeps records.
    boolean exceedsRecordAge = Objects.nonNull(timeoutSeconds) && timeoutSeconds > kafkaMaxRecordAgeSeconds;
    if (exceedsRecordAge) {
        LOGGER.info(String.format("Maximum Running Timeout set to '%d's since the provided value '%d's exceeded the Messaging sub-system configuration '%d's.", kafkaMaxRecordAgeSeconds, timeoutSeconds, kafkaMaxRecordAgeSeconds));
        timeoutSeconds = kafkaMaxRecordAgeSeconds;
    }

    // See https://issues.redhat.com/browse/FAI-473 and https://issues.redhat.com/browse/FAI-474
    if (isUnsupportedModel(inputs, sortedGoals, domains)) {
        throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
    }

    // Arguments are evaluated left to right, so the input is built before the output, as before.
    return new CounterfactualPrediction(
            new PredictionInput(toFeatureList(inputs, domains)),
            new PredictionOutput(toOutputList(sortedGoals)),
            null,
            UUID.fromString(request.getExecutionId()),
            timeoutSeconds);
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
Class LimeExplainerServiceHandler, method getPrediction.
/**
 * Builds a {@link SimplePrediction} from the inputs and outputs carried by a LIME explainability request.
 *
 * @param request the LIME explainability request
 * @return a prediction whose features come from the request inputs and whose outputs come from the request outputs
 */
@Override
public Prediction getPrediction(LIMEExplainabilityRequest request) {
    Collection<NamedTypedValue> requestInputs = request.getInputs();
    Collection<NamedTypedValue> requestOutputs = request.getOutputs();
    return new SimplePrediction(
            new PredictionInput(toFeatureList(requestInputs)),
            new PredictionOutput(toOutputList(requestOutputs)));
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
Class TrustyServiceTest, method doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest.
/**
 * Scenario body: given a stored execution, requesting counterfactuals emits a
 * {@link CounterfactualExplainabilityRequest} event for that execution.
 * Presumably driven by tests for each {@link CounterfactualDomain} variant — TODO confirm with callers.
 *
 * @param domain the search-domain definition applied to every searchable input attribute
 */
@SuppressWarnings("unchecked")
void doGivenStoredExecutionWhenCounterfactualRequestIsMadeThenExplainabilityEventIsEmittedTest(CounterfactualDomain domain) {
    Storage<String, Decision> decisions = mock(Storage.class);
    Storage<String, CounterfactualExplainabilityRequest> counterfactualRequests = mock(Storage.class);
    ArgumentCaptor<BaseExplainabilityRequest> eventCaptor = ArgumentCaptor.forClass(BaseExplainabilityRequest.class);

    when(decisions.containsKey(eq(TEST_EXECUTION_ID))).thenReturn(true);
    when(trustyStorageServiceMock.getDecisionsStorage()).thenReturn(decisions);
    when(trustyStorageServiceMock.getCounterfactualRequestStorage()).thenReturn(counterfactualRequests);
    when(decisions.get(eq(TEST_EXECUTION_ID))).thenReturn(TrustyServiceTestUtils.buildCorrectDecision(TEST_EXECUTION_ID));

    // The goal structures must be comparable to the outcomes of the original decision.
    List<NamedTypedValue> goals = List.of(
            new NamedTypedValue("Fine", new StructureValue("tFine", Map.of(
                    "Amount", new UnitValue("number", "number", new IntNode(0)),
                    "Points", new UnitValue("number", "number", new IntNode(0))))),
            new NamedTypedValue("Should the driver be suspended?", new UnitValue("string", "string", new TextNode("No"))));

    // The search-domain structures must be identical to those of the original decision inputs.
    List<CounterfactualSearchDomain> searchDomains = List.of(
            new CounterfactualSearchDomain("Violation", new CounterfactualSearchDomainStructureValue("tViolation", Map.of(
                    "Type", new CounterfactualSearchDomainUnitValue("string", "string", true, domain),
                    "Actual Speed", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Speed Limit", new CounterfactualSearchDomainUnitValue("number", "number", true, domain)))),
            new CounterfactualSearchDomain("Driver", new CounterfactualSearchDomainStructureValue("tDriver", Map.of(
                    "Age", new CounterfactualSearchDomainUnitValue("number", "number", true, domain),
                    "Points", new CounterfactualSearchDomainUnitValue("number", "number", true, domain)))));

    trustyService.requestCounterfactuals(TEST_EXECUTION_ID, goals, searchDomains);

    verify(explainabilityRequestProducerMock).sendEvent(eventCaptor.capture());
    BaseExplainabilityRequest event = eventCaptor.getValue();
    assertNotNull(event);
    assertTrue(event instanceof CounterfactualExplainabilityRequest);
    assertEquals(TEST_EXECUTION_ID, ((CounterfactualExplainabilityRequest) event).getExecutionId());
}
Aggregations