Usage of org.kie.kogito.explainability.api.BaseExplainabilityResult in the kogito-apps project by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResult:
@Test
public void testCreateSucceededResult() {
    // Only identifiers matter for this test; inputs, outputs and search domains are left empty.
    CounterfactualExplainabilityRequest explainabilityRequest = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // A single numeric counterfactual entity ("input1" = 123.0) bounded to [0, 1000].
    Feature inputFeature = new Feature("input1", Type.NUMBER, new Value(123.0d));
    List<CounterfactualEntity> counterfactualEntities = List.of(DoubleEntity.from(inputFeature, 0, 1000));

    // A single numeric prediction output ("output1" = 555.0) with score 1.0.
    Output predictedOutput = new Output("output1", Type.NUMBER, new Value(555.0d), 1.0);
    PredictionOutput predictionOutput = new PredictionOutput(List.of(predictedOutput));

    CounterfactualResult counterfactualResult = new CounterfactualResult(
            counterfactualEntities,
            counterfactualEntities.stream()
                    .map(CounterfactualEntity::asFeature)
                    .collect(Collectors.toList()),
            List.of(predictionOutput),
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    BaseExplainabilityResult created = handler.createSucceededResult(explainabilityRequest, counterfactualResult);
    assertTrue(created instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult explanation = (CounterfactualExplainabilityResult) created;

    // Status and identifiers of the final result.
    assertEquals(ExplainabilityStatus.SUCCEEDED, explanation.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, explanation.getStage());
    assertEquals(EXECUTION_ID, explanation.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, explanation.getCounterfactualId());

    // The single input is mapped to a Double-typed unit value.
    assertEquals(1, explanation.getInputs().size());
    assertTrue(explanation.getInputs().stream().anyMatch(candidate -> candidate.getName().equals("input1")));
    NamedTypedValue firstInput = explanation.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstInput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstInput.getValue().getKind());
    assertEquals(123.0, firstInput.getValue().toUnit().getValue().asDouble());

    // The single output is mapped to a Double-typed unit value.
    assertEquals(1, explanation.getOutputs().size());
    assertTrue(explanation.getOutputs().stream().anyMatch(candidate -> candidate.getName().equals("output1")));
    NamedTypedValue firstOutput = explanation.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstOutput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstOutput.getValue().getKind());
    assertEquals(555.0, firstOutput.getValue().toUnit().getValue().asDouble());
}
Usage of org.kie.kogito.explainability.api.BaseExplainabilityResult in the kogito-apps project by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testCreateFailedResult:
@Test
public void testCreateFailedResult() {
    // The assertions below only inspect status and identifiers, so empty collections suffice.
    CounterfactualExplainabilityRequest failingRequest = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    String failureMessage = "Something went wrong";
    NullPointerException failure = new NullPointerException(failureMessage);

    BaseExplainabilityResult created = handler.createFailedResult(failingRequest, failure);
    assertTrue(created instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult explanation = (CounterfactualExplainabilityResult) created;

    // The exception message is surfaced as the status details.
    assertEquals(ExplainabilityStatus.FAILED, explanation.getStatus());
    assertEquals(failureMessage, explanation.getStatusDetails());
    assertEquals(EXECUTION_ID, explanation.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, explanation.getCounterfactualId());
}
Usage of org.kie.kogito.explainability.api.BaseExplainabilityResult in the kogito-apps project by kiegroup.
Class LimeExplainerServiceHandlerTest, method testCreateSucceededResult:
@Test
public void testCreateSucceededResult() {
    LIMEExplainabilityRequest limeRequest = new LIMEExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            Collections.emptyList(),
            Collections.emptyList());

    // One saliency for the "salary" output with two feature importances, in a fixed order.
    Feature ageFeature = new Feature("age", Type.NUMBER, new Value(25.0));
    Feature dependentsFeature = new Feature("dependents", Type.NUMBER, new Value(2));
    List<FeatureImportance> importances = List.of(
            new FeatureImportance(ageFeature, 5.0),
            new FeatureImportance(dependentsFeature, -11.0));
    Saliency salarySaliency = new Saliency(new Output("salary", Type.NUMBER), importances);
    Map<String, Saliency> saliencyByKey = Map.of("s1", salarySaliency);

    BaseExplainabilityResult created = handler.createSucceededResult(limeRequest, saliencyByKey);
    assertTrue(created instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult explanation = (LIMEExplainabilityResult) created;

    assertEquals(ExplainabilityStatus.SUCCEEDED, explanation.getStatus());
    assertEquals(EXECUTION_ID, explanation.getExecutionId());
    assertEquals(1, explanation.getSaliencies().size());

    // Feature importances keep their original order, names and scores.
    SaliencyModel onlySaliency = explanation.getSaliencies().iterator().next();
    assertEquals(2, onlySaliency.getFeatureImportance().size());
    assertEquals("age", onlySaliency.getFeatureImportance().get(0).getFeatureName());
    assertEquals(5.0, onlySaliency.getFeatureImportance().get(0).getFeatureScore());
    assertEquals("dependents", onlySaliency.getFeatureImportance().get(1).getFeatureName());
    assertEquals(-11.0, onlySaliency.getFeatureImportance().get(1).getFeatureScore());
}
Usage of org.kie.kogito.explainability.api.BaseExplainabilityResult in the kogito-apps project by kiegroup.
Class LimeExplainerServiceHandlerTest, method testCreateFailedResult:
@Test
public void testCreateFailedResult() {
    LIMEExplainabilityRequest failingRequest = new LIMEExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            Collections.emptyList(),
            Collections.emptyList());
    String failureMessage = "Something went wrong";
    NullPointerException failure = new NullPointerException(failureMessage);

    BaseExplainabilityResult created = handler.createFailedResult(failingRequest, failure);
    assertTrue(created instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult explanation = (LIMEExplainabilityResult) created;

    // The exception message is surfaced as the status details.
    assertEquals(ExplainabilityStatus.FAILED, explanation.getStatus());
    assertEquals(failureMessage, explanation.getStatusDetails());
    assertEquals(EXECUTION_ID, explanation.getExecutionId());
}
Usage of org.kie.kogito.explainability.api.BaseExplainabilityResult in the kogito-apps project by kiegroup.
Class BaseExplainabilityMessagingHandlerIT, method explainabilityRequestIsProcessedAndAnIntermediateMessageIsSent:
/**
 * Publishes an explainability request on {@code TOPIC_REQUEST}, checks that the explanation service is
 * invoked exactly once, then waits until {@code getTotalExpectedEventCountWithIntermediateResults()}
 * results have been consumed from {@code TOPIC_RESULT}. Each consumed result is checked with
 * {@code assertResult}.
 *
 * <p>The Kafka client is shut down in a {@code finally} block so that it is released even when one of the
 * assertions or Kafka calls fails.
 */
@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void explainabilityRequestIsProcessedAndAnIntermediateMessageIsSent() throws Exception {
    BaseExplainabilityRequest request = buildRequest();
    BaseExplainabilityResult result = buildResult();
    // Stub explainAsync: hand its second argument (the result consumer) to the helper, which presumably
    // emits the intermediate results through it; then complete with the final result.
    doAnswer(i -> {
        Object parameter = i.getArguments()[1];
        Consumer<BaseExplainabilityResult> consumer = (Consumer) parameter;
        mockExplainAsyncInvocationWithIntermediateResults(consumer);
        return CompletableFuture.completedFuture(result);
    }).when(explanationService).explainAsync(any(BaseExplainabilityRequest.class), any());

    try {
        kafkaClient.produce(ExplainabilityCloudEventBuilder.buildCloudEventJsonString(request), TOPIC_REQUEST);
        verify(explanationService, timeout(2000).times(1)).explainAsync(any(BaseExplainabilityRequest.class), any());

        // One count per expected result message (intermediate + final).
        final CountDownLatch countDownLatch = new CountDownLatch(getTotalExpectedEventCountWithIntermediateResults());
        kafkaClient.consume(TOPIC_RESULT, s -> {
            LOGGER.info("Received from kafka: {}", s);
            CloudEventUtils.decode(s).ifPresent((CloudEvent cloudEvent) -> {
                try {
                    BaseExplainabilityResult event = objectMapper.readValue(cloudEvent.getData().toBytes(), BaseExplainabilityResult.class);
                    assertNotNull(event);
                    assertResult(event);
                    countDownLatch.countDown();
                } catch (IOException e) {
                    LOGGER.error("Error parsing {}", s, e);
                    throw new RuntimeException(e);
                }
            });
        });
        // If the consumer callback fails, the latch is never released and this assertion fails on timeout.
        assertTrue(countDownLatch.await(5, TimeUnit.SECONDS));
    } finally {
        // Always release the Kafka client, including when an assertion above fails.
        kafkaClient.shutdown();
    }
}
Aggregations