Example 1 with LIMEExplainabilityRequest

Use of org.kie.kogito.explainability.api.LIMEExplainabilityRequest in project kogito-apps by kiegroup.

From the class LimeExplainerServiceHandlerTest, method testCreateSucceededResult.

@Test
public void testCreateSucceededResult() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, Collections.emptyList(), Collections.emptyList());
    Map<String, Saliency> saliencies = Map.of("s1", new Saliency(
            new Output("salary", Type.NUMBER),
            List.of(
                    new FeatureImportance(new Feature("age", Type.NUMBER, new Value(25.0)), 5.0),
                    new FeatureImportance(new Feature("dependents", Type.NUMBER, new Value(2)), -11.0))));
    BaseExplainabilityResult base = handler.createSucceededResult(request, saliencies);
    assertTrue(base instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult result = (LIMEExplainabilityResult) base;
    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(1, result.getSaliencies().size());
    SaliencyModel saliencyModel = result.getSaliencies().iterator().next();
    assertEquals(2, saliencyModel.getFeatureImportance().size());
    assertEquals("age", saliencyModel.getFeatureImportance().get(0).getFeatureName());
    assertEquals(5.0, saliencyModel.getFeatureImportance().get(0).getFeatureScore());
    assertEquals("dependents", saliencyModel.getFeatureImportance().get(1).getFeatureName());
    assertEquals(-11.0, saliencyModel.getFeatureImportance().get(1).getFeatureScore());
}
Also used : LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature) Test(org.junit.jupiter.api.Test)

Example 2 with LIMEExplainabilityRequest

Use of org.kie.kogito.explainability.api.LIMEExplainabilityRequest in project kogito-apps by kiegroup.

From the class LimeExplainerServiceHandlerTest, method testCreateIntermediateResult.

@Test
public void testCreateIntermediateResult() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, Collections.emptyList(), Collections.emptyList());
    assertThrows(UnsupportedOperationException.class, () -> handler.createIntermediateResult(request, null));
}
Also used : LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) Test(org.junit.jupiter.api.Test)

Example 3 with LIMEExplainabilityRequest

Use of org.kie.kogito.explainability.api.LIMEExplainabilityRequest in project kogito-apps by kiegroup.

From the class LimeExplainerServiceHandlerTest, method testCreateFailedResult.

@Test
public void testCreateFailedResult() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, Collections.emptyList(), Collections.emptyList());
    BaseExplainabilityResult base = handler.createFailedResult(request, new NullPointerException("Something went wrong"));
    assertTrue(base instanceof LIMEExplainabilityResult);
    LIMEExplainabilityResult result = (LIMEExplainabilityResult) base;
    assertEquals(ExplainabilityStatus.FAILED, result.getStatus());
    assertEquals("Something went wrong", result.getStatusDetails());
    assertEquals(EXECUTION_ID, result.getExecutionId());
}
Also used : LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) Test(org.junit.jupiter.api.Test)

Example 4 with LIMEExplainabilityRequest

Use of org.kie.kogito.explainability.api.LIMEExplainabilityRequest in project kogito-apps by kiegroup.

From the class ExplainabilityApiV1IT, method testEndpointWithBadRequests.

@Test
void testEndpointWithBadRequests() throws JsonProcessingException {
    // Each request has exactly one blank field: the first or second ModelIdentifier argument, or the service URL.
    LIMEExplainabilityRequest[] badRequests = new LIMEExplainabilityRequest[] {
            new LIMEExplainabilityRequest(executionId, serviceUrl, new ModelIdentifier("", "test"), Collections.emptyList(), Collections.emptyList()),
            new LIMEExplainabilityRequest(executionId, serviceUrl, new ModelIdentifier("test", ""), Collections.emptyList(), Collections.emptyList()),
            new LIMEExplainabilityRequest(executionId, "", new ModelIdentifier("test", "test"), Collections.emptyList(), Collections.emptyList())
    };
    for (LIMEExplainabilityRequest badRequest : badRequests) {
        String body = MAPPER.writeValueAsString(badRequest);
        given().contentType(ContentType.JSON).body(body).when().post("/v1/explain").then().statusCode(400);
    }
}
Also used : LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) Test(org.junit.jupiter.api.Test) QuarkusTest(io.quarkus.test.junit.QuarkusTest)
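
The test above only shows that requests with a blank field are rejected with status 400; it does not show how the service decides that. As a rough illustration, the check could look like the sketch below. Everything in it is hypothetical: ExplainRequestSketch, ExplainRequestValidator, and the field names resourceType and resourceId are stand-ins chosen for this example and are not taken from kogito-apps.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for the fields the test blanks out; not the real request class.
final class ExplainRequestSketch {
    final String serviceUrl;
    final String resourceType; // assumed name for the first ModelIdentifier argument
    final String resourceId;   // assumed name for the second ModelIdentifier argument

    ExplainRequestSketch(String serviceUrl, String resourceType, String resourceId) {
        this.serviceUrl = serviceUrl;
        this.resourceType = resourceType;
        this.resourceId = resourceId;
    }
}

// Hypothetical validator: collects one message per blank field.
final class ExplainRequestValidator {

    // Returns an empty list when every field is non-blank.
    static List<String> validate(ExplainRequestSketch request) {
        List<String> errors = new ArrayList<>();
        if (isBlank(request.serviceUrl)) {
            errors.add("serviceUrl must not be blank");
        }
        if (isBlank(request.resourceType)) {
            errors.add("resourceType must not be blank");
        }
        if (isBlank(request.resourceId)) {
            errors.add("resourceId must not be blank");
        }
        return errors;
    }

    // Treats null, empty, and whitespace-only strings as blank.
    private static boolean isBlank(String s) {
        return s == null || s.trim().isEmpty();
    }
}
```

An endpoint built along these lines would answer with status 400 whenever validate returns a non-empty list, which is consistent with what the test expects for each of its three requests.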

Example 5 with LIMEExplainabilityRequest

Use of org.kie.kogito.explainability.api.LIMEExplainabilityRequest in project kogito-apps by kiegroup.

From the class LimeExplainerServiceHandlerTest, method testGetPredictionWithNonEmptyDefinition.

@Test
@SuppressWarnings("unchecked")
public void testGetPredictionWithNonEmptyDefinition() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER,
            List.of(
                    new NamedTypedValue("input1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("input2", new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("input3", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))),
            List.of(
                    new NamedTypedValue("output1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("output2", new StructureValue("number", Map.of("output2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("output3", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))));
    Prediction prediction = handler.getPrediction(request);
    // Inputs
    assertEquals(3, prediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());
    Optional<Feature> oInput2 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input2")).findFirst();
    assertTrue(oInput2.isPresent());
    Feature input2 = oInput2.get();
    assertEquals(Type.COMPOSITE, input2.getType());
    assertTrue(input2.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input2Object = (List<Feature>) input2.getValue().getUnderlyingObject();
    assertEquals(1, input2Object.size());
    Optional<Feature> oInput2Child = input2Object.stream().filter(f -> f.getName().equals("input2b")).findFirst();
    assertTrue(oInput2Child.isPresent());
    Feature input2Child = oInput2Child.get();
    assertEquals(Type.NUMBER, input2Child.getType());
    assertEquals(55, input2Child.getValue().asNumber());
    Optional<Feature> oInput3 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input3")).findFirst();
    assertTrue(oInput3.isPresent());
    Feature input3 = oInput3.get();
    assertEquals(Type.COMPOSITE, input3.getType());
    assertTrue(input3.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input3Object = (List<Feature>) input3.getValue().getUnderlyingObject();
    assertEquals(1, input3Object.size());
    Feature input3Child = input3Object.get(0);
    assertEquals(Type.NUMBER, input3Child.getType());
    assertEquals(100, input3Child.getValue().asNumber());
    // Outputs
    assertEquals(3, prediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output1")).findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());
    Optional<Output> oOutput2 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output2")).findFirst();
    assertTrue(oOutput2.isPresent());
    Output output2 = oOutput2.get();
    assertEquals(Type.COMPOSITE, output2.getType());
    assertTrue(output2.getValue().getUnderlyingObject() instanceof List);
    List<Output> output2Object = (List<Output>) output2.getValue().getUnderlyingObject();
    assertEquals(1, output2Object.size());
    Optional<Output> oOutput2Child = output2Object.stream().filter(f -> f.getName().equals("output2b")).findFirst();
    assertTrue(oOutput2Child.isPresent());
    Output output2Child = oOutput2Child.get();
    assertEquals(Type.NUMBER, output2Child.getType());
    assertEquals(55, output2Child.getValue().asNumber());
    Optional<Output> oOutput3 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output3")).findFirst();
    assertTrue(oOutput3.isPresent());
    Output output3 = oOutput3.get();
    assertEquals(Type.COMPOSITE, output3.getType());
    assertTrue(output3.getValue().getUnderlyingObject() instanceof List);
    List<Output> output3Object = (List<Output>) output3.getValue().getUnderlyingObject();
    assertEquals(1, output3Object.size());
    Output output3Child = output3Object.get(0);
    assertEquals(Type.NUMBER, output3Child.getType());
    assertEquals(100, output3Child.getValue().asNumber());
}
Also used : Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) BeforeEach(org.junit.jupiter.api.BeforeEach) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) IntNode(com.fasterxml.jackson.databind.node.IntNode) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) Value(org.kie.kogito.explainability.model.Value) Saliency(org.kie.kogito.explainability.model.Saliency) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) Test(org.junit.jupiter.api.Test) List(java.util.List) LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) Mockito.mock(org.mockito.Mockito.mock)
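
The assertions in Example 5 imply a shape for the conversion: a unit value becomes a NUMBER leaf, while a structure or a collection becomes a COMPOSITE whose underlying object is a list of children. The sketch below illustrates that shape with plain Java values (Number, Map, List) rather than the Kogito typed-value classes. It is only an approximation under stated assumptions: TypedValueSketch, Node, Kind, and toNode are hypothetical names, and naming collection elements as name_index is a choice made for this example, not something the test reveals.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of the NUMBER / COMPOSITE tree the test checks.
final class TypedValueSketch {

    enum Kind { NUMBER, COMPOSITE }

    // A leaf holds a Number in value; a composite holds a List<Node> of children.
    static final class Node {
        final String name;
        final Kind kind;
        final Object value;

        Node(String name, Kind kind, Object value) {
            this.name = name;
            this.kind = kind;
            this.value = value;
        }
    }

    // Recursively converts a raw value:
    //   Number -> NUMBER leaf
    //   Map    -> COMPOSITE with one child per entry, named after the key
    //   List   -> COMPOSITE with one child per element (element naming is an assumption)
    static Node toNode(String name, Object raw) {
        if (raw instanceof Number) {
            return new Node(name, Kind.NUMBER, raw);
        }
        if (raw instanceof Map) {
            List<Node> children = new ArrayList<>();
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) raw).entrySet()) {
                children.add(toNode(String.valueOf(entry.getKey()), entry.getValue()));
            }
            return new Node(name, Kind.COMPOSITE, children);
        }
        if (raw instanceof List) {
            List<Node> children = new ArrayList<>();
            int index = 0;
            for (Object element : (List<?>) raw) {
                children.add(toNode(name + "_" + index, element));
                index++;
            }
            return new Node(name, Kind.COMPOSITE, children);
        }
        throw new IllegalArgumentException("Unsupported value for " + name);
    }
}
```

With this sketch, converting a map holding the single entry input2b = 55 under the name input2 yields a COMPOSITE node with one NUMBER child named input2b, mirroring the checks the test makes on input2 and output2.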
