use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.
the class ExplanationServiceImplTest method testCounterfactualsxplainAsyncFailed.
/**
 * Checks that a {@link RuntimeException} thrown by the counterfactual explainer mock
 * does not escape {@code explainAsync}: the returned future still completes, and the
 * result is a {@link CounterfactualExplainabilityResult} flagged as
 * {@link ExplainabilityStatus#FAILED} whose status details equal the exception message.
 */
@Test
@SuppressWarnings("unchecked")
void testCounterfactualsxplainAsyncFailed() {
    // Arrange: only the counterfactual handler is available and its explainer throws.
    final String failureMessage = "Something bad happened";
    final RuntimeException failure = new RuntimeException(failureMessage);
    when(instance.stream()).thenReturn(Stream.of(cfExplainerServiceHandlerMock));
    when(cfExplainerMock.explainAsync(any(Prediction.class), eq(predictionProviderMock), any(Consumer.class)))
            .thenThrow(failure);

    // Act: wait for the asynchronous explanation, bounded by the configured timeout.
    final BaseExplainabilityResult outcome = assertDoesNotThrow(
            () -> explanationService.explainAsync(COUNTERFACTUAL_REQUEST, callbackMock)
                    .toCompletableFuture()
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));

    // Assert: the failure is reported through the result rather than thrown.
    assertNotNull(outcome);
    assertTrue(outcome instanceof CounterfactualExplainabilityResult);
    final CounterfactualExplainabilityResult failedResult = (CounterfactualExplainabilityResult) outcome;
    assertEquals(EXECUTION_ID, failedResult.getExecutionId());
    assertSame(ExplainabilityStatus.FAILED, failedResult.getStatus());
    assertEquals(failureMessage, failedResult.getStatusDetails());
}
use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.
the class ExplanationServiceImplTest method testLIMEExplainAsyncSuccess.
/**
 * Shared assertions for the successful LIME path. Stubs the LIME explainer mock to
 * return an already-completed future holding {@code SALIENCY_MAP}, evaluates the
 * supplied invocation, and checks that the result is a succeeded
 * {@link LIMEExplainabilityResult} whose saliencies match the fixture.
 *
 * @param invocation supplier evaluated through {@code assertDoesNotThrow}; the value
 *        it yields is the result under inspection
 */
@SuppressWarnings("unchecked")
void testLIMEExplainAsyncSuccess(ThrowingSupplier<BaseExplainabilityResult> invocation) {
    // Arrange: only the LIME handler is available and its explainer succeeds immediately.
    when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerMock));
    when(limeExplainerMock.explainAsync(any(Prediction.class), eq(predictionProviderMock), any(Consumer.class)))
            .thenReturn(CompletableFuture.completedFuture(SALIENCY_MAP));

    // Act
    final BaseExplainabilityResult outcome = assertDoesNotThrow(invocation);

    // Assert: result type, identity and status.
    assertNotNull(outcome);
    assertTrue(outcome instanceof LIMEExplainabilityResult);
    final LIMEExplainabilityResult lime = (LIMEExplainabilityResult) outcome;
    assertEquals(EXECUTION_ID, lime.getExecutionId());
    assertSame(ExplainabilityStatus.SUCCEEDED, lime.getStatus());
    assertNull(lime.getStatusDetails());

    // Assert: saliency content, checked on the first saliency and its first feature importance.
    assertEquals(SALIENCY_MAP.size(), lime.getSaliencies().size());
    final SaliencyModel firstSaliency = lime.getSaliencies().iterator().next();
    assertEquals(SALIENCY.getPerFeatureImportance().size(), firstSaliency.getFeatureImportance().size());
    final FeatureImportanceModel firstImportance = firstSaliency.getFeatureImportance().get(0);
    assertEquals(FEATURE_IMPORTANCE_1.getFeature().getName(), firstImportance.getFeatureName());
    assertEquals(FEATURE_IMPORTANCE_1.getScore(), firstImportance.getFeatureScore(), 0.01);
}
use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.
the class ExplanationServiceImplTest method testServiceHandlerLookupCounterfactuals.
/**
 * With both the LIME and the counterfactual handlers available, a counterfactual
 * request must produce a {@link CounterfactualExplainabilityResult}.
 */
