Search in sources :

Example 26 with CounterfactualExplainabilityResult

use of org.kie.kogito.explainability.api.CounterfactualExplainabilityResult in project kogito-apps by kiegroup.

the class ExplainerServiceHandlerRegistryTest method testCounterfactual_getExplainabilityResultByIdWithOnlyIntermediateResults.

/**
 * Looks up a counterfactual result by execution id when the stubbed storage
 * query yields a single result whose stage is INTERMEDIATE, and expects the
 * registry call to fail with an {@link IllegalArgumentException}.
 */
@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
public void testCounterfactual_getExplainabilityResultByIdWithOnlyIntermediateResults() {
    // The one and only stored result reports the INTERMEDIATE stage.
    CounterfactualExplainabilityResult intermediateOnly = mock(CounterfactualExplainabilityResult.class);
    when(intermediateOnly.getStage()).thenReturn(CounterfactualExplainabilityResult.Stage.INTERMEDIATE);

    // Query stub: filtering returns the same query, executing returns the single result.
    Query storageQuery = mock(Query.class);
    when(storageQuery.filter(any())).thenReturn(storageQuery);
    when(storageQuery.execute()).thenReturn(List.of(intermediateOnly));

    // Storage stub: reports the execution id as present and hands out the stubbed query.
    when(storageCounterfactual.query()).thenReturn(storageQuery);
    when(storageCounterfactual.containsKey(eq(EXECUTION_ID))).thenReturn(true);

    assertThrows(
            IllegalArgumentException.class,
            () -> registry.getExplainabilityResultById(EXECUTION_ID, CounterfactualExplainabilityResult.class));
}
Also used : Query(org.kie.kogito.persistence.api.query.Query) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 27 with CounterfactualExplainabilityResult

use of org.kie.kogito.explainability.api.CounterfactualExplainabilityResult in project kogito-apps by kiegroup.

the class ExplainerServiceHandlerRegistryTest method testCounterfactual_getExplainabilityResultByIdWithAllResults.

/**
 * Looks up a counterfactual result by execution id when the stubbed storage
 * query yields both an INTERMEDIATE and a FINAL result. Verifies that the
 * counterfactual handler is consulted and that the FINAL result is returned.
 */
@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
public void testCounterfactual_getExplainabilityResultByIdWithAllResults() {
    // Two stored results: one INTERMEDIATE, one FINAL.
    CounterfactualExplainabilityResult intermediateResult = mock(CounterfactualExplainabilityResult.class);
    when(intermediateResult.getStage()).thenReturn(CounterfactualExplainabilityResult.Stage.INTERMEDIATE);
    CounterfactualExplainabilityResult finalResult = mock(CounterfactualExplainabilityResult.class);
    when(finalResult.getStage()).thenReturn(CounterfactualExplainabilityResult.Stage.FINAL);

    // Query stub: filtering returns the same query, executing returns both results in order.
    Query storageQuery = mock(Query.class);
    when(storageQuery.filter(any())).thenReturn(storageQuery);
    when(storageQuery.execute()).thenReturn(List.of(intermediateResult, finalResult));

    // Storage stub: reports the execution id as present and hands out the stubbed query.
    when(storageCounterfactual.query()).thenReturn(storageQuery);
    when(storageCounterfactual.containsKey(eq(EXECUTION_ID))).thenReturn(true);

    CounterfactualExplainabilityResult returned =
            registry.getExplainabilityResultById(EXECUTION_ID, CounterfactualExplainabilityResult.class);

    verify(counterfactualExplainerServiceHandler).getExplainabilityResultById(eq(EXECUTION_ID));
    assertEquals(finalResult, returned);
}
Also used : Query(org.kie.kogito.persistence.api.query.Query) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 28 with CounterfactualExplainabilityResult

use of org.kie.kogito.explainability.api.CounterfactualExplainabilityResult in project kogito-apps by kiegroup.

the class TrustyServiceTest method givenStoredCounterfactualResultsWhenGetCounterfactualResultsThenResultsAreReturned.

/**
 * Given a counterfactual storage whose query yields two results, checks that
 * {@code trustyService.getCounterfactualResults} returns a collection
 * containing both of them.
 */
@Test
@SuppressWarnings({ "unchecked", "rawtypes" })
void givenStoredCounterfactualResultsWhenGetCounterfactualResultsThenResultsAreReturned() {
    CounterfactualExplainabilityResult firstStored = mock(CounterfactualExplainabilityResult.class);
    CounterfactualExplainabilityResult secondStored = mock(CounterfactualExplainabilityResult.class);
    // Immutable list, used both as the stubbed query output and as the expectation.
    List<CounterfactualExplainabilityResult> storedResults = List.of(firstStored, secondStored);

    // Query stub: filtering with any list returns the same query, executing returns the stored results.
    Query stubbedQuery = mock(Query.class);
    when(stubbedQuery.filter(any(List.class))).thenReturn(stubbedQuery);
    when(stubbedQuery.execute()).thenReturn(storedResults);

    // Wire the stubbed query into a mocked storage exposed through the storage service mock.
    Storage<String, CounterfactualExplainabilityResult> storageMock = mock(Storage.class);
    when(storageMock.query()).thenReturn(stubbedQuery);
    when(trustyStorageServiceMock.getCounterfactualResultStorage()).thenReturn(storageMock);

    assertTrue(
            trustyService.getCounterfactualResults(TEST_EXECUTION_ID, TEST_COUNTERFACTUAL_ID)
                    .containsAll(storedResults));
}
Also used : Query(org.kie.kogito.persistence.api.query.Query) List(java.util.List) ArrayList(java.util.ArrayList) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 29 with CounterfactualExplainabilityResult

