
Example 21 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed.

@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            // search domain for "output1": not fixed (third argument false), range [10, 20]
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number", "number", false,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());
    // JUnit expects (expected, actual): the request value is the expected one
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)

Example 22 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsFixed.

@Test
public void testGetPredictionWithFlatSearchDomainsFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(
            EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            // search domain for "output1": fixed (third argument true), range [10, 20]
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number", "number", true,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof EmptyFeatureDomain);
    // JUnit expects (expected, actual): the request value is the expected one
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) IntNode(com.fasterxml.jackson.databind.node.IntNode) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)
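
Examples 21 and 22 differ only in the `fixed` flag passed to `CounterfactualSearchDomainUnitValue`: with `false` the feature is expected to get a `NumericalFeatureDomain` bounded by the range [10, 20], with `true` an `EmptyFeatureDomain`. The following is a minimal, self-contained sketch of that mapping. It assumes hypothetical simplified records (`UnitSearchDomain`, `EmptyDomain`, `NumericDomain`) in place of the real Kogito classes and mirrors only what the two tests assert, not the actual handler code.

```java
// Hypothetical, simplified stand-ins for the Kogito types; this is not the real API.
public class SearchDomainSketch {

    /** A flat search domain: the "fixed" flag plus a numeric range. */
    record UnitSearchDomain(String name, boolean fixed, double lower, double upper) {}

    /** Simplified feature domain: either empty or a numeric [lowerBound, upperBound] interval. */
    interface Domain {}

    record EmptyDomain() implements Domain {}

    record NumericDomain(double lowerBound, double upperBound) implements Domain {}

    /** Fixed features are not searched, so they get an empty domain; otherwise the range is kept. */
    static Domain toDomain(UnitSearchDomain searchDomain) {
        if (searchDomain.fixed()) {
            return new EmptyDomain();
        }
        return new NumericDomain(searchDomain.lower(), searchDomain.upper());
    }

    public static void main(String[] args) {
        // Same range as Examples 21 and 22: [10, 20], first not fixed, then fixed.
        System.out.println(toDomain(new UnitSearchDomain("output1", false, 10, 20)));
        System.out.println(toDomain(new UnitSearchDomain("output1", true, 10, 20)));
    }
}
```

Modelling the domain as a small type hierarchy keeps the `instanceof` checks in the tests meaningful: callers can tell "not searchable" apart from "searchable within bounds" without sentinel values.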

Example 23 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class LimeExplainerServiceHandlerTest, method testGetPredictionWithNonEmptyDefinition.

@Test
@SuppressWarnings("unchecked")
public void testGetPredictionWithNonEmptyDefinition() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER,
            // inputs: a unit value, a structure and a collection
            List.of(new NamedTypedValue("input1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("input2", new StructureValue("number", Map.of("input2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("input3", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))),
            // outputs: the same three shapes
            List.of(new NamedTypedValue("output1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("output2", new StructureValue("number", Map.of("output2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("output3", new CollectionValue("number", List.of(new UnitValue("number", new IntNode(100)))))));
    Prediction prediction = handler.getPrediction(request);
    // Inputs
    assertEquals(3, prediction.getInput().getFeatures().size());
    Optional<Feature> oInput1 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());
    Optional<Feature> oInput2 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input2")).findFirst();
    assertTrue(oInput2.isPresent());
    Feature input2 = oInput2.get();
    assertEquals(Type.COMPOSITE, input2.getType());
    assertTrue(input2.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input2Object = (List<Feature>) input2.getValue().getUnderlyingObject();
    assertEquals(1, input2Object.size());
    Optional<Feature> oInput2Child = input2Object.stream().filter(f -> f.getName().equals("input2b")).findFirst();
    assertTrue(oInput2Child.isPresent());
    Feature input2Child = oInput2Child.get();
    assertEquals(Type.NUMBER, input2Child.getType());
    assertEquals(55, input2Child.getValue().asNumber());
    Optional<Feature> oInput3 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input3")).findFirst();
    assertTrue(oInput3.isPresent());
    Feature input3 = oInput3.get();
    assertEquals(Type.COMPOSITE, input3.getType());
    assertTrue(input3.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input3Object = (List<Feature>) input3.getValue().getUnderlyingObject();
    assertEquals(1, input3Object.size());
    Feature input3Child = input3Object.get(0);
    assertEquals(Type.NUMBER, input3Child.getType());
    assertEquals(100, input3Child.getValue().asNumber());
    // Outputs
    assertEquals(3, prediction.getOutput().getOutputs().size());
    Optional<Output> oOutput1 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output1")).findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());
    Optional<Output> oOutput2 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output2")).findFirst();
    assertTrue(oOutput2.isPresent());
    Output output2 = oOutput2.get();
    assertEquals(Type.COMPOSITE, output2.getType());
    assertTrue(output2.getValue().getUnderlyingObject() instanceof List);
    List<Output> output2Object = (List<Output>) output2.getValue().getUnderlyingObject();
    assertEquals(1, output2Object.size());
    Optional<Output> oOutput2Child = output2Object.stream().filter(f -> f.getName().equals("output2b")).findFirst();
    assertTrue(oOutput2Child.isPresent());
    Output output2Child = oOutput2Child.get();
    assertEquals(Type.NUMBER, output2Child.getType());
    assertEquals(55, output2Child.getValue().asNumber());
    Optional<Output> oOutput3 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output3")).findFirst();
    assertTrue(oOutput3.isPresent());
    Output output3 = oOutput3.get();
    assertEquals(Type.COMPOSITE, output3.getType());
    assertTrue(output3.getValue().getUnderlyingObject() instanceof List);
    List<Output> output3Object = (List<Output>) output3.getValue().getUnderlyingObject();
    assertEquals(1, output3Object.size());
    Output output3Child = output3Object.get(0);
    assertEquals(Type.NUMBER, output3Child.getType());
    assertEquals(100, output3Child.getValue().asNumber());
}
Also used : Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) BeforeEach(org.junit.jupiter.api.BeforeEach) LIMEExplainabilityResult(org.kie.kogito.explainability.api.LIMEExplainabilityResult) IntNode(com.fasterxml.jackson.databind.node.IntNode) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) Value(org.kie.kogito.explainability.model.Value) Saliency(org.kie.kogito.explainability.model.Saliency) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) Test(org.junit.jupiter.api.Test) List(java.util.List) LIMEExplainabilityRequest(org.kie.kogito.explainability.api.LIMEExplainabilityRequest) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) SaliencyModel(org.kie.kogito.explainability.api.SaliencyModel) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) Mockito.mock(org.mockito.Mockito.mock)
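
The LIME test above expects a unit value to become a `Type.NUMBER` feature, while a structure or a collection becomes a `Type.COMPOSITE` feature whose underlying object is a `List` of child features. A rough recursive sketch of that flattening follows. The nested types (`UnitValue`, `StructureValue`, `CollectionValue`, `Feature`) are hypothetical simplifications that only borrow names from the Kogito classes, and naming collection items by index is an illustrative assumption; the test does not check those names.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified model; the names echo Kogito classes but this is not their API.
public class FlattenSketch {

    interface TypedValue {}

    record UnitValue(double value) implements TypedValue {}

    record StructureValue(Map<String, TypedValue> fields) implements TypedValue {}

    record CollectionValue(List<TypedValue> items) implements TypedValue {}

    enum Type { NUMBER, COMPOSITE }

    /** NUMBER features hold a Double; COMPOSITE features hold a List<Feature>. */
    record Feature(String name, Type type, Object value) {}

    static Feature toFeature(String name, TypedValue typedValue) {
        if (typedValue instanceof UnitValue unit) {
            return new Feature(name, Type.NUMBER, unit.value());
        }
        List<Feature> children = new ArrayList<>();
        if (typedValue instanceof StructureValue structure) {
            // Structure fields keep their own names, e.g. "input2b".
            structure.fields().forEach((key, child) -> children.add(toFeature(key, child)));
        } else if (typedValue instanceof CollectionValue collection) {
            // Collection items are unnamed; naming them by index is an illustrative assumption.
            for (int i = 0; i < collection.items().size(); i++) {
                children.add(toFeature(name + "_" + i, collection.items().get(i)));
            }
        } else {
            throw new IllegalArgumentException("Unsupported value: " + typedValue);
        }
        return new Feature(name, Type.COMPOSITE, children);
    }

    public static void main(String[] args) {
        // Mirrors "input2" and "input3" from the LIME example.
        Feature input2 = toFeature("input2", new StructureValue(Map.of("input2b", new UnitValue(55))));
        Feature input3 = toFeature("input3", new CollectionValue(List.of(new UnitValue(100))));
        System.out.println(input2);
        System.out.println(input3);
    }
}
```

Recursion handles arbitrarily nested structures and collections with the same two cases, which is why the test only needs to look one level deep to confirm the shape.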

Example 24 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class ConversionUtilsTest, method toFeatureDomainsConstraintsMultiElement.

@Test
void toFeatureDomainsConstraintsMultiElement() {
    final Random random = new Random();
    List<NamedTypedValue> values = IntStream.range(0, 10).mapToObj(i -> getDoubleUnit("f-" + i, random.nextDouble())).collect(Collectors.toList());
    List<CounterfactualSearchDomain> domains = IntStream.range(0, 10).mapToObj(i -> getDoubleSearchDomain("f-" + i, -1, 1)).collect(Collectors.toList());
    final List<Feature> features = ConversionUtils.toFeatureList(values, domains);
    assertEquals(10, features.size());
    assertTrue(features.stream().allMatch(f -> f.getType() == Type.NUMBER));
    assertTrue(features.stream().noneMatch(Feature::isConstrained));
    assertTrue(features.stream().map(Feature::getDomain).noneMatch(FeatureDomain::isEmpty));
    assertTrue(features.stream().map(Feature::getDomain).map(FeatureDomain::getLowerBound).allMatch(lb -> lb == -1.0));
    assertTrue(features.stream().map(Feature::getDomain).map(FeatureDomain::getUpperBound).allMatch(ub -> ub == 1.0));
}
Also used : IntStream(java.util.stream.IntStream) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) IntNode(com.fasterxml.jackson.databind.node.IntNode) Feature(org.kie.kogito.explainability.model.Feature) Assertions.assertNull(org.junit.jupiter.api.Assertions.assertNull) Random(java.util.Random) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) Value(org.kie.kogito.explainability.model.Value) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) ObjectNode(com.fasterxml.jackson.databind.node.ObjectNode) ArrayList(java.util.ArrayList) Pair(org.apache.commons.lang3.tuple.Pair) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) JsonNode(com.fasterxml.jackson.databind.JsonNode) JsonObject(io.vertx.core.json.JsonObject) Collections.singletonMap(java.util.Collections.singletonMap) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualDomainCategorical(org.kie.kogito.explainability.api.CounterfactualDomainCategorical) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) CategoricalFeatureDomain(org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain) Collections.emptyList(java.util.Collections.emptyList) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) TextNode(com.fasterxml.jackson.databind.node.TextNode) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CounterfactualDomain(org.kie.kogito.explainability.api.CounterfactualDomain) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) JsonNodeFactory(com.fasterxml.jackson.databind.node.JsonNodeFactory) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain)

Example 25 with NamedTypedValue

Use of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.

From the class ConversionUtilsTest, method toFeatureDomainsConstraintsSingleElement.

@Test
void toFeatureDomainsConstraintsSingleElement() {
    NamedTypedValue typedValue = getDoubleUnit("f-1", 20.0);
    CounterfactualSearchDomain domain = getDoubleSearchDomain("f-1", 18.0, 65.0);
    final List<Feature> features = ConversionUtils.toFeatureList(List.of(typedValue), List.of(domain));
    assertEquals(1, features.size());
    final Feature feature = features.get(0);
    assertEquals(Type.NUMBER, feature.getType());
    assertEquals("f-1", feature.getName());
    assertEquals(20.0, feature.getValue().asNumber());
    assertFalse(feature.isConstrained());
    assertFalse(feature.getDomain().isEmpty());
    assertEquals(18, feature.getDomain().getLowerBound());
    assertEquals(65, feature.getDomain().getUpperBound());
}
Also used : NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Test(org.junit.jupiter.api.Test)
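
Both `ConversionUtilsTest` examples expect `ConversionUtils.toFeatureList` to pair each `NamedTypedValue` with the `CounterfactualSearchDomain` of the same name and to copy that domain's bounds onto the resulting feature. The sketch below illustrates that name-based pairing with hypothetical records (`NamedValue`, `SearchDomain`, `Feature`). It is not the real implementation: it ignores fixed domains and nested values, and it assumes every value has a matching domain and that domain names are unique.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-ins; this is not the real ConversionUtils.toFeatureList.
public class FeatureListSketch {

    record NamedValue(String name, double value) {}

    record SearchDomain(String name, double lower, double upper) {}

    record Feature(String name, double value, double lowerBound, double upperBound) {}

    static List<Feature> toFeatureList(List<NamedValue> values, List<SearchDomain> domains) {
        // Index the search domains by name; toMap throws IllegalStateException on duplicate names.
        Map<String, SearchDomain> domainsByName = domains.stream()
                .collect(Collectors.toMap(SearchDomain::name, Function.identity()));
        return values.stream()
                .map(value -> {
                    SearchDomain domain = domainsByName.get(value.name());
                    if (domain == null) {
                        throw new IllegalArgumentException("No search domain for " + value.name());
                    }
                    // Copy the domain's bounds onto the feature built from the value.
                    return new Feature(value.name(), value.value(), domain.lower(), domain.upper());
                })
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Same data as toFeatureDomainsConstraintsSingleElement: value 20.0 searched within [18, 65].
        List<Feature> features = toFeatureList(
                List.of(new NamedValue("f-1", 20.0)),
                List.of(new SearchDomain("f-1", 18.0, 65.0)));
        System.out.println(features);
    }
}
```

Matching by name rather than by list position means the two lists may arrive in different orders, which is what lets the multi-element test build values and domains independently.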

Aggregations

NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue): 33
Test (org.junit.jupiter.api.Test): 27
UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue): 22
IntNode (com.fasterxml.jackson.databind.node.IntNode): 19
CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest): 19
CounterfactualSearchDomain (org.kie.kogito.explainability.api.CounterfactualSearchDomain): 19
CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue): 16
CounterfactualDomainRange (org.kie.kogito.explainability.api.CounterfactualDomainRange): 12
CounterfactualExplainabilityResult (org.kie.kogito.explainability.api.CounterfactualExplainabilityResult): 11
StructureValue (org.kie.kogito.tracing.typedvalue.StructureValue): 11
List (java.util.List): 9
CounterfactualSearchDomainStructureValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue): 9
Prediction (org.kie.kogito.explainability.model.Prediction): 9
CollectionValue (org.kie.kogito.tracing.typedvalue.CollectionValue): 9
Map (java.util.Map): 8
Assertions.assertEquals (org.junit.jupiter.api.Assertions.assertEquals): 8
Assertions.assertThrows (org.junit.jupiter.api.Assertions.assertThrows): 8
Assertions.assertTrue (org.junit.jupiter.api.Assertions.assertTrue): 8
ModelIdentifier (org.kie.kogito.explainability.api.ModelIdentifier): 8
CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction): 8