
Example 1 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

From the class ExplanationServiceImplTest, method testCounterfactualsxplainAsyncFailed.

@Test
@SuppressWarnings("unchecked")
void testCounterfactualsxplainAsyncFailed() {
    String errorMessage = "Something bad happened";
    RuntimeException exception = new RuntimeException(errorMessage);
    // Only the counterfactual handler is available, and its explainer fails immediately.
    when(instance.stream()).thenReturn(Stream.of(cfExplainerServiceHandlerMock));
    when(cfExplainerMock.explainAsync(any(Prediction.class), eq(predictionProviderMock), any(Consumer.class)))
            .thenThrow(exception);
    // The failure must surface as a FAILED result, not as an exception from the future.
    BaseExplainabilityResult result = assertDoesNotThrow(() -> explanationService
            .explainAsync(COUNTERFACTUAL_REQUEST, callbackMock)
            .toCompletableFuture()
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    assertNotNull(result);
    assertTrue(result instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult exceptionResult = (CounterfactualExplainabilityResult) result;
    assertEquals(EXECUTION_ID, exceptionResult.getExecutionId());
    assertSame(ExplainabilityStatus.FAILED, exceptionResult.getStatus());
    assertEquals(errorMessage, exceptionResult.getStatusDetails());
}
Also used : Consumer(java.util.function.Consumer) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)
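Example 1 expects the service to turn an explainer failure into a result with status FAILED and the exception message as its status details, rather than letting the exception escape the returned future. The following is a minimal sketch of that pattern with plain CompletableFuture; it assumes nothing about how kogito-apps actually implements it, and Status, Result, FailureMapping and explainAsync are hypothetical stand-ins. Because the mocked explainer throws synchronously (thenThrow), the sketch guards the call itself as well as the future it returns.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Supplier;

// Hypothetical stand-ins for ExplainabilityStatus and the result DTO.
enum Status { SUCCEEDED, FAILED }

record Result(String executionId, Status status, String statusDetails) {}

final class FailureMapping {

    private FailureMapping() {}

    // Runs the explainer and always completes normally with a Result.
    static CompletableFuture<Result> explainAsync(String executionId,
                                                  Supplier<CompletableFuture<?>> explainer) {
        CompletableFuture<?> future;
        try {
            // The explainer may throw before it ever returns a future.
            future = explainer.get();
        } catch (RuntimeException e) {
            future = CompletableFuture.failedFuture(e);
        }
        return future.handle((value, error) -> {
            if (error == null) {
                return new Result(executionId, Status.SUCCEEDED, null);
            }
            // Failures raised inside supplyAsync or earlier stages arrive wrapped in CompletionException.
            Throwable cause = (error instanceof CompletionException && error.getCause() != null)
                    ? error.getCause()
                    : error;
            return new Result(executionId, Status.FAILED, cause.getMessage());
        });
    }
}
```

Using handle rather than exceptionally keeps the success and failure mapping in one place; the unwrapping step matters because a failure thrown inside supplyAsync is reported to handle as a CompletionException.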

Example 2 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

From the class ExplanationServiceImplTest, method testLIMEExplainAsyncSuccess.

// Shared assertions for a successful LIME explanation; the caller supplies the service invocation under test.
@SuppressWarnings("unchecked")
void testLIMEExplainAsyncSuccess(ThrowingSupplier<BaseExplainabilityResult> invocation) {
    when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerMock));
    when(limeExplainerMock.explainAsync(any(Prediction.class), eq(predictionProviderMock), any(Consumer.class)))
            .thenReturn(CompletableFuture.completedFuture(SALIENCY_MAP));
    BaseExplainabilityResult result = assertDoesNotThrow(invocation);
    assertNotNull(result);
    assertTrue(result instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult limeResult = (LIMEExplainabilityResult) result;
    assertEquals(EXECUTION_ID, limeResult.getExecutionId());
    assertSame(ExplainabilityStatus.SUCCEEDED, limeResult.getStatus());
    assertNull(limeResult.getStatusDetails());
    assertEquals(SALIENCY_MAP.size(), limeResult.getSaliencies().size());
    SaliencyModel saliency = limeResult.getSaliencies().iterator().next();
    assertEquals(SALIENCY.getPerFeatureImportance().size(), saliency.getFeatureImportance().size());
    FeatureImportanceModel featureImportance1 = saliency.getFeatureImportance().get(0);
    assertEquals(FEATURE_IMPORTANCE_1.getFeature().getName(), featureImportance1.getFeatureName());
    assertEquals(FEATURE_IMPORTANCE_1.getScore(), featureImportance1.getFeatureScore(), 0.01);
}
Also used : LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) Consumer(java.util.function.Consumer) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) FeatureImportanceModel(org.kie.kogito.explainability.api.FeatureImportanceModel) Prediction(org.kie.kogito.explainability.model.Prediction)
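Examples 1, 2 and 5 check the concrete result type with assertTrue(result instanceof ...) and then cast. Since Java 16, pattern matching for instanceof folds the check and the cast into one step (JUnit 5.8 and later also provide Assertions.assertInstanceOf, which returns the cast value). The sketch below shows the idea on a hypothetical sealed hierarchy; ExplainResult, LimeResult, CfResult and ResultDescriber are illustrative names, not kogito-apps types.

```java
import java.util.List;

// Hypothetical, simplified result hierarchy.
sealed interface ExplainResult permits LimeResult, CfResult {
    String executionId();
}

record LimeResult(String executionId, List<String> saliencies) implements ExplainResult {}

record CfResult(String executionId, String counterfactualId) implements ExplainResult {}

final class ResultDescriber {

    private ResultDescriber() {}

    // Binds a typed variable only when the runtime type matches; no separate cast is needed.
    static String describe(ExplainResult result) {
        if (result instanceof LimeResult lime) {
            return "LIME " + lime.executionId() + " with " + lime.saliencies().size() + " saliencies";
        }
        if (result instanceof CfResult cf) {
            return "CF " + cf.executionId() + " for " + cf.counterfactualId();
        }
        throw new IllegalArgumentException("Unknown result type: " + result);
    }
}
```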

Example 3 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

From the class ExplanationServiceImplTest, method testServiceHandlerLookupCounterfactuals.

@Test
void testServiceHandlerLookupCounterfactuals() {
    // Both handlers are registered; the counterfactual request must be routed to the counterfactual explainer.
    when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerMock, cfExplainerServiceHandlerMock));
    when(cfExplainerMock.explainAsync(any(), any(), any()))
            .thenReturn(CompletableFuture.completedFuture(COUNTERFACTUAL_RESULT));
    BaseExplainabilityResult result = assertDoesNotThrow(() -> explanationService
            .explainAsync(COUNTERFACTUAL_REQUEST, callbackMock)
            .toCompletableFuture()
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    assertNotNull(result);
    assertTrue(result instanceof CounterfactualExplainabilityResult);
}
Also used : BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 4 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

From the class ExplanationServiceImplTest, method testServiceHandlerLookupLIME.

@Test
void testServiceHandlerLookupLIME() {
    // Both handlers are registered; the LIME request must be routed to the LIME explainer.
    when(instance.stream()).thenReturn(Stream.of(limeExplainerServiceHandlerMock, cfExplainerServiceHandlerMock));
    when(limeExplainerMock.explainAsync(any(), any(), any()))
            .thenReturn(CompletableFuture.completedFuture(SALIENCY_MAP));
    BaseExplainabilityResult result = assertDoesNotThrow(() -> explanationService
            .explainAsync(LIME_REQUEST, callbackMock)
            .toCompletableFuture()
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()));
    assertNotNull(result);
    assertTrue(result instanceof LIMEExplainabilityResult);
}
Also used : LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) Test(org.junit.jupiter.api.Test)
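Examples 3 and 4 register both the LIME and the counterfactual handlers and expect the service to route each request to the matching one. A minimal sketch of that dispatch, assuming each handler can report which request type it supports; Request, Handler, supports, HandlerLookup, lookup and forType are hypothetical names, and the real ExplanationServiceImpl may select handlers differently.

```java
import java.util.Optional;
import java.util.stream.Stream;

// Hypothetical request types standing in for LIME and counterfactual requests.
interface Request { String executionId(); }

record LimeRequest(String executionId) implements Request {}

record CounterfactualRequest(String executionId) implements Request {}

// Hypothetical handler contract: declares the request type it can serve.
interface Handler {
    boolean supports(Class<? extends Request> type);
    String explain(Request request);
}

final class HandlerLookup {

    private HandlerLookup() {}

    // Picks the first handler whose supported type matches the request's runtime class.
    static Optional<Handler> lookup(Stream<Handler> handlers, Request request) {
        return handlers.filter(h -> h.supports(request.getClass())).findFirst();
    }

    // Builds a handler bound to one request type; explain() just echoes its name and the execution id.
    static Handler forType(Class<? extends Request> supported, String name) {
        return new Handler() {
            @Override
            public boolean supports(Class<? extends Request> type) {
                return supported.isAssignableFrom(type);
            }

            @Override
            public String explain(Request request) {
                return name + ":" + request.executionId();
            }
        };
    }
}
```

As in the tests, the handler order in the stream does not matter: the LIME handler comes first, yet a counterfactual request still reaches the counterfactual handler because the filter rejects non-matching types.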

Example 5 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testCreateIntermediateResult.

@Test
public void testCreateIntermediateResult() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL,
            MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), Collections.emptyList(),
            Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    // A single numeric input, "input1" = 123.0, with bounds 0 and 1000.
    List<CounterfactualEntity> entities = List.of(
            DoubleEntity.from(new Feature("input1", Type.NUMBER, new Value(123.0d)), 0, 1000));
    // One counterfactual solution whose single output is "output1" = 555.0.
    CounterfactualResult counterfactuals = new CounterfactualResult(entities,
            entities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList()),
            List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)))),
            true, UUID.fromString(SOLUTION_ID), UUID.fromString(EXECUTION_ID), 0);
    BaseExplainabilityResult base = handler.createIntermediateResult(request, counterfactuals);
    assertTrue(base instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult result = (CounterfactualExplainabilityResult) base;
    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, result.getStage());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, result.getCounterfactualId());
    assertEquals(1, result.getInputs().size());
    assertTrue(result.getInputs().stream().anyMatch(i -> i.getName().equals("input1")));
    NamedTypedValue input1 = result.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), input1.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, input1.getValue().getKind());
    assertEquals(123.0, input1.getValue().toUnit().getValue().asDouble());
    assertEquals(1, result.getOutputs().size());
    assertTrue(result.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));
    NamedTypedValue output1 = result.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), output1.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, output1.getValue().getKind());
    assertEquals(555.0, output1.getValue().toUnit().getValue().asDouble());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier)
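Example 5 expects createIntermediateResult to carry each numeric input and output across as a UNIT-kind typed value whose type name is "Double" (Double.class.getSimpleName()), and to tag the result with Stage.INTERMEDIATE. The sketch below reproduces only that mapping, using hypothetical records (Kind, Stage, TypedUnit, NamedValue, CfSnapshot, IntermediateMapper); it illustrates what the assertions check, not how kogito-apps builds CounterfactualExplainabilityResult.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical simplified equivalents of TypedValue.Kind and the result stage.
enum Kind { UNIT, COLLECTION, STRUCTURE }

enum Stage { INTERMEDIATE, FINAL }

record TypedUnit(String type, Kind kind, double value) {}

record NamedValue(String name, TypedUnit value) {}

record CfSnapshot(Stage stage, List<NamedValue> inputs, List<NamedValue> outputs) {}

final class IntermediateMapper {

    private IntermediateMapper() {}

    // A numeric value becomes a UNIT whose type name mirrors Double.class.getSimpleName().
    static NamedValue toNamedValue(String name, double value) {
        return new NamedValue(name, new TypedUnit(Double.class.getSimpleName(), Kind.UNIT, value));
    }

    // Builds an INTERMEDIATE snapshot from name-to-value maps of inputs and outputs.
    static CfSnapshot intermediate(Map<String, Double> inputs, Map<String, Double> outputs) {
        return new CfSnapshot(Stage.INTERMEDIATE, convert(inputs), convert(outputs));
    }

    private static List<NamedValue> convert(Map<String, Double> values) {
        return values.entrySet().stream()
                .map(e -> toNamedValue(e.getKey(), e.getValue()))
                .collect(Collectors.toList());
    }
}
```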

Aggregations

BaseExplainabilityResult (org.kie.kogito.explainability.api.BaseExplainabilityResult) 16
Test (org.junit.jupiter.api.Test) 14
Consumer (java.util.function.Consumer) 8
Prediction (org.kie.kogito.explainability.model.Prediction) 7
CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) 6
LIMEExplainabilityResult (org.kie.kogito.explainability.api.LIMEExplainabilityResult) 6
BaseExplainabilityRequest (org.kie.kogito.explainability.api.BaseExplainabilityRequest) 5
NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue) 4
Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals) 3
Assertions.assertThrows (org.junit.jupiter.api.Assertions.assertThrows) 3
Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue) 3
BeforeEach (org.junit.jupiter.api.BeforeEach) 3
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) 3
ExplainabilityStatus (org.kie.kogito.explainability.api.ExplainabilityStatus) 3
LIMEExplainabilityRequest (org.kie.kogito.explainability.api.LIMEExplainabilityRequest) 3
ModelIdentifier (org.kie.kogito.explainability.api.ModelIdentifier) 3
SaliencyModel (org.kie.kogito.explainability.api.SaliencyModel) 3
Feature (org.kie.kogito.explainability.model.Feature) 3
Output (org.kie.kogito.explainability.model.Output) 3
Value (org.kie.kogito.explainability.model.Value) 3