Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in the kogito-apps project by kiegroup.
In class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithCollectionSearchDomains:
/**
 * Verifies that the handler rejects a request whose search domain is collection-valued:
 * {@code getPrediction} must throw {@link IllegalArgumentException}.
 */
@Test
public void testGetPredictionWithCollectionSearchDomains() {
    // Innermost piece: a single "number" unit domain over [10, 20].
    CounterfactualDomainRange range = new CounterfactualDomainRange(new IntNode(10), new IntNode(20));
    CounterfactualSearchDomainUnitValue element = new CounterfactualSearchDomainUnitValue("number", "number", true, range);

    // Wrap it in a collection domain attached to "input1"; this is the shape under test.
    CounterfactualSearchDomainCollectionValue collectionDomain = new CounterfactualSearchDomainCollectionValue("number", List.of(element));
    List<CounterfactualSearchDomain> searchDomains = List.of(new CounterfactualSearchDomain("input1", collectionDomain));

    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            searchDomains,
            MAX_RUNNING_TIME_SECONDS);

    assertThrows(IllegalArgumentException.class, () -> handler.getPrediction(request));
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in the kogito-apps project by kiegroup.
In class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsNotFixed:
/**
 * Verifies that a flat (unit) search domain passed with {@code false} as its fixed flag over
 * [10, 20] is converted into a single feature whose domain is a {@link NumericalFeatureDomain}
 * with those bounds. Also checks that the request's max running time is propagated to the
 * resulting {@link CounterfactualPrediction}.
 */
@Test
public void testGetPredictionWithFlatSearchDomainsNotFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))), Collections.emptyList(), List.of(new CounterfactualSearchDomain("output1", new CounterfactualSearchDomainUnitValue("number", "number", false, new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))), MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain domain = (NumericalFeatureDomain) feature1.getDomain();
    assertEquals(10, domain.getLowerBound());
    assertEquals(20, domain.getUpperBound());
    // JUnit's contract is assertEquals(expected, actual): the request is the source of truth.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in the kogito-apps project by kiegroup.
In class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatSearchDomainsFixed:
/**
 * Verifies that a flat (unit) search domain passed with {@code true} as its fixed flag is
 * converted into a single feature whose domain is an {@link EmptyFeatureDomain}, even though a
 * [10, 20] range is supplied. Also checks that the request's max running time is propagated to
 * the resulting {@link CounterfactualPrediction}.
 */
@Test
public void testGetPredictionWithFlatSearchDomainsFixed() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, List.of(new NamedTypedValue("output1", new UnitValue("number", new IntNode(25)))), Collections.emptyList(), List.of(new CounterfactualSearchDomain("output1", new CounterfactualSearchDomainUnitValue("number", "number", true, new CounterfactualDomainRange(new IntNode(10), new IntNode(20))))), MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    assertEquals(1, counterfactualPrediction.getInput().getFeatures().size());
    Feature feature1 = counterfactualPrediction.getInput().getFeatures().get(0);
    assertTrue(feature1.getDomain() instanceof EmptyFeatureDomain);
    // JUnit's contract is assertEquals(expected, actual): the request is the source of truth.
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in the kogito-apps project by kiegroup.
In class ConversionUtilsTest, method getDoubleSearchDomain:
/**
 * Builds a named search domain whose value is a "double"/"double" unit domain over
 * [{@code lowerBound}, {@code upperBound}].
 *
 * @param name       name of the search domain
 * @param lowerBound lower bound of the range, wrapped in a {@code DoubleNode}
 * @param upperBound upper bound of the range, wrapped in a {@code DoubleNode}
 * @return the assembled {@link CounterfactualSearchDomain}
 */
private static CounterfactualSearchDomain getDoubleSearchDomain(String name, double lowerBound, double upperBound) {
    // Third argument is Boolean.FALSE — presumably the "fixed" flag, i.e. the domain is searchable.
    return new CounterfactualSearchDomain(name,
            new CounterfactualSearchDomainUnitValue("double", "double", Boolean.FALSE,
                    new CounterfactualDomainRange(DoubleNode.valueOf(lowerBound), DoubleNode.valueOf(upperBound))));
}
Usage of org.kie.kogito.explainability.api.CounterfactualSearchDomain in the kogito-apps project by kiegroup.
In class ConversionUtilsTest, method toFeatureDomainsConstraintsMultiElement:
/**
 * Verifies that {@code ConversionUtils.toFeatureList} turns several named double values, each
 * paired with a [-1, 1] search domain, into NUMBER features. Each feature must be unconstrained
 * and have a non-empty domain with lower bound -1.0 and upper bound 1.0.
 */
@Test
void toFeatureDomainsConstraintsMultiElement() {
    final int featureCount = 10;
    // Fixed seed: the concrete values are irrelevant to the assertions, but a failure must be reproducible.
    final Random random = new Random(0L);
    List<NamedTypedValue> values = IntStream.range(0, featureCount).mapToObj(i -> getDoubleUnit("f-" + i, random.nextDouble())).collect(Collectors.toList());
    List<CounterfactualSearchDomain> domains = IntStream.range(0, featureCount).mapToObj(i -> getDoubleSearchDomain("f-" + i, -1, 1)).collect(Collectors.toList());
    final List<Feature> features = ConversionUtils.toFeatureList(values, domains);
    assertEquals(featureCount, features.size());
    assertTrue(features.stream().allMatch(f -> f.getType() == Type.NUMBER));
    assertTrue(features.stream().noneMatch(Feature::isConstrained));
    assertTrue(features.stream().map(Feature::getDomain).noneMatch(FeatureDomain::isEmpty));
    assertTrue(features.stream().map(Feature::getDomain).map(FeatureDomain::getLowerBound).allMatch(lb -> lb == -1.0));
    assertTrue(features.stream().map(Feature::getDomain).map(FeatureDomain::getUpperBound).allMatch(ub -> ub == 1.0));
}
Aggregations