
Example 16 with NamedTypedValue

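Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class ExplanationServiceImplTest, helper method testCounterfactualsExplainAsyncSuccess. It has no @Test annotation of its own. Instead, the call under test is passed in as a ThrowingSupplier.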

@SuppressWarnings("unchecked")
void testCounterfactualsExplainAsyncSuccess(ThrowingSupplier<BaseExplainabilityResult> invocation) {
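    // Make `instance` yield only the counterfactual handler mock.
    when(instance.stream()).thenReturn(Stream.of(cfExplainerServiceHandlerMock));
    // Stub the explainer so it completes immediately with the canned COUNTERFACTUAL_RESULT.
    when(cfExplainerMock.explainAsync(any(Prediction.class), eq(predictionProviderMock), any(Consumer.class)))
            .thenReturn(CompletableFuture.completedFuture(COUNTERFACTUAL_RESULT));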
    BaseExplainabilityResult result = assertDoesNotThrow(invocation);
    assertNotNull(result);
    assertTrue(result instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult counterfactualResult = (CounterfactualExplainabilityResult) result;
    assertEquals(EXECUTION_ID, counterfactualResult.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, counterfactualResult.getCounterfactualId());
    assertSame(ExplainabilityStatus.SUCCEEDED, counterfactualResult.getStatus());
    assertNull(counterfactualResult.getStatusDetails());
    assertEquals(COUNTERFACTUAL_RESULT.getEntities().size(), counterfactualResult.getInputs().size());
    assertEquals(COUNTERFACTUAL_RESULT.getOutput().size(), counterfactualResult.getOutputs().size());
    assertTrue(counterfactualResult.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));
    NamedTypedValue value = counterfactualResult.getOutputs().iterator().next();
    assertTrue(value.getValue().isUnit());
    assertEquals(Double.class.getSimpleName(), value.getValue().toUnit().getType());
    assertEquals(555.0, value.getValue().toUnit().getValue().asDouble());
}
Also used : Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) SALIENCY_MAP(org.kie.kogito.explainability.TestUtils.SALIENCY_MAP) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) BeforeEach(org.junit.jupiter.api.BeforeEach) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) LimeExplainerServiceHandler(org.kie.kogito.explainability.handlers.LimeExplainerServiceHandler) FeatureImportanceModel(org.kie.kogito.explainability.api.FeatureImportanceModel) COUNTERFACTUAL_ID(org.kie.kogito.explainability.TestUtils.COUNTERFACTUAL_ID) Prediction(org.kie.kogito.explainability.model.Prediction) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) Assertions.assertNull(org.junit.jupiter.api.Assertions.assertNull) CompletableFuture(java.util.concurrent.CompletableFuture) Mockito.spy(org.mockito.Mockito.spy) FEATURE_IMPORTANCE_1(org.kie.kogito.explainability.TestUtils.FEATURE_IMPORTANCE_1) Mockito.doThrow(org.mockito.Mockito.doThrow) CounterfactualExplainerServiceHandler(org.kie.kogito.explainability.handlers.CounterfactualExplainerServiceHandler) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SALIENCY(org.kie.kogito.explainability.TestUtils.SALIENCY) Instance(javax.enterprise.inject.Instance) LocalExplainerServiceHandlerRegistry(org.kie.kogito.explainability.handlers.LocalExplainerServiceHandlerRegistry) EXECUTION_ID(org.kie.kogito.explainability.TestUtils.EXECUTION_ID) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) Mockito.when(org.mockito.Mockito.when) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Assertions.assertSame(org.junit.jupiter.api.Assertions.assertSame) COUNTERFACTUAL_RESULT(org.kie.kogito.explainability.TestUtils.COUNTERFACTUAL_RESULT) Consumer(java.util.function.Consumer) Test(org.junit.jupiter.api.Test) COUNTERFACTUAL_REQUEST(org.kie.kogito.explainability.TestUtils.COUNTERFACTUAL_REQUEST) LIME_REQUEST(org.kie.kogito.explainability.TestUtils.LIME_REQUEST) Stream(java.util.stream.Stream) ThrowingSupplier(org.junit.jupiter.api.function.ThrowingSupplier) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) Mockito.mock(org.mockito.Mockito.mock)
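The second `when(...)` stub is what lets this helper read an asynchronous result synchronously: `CompletableFuture.completedFuture` returns a future that is already complete, so any wait on it returns at once. The plain-JDK sketch below illustrates that idea without Mockito. `AsyncExplainerSketch`, `CompletedFutureStubDemo` and the string results are hypothetical stand-ins, not the kogito-apps API, and the one-second timeout is an arbitrary choice.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;

// Hypothetical stand-in for an asynchronous explainer (not the kogito-apps interface).
interface AsyncExplainerSketch {
    CompletionStage<String> explainAsync(String prediction);
}

class CompletedFutureStubDemo {

    // Hand-written stub: every call returns an already-completed future, similar in
    // spirit to thenReturn(CompletableFuture.completedFuture(COUNTERFACTUAL_RESULT)).
    static final AsyncExplainerSketch STUB =
            prediction -> CompletableFuture.completedFuture("result-for-" + prediction);

    // Blocks for the result with a timeout; with a completed future this never waits.
    static String explainBlocking(AsyncExplainerSketch explainer, String prediction) throws Exception {
        return explainer.explainAsync(prediction)
                .toCompletableFuture()
                .get(1, TimeUnit.SECONDS);
    }

    public static void main(String[] args) throws Exception {
        System.out.println(STUB.explainAsync("p1").toCompletableFuture().isDone()); // true
        System.out.println(explainBlocking(STUB, "p1"));                            // result-for-p1
    }
}
```

Wrapped as a `ThrowingSupplier`, for example `() -> explainBlocking(STUB, "p1")`, a call like this is the kind of invocation the helper above hands to `assertDoesNotThrow`.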

Example 17 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredOutputModel.

@Test
public void testGetPredictionWithStructuredOutputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(new NamedTypedValue("input1", new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55)))))), // outputs
            Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    // A structured (nested) output value is expected to be rejected.
    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) Test(org.junit.jupiter.api.Test)

Example 18 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatInputModel.

@Test
public void testGetPredictionWithFlatInputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("input1", new UnitValue("number", new IntNode(20)))), // inputs
            Collections.emptyList(), // outputs
            Collections.emptyList(), // search domains
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = counterfactualPrediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(f -> f.getDomain().isEmpty()));
    assertTrue(counterfactualPrediction.getInput().getFeatures().stream().allMatch(Feature::isConstrained));
    assertTrue(counterfactualPrediction.getOutput().getOutputs().isEmpty());
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier)
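The assertions in this test pin down how a single flat input is expected to become a model feature: one feature named after the input, of type NUMBER with value 20, with an empty domain and the constrained flag set. Since the request supplies no search domain, a plausible reading is that such a feature is kept fixed during the counterfactual search. The sketch below models that reading with hypothetical simplified types (`UnitInput`, `FeatureSketch`, `FeatureType`, `FlatInputMapper`) in place of the real NamedTypedValue/UnitValue and Feature classes. The `{min, max}` search-domain map is likewise an assumption made for illustration, not the handler's actual logic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical simplified stand-ins; they are NOT the kogito-apps classes.
enum FeatureType { NUMBER } // only the case exercised by the test above

class UnitInput {
    final String name;
    final String typeRef; // e.g. "number", like the UnitValue type reference
    final Object value;   // e.g. 20, like the IntNode payload

    UnitInput(String name, String typeRef, Object value) {
        this.name = name;
        this.typeRef = typeRef;
        this.value = value;
    }
}

class FeatureSketch {
    final String name;
    final FeatureType type;
    final double value;
    final boolean constrained;  // assumed meaning: the search must keep this value fixed
    final List<Double> domain;  // empty when no search domain was supplied

    FeatureSketch(String name, FeatureType type, double value, boolean constrained, List<Double> domain) {
        this.name = name;
        this.type = type;
        this.value = value;
        this.constrained = constrained;
        this.domain = domain;
    }
}

class FlatInputMapper {
    // One feature per flat numeric input. Inputs without an entry in searchDomains
    // (name -> {min, max}) become constrained features with an empty domain.
    static List<FeatureSketch> toFeatures(List<UnitInput> inputs, Map<String, double[]> searchDomains) {
        List<FeatureSketch> features = new ArrayList<>();
        for (UnitInput in : inputs) {
            if (!"number".equals(in.typeRef) || !(in.value instanceof Number)) {
                throw new IllegalArgumentException("Sketch only handles numeric unit inputs: " + in.name);
            }
            double[] range = searchDomains.get(in.name);
            List<Double> domain = new ArrayList<>();
            if (range != null) {
                domain.add(range[0]);
                domain.add(range[1]);
            }
            double numeric = ((Number) in.value).doubleValue();
            features.add(new FeatureSketch(in.name, FeatureType.NUMBER, numeric, range == null, domain));
        }
        return features;
    }
}
```

Calling `FlatInputMapper.toFeatures(List.of(new UnitInput("input1", "number", 20)), Map.of())` yields one feature that mirrors the assertions: named `input1`, type NUMBER, value 20.0, constrained, with an empty domain.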

Example 19 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModel.

@Test
public void testGetPredictionWithFlatOutputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            Collections.emptyList(), // inputs
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(20)))), // outputs
            Collections.emptyList(), // search domains
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = counterfactualPrediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output1")).findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());
    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier)
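The flat-input and flat-output tests both locate an element by name with `stream().filter(...).findFirst()`. If that lookup recurs across many tests, a small generic helper could factor it out. Here is a JDK-only sketch. `NamedLookup.findByName` and the sample `Item` class are hypothetical names introduced for illustration, with `Item` standing in for `Feature` or `Output`.

```java
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

final class NamedLookup {
    private NamedLookup() {
    }

    // Returns the first element whose extracted name equals `name`, or Optional.empty().
    static <T> Optional<T> findByName(Collection<T> items, Function<? super T, String> nameOf, String name) {
        return items.stream()
                .filter(item -> name.equals(nameOf.apply(item)))
                .findFirst();
    }
}

// Hypothetical element type standing in for Feature or Output.
final class Item {
    private final String name;
    private final double value;

    Item(String name, double value) {
        this.name = name;
        this.value = value;
    }

    String getName() { return name; }

    double getValue() { return value; }
}

class NamedLookupDemo {
    public static void main(String[] args) {
        List<Item> items = List.of(new Item("input1", 20), new Item("input2", 5));
        Optional<Item> hit = NamedLookup.findByName(items, Item::getName, "input1");
        System.out.println(hit.map(Item::getValue).orElse(-1.0));                              // 20.0
        System.out.println(NamedLookup.findByName(items, Item::getName, "missing").isPresent()); // false
    }
}
```

With the real types, a call might read `findByName(counterfactualPrediction.getInput().getFeatures(), Feature::getName, "input1")`, assuming `getFeatures()` returns a `List<Feature>` as its use with `stream()` suggests.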

Example 20 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionInputModel.

@Test
public void testGetPredictionWithCollectionInputModel() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("input1", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))), // inputs
            Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    // A collection input value is expected to be rejected.
    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) Test(org.junit.jupiter.api.Test)
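The structured-output test and this collection-input test both expect `getPrediction` to throw `IllegalArgumentException` when a value is not a plain unit value. The sketch below shows a minimal guard of that kind. It uses a hypothetical mini hierarchy (`TypedValueSketch`, `UnitSketch`, `StructureSketch`, `CollectionSketch`) rather than the real `org.kie.kogito.tracing.typedvalue` classes. It also rejects every non-unit value, which may be stricter than the actual handler: these tests only show a structured output and a collection input being refused.

```java
import java.util.List;
import java.util.Map;

// Hypothetical mini model of typed values (NOT the kogito-apps classes).
abstract class TypedValueSketch {
    final String typeRef;

    TypedValueSketch(String typeRef) { this.typeRef = typeRef; }
}

class UnitSketch extends TypedValueSketch {
    final Object value;

    UnitSketch(String typeRef, Object value) { super(typeRef); this.value = value; }
}

class StructureSketch extends TypedValueSketch {
    final Map<String, TypedValueSketch> fields;

    StructureSketch(String typeRef, Map<String, TypedValueSketch> fields) { super(typeRef); this.fields = fields; }
}

class CollectionSketch extends TypedValueSketch {
    final List<TypedValueSketch> items;

    CollectionSketch(String typeRef, List<TypedValueSketch> items) { super(typeRef); this.items = items; }
}

final class FlatModelValidator {
    private FlatModelValidator() {
    }

    // Accepts only unit (scalar) values; anything else, including structures and
    // collections, is rejected with IllegalArgumentException.
    static void requireFlat(String name, TypedValueSketch value) {
        if (value instanceof UnitSketch) {
            return;
        }
        String kind = value instanceof StructureSketch ? "structured"
                : value instanceof CollectionSketch ? "collection" : "unknown";
        throw new IllegalArgumentException("Unsupported " + kind + " value for '" + name + "'");
    }
}

class FlatModelValidatorDemo {
    public static void main(String[] args) {
        FlatModelValidator.requireFlat("input1", new UnitSketch("number", 20)); // accepted silently
        TypedValueSketch collection =
                new CollectionSketch("number", List.<TypedValueSketch>of(new UnitSketch("number", 100)));
        try {
            FlatModelValidator.requireFlat("input1", collection);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage()); // Unsupported collection value for 'input1'
        }
    }
}
```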

Aggregations

NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue): 33
Test (org.junit.jupiter.api.Test): 27
UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue): 22
IntNode (com.fasterxml.jackson.databind.node.IntNode): 19
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest): 19
CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain): 19
CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue): 16
CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange): 12
CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult): 11
StructureValue (org.kie.kogito.tracing.typedvalue.StructureValue): 11
List (java.util.List): 9
CounterfactualSearchDomainStructureValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue): 9
Prediction (org.kie.kogito.explainability.model.Prediction): 9
CollectionValue (org.kie.kogito.tracing.typedvalue.CollectionValue): 9
Map (java.util.Map): 8
Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals): 8
Assertions.assertThrows (org.junit.jupiter.api.Assertions.assertThrows): 8
Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue): 8
ModelIdentifier (org.kie.kogito.explainability.api.ModelIdentifier): 8
CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction): 8