Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed.
/**
 * Verifies that a request whose single unit value has a non-fixed flat search domain of [10, 20]
 * yields a {@link CounterfactualPrediction} with one feature backed by a
 * {@link NumericalFeatureDomain} spanning that range, and that the request's maximum running
 * time is carried over to the prediction.
 */
@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number",
                            "number",
                            false,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());

    // Expected (request) first, actual (prediction) second, so failure messages read correctly.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsFixed.
/**
 * Verifies that a request whose single unit value has a fixed flat search domain yields a
 * {@link CounterfactualPrediction} with one feature backed by an {@link EmptyFeatureDomain}
 * (the supplied [10, 20] range is not turned into a numerical domain), and that the request's
 * maximum running time is carried over to the prediction.
 */
@Test
public void testGetPredictionWithFlatSearchDomainsFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))),
            Collections.emptyList(),
            List.of(new CounterfactualSearchDomain("output1",
                    new CounterfactualSearchDomainUnitValue("number",
                            "number",
                            true,
                            new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof EmptyFeatureDomain);

    // Expected (request) first, actual (prediction) second, so failure messages read correctly.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class LimeExplainerServiceHandlerTest, method testGetPredictionWithNonEmptyDefinition.
/**
 * Verifies that a LIME request with unit, structure and collection values (for both inputs and
 * outputs) is converted into a {@link Prediction} whose entries carry the expected names, types
 * and numeric values. Structure and collection values must come back as {@code Type.COMPOSITE}
 * entries whose underlying object is a list holding the single child entry.
 */
@Test
@SuppressWarnings("unchecked")
public void testGetPredictionWithNonEmptyDefinition() {
    LIMEExplainabilityRequest request = new LIMEExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            List.of(new NamedTypedValue("input1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("input2", new StructureValue("number",
                            Map.of("input2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("input3", new CollectionValue("number",
                            List.of(new UnitValue("number", new IntNode(100)))))),
            List.of(new NamedTypedValue("output1", new UnitValue("number", "number", new IntNode(20))),
                    new NamedTypedValue("output2", new StructureValue("number",
                            Map.of("output2b", new UnitValue("number", new IntNode(55))))),
                    new NamedTypedValue("output3", new CollectionValue("number",
                            List.of(new UnitValue("number", new IntNode(100)))))));

    Prediction prediction = handler.getPrediction(request);

    // Inputs
    assertEquals(3, prediction.getInput().getFeatures().size());

    // input1: plain unit value
    Optional<Feature> oInput1 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input1")).findFirst();
    assertTrue(oInput1.isPresent());
    Feature input1 = oInput1.get();
    assertEquals(Type.NUMBER, input1.getType());
    assertEquals(20, input1.getValue().asNumber());

    // input2: structure value -> composite holding one named child
    Optional<Feature> oInput2 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input2")).findFirst();
    assertTrue(oInput2.isPresent());
    Feature input2 = oInput2.get();
    assertEquals(Type.COMPOSITE, input2.getType());
    assertTrue(input2.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input2Object = (List<Feature>) input2.getValue().getUnderlyingObject();
    assertEquals(1, input2Object.size());
    Optional<Feature> oInput2Child = input2Object.stream().filter(f -> f.getName().equals("input2b")).findFirst();
    assertTrue(oInput2Child.isPresent());
    Feature input2Child = oInput2Child.get();
    assertEquals(Type.NUMBER, input2Child.getType());
    assertEquals(55, input2Child.getValue().asNumber());

    // input3: collection value -> composite holding one child
    Optional<Feature> oInput3 = prediction.getInput().getFeatures().stream().filter(f -> f.getName().equals("input3")).findFirst();
    assertTrue(oInput3.isPresent());
    Feature input3 = oInput3.get();
    assertEquals(Type.COMPOSITE, input3.getType());
    assertTrue(input3.getValue().getUnderlyingObject() instanceof List);
    List<Feature> input3Object = (List<Feature>) input3.getValue().getUnderlyingObject();
    assertEquals(1, input3Object.size());
    Feature input3Child = input3Object.get(0);
    assertEquals(Type.NUMBER, input3Child.getType());
    assertEquals(100, input3Child.getValue().asNumber());

    // Outputs
    assertEquals(3, prediction.getOutput().getOutputs().size());

    // output1: plain unit value
    Optional<Output> oOutput1 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output1")).findFirst();
    assertTrue(oOutput1.isPresent());
    Output output1 = oOutput1.get();
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(20, output1.getValue().asNumber());

    // output2: structure value -> composite holding one named child
    Optional<Output> oOutput2 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output2")).findFirst();
    assertTrue(oOutput2.isPresent());
    Output output2 = oOutput2.get();
    // Assert on the output's own type; checking the input here would leave output2 unverified.
    assertEquals(Type.COMPOSITE, output2.getType());
    assertTrue(output2.getValue().getUnderlyingObject() instanceof List);
    List<Output> output2Object = (List<Output>) output2.getValue().getUnderlyingObject();
    assertEquals(1, output2Object.size());
    Optional<Output> oOutput2Child = output2Object.stream().filter(f -> f.getName().equals("output2b")).findFirst();
    assertTrue(oOutput2Child.isPresent());
    Output output2Child = oOutput2Child.get();
    assertEquals(Type.NUMBER, output2Child.getType());
    assertEquals(55, output2Child.getValue().asNumber());

    // output3: collection value -> composite holding one child
    Optional<Output> oOutput3 = prediction.getOutput().getOutputs().stream().filter(o -> o.getName().equals("output3")).findFirst();
    assertTrue(oOutput3.isPresent());
    Output output3 = oOutput3.get();
    assertEquals(Type.COMPOSITE, output3.getType());
    assertTrue(output3.getValue().getUnderlyingObject() instanceof List);
    List<Output> output3Object = (List<Output>) output3.getValue().getUnderlyingObject();
    assertEquals(1, output3Object.size());
    Output output3Child = output3Object.get(0);
    assertEquals(Type.NUMBER, output3Child.getType());
    assertEquals(100, output3Child.getValue().asNumber());
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class ConversionUtilsTest, method toFeatureDomainsConstraintsMultiElement.
/**
 * Converts ten double-valued units named "f-0".."f-9", each paired with a [-1, 1] search domain,
 * and checks that every resulting feature is numeric and unconstrained, and that its domain is
 * non-empty with bounds -1.0 and 1.0.
 */
@Test
void toFeatureDomainsConstraintsMultiElement() {
    final Random rng = new Random();
    final List<String> names = IntStream.range(0, 10)
            .mapToObj(i -> "f-" + i)
            .collect(Collectors.toList());

    final List<NamedTypedValue> inputs = names.stream()
            .map(name -> getDoubleUnit(name, rng.nextDouble()))
            .collect(Collectors.toList());
    final List<CounterfactualSearchDomain> searchDomains = names.stream()
            .map(name -> getDoubleSearchDomain(name, -1, 1))
            .collect(Collectors.toList());

    final List<Feature> converted = ConversionUtils.toFeatureList(inputs, searchDomains);
    assertEquals(10, converted.size());
    assertTrue(converted.stream().allMatch(feature -> Type.NUMBER == feature.getType()));
    assertFalse(converted.stream().anyMatch(Feature::isConstrained));

    final List<FeatureDomain> featureDomains = converted.stream()
            .map(Feature::getDomain)
            .collect(Collectors.toList());
    assertTrue(featureDomains.stream().noneMatch(FeatureDomain::isEmpty));
    assertTrue(featureDomains.stream().allMatch(d -> d.getLowerBound() == -1.0));
    assertTrue(featureDomains.stream().allMatch(d -> d.getUpperBound() == 1.0));
}
Usage of org.kie.kogito.explainability.api.NamedTypedValue in project kogito-apps by kiegroup.
From the class ConversionUtilsTest, method toFeatureDomainsConstraintsSingleElement.
/**
 * Converts a single double unit "f-1" = 20.0 paired with an [18, 65] search domain and checks
 * that the resulting feature keeps its name, type and value, is unconstrained, and carries a
 * non-empty domain with those bounds.
 */
@Test
void toFeatureDomainsConstraintsSingleElement() {
    final List<Feature> converted = ConversionUtils.toFeatureList(
            List.of(getDoubleUnit("f-1", 20.0)),
            List.of(getDoubleSearchDomain("f-1", 18.0, 65.0)));
    assertEquals(1, converted.size());

    final Feature only = converted.get(0);
    assertEquals(Type.NUMBER, only.getType());
    assertEquals("f-1", only.getName());
    assertEquals(20.0, only.getValue().asNumber());
    assertFalse(only.isConstrained());

    final FeatureDomain featureDomain = only.getDomain();
    assertFalse(featureDomain.isEmpty());
    assertEquals(18, featureDomain.getLowerBound());
    assertEquals(65, featureDomain.getUpperBound());
}
Aggregations