
Example 1 with CounterfactualRequestResponse

Use of org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse in the kogito-apps project by kiegroup.

From class ExplainabilityApiV1Test, method testRequestCounterfactualsWhenExecutionDoesExist.

@Test
public void testRequestCounterfactualsWhenExecutionDoesExist() {
    // Stub the service so that any counterfactual request returns a known stored request.
    when(trustyService.requestCounterfactuals(anyString(), any(), any()))
            .thenReturn(new CounterfactualExplainabilityRequest(EXECUTION_ID,
                    SERVICE_URL,
                    new ModelIdentifier("resourceType", "resourceIdentifier"),
                    COUNTERFACTUAL_ID,
                    Collections.emptyList(),
                    Collections.emptyList(),
                    Collections.emptyList(),
                    MAX_RUNNING_TIME_SECONDS));
    // An empty request (no goals, no search domains) is enough to exercise the endpoint.
    org.kie.kogito.trusty.service.common.requests.CounterfactualRequest request =
            new org.kie.kogito.trusty.service.common.requests.CounterfactualRequest(Collections.emptyList(), Collections.emptyList());
    Response response = explainabilityEndpoint.requestCounterfactuals(EXECUTION_ID, request);
    assertNotNull(response);
    assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
    Object entity = response.getEntity();
    assertNotNull(entity);
    assertTrue(entity instanceof CounterfactualRequestResponse);
    CounterfactualRequestResponse counterfactualRequestResponse = (CounterfactualRequestResponse) entity;
    assertEquals(EXECUTION_ID, counterfactualRequestResponse.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, counterfactualRequestResponse.getCounterfactualId());
    assertEquals(MAX_RUNNING_TIME_SECONDS, counterfactualRequestResponse.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) CounterfactualRequestResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse) Response(javax.ws.rs.core.Response) CounterfactualResultsResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualResultsResponse) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) Test(org.junit.jupiter.api.Test)
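Example 1 implies that the endpoint copies three fields from the stored `CounterfactualExplainabilityRequest` into the `CounterfactualRequestResponse`: the execution ID, the counterfactual ID and the maximum running time. The plain-Java sketch below illustrates that mapping under this assumption. `StoredCounterfactualRequest`, `CounterfactualSummary` and `from` are hypothetical stand-ins, not the kogito-apps classes, and the `long` type for the running time is also an assumption.

```java
// Hypothetical stand-in for the stored explainability request (not the kogito-apps class).
class StoredCounterfactualRequest {
    final String executionId;
    final String counterfactualId;
    final long maxRunningTimeSeconds;

    StoredCounterfactualRequest(String executionId, String counterfactualId, long maxRunningTimeSeconds) {
        this.executionId = executionId;
        this.counterfactualId = counterfactualId;
        this.maxRunningTimeSeconds = maxRunningTimeSeconds;
    }
}

// Hypothetical response DTO; the getter names mirror the ones Example 1 asserts on.
class CounterfactualSummary {
    private final String executionId;
    private final String counterfactualId;
    private final long maxRunningTimeSeconds;

    CounterfactualSummary(String executionId, String counterfactualId, long maxRunningTimeSeconds) {
        this.executionId = executionId;
        this.counterfactualId = counterfactualId;
        this.maxRunningTimeSeconds = maxRunningTimeSeconds;
    }

    String getExecutionId() {
        return executionId;
    }

    String getCounterfactualId() {
        return counterfactualId;
    }

    long getMaxRunningTimeSeconds() {
        return maxRunningTimeSeconds;
    }

    // Copies exactly the fields that the test compares against the stubbed request.
    static CounterfactualSummary from(StoredCounterfactualRequest request) {
        return new CounterfactualSummary(request.executionId, request.counterfactualId, request.maxRunningTimeSeconds);
    }
}
```

With a mapping like this, the test only has to check that each field survives the round trip, which is what its three `assertEquals` calls do.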

Example 2 with CounterfactualRequestResponse

Use of org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse in the kogito-apps project by kiegroup.

From class AbstractTrustyExplainabilityEnd2EndIT, method doCounterfactualRequests.

private Callable<Boolean> doCounterfactualRequests(final InfinispanTrustyServiceContainer trustyService, final String accessToken, final String executionId) {
    LOGGER.info(String.format("Reading Decision [%s]'s Inputs...", executionId));
    DecisionStructuredInputsResponse inputs = given().port(trustyService.getFirstMappedPort())
            .auth().oauth2(accessToken)
            .when().get("/executions/decisions/" + executionId + "/structuredInputs")
            .then().statusCode(200)
            .extract().as(DecisionStructuredInputsResponse.class);
    LOGGER.info(String.format("Reading Decision [%s]'s Outputs...", executionId));
    DecisionOutcomesResponse outcomes = given().port(trustyService.getFirstMappedPort())
            .auth().oauth2(accessToken)
            .when().get("/executions/decisions/" + executionId + "/outcomes")
            .then().statusCode(200)
            .extract().as(DecisionOutcomesResponse.class);
    // Dump the inputs and outcomes as pretty-printed JSON; useful when debugging a failing run.
    ObjectMapper mapper = new ObjectMapper();
    ObjectWriter writer = mapper.writerWithDefaultPrettyPrinter();
    StringBuilder sb = new StringBuilder();
    sb.append("== INPUTS ==>\n");
    inputs.getInputs().forEach(i -> {
        try {
            sb.append(writer.writeValueAsString(i)).append("\n");
        } catch (JsonProcessingException jpe) {
            // Ignored: this dump is diagnostic only and must not fail the test.
        }
    });
    sb.append("== OUTPUTS ==>\n");
    outcomes.getOutcomes().forEach(o -> {
        try {
            sb.append(writer.writeValueAsString(o.getOutcomeResult())).append("\n");
        } catch (JsonProcessingException jpe) {
            // Ignored: this dump is diagnostic only and must not fail the test.
        }
    });
    LOGGER.debug(sb.toString());
    return () -> {
        LOGGER.info(String.format("Checking Decision [%s]'s Counterfactual request was successful...", executionId));
        // The Goals and Search Domain structures must match those of the original decision
        // See https://issues.redhat.com/browse/FAI-486
        CounterfactualRequestResponse counterfactualRequestResponse = given().port(trustyService.getFirstMappedPort())
                .auth().oauth2(accessToken)
                .contentType(ContentType.JSON)
                .body(new CounterfactualRequest(
                        outcomes.getOutcomes().stream().map(AbstractTrustyExplainabilityEnd2EndIT::toCounterfactualGoal).collect(Collectors.toList()),
                        inputs.getInputs().stream().map(AbstractTrustyExplainabilityEnd2EndIT::toCounterfactualSearchDomain).collect(Collectors.toList())))
                .when().post("/executions/decisions/" + executionId + "/explanations/counterfactuals")
                .then().statusCode(200)
                .extract().as(CounterfactualRequestResponse.class);
        return Objects.nonNull(counterfactualRequestResponse)
                && Objects.equals(executionId, counterfactualRequestResponse.getExecutionId())
                && Objects.nonNull(counterfactualRequestResponse.getCounterfactualId());
    };
}
Also used : DecisionStructuredInputsResponse(org.kie.kogito.trusty.service.common.responses.decision.DecisionStructuredInputsResponse) DecisionOutcomesResponse(org.kie.kogito.trusty.service.common.responses.decision.DecisionOutcomesResponse) ObjectWriter(com.fasterxml.jackson.databind.ObjectWriter) JsonProcessingException(com.fasterxml.jackson.core.JsonProcessingException) CounterfactualRequestResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse) ObjectMapper(com.fasterxml.jackson.databind.ObjectMapper) CounterfactualRequest(org.kie.kogito.trusty.service.common.requests.CounterfactualRequest)
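`doCounterfactualRequests` does not perform the check itself. Instead it returns a `Callable<Boolean>` that posts the counterfactual request and reports whether the response carries the expected execution ID and a counterfactual ID. The snippet does not show how the IT consumes that callable. Presumably the caller retries it until it returns `true`, which libraries such as Awaitility support through `until(Callable<Boolean>)`. The stdlib-only sketch below shows such a retry loop. `Poller.pollUntilTrue` is a hypothetical helper, and its timeout and interval are illustrative parameters rather than values taken from kogito-apps.

```java
import java.time.Duration;
import java.util.concurrent.Callable;

// Hypothetical polling helper (not part of kogito-apps or Awaitility).
final class Poller {
    private Poller() {
    }

    // Calls the condition until it returns true or the timeout elapses.
    // Exceptions and assertion failures (such as a failed status-code check)
    // count as "not yet" and the condition is retried after the interval.
    static boolean pollUntilTrue(Callable<Boolean> condition, Duration timeout, Duration interval) {
        long deadline = System.nanoTime() + timeout.toNanos();
        while (true) {
            try {
                if (Boolean.TRUE.equals(condition.call())) {
                    return true;
                }
            } catch (Exception | AssertionError e) {
                // Not satisfied yet; fall through and retry.
            }
            if (System.nanoTime() - deadline >= 0) {
                return false;
            }
            try {
                Thread.sleep(interval.toMillis());
            } catch (InterruptedException ie) {
                // Restore the interrupt flag and give up.
                Thread.currentThread().interrupt();
                return false;
            }
        }
    }
}
```

Treating assertion failures as retryable lets the callable keep its strict `statusCode(200)` check while a caller waits for the service to become ready.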

Example 3 with CounterfactualRequestResponse

Use of org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse in the kogito-apps project by kiegroup.

From class ExplainabilityApiV1IT, method testCounterfactualRequestWithStructuredModel.

@Test
@SuppressWarnings("unchecked")
void testCounterfactualRequestWithStructuredModel() {
    ArgumentCaptor<List<NamedTypedValue>> goalsCaptor = ArgumentCaptor.forClass(List.class);
    ArgumentCaptor<List<CounterfactualSearchDomain>> searchDomainsCaptor = ArgumentCaptor.forClass(List.class);
    mockServiceWithCounterfactualRequest();
    CounterfactualRequestResponse response = given()
            .filter(new RequestLoggingFilter())
            .filter(new ResponseLoggingFilter())
            .contentType(MediaType.APPLICATION_JSON)
            .body(getCounterfactualWithStructuredModelJsonRequest())
            .when()
            .post("/executions/decisions/" + TEST_EXECUTION_ID + "/explanations/counterfactuals")
            .as(CounterfactualRequestResponse.class);
    assertNotNull(response);
    assertNotNull(response.getExecutionId());
    assertNotNull(response.getCounterfactualId());
    // JUnit's assertEquals takes (expected, actual).
    assertEquals(TEST_EXECUTION_ID, response.getExecutionId());
    assertEquals(TEST_COUNTERFACTUAL_ID, response.getCounterfactualId());
    verify(executionService).requestCounterfactuals(eq(TEST_EXECUTION_ID), goalsCaptor.capture(), searchDomainsCaptor.capture());
    List<NamedTypedValue> goalsParameter = goalsCaptor.getValue();
    assertNotNull(goalsParameter);
    assertEquals(1, goalsParameter.size());
    NamedTypedValue goal1 = goalsParameter.get(0);
    assertEquals(TypedValue.Kind.STRUCTURE, goal1.getValue().getKind());
    assertEquals("Fine", goal1.getName());
    assertEquals("tFine", goal1.getValue().getType());
    assertEquals(2, goal1.getValue().toStructure().getValue().size());
    Iterator<Map.Entry<String, TypedValue>> goal1ChildIterator = goal1.getValue().toStructure().getValue().entrySet().iterator();
    Map.Entry<String, TypedValue> goal1Child1 = goal1ChildIterator.next();
    Map.Entry<String, TypedValue> goal1Child2 = goal1ChildIterator.next();
    assertEquals(TypedValue.Kind.UNIT, goal1Child1.getValue().getKind());
    assertEquals("Amount", goal1Child1.getKey());
    assertEquals("number", goal1Child1.getValue().getType());
    assertEquals(100, goal1Child1.getValue().toUnit().getValue().asInt());
    assertEquals(TypedValue.Kind.UNIT, goal1Child2.getValue().getKind());
    assertEquals("Points", goal1Child2.getKey());
    assertEquals("number", goal1Child2.getValue().getType());
    assertEquals(0, goal1Child2.getValue().toUnit().getValue().asInt());
    List<CounterfactualSearchDomain> searchDomainsParameter = searchDomainsCaptor.getValue();
    assertNotNull(searchDomainsParameter);
    assertEquals(1, searchDomainsParameter.size());
    CounterfactualSearchDomain domain1 = searchDomainsParameter.get(0);
    assertEquals(TypedValue.Kind.STRUCTURE, domain1.getValue().getKind());
    assertEquals("Violation", domain1.getName());
    assertEquals("tViolation", domain1.getValue().getType());
    assertEquals(3, domain1.getValue().toStructure().getValue().size());
    Iterator<Map.Entry<String, CounterfactualSearchDomainValue>> domain1ChildIterator = domain1.getValue().toStructure().getValue().entrySet().iterator();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child1 = domain1ChildIterator.next();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child2 = domain1ChildIterator.next();
    Map.Entry<String, CounterfactualSearchDomainValue> domain1Child3 = domain1ChildIterator.next();
    assertEquals(TypedValue.Kind.UNIT, domain1Child1.getValue().getKind());
    assertFalse(domain1Child1.getValue().toUnit().isFixed());
    assertEquals("Type", domain1Child1.getKey());
    assertEquals("string", domain1Child1.getValue().getType());
    assertNotNull(domain1Child1.getValue().toUnit().getDomain());
    assertTrue(domain1Child1.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical domain1Child1Def = (CounterfactualDomainCategorical) domain1Child1.getValue().toUnit().getDomain();
    assertEquals(2, domain1Child1Def.getCategories().size());
    assertTrue(domain1Child1Def.getCategories().stream().map(JsonNode::asText).collect(Collectors.toList()).containsAll(Arrays.asList("speed", "driving under the influence")));
    assertEquals(TypedValue.Kind.UNIT, domain1Child2.getValue().getKind());
    assertFalse(domain1Child2.getValue().toUnit().isFixed());
    assertEquals("Actual Speed", domain1Child2.getKey());
    assertEquals("number", domain1Child2.getValue().getType());
    assertNotNull(domain1Child2.getValue().toUnit().getDomain());
    assertTrue(domain1Child2.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain1Child2Def = (CounterfactualDomainRange) domain1Child2.getValue().toUnit().getDomain();
    assertEquals(0, domain1Child2Def.getLowerBound().asInt());
    assertEquals(100, domain1Child2Def.getUpperBound().asInt());
    assertEquals(TypedValue.Kind.UNIT, domain1Child3.getValue().getKind());
    assertFalse(domain1Child3.getValue().toUnit().isFixed());
    assertEquals("Speed Limit", domain1Child3.getKey());
    assertEquals("number", domain1Child3.getValue().getType());
    assertNotNull(domain1Child3.getValue().toUnit().getDomain());
    assertTrue(domain1Child3.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain1Child3Def = (CounterfactualDomainRange) domain1Child3.getValue().toUnit().getDomain();
    assertEquals(0, domain1Child3Def.getLowerBound().asInt());
    assertEquals(100, domain1Child3Def.getUpperBound().asInt());
}
Also used : CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualDomainCategorical(org.kie.kogito.explainability.api.CounterfactualDomainCategorical) ArgumentMatchers.anyString(org.mockito.ArgumentMatchers.anyString) ResponseLoggingFilter(io.restassured.filter.log.ResponseLoggingFilter) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) List(java.util.List) ArrayList(java.util.ArrayList) RequestLoggingFilter(io.restassured.filter.log.RequestLoggingFilter) CounterfactualRequestResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse) Map(java.util.Map) CounterfactualSearchDomainValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test) QuarkusTest(io.quarkus.test.junit.QuarkusTest)
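Example 3 reads the children of each structure through `entrySet().iterator()` and asserts on them by position. That only works if the map behind `toStructure().getValue()` preserves the order of the JSON request. Jackson's default for `Map`-typed properties is `LinkedHashMap`, which keeps insertion order, but whether these kogito-apps classes rely on that is an assumption here. The stdlib-only sketch below contrasts positional iteration over a `LinkedHashMap` with a key-based lookup that does not depend on ordering. `StructureChildren` and `childByKey` are hypothetical helpers.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NoSuchElementException;

// Hypothetical helpers illustrating ordered vs. key-based access to structure children.
final class StructureChildren {
    private StructureChildren() {
    }

    // Looks a child up by name, so the check holds regardless of map ordering.
    static <V> V childByKey(Map<String, V> children, String key) {
        V value = children.get(key);
        if (value == null) {
            throw new NoSuchElementException("No child named " + key);
        }
        return value;
    }

    // Child names and types of "Violation", inserted in the order Example 3 asserts them.
    // LinkedHashMap keeps insertion order, so positional iteration is predictable.
    static Map<String, String> violationChildTypes() {
        Map<String, String> children = new LinkedHashMap<>();
        children.put("Type", "string");
        children.put("Actual Speed", "number");
        children.put("Speed Limit", "number");
        return children;
    }

    // Returns the first child name, as Example 3's first iterator.next() would.
    static String firstChildName(Map<String, ?> children) {
        Iterator<String> names = children.keySet().iterator();
        return names.next();
    }
}
```

A key-based lookup like `childByKey(children, "Speed Limit")` would keep such a test stable even if the backing map type changed to one without a defined iteration order.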

Example 4 with CounterfactualRequestResponse

Use of org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse in the kogito-apps project by kiegroup.

From class ExplainabilityApiV1IT, method testCounterfactualRequest.

@Test
@SuppressWarnings("unchecked")
void testCounterfactualRequest() {
    ArgumentCaptor<List<NamedTypedValue>> goalsCaptor = ArgumentCaptor.forClass(List.class);
    ArgumentCaptor<List<CounterfactualSearchDomain>> searchDomainsCaptor = ArgumentCaptor.forClass(List.class);
    mockServiceWithCounterfactualRequest();
    CounterfactualRequestResponse response = given()
            .filter(new RequestLoggingFilter())
            .filter(new ResponseLoggingFilter())
            .contentType(MediaType.APPLICATION_JSON)
            .body(getCounterfactualJsonRequest())
            .when()
            .post("/executions/decisions/" + TEST_EXECUTION_ID + "/explanations/counterfactuals")
            .as(CounterfactualRequestResponse.class);
    assertNotNull(response);
    assertNotNull(response.getExecutionId());
    assertNotNull(response.getCounterfactualId());
    // JUnit's assertEquals takes (expected, actual).
    assertEquals(TEST_EXECUTION_ID, response.getExecutionId());
    assertEquals(TEST_COUNTERFACTUAL_ID, response.getCounterfactualId());
    verify(executionService).requestCounterfactuals(eq(TEST_EXECUTION_ID), goalsCaptor.capture(), searchDomainsCaptor.capture());
    List<NamedTypedValue> goalsParameter = goalsCaptor.getValue();
    assertNotNull(goalsParameter);
    assertEquals(2, goalsParameter.size());
    NamedTypedValue goal1 = goalsParameter.get(0);
    assertEquals(TypedValue.Kind.UNIT, goal1.getValue().getKind());
    assertEquals("deposit", goal1.getName());
    assertEquals("number", goal1.getValue().getType());
    assertEquals(5000, goal1.getValue().toUnit().getValue().asInt());
    NamedTypedValue goal2 = goalsParameter.get(1);
    assertEquals(TypedValue.Kind.UNIT, goal2.getValue().getKind());
    assertEquals("approved", goal2.getName());
    assertEquals("boolean", goal2.getValue().getType());
    assertEquals(Boolean.TRUE, goal2.getValue().toUnit().getValue().asBoolean());
    List<CounterfactualSearchDomain> searchDomainsParameter = searchDomainsCaptor.getValue();
    assertNotNull(searchDomainsParameter);
    assertEquals(3, searchDomainsParameter.size());
    CounterfactualSearchDomain domain1 = searchDomainsParameter.get(0);
    assertEquals(TypedValue.Kind.UNIT, domain1.getValue().getKind());
    assertTrue(domain1.getValue().toUnit().isFixed());
    assertEquals("age", domain1.getName());
    assertEquals("number", domain1.getValue().getType());
    assertNull(domain1.getValue().toUnit().getDomain());
    CounterfactualSearchDomain domain2 = searchDomainsParameter.get(1);
    assertEquals(TypedValue.Kind.UNIT, domain2.getValue().getKind());
    assertFalse(domain2.getValue().toUnit().isFixed());
    assertEquals("income", domain2.getName());
    assertEquals("number", domain2.getValue().getType());
    assertNotNull(domain2.getValue().toUnit().getDomain());
    assertTrue(domain2.getValue().toUnit().getDomain() instanceof CounterfactualDomainRange);
    CounterfactualDomainRange domain2Def = (CounterfactualDomainRange) domain2.getValue().toUnit().getDomain();
    assertEquals(0, domain2Def.getLowerBound().asInt());
    assertEquals(1000, domain2Def.getUpperBound().asInt());
    CounterfactualSearchDomain domain3 = searchDomainsParameter.get(2);
    assertEquals(TypedValue.Kind.UNIT, domain3.getValue().getKind());
    assertFalse(domain3.getValue().toUnit().isFixed());
    assertEquals("taxCode", domain3.getName());
    assertEquals("string", domain3.getValue().getType());
    assertNotNull(domain3.getValue().toUnit().getDomain());
    assertTrue(domain3.getValue().toUnit().getDomain() instanceof CounterfactualDomainCategorical);
    CounterfactualDomainCategorical domain3Def = (CounterfactualDomainCategorical) domain3.getValue().toUnit().getDomain();
    assertEquals(3, domain3Def.getCategories().size());
    assertTrue(domain3Def.getCategories().stream().map(JsonNode::asText).collect(Collectors.toList()).containsAll(Arrays.asList("A", "B", "C")));
}
Also used : NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) List(java.util.List) ArrayList(java.util.ArrayList) RequestLoggingFilter(io.restassured.filter.log.RequestLoggingFilter) CounterfactualDomainCategorical(org.kie.kogito.explainability.api.CounterfactualDomainCategorical) ResponseLoggingFilter(io.restassured.filter.log.ResponseLoggingFilter) CounterfactualRequestResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test) QuarkusTest(io.quarkus.test.junit.QuarkusTest)
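Example 4 exercises three kinds of search-domain entry:
- `age`: a fixed input with no domain.
- `income`: a numeric range from 0 to 1000.
- `taxCode`: a set of categories A, B and C.

The sketch below models what each kind permits during a counterfactual search, in plain Java. `Domain`, `RangeDomain`, `CategoricalDomain`, `SearchDomain` and `allows` are hypothetical names, not the `org.kie.kogito.explainability.api` classes. Inclusive range bounds are an assumption, as is the rule that a fixed input may only keep its original value.

```java
import java.util.Objects;
import java.util.Set;

// Hypothetical model of a search-domain definition (not the kogito-apps API).
interface Domain {
    // Whether a candidate value may be tried for this input during the search.
    boolean allows(Object candidate);
}

// Numeric range; the bounds are treated as inclusive, which is an assumption.
class RangeDomain implements Domain {
    private final double lower;
    private final double upper;

    RangeDomain(double lower, double upper) {
        if (lower > upper) {
            throw new IllegalArgumentException("lower bound must not exceed upper bound");
        }
        this.lower = lower;
        this.upper = upper;
    }

    @Override
    public boolean allows(Object candidate) {
        if (!(candidate instanceof Number)) {
            return false;
        }
        double value = ((Number) candidate).doubleValue();
        return value >= lower && value <= upper;
    }
}

// Finite set of allowed string categories.
class CategoricalDomain implements Domain {
    private final Set<String> categories;

    CategoricalDomain(Set<String> categories) {
        this.categories = Set.copyOf(categories);
    }

    @Override
    public boolean allows(Object candidate) {
        return candidate instanceof String && categories.contains(candidate);
    }
}

// One search input: a fixed input keeps its original value and carries no domain.
class SearchDomain {
    private final boolean fixed;
    private final Object originalValue;
    private final Domain domain;

    private SearchDomain(boolean fixed, Object originalValue, Domain domain) {
        this.fixed = fixed;
        this.originalValue = originalValue;
        this.domain = domain;
    }

    static SearchDomain ofFixed(Object originalValue) {
        return new SearchDomain(true, originalValue, null);
    }

    static SearchDomain ofDomain(Domain domain) {
        return new SearchDomain(false, null, Objects.requireNonNull(domain));
    }

    boolean isFixed() {
        return fixed;
    }

    boolean allows(Object candidate) {
        return fixed ? Objects.equals(originalValue, candidate) : domain.allows(candidate);
    }
}
```

This mirrors the shape the test asserts: `isFixed()` is true and the domain is null for `age`, while `income` and `taxCode` carry a range and a categorical domain respectively.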

Example 5 with CounterfactualRequestResponse

Use of org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse in the kogito-apps project by kiegroup.

From class ExplainabilityApiV1Test, method testGetAllCounterfactualsWhenExecutionDoesExist.

@Test
@SuppressWarnings({ "rawtypes", "unchecked" })
public void testGetAllCounterfactualsWhenExecutionDoesExist() {
    // Stub the service so that the execution has exactly one stored counterfactual request.
    when(trustyService.getCounterfactualRequests(anyString()))
            .thenReturn(List.of(new CounterfactualExplainabilityRequest(EXECUTION_ID,
                    SERVICE_URL,
                    new ModelIdentifier("resourceType", "resourceIdentifier"),
                    COUNTERFACTUAL_ID,
                    Collections.emptyList(),
                    Collections.emptyList(),
                    Collections.emptyList(),
                    MAX_RUNNING_TIME_SECONDS)));
    Response response = explainabilityEndpoint.getAllCounterfactualsSummary(EXECUTION_ID);
    assertNotNull(response);
    assertEquals(Response.Status.OK.getStatusCode(), response.getStatus());
    Object entity = response.getEntity();
    assertNotNull(entity);
    assertTrue(entity instanceof List);
    List<CounterfactualRequestResponse> counterfactualRequestResponse = (List) entity;
    assertEquals(1, counterfactualRequestResponse.size());
    CounterfactualRequestResponse counterfactual = counterfactualRequestResponse.get(0);
    assertEquals(EXECUTION_ID, counterfactual.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, counterfactual.getCounterfactualId());
    assertEquals(MAX_RUNNING_TIME_SECONDS, counterfactual.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) CounterfactualRequestResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse) Response(javax.ws.rs.core.Response) CounterfactualResultsResponse(org.kie.kogito.trusty.service.common.responses.CounterfactualResultsResponse) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) List(java.util.List) Test(org.junit.jupiter.api.Test)
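Example 5 checks `entity instanceof List` and then performs a raw cast to `List<CounterfactualRequestResponse>`. That cast is why the method needs `@SuppressWarnings({ "rawtypes", "unchecked" })`, and the element type is never actually verified. A small helper can check every element before returning a typed list. The sketch below is a hypothetical stdlib-only utility (`EntityCasts.asListOf`), not something kogito-apps provides.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical utility for turning an untyped JAX-RS entity into a typed list.
final class EntityCasts {
    private EntityCasts() {
    }

    // Returns the entity as an unmodifiable List<T>. Fails fast if the entity is not
    // a List or if any non-null element is not an instance of the expected type.
    static <T> List<T> asListOf(Object entity, Class<T> type) {
        if (!(entity instanceof List)) {
            throw new ClassCastException("Expected a List but got "
                    + (entity == null ? "null" : entity.getClass().getName()));
        }
        List<?> raw = (List<?>) entity;
        List<T> typed = new ArrayList<>(raw.size());
        for (Object element : raw) {
            // Class.cast throws ClassCastException on a type mismatch.
            typed.add(type.cast(element));
        }
        return Collections.unmodifiableList(typed);
    }
}
```

In Example 5 this would read `List<CounterfactualRequestResponse> counterfactuals = EntityCasts.asListOf(entity, CounterfactualRequestResponse.class);`, and the `@SuppressWarnings` annotation would no longer be needed.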

Aggregations

CounterfactualRequestResponse (org.kie.kogito.trusty.service.common.responses.CounterfactualRequestResponse): 5
Test (org.junit.jupiter.api.Test): 4
List (java.util.List): 3
QuarkusTest (io.quarkus.test.junit.QuarkusTest): 2
RequestLoggingFilter (io.restassured.filter.log.RequestLoggingFilter): 2
ResponseLoggingFilter (io.restassured.filter.log.ResponseLoggingFilter): 2
ArrayList (java.util.ArrayList): 2
Response (javax.ws.rs.core.Response): 2
CounterfactualDomainCategorical (org.kie.kogito.explainability.api.CounterfactualDomainCategorical): 2
CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange): 2
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest): 2
CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain): 2
ModelIdentifier (org.kie.kogito.explainability.api.ModelIdentifier): 2
NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue): 2
CounterfactualResultsResponse (org.kie.kogito.trusty.service.common.responses.CounterfactualResultsResponse): 2
JsonProcessingException (com.fasterxml.jackson.core.JsonProcessingException): 1
ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper): 1
ObjectWriter (com.fasterxml.jackson.databind.ObjectWriter): 1
Map (java.util.Map): 1
CounterfactualSearchDomainValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainValue): 1