Search in sources:

Example 1 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.

From class ExplainabilityApiV1IT, method testSalienciesWithExplainabilityResult.

/**
 * Requests the saliencies of a decision with two outcomes and checks that one saliency per
 * outcome is returned, each carrying two feature importances with the expected names and scores
 * (the scores are expected to come from the result stubbed by
 * {@code mockServiceWithExplainabilityResult()}).
 *
 * <p>Collection sizes are compared with {@code assertEquals}, not {@code assertSame}: the latter
 * checks the identity of boxed {@link Integer}s and only happens to pass for small values because
 * of the Integer cache.
 */
@Test
void testSalienciesWithExplainabilityResult() {
    mockServiceWithExplainabilityResult();
    Decision decision = new Decision(TEST_EXECUTION_ID, "sourceUrl", "serviceUrl", 0L, true, "executorName", "executorModelName", "executorModelNamespace", new ArrayList<>(), new ArrayList<>());
    decision.getOutcomes().add(new DecisionOutcome("outcomeId1", "Output1", ExplainabilityStatus.SUCCEEDED.name(), new UnitValue("type", new IntNode(1)), Collections.emptyList(), Collections.emptyList()));
    decision.getOutcomes().add(new DecisionOutcome("outcomeId2", "Output2", ExplainabilityStatus.SUCCEEDED.name(), new UnitValue("type2", new IntNode(2)), Collections.emptyList(), Collections.emptyList()));
    when(executionService.getDecisionById(eq(TEST_EXECUTION_ID))).thenReturn(decision);
    SalienciesResponse response = given().filter(new ResponseLoggingFilter()).when().get("/executions/decisions/" + TEST_EXECUTION_ID + "/explanations/saliencies").as(SalienciesResponse.class);
    assertNotNull(response);
    assertNotNull(response.getSaliencies());
    assertEquals(2, response.getSaliencies().size());
    // Sort by outcome name so the index-based assertions below do not depend on response ordering.
    List<SaliencyModel> sortedSaliencies = response.getSaliencies().stream().sorted((s1, s2) -> new CompareToBuilder().append(s1.getOutcomeName(), s2.getOutcomeName()).toComparison()).collect(Collectors.toList());
    assertNotNull(sortedSaliencies.get(0));
    assertEquals("Output1", sortedSaliencies.get(0).getOutcomeName());
    assertNotNull(sortedSaliencies.get(0).getFeatureImportance());
    assertEquals(2, sortedSaliencies.get(0).getFeatureImportance().size());
    assertEquals("Feature1", sortedSaliencies.get(0).getFeatureImportance().get(0).getFeatureName());
    assertEquals(0.49384, sortedSaliencies.get(0).getFeatureImportance().get(0).getFeatureScore());
    assertEquals("Feature2", sortedSaliencies.get(0).getFeatureImportance().get(1).getFeatureName());
    assertEquals(-0.1084, sortedSaliencies.get(0).getFeatureImportance().get(1).getFeatureScore());
    assertNotNull(sortedSaliencies.get(1));
    assertEquals("Output2", sortedSaliencies.get(1).getOutcomeName());
    assertNotNull(sortedSaliencies.get(1).getFeatureImportance());
    assertEquals(2, sortedSaliencies.get(1).getFeatureImportance().size());
    assertEquals("Feature1", sortedSaliencies.get(1).getFeatureImportance().get(0).getFeatureName());
    assertEquals(0.0, sortedSaliencies.get(1).getFeatureImportance().get(0).getFeatureScore());
    assertEquals("Feature2", sortedSaliencies.get(1).getFeatureImportance().get(1).getFeatureName());
    assertEquals(0.70293, sortedSaliencies.get(1).getFeatureImportance().get(1).getFeatureScore());
}
Also used : LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) Arrays(java.util.Arrays) FeatureImportanceModel(org.kie.kogito.explainability.api.FeatureImportanceModel) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualSearchDomainValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainValue) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) TrustyServiceTestUtils.getCounterfactualJsonRequest(org.kie.kogito.trusty.service.common.TrustyServiceTestUtils.getCounterfactualJsonRequest) SalienciesResponse(org.kie.kogito.trusty.service.common.responses.SalienciesResponse) MediaType(javax.ws.rs.core.MediaType) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) JsonNode(com.fasterxml.jackson.databind.JsonNode) CounterfactualDomainCategorical(org.kie.kogito.explainability.api.CounterfactualDomainCategorical) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) RequestLoggingFilter(io.restassured.filter.log.RequestLoggingFilter) CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) RestAssured.given(io.restassured.RestAssured.given) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) TrustyService(org.kie.kogito.trusty.service.common.TrustyService) 
DecisionOutcome(org.kie.kogito.trusty.storage.api.model.decision.DecisionOutcome) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) Decision(org.kie.kogito.trusty.storage.api.model.decision.Decision) IntNode(com.fasterxml.jackson.databind.node.IntNode) Assertions.assertNull(org.junit.jupiter.api.Assertions.assertNull) CounterfactualRequestResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse) ArrayList(java.util.ArrayList) QuarkusTest(io.quarkus.test.junit.QuarkusTest) ArgumentCaptor(org.mockito.ArgumentCaptor) CompareToBuilder(org.testcontainers.shaded.org.apache.commons.lang.builder.CompareToBuilder) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) InjectMock(io.quarkus.test.junit.mockito.InjectMock) Iterator(java.util.Iterator) ResponseLoggingFilter(io.restassured.filter.log.ResponseLoggingFilter) TrustyServiceTestUtils.getCounterfactualWithStructuredModelJsonRequest(org.kie.kogito.trusty.service.common.TrustyServiceTestUtils.getCounterfactualWithStructuredModelJsonRequest) Mockito.when(org.mockito.Mockito.when) Assertions.assertSame(org.junit.jupiter.api.Assertions.assertSame) Mockito.verify(org.mockito.Mockito.verify) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) CounterfactualResultsResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualResultsResponse) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) SalienciesResponse(org.kie.kogito.trusty.service.common.responses.SalienciesResponse) IntNode(com.fasterxml.jackson.databind.node.IntNode) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) DecisionOutcome(org.kie.kogito.trusty.storage.api.model.decision.DecisionOutcome) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) ResponseLoggingFilter(io.restassured.filter.log.ResponseLoggingFilter) 
CompareToBuilder(org.testcontainers.shaded.org.apache.commons.lang.builder.CompareToBuilder) Decision(org.kie.kogito.trusty.storage.api.model.decision.Decision) Test(org.junit.jupiter.api.Test) QuarkusTest(io.quarkus.test.junit.QuarkusTest)

