Search in sources:

Example 6 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

The class CounterfactualExplainerServiceHandlerTest, method testCreateSucceededResult.

@Test
public void testCreateSucceededResult() {
    // Only identifiers matter for this request: inputs, outputs and search domains are all empty.
    CounterfactualExplainabilityRequest cfRequest = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    // One numeric counterfactual entity ("input1" = 123.0, created with range arguments 0 and 1000)
    // whose prediction is a single numeric output ("output1" = 555.0).
    Feature inputFeature = new Feature("input1", Type.NUMBER, new Value(123.0d));
    List<CounterfactualEntity> cfEntities = List.of(DoubleEntity.from(inputFeature, 0, 1000));
    List<Feature> cfFeatures = cfEntities.stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    Output predictedOutput = new Output("output1", Type.NUMBER, new Value(555.0d), 1.0);
    List<PredictionOutput> cfOutputs = List.of(new PredictionOutput(List.of(predictedOutput)));
    CounterfactualResult counterfactuals = new CounterfactualResult(cfEntities,
            cfFeatures,
            cfOutputs,
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    BaseExplainabilityResult baseResult = handler.createSucceededResult(cfRequest, counterfactuals);

    // The handler must produce a counterfactual-specific result in its FINAL stage.
    assertTrue(baseResult instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult cfResult = (CounterfactualExplainabilityResult) baseResult;
    assertEquals(ExplainabilityStatus.SUCCEEDED, cfResult.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, cfResult.getStage());
    assertEquals(EXECUTION_ID, cfResult.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, cfResult.getCounterfactualId());

    // The single input is exposed as a Double unit value.
    assertEquals(1, cfResult.getInputs().size());
    assertTrue(cfResult.getInputs().stream().anyMatch(in -> in.getName().equals("input1")));
    TypedValue inputValue = cfResult.getInputs().iterator().next().getValue();
    assertEquals(Double.class.getSimpleName(), inputValue.getType());
    assertEquals(TypedValue.Kind.UNIT, inputValue.getKind());
    assertEquals(123.0, inputValue.toUnit().getValue().asDouble());

    // The single output is exposed as a Double unit value.
    assertEquals(1, cfResult.getOutputs().size());
    assertTrue(cfResult.getOutputs().stream().anyMatch(out -> out.getName().equals("output1")));
    TypedValue outputValue = cfResult.getOutputs().iterator().next().getValue();
    assertEquals(Double.class.getSimpleName(), outputValue.getType());
    assertEquals(TypedValue.Kind.UNIT, outputValue.getKind());
    assertEquals(555.0, outputValue.toUnit().getValue().asDouble());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) 
Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) 
NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) Test(org.junit.jupiter.api.Test)

Example 7 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

The class CounterfactualExplainerServiceHandlerTest, method testCreateFailedResult.

