Usage of org.kie.kogito.explainability.api.ModelIdentifier in project kogito-apps by kiegroup.
Class ExplainabilityApiV1Test, method testRequestCounterfactualsWhenExecutionDoesExist.
/**
 * Requesting counterfactuals for an existing execution must answer HTTP 200 with a
 * {@link CounterfactualRequestResponse} echoing the execution id, counterfactual id and
 * maximum running time of the request returned by the (mocked) trusty service.
 */
@Test
public void testRequestCounterfactualsWhenExecutionDoesExist() {
    // The mocked service hands back a request with no inputs, goals or search domains.
    CounterfactualExplainabilityRequest stubbedRequest = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            new ModelIdentifier("resourceType", "resourceIdentifier"),
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    when(trustyService.requestCounterfactuals(anyString(), any(), any())).thenReturn(stubbedRequest);

    // Fully qualified name, presumably to avoid a clash with another CounterfactualRequest type
    // imported in this test class (NOTE(review): confirm against the file's imports).
    org.kie.kogito.trusty.service.common.requests.CounterfactualRequest payload =
            new org.kie.kogito.trusty.service.common.requests.CounterfactualRequest(
                    Collections.emptyList(), Collections.emptyList());

    Response httpResponse = explainabilityEndpoint.requestCounterfactuals(EXECUTION_ID, payload);

    assertNotNull(httpResponse);
    assertEquals(Response.Status.OK.getStatusCode(), httpResponse.getStatus());

    Object body = httpResponse.getEntity();
    assertNotNull(body);
    assertTrue(body instanceof CounterfactualRequestResponse);

    CounterfactualRequestResponse accepted = (CounterfactualRequestResponse) body;
    assertEquals(EXECUTION_ID, accepted.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, accepted.getCounterfactualId());
    assertEquals(MAX_RUNNING_TIME_SECONDS, accepted.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.ModelIdentifier in project kogito-apps by kiegroup.
Class CounterfactualExplainabilityRequestMarshallerTest, method testWriteAndRead.
/**
 * Round-trips a {@link CounterfactualExplainabilityRequest} through
 * {@link CounterfactualExplainabilityRequestMarshaller} (write to {@code writer}, read back from
 * {@code reader}) and checks that the values read back match those that were written.
 *
 * @throws IOException if writing to or reading from the test streams fails
 */
@Test
public void testWriteAndRead() throws IOException {
    ModelIdentifier modelIdentifier = new ModelIdentifier("resourceType", "resourceId");
    List<NamedTypedValue> originalInputs = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<NamedTypedValue> goals = Collections.singletonList(new NamedTypedValue("unitIn", new UnitValue("number", "number", JsonNodeFactory.instance.numberNode(10))));
    List<CounterfactualSearchDomain> searchDomains = Collections.singletonList(new CounterfactualSearchDomain("age", new CounterfactualSearchDomainUnitValue("integer", "integer", Boolean.TRUE, new CounterfactualDomainRange(JsonNodeFactory.instance.numberNode(0), JsonNodeFactory.instance.numberNode(10)))));
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest("executionId", "serviceUrl", modelIdentifier, "counterfactualId", originalInputs, goals, searchDomains, 60L);
    CounterfactualExplainabilityRequestMarshaller marshaller = new CounterfactualExplainabilityRequestMarshaller(new ObjectMapper());

    marshaller.writeTo(writer, request);
    CounterfactualExplainabilityRequest retrieved = marshaller.readFrom(reader);

    Assertions.assertEquals(request.getExecutionId(), retrieved.getExecutionId());
    Assertions.assertEquals(request.getCounterfactualId(), retrieved.getCounterfactualId());
    // iterator().next() fails loudly (NoSuchElementException) if a collection came back empty,
    // without the unchecked Optional.get() of stream().findFirst().get().
    Assertions.assertEquals(goals.get(0).getName(), retrieved.getGoals().iterator().next().getName());

    CounterfactualSearchDomain retrievedDomain = retrieved.getSearchDomains().iterator().next();
    Assertions.assertEquals(searchDomains.get(0).getName(), retrievedDomain.getName());
    Assertions.assertEquals(0, ((CounterfactualDomainRange) retrievedDomain.getValue().toUnit().getDomain()).getLowerBound().asInt());

    // Check the *retrieved* instance: asserting on `request` would only re-check the constructor
    // argument and never verify that maxRunningTimeSeconds survives the marshalling round-trip.
    Assertions.assertEquals(60L, retrieved.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.ModelIdentifier in project kogito-apps by kiegroup.
Class ExplainabilityApiV1IT, method testEndpointWithBadRequests.
/**
 * Posts malformed LIME explainability requests to {@code /v1/explain} and expects each one to be
 * rejected with HTTP 400. The three cases have a blank resource type, a blank resource id, and a
 * blank service URL.
 *
 * @throws JsonProcessingException if a request cannot be serialized to JSON
 */
@Test
void testEndpointWithBadRequests() throws JsonProcessingException {
    LIMEExplainabilityRequest blankResourceType = new LIMEExplainabilityRequest(
            executionId, serviceUrl, new ModelIdentifier("", "test"), Collections.emptyList(), Collections.emptyList());
    LIMEExplainabilityRequest blankResourceId = new LIMEExplainabilityRequest(
            executionId, serviceUrl, new ModelIdentifier("test", ""), Collections.emptyList(), Collections.emptyList());
    LIMEExplainabilityRequest blankServiceUrl = new LIMEExplainabilityRequest(
            executionId, "", new ModelIdentifier("test", "test"), Collections.emptyList(), Collections.emptyList());

    LIMEExplainabilityRequest[] invalidRequests = { blankResourceType, blankResourceId, blankServiceUrl };
    for (LIMEExplainabilityRequest invalid : invalidRequests) {
        String json = MAPPER.writeValueAsString(invalid);
        given()
                .contentType(ContentType.JSON)
                .body(json)
                .when()
                .post("/v1/explain")
                .then()
                .statusCode(400);
    }
}
Usage of org.kie.kogito.explainability.api.ModelIdentifier in project kogito-apps by kiegroup.
Class ExplainabilityApiV1IT, method testEndpointWithRequest.
/**
 * Posts a well-formed LIME explainability request (resource type {@code "dmn"}) to
 * {@code /v1/explain} and checks that the returned result carries the same execution id.
 *
 * @throws JsonProcessingException if the request cannot be serialized to JSON
 */
@Test
void testEndpointWithRequest() throws JsonProcessingException {
    ModelIdentifier dmnModel = new ModelIdentifier("dmn", "namespace:name");
    LIMEExplainabilityRequest limeRequest = new LIMEExplainabilityRequest(
            executionId, serviceUrl, dmnModel, Collections.emptyList(), Collections.emptyList());
    String json = MAPPER.writeValueAsString(limeRequest);

    BaseExplainabilityResult explainResult = given()
            .contentType(ContentType.JSON)
            .body(json)
            .when()
            .post("/v1/explain")
            .as(BaseExplainabilityResult.class);

    assertEquals(executionId, explainResult.getExecutionId());
}
Usage of org.kie.kogito.explainability.api.ModelIdentifier in project kogito-apps by kiegroup.
Class ExplainabilityApiV1Test, method testGetCounterfactualResultsWhenExecutionDoesExistAndResultsHaveBeenCreated.
/**
 * Sets up the trusty service to return a counterfactual request and two results. The request has
 * one goal and one search domain. The results are an INTERMEDIATE solution followed by a FINAL
 * one. The details endpoint must then answer HTTP 200 with a {@link CounterfactualResultsResponse}
 * that exposes the request metadata, the goal, the search domain and both solutions in that order.
 */
@Test
public void testGetCounterfactualResultsWhenExecutionDoesExistAndResultsHaveBeenCreated() {
    // --- arrange ---
    NamedTypedValue goal = buildGoalUnit("unit", "string", new TextNode("hello"));
    CounterfactualSearchDomain searchDomain = buildSearchDomainUnit("unit", "string",
            new CounterfactualDomainCategorical(List.of(new TextNode("hello"), new TextNode("goodbye"))));

    CounterfactualExplainabilityResult intermediateSolution = new CounterfactualExplainabilityResult(
            EXECUTION_ID, COUNTERFACTUAL_ID, "solution1", 0L, ExplainabilityStatus.SUCCEEDED, "", true,
            CounterfactualExplainabilityResult.Stage.INTERMEDIATE, Collections.emptyList(), Collections.emptyList());
    CounterfactualExplainabilityResult finalSolution = new CounterfactualExplainabilityResult(
            EXECUTION_ID, COUNTERFACTUAL_ID, "solution2", 1L, ExplainabilityStatus.SUCCEEDED, "", true,
            CounterfactualExplainabilityResult.Stage.FINAL, Collections.emptyList(), Collections.emptyList());

    CounterfactualExplainabilityRequest storedRequest = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            new ModelIdentifier("resourceType", "resourceIdentifier"),
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(goal),
            List.of(searchDomain),
            MAX_RUNNING_TIME_SECONDS);
    when(trustyService.getCounterfactualRequest(anyString(), anyString())).thenReturn(storedRequest);
    when(trustyService.getCounterfactualResults(anyString(), anyString())).thenReturn(List.of(intermediateSolution, finalSolution));

    // --- act ---
    Response httpResponse = explainabilityEndpoint.getCounterfactualDetails(EXECUTION_ID, COUNTERFACTUAL_ID);

    // --- assert: HTTP envelope ---
    assertNotNull(httpResponse);
    assertEquals(Response.Status.OK.getStatusCode(), httpResponse.getStatus());

    Object body = httpResponse.getEntity();
    assertNotNull(body);
    assertTrue(body instanceof CounterfactualResultsResponse);
    CounterfactualResultsResponse details = (CounterfactualResultsResponse) body;

    // --- assert: request metadata ---
    assertEquals(EXECUTION_ID, details.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, details.getCounterfactualId());
    assertEquals(MAX_RUNNING_TIME_SECONDS, details.getMaxRunningTimeSeconds());

    // --- assert: goals and search domains ---
    assertEquals(1, details.getGoals().size());
    assertEquals(goal, details.getGoals().iterator().next());
    assertEquals(1, details.getSearchDomains().size());
    assertEquals(searchDomain, details.getSearchDomains().iterator().next());

    // --- assert: solutions, in the order the service returned them ---
    assertEquals(2, details.getSolutions().size());
    assertEquals(intermediateSolution, details.getSolutions().get(0));
    assertEquals(finalSolution, details.getSolutions().get(1));
}
Aggregations