Example 2 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.

From class LIMESaliencyConverterTest, method testFromResult_DecisionExists.

/**
 * Converts a succeeded LIME result with two saliencies for a decision that has two outcomes.
 * The test expects each converted saliency to carry the matching outcome id and name, along with
 * every feature importance score from the LIME result.
 */
@Test
public void testFromResult_DecisionExists() {
    // LIME result: outcomeName1 has two features, outcomeName2 has one.
    SaliencyModel firstSaliency = new SaliencyModel("outcomeName1", List.of(new FeatureImportanceModel("feature1a", 1.0), new FeatureImportanceModel("feature1b", 2.0)));
    SaliencyModel secondSaliency = new SaliencyModel("outcomeName2", List.of(new FeatureImportanceModel("feature2", 3.0)));
    LIMEExplainabilityResult result = LIMEExplainabilityResult.buildSucceeded(EXECUTION_ID, List.of(firstSaliency, secondSaliency));

    // Decision whose outcomes share their names with the saliencies above.
    Decision decision = new Decision(EXECUTION_ID, "sourceUrl", "serviceUrl", 0L, true, "executorName", "executorModelName", "executorModelNamespace", new ArrayList<>(), new ArrayList<>());
    String succeeded = ExplainabilityStatus.SUCCEEDED.name();
    decision.getOutcomes().add(new DecisionOutcome("outcomeId1", "outcomeName1", succeeded, new UnitValue("type", new IntNode(1)), Collections.emptyList(), Collections.emptyList()));
    decision.getOutcomes().add(new DecisionOutcome("outcomeId2", "outcomeName2", succeeded, new UnitValue("type2", new IntNode(2)), Collections.emptyList(), Collections.emptyList()));
    when(trustyService.getDecisionById(eq(EXECUTION_ID))).thenReturn(decision);

    SalienciesResponse response = converter.fromResult(EXECUTION_ID, result);

    assertNotNull(response);
    assertEquals(succeeded, response.getStatus());
    assertEquals(2, response.getSaliencies().size());
    List<SaliencyResponse> converted = new ArrayList<>(response.getSaliencies());

    SaliencyResponse firstConverted = converted.get(0);
    assertEquals("outcomeId1", firstConverted.getOutcomeId());
    assertEquals("outcomeName1", firstConverted.getOutcomeName());
    assertEquals(2, firstConverted.getFeatureImportance().size());
    Optional<FeatureImportanceModel> feature1a = firstConverted.getFeatureImportance().stream().filter(importance -> importance.getFeatureName().equals("feature1a")).findFirst();
    assertTrue(feature1a.isPresent());
    assertEquals(1.0, feature1a.get().getFeatureScore());
    Optional<FeatureImportanceModel> feature1b = firstConverted.getFeatureImportance().stream().filter(importance -> importance.getFeatureName().equals("feature1b")).findFirst();
    assertTrue(feature1b.isPresent());
    assertEquals(2.0, feature1b.get().getFeatureScore());

    SaliencyResponse secondConverted = converted.get(1);
    assertEquals("outcomeId2", secondConverted.getOutcomeId());
    assertEquals("outcomeName2", secondConverted.getOutcomeName());
    assertEquals(1, secondConverted.getFeatureImportance().size());
    Optional<FeatureImportanceModel> feature2 = secondConverted.getFeatureImportance().stream().filter(importance -> importance.getFeatureName().equals("feature2")).findFirst();
    assertTrue(feature2.isPresent());
    assertEquals(3.0, feature2.get().getFeatureScore());
}
Also used : Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) BeforeEach(org.junit.jupiter.api.BeforeEach) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) FeatureImportanceModel(org.kie.kogito.explainability.api.FeatureImportanceModel) Decision(org.kie.kogito.trusty.storage.api.model.decision.Decision) IntNode(com.fasterxml.jackson.databind.node.IntNode) Mock(org.mockito.Mock) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) SaliencyResponse(org.kie.kogito.trusty.service.common.responses.SaliencyResponse) ArrayList(java.util.ArrayList) SalienciesResponse(org.kie.kogito.trusty.service.common.responses.SalienciesResponse) ExtendWith(org.junit.jupiter.api.extension.ExtendWith) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) MockitoExtension(org.mockito.junit.jupiter.MockitoExtension) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) Mockito.when(org.mockito.Mockito.when) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Collections(java.util.Collections) TrustyService(org.kie.kogito.trusty.service.common.TrustyService) DecisionOutcome(org.kie.kogito.trusty.storage.api.model.decision.DecisionOutcome) SalienciesResponse(org.kie.kogito.trusty.service.common.responses.SalienciesResponse) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) SaliencyResponse(org.kie.kogito.trusty.service.common.responses.SaliencyResponse) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) DecisionOutcome(org.kie.kogito.trusty.storage.api.model.decision.DecisionOutcome) ArrayList(java.util.ArrayList) 
UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Decision(org.kie.kogito.trusty.storage.api.model.decision.Decision) IntNode(com.fasterxml.jackson.databind.node.IntNode) FeatureImportanceModel(org.kie.kogito.explainability.api.FeatureImportanceModel) Test(org.junit.jupiter.api.Test)