@Test
public void testCreateFailedResult() {
    // Only identifiers matter for this request: inputs, outputs and search domains are all empty.
    CounterfactualExplainabilityRequest cfRequest = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    String failureMessage = "Something went wrong";
    NullPointerException failure = new NullPointerException(failureMessage);

    BaseExplainabilityResult baseResult = handler.createFailedResult(cfRequest, failure);

    // The failure must surface as a FAILED counterfactual result carrying the exception message.
    assertTrue(baseResult instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult cfResult = (CounterfactualExplainabilityResult) baseResult;
    assertEquals(ExplainabilityStatus.FAILED, cfResult.getStatus());
    assertEquals(failureMessage, cfResult.getStatusDetails());
    assertEquals(EXECUTION_ID, cfResult.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, cfResult.getCounterfactualId());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 8 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

The class LimeExplainerServiceHandlerTest, method testCreateSucceededResult.

@Test
public void testCreateSucceededResult() {
    LIMEExplainabilityRequest limeRequest = new LIMEExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            Collections.emptyList(),
            Collections.emptyList());

    // One saliency for output "salary" with two feature importances: age (+5.0) and dependents (-11.0).
    FeatureImportance ageImportance = new FeatureImportance(new Feature("age", Type.NUMBER, new Value(25.0)), 5.0);
    FeatureImportance dependentsImportance = new FeatureImportance(new Feature("dependents", Type.NUMBER, new Value(2)), -11.0);
    Saliency salarySaliency = new Saliency(new Output("salary", Type.NUMBER), List.of(ageImportance, dependentsImportance));
    Map<String, Saliency> saliencies = Map.of("s1", salarySaliency);

    BaseExplainabilityResult baseResult = handler.createSucceededResult(limeRequest, saliencies);

    assertTrue(baseResult instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult limeResult = (LIMEExplainabilityResult) baseResult;
    assertEquals(ExplainabilityStatus.SUCCEEDED, limeResult.getStatus());
    assertEquals(EXECUTION_ID, limeResult.getExecutionId());

    // The feature importances are expected back in the order they were supplied.
    assertEquals(1, limeResult.getSaliencies().size());
    SaliencyModel salaryModel = limeResult.getSaliencies().iterator().next();
    assertEquals(2, salaryModel.getFeatureImportance().size());
    assertEquals("age", salaryModel.getFeatureImportance().get(0).getFeatureName());
    assertEquals(5.0, salaryModel.getFeatureImportance().get(0).getFeatureScore());
    assertEquals("dependents", salaryModel.getFeatureImportance().get(1).getFeatureName());
    assertEquals(-11.0, salaryModel.getFeatureImportance().get(1).getFeatureScore());
}
Also used : LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature) Test(org.junit.jupiter.api.Test)

Example 9 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

The class LimeExplainerServiceHandlerTest, method testCreateFailedResult.

@Test
public void testCreateFailedResult() {
    LIMEExplainabilityRequest limeRequest = new LIMEExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            Collections.emptyList(),
            Collections.emptyList());
    String failureMessage = "Something went wrong";
    NullPointerException failure = new NullPointerException(failureMessage);

    BaseExplainabilityResult baseResult = handler.createFailedResult(limeRequest, failure);

    // The failure must surface as a FAILED LIME result carrying the exception message.
    assertTrue(baseResult instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult limeResult = (LIMEExplainabilityResult) baseResult;
    assertEquals(ExplainabilityStatus.FAILED, limeResult.getStatus());
    assertEquals(failureMessage, limeResult.getStatusDetails());
    assertEquals(EXECUTION_ID, limeResult.getExecutionId());
}
Also used : LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 10 with BaseExplainabilityResult

Use of org.kie.kogito.explainability.api.BaseExplainabilityResult in project kogito-apps by kiegroup.

The class BaseExplainabilityMessagingHandlerIT, method explainabilityRequestIsProcessedAndAnIntermediateMessageIsSent.

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
void explainabilityRequestIsProcessedAndAnIntermediateMessageIsSent() throws Exception {
    try {
        BaseExplainabilityRequest request = buildRequest();
        BaseExplainabilityResult result = buildResult();
        // Stub explainAsync: feed intermediate results to the callback passed as the second argument,
        // then complete with the final result.
        doAnswer(i -> {
            Object parameter = i.getArguments()[1];
            Consumer<BaseExplainabilityResult> consumer = (Consumer) parameter;
            mockExplainAsyncInvocationWithIntermediateResults(consumer);
            return CompletableFuture.completedFuture(result);
        }).when(explanationService).explainAsync(any(BaseExplainabilityRequest.class), any());

        // Send the request as a CloudEvent and wait for the service to be invoked exactly once.
        kafkaClient.produce(ExplainabilityCloudEventBuilder.buildCloudEventJsonString(request), TOPIC_REQUEST);
        verify(explanationService, timeout(2000).times(1)).explainAsync(any(BaseExplainabilityRequest.class), any());

        // Count down once per decoded result event (intermediate and final) seen on the result topic.
        final CountDownLatch countDownLatch = new CountDownLatch(getTotalExpectedEventCountWithIntermediateResults());
        kafkaClient.consume(TOPIC_RESULT, s -> {
            LOGGER.info("Received from kafka: {}", s);
            CloudEventUtils.decode(s).ifPresent((CloudEvent cloudEvent) -> {
                try {
                    BaseExplainabilityResult event = objectMapper.readValue(cloudEvent.getData().toBytes(), BaseExplainabilityResult.class);
                    assertNotNull(event);
                    assertResult(event);
                    countDownLatch.countDown();
                } catch (IOException e) {
                    LOGGER.error("Error parsing {}", s, e);
                    throw new RuntimeException(e);
                }
            });
        });
        assertTrue(countDownLatch.await(5, TimeUnit.SECONDS),
                "Timed out waiting for the expected result events (intermediate and final)");
    } finally {
        // Release the Kafka client even when a verification or assertion above fails,
        // so it does not leak into subsequent tests.
        kafkaClient.shutdown();
    }
}
Also used : BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) Consumer(java.util.function.Consumer) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) IOException(java.io.IOException) CountDownLatch(java.util.concurrent.CountDownLatch) CloudEvent(io.cloudevents.CloudEvent) Test(org.junit.jupiter.api.Test)

Aggregations

BaseExplainabilityResult (org.kie.kogito.explainability.api.BaseExplainabilityResult)16 Test (org.junit.jupiter.api.Test)14 Consumer (java.util.function.Consumer)8 Prediction (org.kie.kogito.explainability.model.Prediction)7 CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult)6 LIMEExplainabilityResult (org.kie.kogito.explainability.api.LIMEExplainabilityResult)6 BaseExplainabilityRequest (org.kie.kogito.explainability.api.BaseExplainabilityRequest)5 NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue)4 Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals)3 Assertions.assertThrows (org.junit.jupiter.api.Assertions.assertThrows)3 Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue)3 BeforeEach (org.junit.jupiter.api.BeforeEach)3 CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest)3 ExplainabilityStatus (org.kie.kogito.explainability.api.ExplainabilityStatus)3 LIMEExplainabilityRequest (org.kie.kogito.explainability.api.LIMEExplainabilityRequest)3 ModelIdentifier (org.kie.kogito.explainability.api.ModelIdentifier)3 SaliencyModel (org.kie.kogito.explainability.api.SaliencyModel)3 Feature (org.kie.kogito.explainability.model.Feature)3 Output (org.kie.kogito.explainability.model.Output)3 Value (org.kie.kogito.explainability.model.Value)3