@Test
void testServiceHandlerLookupCounterfactuals() {
    // Arrange: both handlers are offered; the counterfactual explainer completes immediately.
    when(instance.stream())
            .thenReturn(Stream.of(limeExplainerServiceHandlerMock, cfExplainerServiceHandlerMock));
    when(cfExplainerMock.explainAsync(any(), any(), any()))
            .thenReturn(CompletableFuture.completedFuture(COUNTERFACTUAL_RESULT));

    // Act: wait for the explanation, bounded by the configured timeout.
    final BaseExplainabilityResult outcome = assertDoesNotThrow(
            () -> explanationService.explainAsync(COUNTERFACTUAL_REQUEST, callbackMock)
                    .toCompletableFuture()
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));

    // Assert: the result type matches the request type.
    assertNotNull(outcome);
    assertTrue(outcome instanceof CounterfactualExplainabilityResult);
}
use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.
the class ExplanationServiceImplTest method testServiceHandlerLookupLIME.
/**
 * With both the LIME and the counterfactual handlers available, a LIME request
 * must produce a {@link LIMEExplainabilityResult}.
 */
@Test
void testServiceHandlerLookupLIME() {
    // Arrange: both handlers are offered; the LIME explainer completes immediately.
    when(instance.stream())
            .thenReturn(Stream.of(limeExplainerServiceHandlerMock, cfExplainerServiceHandlerMock));
    when(limeExplainerMock.explainAsync(any(), any(), any()))
            .thenReturn(CompletableFuture.completedFuture(SALIENCY_MAP));

    // Act: wait for the explanation, bounded by the configured timeout.
    final BaseExplainabilityResult outcome = assertDoesNotThrow(
            () -> explanationService.explainAsync(LIME_REQUEST, callbackMock)
                    .toCompletableFuture()
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));

    // Assert: the result type matches the request type.
    assertNotNull(outcome);
    assertTrue(outcome instanceof LIMEExplainabilityResult);
}
use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.
the class CounterfactualExplainerServiceHandlerTest method testCreateIntermediateResult.
/**
 * Builds a counterfactual request and a single-input / single-output
 * {@link CounterfactualResult}, passes both to {@code handler.createIntermediateResult},
 * and checks that the returned {@link CounterfactualExplainabilityResult} is a succeeded
 * INTERMEDIATE-stage result carrying the same ids, one input ("input1" = 123.0) and
 * one output ("output1" = 555.0), each exposed as a unit value typed as Double.
 */
@Test
public void testCreateIntermediateResult() {
    // Arrange: request without goals, original inputs or search domains.
    final CounterfactualExplainabilityRequest cfRequest = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // Arrange: one numeric entity and one numeric prediction output.
    final Feature inputFeature = new Feature("input1", Type.NUMBER, new Value(123.0d));
    final List<CounterfactualEntity> cfEntities = List.of(DoubleEntity.from(inputFeature, 0, 1000));
    final CounterfactualResult cfResult = new CounterfactualResult(
            cfEntities,
            cfEntities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList()),
            List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)))),
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    // Act
    final BaseExplainabilityResult produced = handler.createIntermediateResult(cfRequest, cfResult);

    // Assert: type, status, stage and identifiers.
    assertTrue(produced instanceof CounterfactualExplainabilityResult);
    final CounterfactualExplainabilityResult intermediate = (CounterfactualExplainabilityResult) produced;
    assertEquals(ExplainabilityStatus.SUCCEEDED, intermediate.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, intermediate.getStage());
    assertEquals(EXECUTION_ID, intermediate.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, intermediate.getCounterfactualId());

    // Assert: the single input.
    assertEquals(1, intermediate.getInputs().size());
    assertTrue(intermediate.getInputs().stream().anyMatch(in -> in.getName().equals("input1")));
    final NamedTypedValue firstInput = intermediate.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstInput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstInput.getValue().getKind());
    assertEquals(123.0, firstInput.getValue().toUnit().getValue().asDouble());

    // Assert: the single output.
    assertEquals(1, intermediate.getOutputs().size());
    assertTrue(intermediate.getOutputs().stream().anyMatch(out -> out.getName().equals("output1")));
    final NamedTypedValue firstOutput = intermediate.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), firstOutput.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, firstOutput.getValue().getKind());
    assertEquals(555.0, firstOutput.getValue().toUnit().getValue().asDouble());
}
Aggregations