use of org.kie.kogito.explainability.api.CounterfactualExplainabilityResult in project kogito-apps by kiegroup.

the class AbstractTrustyServiceIT method testStoreExplainabilityResult_Counterfactual_DuplicateRemoval_FinalThenIntermediate.

/**
 * Stores a FINAL counterfactual result followed by an INTERMEDIATE one for the
 * same execution/counterfactual ids and asserts that, after each store, the
 * service returns exactly one result: the FINAL solution ("solutionId1").
 *
 * <p>The second group of assertions inspects the list fetched AFTER the
 * INTERMEDIATE result was stored ({@code resultsAfterIntermediate}); checking
 * the earlier list again would not exercise the behaviour under test.
 */
@Test
public void testStoreExplainabilityResult_Counterfactual_DuplicateRemoval_FinalThenIntermediate() {
    String executionId = "myCFExecution1Store";
    String counterfactualId = "myCFCounterfactualId";
    NamedTypedValue input1 = new NamedTypedValue("field1", new UnitValue("typeRef1", "typeRef1", new IntNode(25)));
    NamedTypedValue input2 = new NamedTypedValue("field2", new UnitValue("typeRef2", "typeRef2", new IntNode(99)));
    NamedTypedValue output1 = new NamedTypedValue("field3", new UnitValue("typeRef3", "typeRef3", new IntNode(200)));
    NamedTypedValue output2 = new NamedTypedValue("field4", new UnitValue("typeRef4", "typeRef4", new IntNode(1000)));

    // First solution is the FINAL (for whatever reason, e.g. messaging delays, the INTERMEDIATE is received afterwards)
    trustyService.storeExplainabilityResult(executionId, new CounterfactualExplainabilityResult(executionId, counterfactualId, "solutionId1", 0L, ExplainabilityStatus.SUCCEEDED, "status", true, CounterfactualExplainabilityResult.Stage.FINAL, List.of(input1, input2), List.of(output1, output2)));
    List<CounterfactualExplainabilityResult> resultsAfterFinal = trustyService.getCounterfactualResults(executionId, counterfactualId);
    assertNotNull(resultsAfterFinal);
    assertEquals(1, resultsAfterFinal.size());
    assertEquals("solutionId1", resultsAfterFinal.get(0).getSolutionId());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, resultsAfterFinal.get(0).getStage());

    // The late INTERMEDIATE solution is stored next; the assertions below expect the FINAL to remain the only result.
    trustyService.storeExplainabilityResult(executionId, new CounterfactualExplainabilityResult(executionId, counterfactualId, "solutionId2", 0L, ExplainabilityStatus.SUCCEEDED, "status", true, CounterfactualExplainabilityResult.Stage.INTERMEDIATE, List.of(input1, input2), List.of(output1, output2)));
    List<CounterfactualExplainabilityResult> resultsAfterIntermediate = trustyService.getCounterfactualResults(executionId, counterfactualId);
    assertNotNull(resultsAfterIntermediate);
    // Assert on the freshly fetched list, not on the one read before the second store.
    assertEquals(1, resultsAfterIntermediate.size());
    assertEquals("solutionId1", resultsAfterIntermediate.get(0).getSolutionId());
    assertEquals(CounterfactualExplainabilityResult.Stage.FINAL, resultsAfterIntermediate.get(0).getStage());
}
Also used : NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) Test(org.junit.jupiter.api.Test)

Aggregations

CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult)29 Test (org.junit.jupiter.api.Test)26 NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue)8 BaseExplainabilityResult (org.kie.kogito.explainability.api.BaseExplainabilityResult)6 UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue)6 IntNode (com.fasterxml.jackson.databind.node.IntNode)5 ArrayList (java.util.ArrayList)4 List (java.util.List)4 Consumer (java.util.function.Consumer)4 CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest)4 Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals)3 Assertions.assertThrows (org.junit.jupiter.api.Assertions.assertThrows)3 Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue)3 BeforeEach (org.junit.jupiter.api.BeforeEach)3 CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain)3 ExplainabilityStatus (org.kie.kogito.explainability.api.ExplainabilityStatus)3 ModelIdentifier (org.kie.kogito.explainability.api.ModelIdentifier)3 Prediction (org.kie.kogito.explainability.model.Prediction)3 Query (org.kie.kogito.persistence.api.query.Query)3 BooleanNode (com.fasterxml.jackson.databind.node.BooleanNode)2