Example 21 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed.

@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            // not fixed: the range [10, 20] is expected to become a NumericalFeatureDomain
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number", "number", false,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)

Example 22 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsFixed.

@Test
public void testGetPredictionWithFlatSearchDomainsFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            // fixed: the feature is expected to get an EmptyFeatureDomain
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number", "number", true,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof EmptyFeatureDomain);
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)
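Examples 21 and 22 build the same request and differ only in the boolean passed to CounterfactualSearchDomainUnitValue: when it is false the feature ends up with a NumericalFeatureDomain bounded by the declared range [10, 20], and when it is true (a fixed feature) it ends up with an EmptyFeatureDomain. The sketch below restates that observed rule in plain Java, assuming Java 16 or later for records. Every name in it (SearchDomainMappingSketch, UnitSearchDomain, NumericDomain, EmptyDomain, toDomain) is a hypothetical stand-in rather than the kogito-apps API, and rejecting an inverted range is an added assumption.

```java
import java.util.Objects;

// Hypothetical sketch of the fixed / not-fixed rule seen in Examples 21 and 22.
// None of these types belong to kogito-apps.
public class SearchDomainMappingSketch {

    // A feature domain is either empty (the feature is fixed) or a numeric range.
    interface FeatureDomain {}

    record EmptyDomain() implements FeatureDomain {}

    record NumericDomain(double lowerBound, double upperBound) implements FeatureDomain {
        NumericDomain {
            // Assumption: an inverted range is rejected up front.
            if (lowerBound > upperBound) {
                throw new IllegalArgumentException("lowerBound must not exceed upperBound");
            }
        }
    }

    // Stand-in for a unit search domain: a name, a fixed flag and a [lower, upper] range.
    record UnitSearchDomain(String name, boolean fixed, double lower, double upper) {}

    // A fixed feature cannot be changed by the search, so it gets an empty domain;
    // otherwise the declared range becomes its numeric domain.
    static FeatureDomain toDomain(UnitSearchDomain searchDomain) {
        Objects.requireNonNull(searchDomain, "searchDomain");
        return searchDomain.fixed()
                ? new EmptyDomain()
                : new NumericDomain(searchDomain.lower(), searchDomain.upper());
    }

    public static void main(String[] args) {
        System.out.println(toDomain(new UnitSearchDomain("output1", false, 10, 20)));
        System.out.println(toDomain(new UnitSearchDomain("output1", true, 10, 20)));
    }
}
```

Modelling the two outcomes as separate types keeps the test's instanceof checks meaningful: a fixed feature simply has no range to inspect.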

Example 23 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

From the class LimeExplainerServiceHandlerTest, method testGetPredictionWithNonEmptyDefinition.

@Test
@SuppressWarnings("unchecked")
public void testGetPredictionWithNonEmptyDefinition() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER,
            // Inputs: a unit, a structure and a collection
            List.of(new NamedTypedValue("input1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("input2", new StructureValue("number",
                            Map.of("input2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("input3", new CollectionValue("number",
                            List.of(new UnitValue("number", new IntNode(100)))))),
            // Outputs: the same three shapes
            List.of(new NamedTypedValue("output1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("output2", new StructureValue("number",
                            Map.of("output2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("output3", new CollectionValue("number",
                            List.of(new UnitValue("number", new IntNode(100)))))));
    Prediction prediction = handler.getPrediction(request);
    // Inputs
    assertEquals(3, prediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());
    Optional<Feature> oInput2 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input2")).findFirst();
    assertTrue(oInput2.isPresent());
    Feature input2 = oInput2.get();
    assertEquals(Type.COMPOSITE, input2.getType());
    assertTrue(input2.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input2Object = (List<Feature>) input2.getValue().getUnderlyingObject();
    assertEquals(1, input2Object.size());
    Optional<Feature> oInput2Child = input2Object.stream().filter(f -> f.getName().equals("input2b")).findFirst();
    assertTrue(oInput2Child.isPresent());
    Feature input2Child = oInput2Child.get();
    assertEquals(Type.NUMBER, input2Child.getType());
    assertEquals(55, input2Child.getValue().asNumber());
    Optional<Feature> oInput3 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input3")).findFirst();
    assertTrue(oInput3.isPresent());
    Feature input3 = oInput3.get();
    assertEquals(Type.COMPOSITE, input3.getType());
    assertTrue(input3.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input3Object = (List<Feature>) input3.getValue().getUnderlyingObject();
    assertEquals(1, input3Object.size());
    Feature input3Child = input3Object.get(0);
    assertEquals(Type.NUMBER, input3Child.getType());
    assertEquals(100, input3Child.getValue().asNumber());
    // Outputs
    assertEquals(3, prediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output1")).findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());
    Optional<Output> oOutput2 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output2")).findFirst();
    assertTrue(oOutput2.isPresent());
    Output output2 = oOutput2.get();
    assertEquals(Type.COMPOSITE, output2.getType());
    assertTrue(output2.getValue().getUnderlyingObject() instanceof List);
    List<Output> output2Object = (List<Output>) output2.getValue().getUnderlyingObject();
    assertEquals(1, output2Object.size());
    Optional<Output> oOutput2Child = output2Object.stream().filter(f -> f.getName().equals("output2b")).findFirst();
    assertTrue(oOutput2Child.isPresent());
    Output output2Child = oOutput2Child.get();
    assertEquals(Type.NUMBER, output2Child.getType());
    assertEquals(55, output2Child.getValue().asNumber());
    Optional<Output> oOutput3 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output3")).findFirst();
    assertTrue(oOutput3.isPresent());
    Output output3 = oOutput3.get();
    assertEquals(Type.COMPOSITE, output3.getType());
    assertTrue(output3.getValue().getUnderlyingObject() instanceof List);
    List<Output> output3Object = (List<Output>) output3.getValue().getUnderlyingObject();
    assertEquals(1, output3Object.size());
    Output output3Child = output3Object.get(0);
    assertEquals(Type.NUMBER, output3Child.getType());
    assertEquals(100, output3Child.getValue().asNumber());
}
Also used : Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) BeforeEach(org.junit.jupiter.api.BeforeEach) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) IntNode(com.fasterxml.jackson.databind.node.IntNode) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) Value(org.kie.kogito.explainability.model.Value) Saliency(org.kie.kogito.explainability.model.Saliency) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) Test(org.junit.jupiter.api.Test) List(java.util.List) LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) Mockito.mock(org.mockito.Mockito.mock)
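Example 23 repeats the same lookup many times: stream the features or outputs, filter by name, call findFirst, assert that the Optional is present, then unwrap it. A small generic helper can fold those steps into one call that fails with a descriptive message. This is a hypothetical test utility (findByName and the Named record belong neither to kogito-apps nor to JUnit); in the real test it would presumably be called with Feature::getName or Output::getName. Records assume Java 16 or later.

```java
import java.util.List;
import java.util.NoSuchElementException;
import java.util.function.Function;

// Hypothetical test helper; not part of kogito-apps or JUnit.
public class FindByNameSketch {

    // Minimal stand-in for Feature/Output: just a name and a value.
    record Named(String name, Object value) {}

    // Returns the first item whose name equals `name`, or throws with a descriptive message.
    static <T> T findByName(List<T> items, Function<T, String> nameOf, String name) {
        return items.stream()
                .filter(item -> name.equals(nameOf.apply(item)))
                .findFirst()
                .orElseThrow(() -> new NoSuchElementException("No element named '" + name + "'"));
    }

    public static void main(String[] args) {
        List<Named> features = List.of(new Named("input1", 20), new Named("input2", 55));
        Named input2 = findByName(features, Named::name, "input2");
        System.out.println(input2.value()); // prints 55
    }
}
```

Because orElseThrow raises NoSuchElementException when nothing matches, the separate assertTrue(oX.isPresent()) calls become unnecessary, and the failure message names the missing element.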

Example 24 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

From the class ConversionUtilsTest, method toFeatureTypedValue.

@Test
void toFeatureTypedValue() {
    Feature name = ConversionUtils.toFeature("name", new UnitValue("number", new DoubleNode(10d)));
    assertNotNull(name);
    assertEquals("name", name.getName());
    assertEquals(Type.NUMBER, name.getType());
    assertEquals(10d, name.getValue().getUnderlyingObject());
    assertTrue(name.isConstrained());
    assertTrue(name.getDomain().isEmpty());
    Feature name1 = ConversionUtils.toFeature("name1", new StructureValue("complex", singletonMap("key", new UnitValue("string1", new TextNode("stringValue")))));
    assertNotNull(name1);
    assertTrue(name1.isConstrained());
    assertTrue(name1.getDomain().isEmpty());
    assertEquals("name1", name1.getName());
    assertEquals(Type.COMPOSITE, name1.getType());
    assertTrue(name1.getValue().getUnderlyingObject() instanceof List);
    @SuppressWarnings("unchecked") List<Feature> features = (List<Feature>) name1.getValue().getUnderlyingObject();
    assertEquals(1, features.size());
    assertEquals(Type.TEXT, features.get(0).getType());
    assertEquals("stringValue", features.get(0).getValue().getUnderlyingObject());
    List<TypedValue> values = List.of(new UnitValue("number", new DoubleNode(0d)), new UnitValue("number", new DoubleNode(1d)));
    Feature collectionFeature = ConversionUtils.toFeature("name", new CollectionValue("list", values));
    assertNotNull(collectionFeature);
    assertEquals("name", collectionFeature.getName());
    assertEquals(Type.COMPOSITE, collectionFeature.getType());
    assertTrue(collectionFeature.getValue().getUnderlyingObject() instanceof List);
    @SuppressWarnings("unchecked") List<Feature> objects = (List<Feature>) collectionFeature.getValue().getUnderlyingObject();
    assertEquals(2, objects.size());
    for (Feature f : objects) {
        assertNotNull(f);
        assertNotNull(f.getName());
        assertNotNull(f.getType());
        assertEquals(Type.NUMBER, f.getType());
        assertNotNull(f.getValue());
    }
}
Also used : CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) TextNode(com.fasterxml.jackson.databind.node.TextNode) ArrayList(java.util.ArrayList) Collections.emptyList(java.util.Collections.emptyList) List(java.util.List) Feature(org.kie.kogito.explainability.model.Feature) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) Test(org.junit.jupiter.api.Test)
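Examples 23 and 24 show the shape of the conversion from typed values to features: a unit value becomes a single NUMBER or TEXT feature, while a structure or a collection becomes a COMPOSITE feature whose underlying object is a List of child features. A minimal recursive sketch of that mapping follows, assuming Java 16 or later. Unit, Struct, Coll and Feature are simplified hypothetical stand-ins for UnitValue, StructureValue, CollectionValue and Feature; choosing NUMBER versus TEXT from the Java type of the value, and naming collection elements by index, are assumptions rather than what ConversionUtils.toFeature is known to do.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model of typed values and features; not the kogito-apps classes.
public class TypedValueToFeatureSketch {

    enum Type { NUMBER, TEXT, COMPOSITE }

    interface TypedValue {}

    record Unit(Object value) implements TypedValue {}

    record Struct(Map<String, TypedValue> fields) implements TypedValue {}

    record Coll(List<TypedValue> items) implements TypedValue {}

    // For COMPOSITE features, value holds a List<Feature>, mirroring getUnderlyingObject() in the examples.
    record Feature(String name, Type type, Object value) {}

    static Feature toFeature(String name, TypedValue typedValue) {
        if (typedValue instanceof Unit unit) {
            // Assumption: numbers map to NUMBER, everything else to TEXT.
            Type type = unit.value() instanceof Number ? Type.NUMBER : Type.TEXT;
            return new Feature(name, type, unit.value());
        }
        if (typedValue instanceof Struct struct) {
            // Each field becomes a child feature named after its key.
            List<Feature> children = new ArrayList<>();
            struct.fields().forEach((key, child) -> children.add(toFeature(key, child)));
            return new Feature(name, Type.COMPOSITE, children);
        }
        if (typedValue instanceof Coll coll) {
            // Assumption: elements have no key, so each child is named after the parent and its index.
            List<Feature> children = new ArrayList<>();
            for (int i = 0; i < coll.items().size(); i++) {
                children.add(toFeature(name + "_" + i, coll.items().get(i)));
            }
            return new Feature(name, Type.COMPOSITE, children);
        }
        throw new IllegalArgumentException("Unsupported typed value: " + typedValue);
    }

    public static void main(String[] args) {
        Feature structure = toFeature("name1", new Struct(Map.of("key", new Unit("stringValue"))));
        Feature collection = toFeature("name", new Coll(List.of(new Unit(0d), new Unit(1d))));
        System.out.println(structure);
        System.out.println(collection);
    }
}
```

Because the recursion treats structures and collections uniformly, arbitrarily nested inputs (as in Example 25) need no extra code.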

Example 25 with UnitValue

Use of org.kie.kogito.tracing.typedvalue.UnitValue in project kogito-apps by kiegroup.

From the class ConversionUtilsTest, method testNestedCollection.

@Test
void testNestedCollection() {
    Collection<TypedValue> depthTwoOne = new ArrayList<>(2);
    depthTwoOne.add(new StructureValue("complex", singletonMap("key", new UnitValue("string1", new TextNode("value one")))));
    depthTwoOne.add(new StructureValue("complex", singletonMap("key", new UnitValue("string1", new TextNode("value two")))));
    Collection<TypedValue> depthTwoTwo = new ArrayList<>(2);
    depthTwoTwo.add(new StructureValue("complex", singletonMap("key", new UnitValue("string1", new TextNode("value three")))));
    depthTwoTwo.add(new StructureValue("complex", singletonMap("key", new UnitValue("string1", new TextNode("value four")))));
    CollectionValue depthOneLeft = new CollectionValue("list", depthTwoOne);
    CollectionValue depthOneRight = new CollectionValue("list", depthTwoTwo);
    Collection<TypedValue> depthOne = new ArrayList<>(2);
    depthOne.add(depthOneLeft);
    depthOne.add(depthOneRight);
    CollectionValue value = new CollectionValue("list", depthOne);
    Feature collectionFeature = ConversionUtils.toFeature("name", value);
    assertNotNull(collectionFeature);
    assertEquals("name", collectionFeature.getName());
    assertEquals(Type.COMPOSITE, collectionFeature.getType());
    assertTrue(collectionFeature.getValue().getUnderlyingObject() instanceof List);
    @SuppressWarnings("unchecked") List<Feature> deepFeatures = (List<Feature>) collectionFeature.getValue().getUnderlyingObject();
    assertEquals(2, deepFeatures.size());
    for (Feature f : deepFeatures) {
        assertNotNull(f);
        assertNotNull(f.getName());
        assertNotNull(f.getType());
        assertEquals(Type.COMPOSITE, f.getType());
        assertNotNull(f.getValue());
        @SuppressWarnings("unchecked") List<Feature> nestedOneValues = (List<Feature>) f.getValue().getUnderlyingObject();
        for (Feature nestedOneValue : nestedOneValues) {
            assertNotNull(nestedOneValue);
            assertNotNull(nestedOneValue.getName());
            assertNotNull(nestedOneValue.getType());
            assertEquals(Type.COMPOSITE, nestedOneValue.getType());
            assertNotNull(nestedOneValue.getValue());
            @SuppressWarnings("unchecked") List<Feature> nestedTwoValues = (List<Feature>) nestedOneValue.getValue().getUnderlyingObject();
            for (Feature nestedTwoValue : nestedTwoValues) {
                assertNotNull(nestedTwoValue);
                assertNotNull(nestedTwoValue.getName());
                assertNotNull(nestedTwoValue.getType());
                assertEquals(Type.TEXT, nestedTwoValue.getType());
                assertNotNull(nestedTwoValue.getValue());
                assertTrue(nestedTwoValue.getValue().asString().contains("value"));
            }
        }
    }
}
Also used : CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) ArrayList(java.util.ArrayList) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) TextNode(com.fasterxml.jackson.databind.node.TextNode) Collections.emptyList(java.util.Collections.emptyList) List(java.util.List) Feature(org.kie.kogito.explainability.model.Feature) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) Test(org.junit.jupiter.api.Test)
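Example 25 checks each nesting level with its own hand-written loop. When only the leaves matter, a depth-first walk over COMPOSITE features reaches them at any depth. The sketch below uses a hypothetical, self-contained Feature record (not the kogito-apps class) and rebuilds the same two-level collection of structures as the test, assuming Java 16 or later.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, self-contained stand-in for nested features; not the kogito-apps Feature class.
public class LeafCollectorSketch {

    enum Type { TEXT, NUMBER, COMPOSITE }

    // Composite features carry children; leaf features carry a value.
    record Feature(String name, Type type, Object value, List<Feature> children) {
        static Feature leaf(String name, Type type, Object value) {
            return new Feature(name, type, value, List.of());
        }

        static Feature composite(String name, List<Feature> children) {
            return new Feature(name, Type.COMPOSITE, null, List.copyOf(children));
        }
    }

    // Depth-first walk returning every non-composite feature, whatever the nesting depth.
    static List<Feature> leaves(Feature root) {
        List<Feature> out = new ArrayList<>();
        collect(root, out);
        return out;
    }

    private static void collect(Feature feature, List<Feature> out) {
        if (feature.type() == Type.COMPOSITE) {
            for (Feature child : feature.children()) {
                collect(child, out);
            }
        } else {
            out.add(feature);
        }
    }

    // Mirrors new StructureValue("complex", singletonMap("key", <text unit>)) from Example 25.
    private static Feature structure(String text) {
        return Feature.composite("complex", List.of(Feature.leaf("key", Type.TEXT, text)));
    }

    public static void main(String[] args) {
        Feature left = Feature.composite("left", List.of(structure("value one"), structure("value two")));
        Feature right = Feature.composite("right", List.of(structure("value three"), structure("value four")));
        Feature root = Feature.composite("name", List.of(left, right));

        List<Feature> found = leaves(root);
        boolean allText = found.stream()
                .allMatch(f -> f.type() == Type.TEXT && String.valueOf(f.value()).contains("value"));
        // prints: 4 leaves, all TEXT containing "value": true
        System.out.println(found.size() + " leaves, all TEXT containing \"value\": " + allText);
    }
}
```

Unlike the nested loops, this walk does not assert that every intermediate level is COMPOSITE, so a test would still need those checks if the exact shape matters.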

Aggregations

UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue) 31
Test (org.junit.jupiter.api.Test) 24
NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue) 23
IntNode (com.fasterxml.jackson.databind.node.IntNode) 18
CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) 17
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) 14
StructureValue (org.kie.kogito.tracing.typedvalue.StructureValue) 12
List (java.util.List) 10
CollectionValue (org.kie.kogito.tracing.typedvalue.CollectionValue) 10
CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain) 9
ArrayList (java.util.ArrayList) 8
CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) 8
CounterfactualSearchDomainStructureValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) 8
TypedValue (org.kie.kogito.tracing.typedvalue.TypedValue) 8
Decision (org.kie.kogito.trusty.storage.api.model.decision.Decision) 8
Collections (java.util.Collections) 7
Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals) 7
Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue) 7
CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange) 7
ExplainabilityStatus (org.kie.kogito.explainability.api.ExplainabilityStatus) 7