Example 3 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.

From class CounterfactualExplainabilityRequestMarshallerTest, method testWriteAndRead.

/**
 * Writes a {@link CounterfactualExplainabilityRequest} with the marshaller and reads it back.
 * The test checks that the execution id, counterfactual id, goal name, search-domain name and
 * lower bound, and the maximum running time all survive the round trip.
 */
@Test
public void testWriteAndRead() throws IOException {
    ModelIdentifier modelIdentifier = new ModelIdentifier("resourceType", "resourceId");
    List<NamedTypedValue> originalInputs = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> goals = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<CounterfactualSearchDomain> searchDomains = Collections.singletonList(new CounterfactualSearchDomain("age", new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, new CounterfactualDomainRange(JsonNodeFactory.instance.numberNode(0), JsonNodeFactory.instance.numberNode(10)))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest("executionId", "serviceUrl", modelIdentifier, "counterfactualId", originalInputs, goals, searchDomains, 60L);
    CounterfactualExplainabilityRequestMarshaller marshaller = new CounterfactualExplainabilityRequestMarshaller(new ObjectMapper());
    marshaller.writeTo(writer, request);
    CounterfactualExplainabilityRequest retrieved = marshaller.readFrom(reader);
    Assertions.assertEquals(request.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(request.getCounterfactualId(), retrieved.getCounterfactualId());
    Assertions.assertEquals(goals.get(0).getName(), retrieved.getGoals().stream().findFirst().get().getName());
    Assertions.assertEquals(searchDomains.get(0).getName(), retrieved.getSearchDomains().stream().findFirst().get().getName());
    Assertions.assertEquals(0, ((CounterfactualDomainRange) retrieved.getSearchDomains().stream().findFirst().get().getValue().toUnit().getDomain()).getLowerBound().asInt());
    // Check the deserialized request. Asserting on the original `request` is always true and
    // would not catch a marshaller that loses this field.
    Assertions.assertEquals(60L, retrieved.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)

Example 4 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.

From class DecisionMarshallerTest, method testWriteAndRead.

/**
 * Writes a {@link Decision} with one input and one outcome with the marshaller and reads it back.
 * The test compares the decision's identifying fields, the first input's name and value types,
 * and the first outcome's id, and expects the outcome to have no messages.
 */
@Test
public void testWriteAndRead() throws IOException {
    // Fixture: a single numeric input and a single succeeded outcome.
    UnitValue inputValue = new UnitValue("nameIn", "number", JsonNodeFactory.instance.numberNode(10));
    List<DecisionInput> inputs = Collections.singletonList(new DecisionInput("id", "in", inputValue));
    String outcomeStatus = DMNDecisionResult.DecisionEvaluationStatus.SUCCEEDED.toString();
    UnitValue outcomeValue = new UnitValue("nameOut", "number", JsonNodeFactory.instance.numberNode(10));
    List<NamedTypedValue> outcomeEvaluatedInputs = List.of(new NamedTypedValue("nameOut", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<DecisionOutcome> outcomes = Collections.singletonList(new DecisionOutcome("id", "out", outcomeStatus, outcomeValue, outcomeEvaluatedInputs, new ArrayList<>()));
    Decision decision = new Decision("executionId", "source", "serviceUrl", 0L, true, "executor", "model", "namespace", inputs, outcomes);

    // Round trip through the marshaller.
    DecisionMarshaller marshaller = new DecisionMarshaller(new ObjectMapper());
    marshaller.writeTo(writer, decision);
    Decision retrieved = marshaller.readFrom(reader);

    Assertions.assertEquals(decision.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(decision.getSourceUrl(), retrieved.getSourceUrl());
    Assertions.assertEquals(decision.getServiceUrl(), retrieved.getServiceUrl());
    Assertions.assertEquals(decision.getExecutedModelName(), retrieved.getExecutedModelName());

    DecisionInput expectedInput = inputs.get(0);
    DecisionInput retrievedInput = retrieved.getInputs().stream().findFirst().get();
    Assertions.assertEquals(expectedInput.getName(), retrievedInput.getName());
    Assertions.assertEquals(expectedInput.getValue().getType(), retrievedInput.getValue().getType());
    Assertions.assertEquals(expectedInput.getValue().toUnit().getBaseType(), retrievedInput.getValue().toUnit().getBaseType());

    DecisionOutcome retrievedOutcome = retrieved.getOutcomes().stream().findFirst().get();
    Assertions.assertEquals(outcomes.get(0).getOutcomeId(), retrievedOutcome.getOutcomeId());
    Assertions.assertTrue(retrievedOutcome.getMessages().isEmpty());
}
Also used : DecisionInput(org.kie.kogito.trusty.storage.api.model.decision.DecisionInput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) DecisionOutcome(org.kie.kogito.trusty.storage.api.model.decision.DecisionOutcome) ArrayList(java.util.ArrayList) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Decision(org.kie.kogito.trusty.storage.api.model.decision.Decision) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) Test(org.junit.jupiter.api.Test)

Example 5 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in the kogito-apps project by kiegroup.

From class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithStructuredInputModel.

/**
 * Builds a counterfactual request whose only original input is a structure (nested value).
 * The test expects {@code handler.getPrediction} to reject it with an
 * {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithStructuredInputModel() {
    // A structure "input1" containing a single nested unit value "input2b".
    UnitValue nestedUnit = new UnitValue("number", new IntNode(55));
    StructureValue structuredInput = new StructureValue("number", Map.of("input2b", nestedUnit));
    List<NamedTypedValue> originalInputs = List.of(new NamedTypedValue("input1", structuredInput));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, originalInputs, Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) Test(org.junit.jupiter.api.Test)

Aggregations

UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue): 31 · Test (org.junit.jupiter.api.Test): 24 · NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue): 23 · IntNode (com.fasterxml.jackson.databind.node.IntNode): 18 · CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue): 17 · CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest): 14 · StructureValue (org.kie.kogito.tracing.typedvalue.StructureValue): 12 · List (java.util.List): 10 · CollectionValue (org.kie.kogito.tracing.typedvalue.CollectionValue): 10 · CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain): 9 · ArrayList (java.util.ArrayList): 8 · CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult): 8 · CounterfactualSearchDomainStructureValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue): 8 · TypedValue (org.kie.kogito.tracing.typedvalue.TypedValue): 8 · Decision (org.kie.kogito.trusty.storage.api.model.decision.Decision): 8 · Collections (java.util.Collections): 7 · Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals): 7 · Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue): 7 · CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange): 7 · ExplainabilityStatus (org.kie.kogito.explainability.api.ExplainabilityStatus): 7