Usage of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class CounterfactualEntityFactoryTest, method testDoubleFactory:
/**
 * Verifies that {@link CounterfactualEntityFactory#from} turns a double-valued numerical
 * feature into a {@link DoubleEntity} whose feature view still carries the original value.
 */
@Test
void testDoubleFactory() {
    final FeatureDomain range = NumericalFeatureDomain.create(0.0, 10.0);
    final double expected = 5.5;
    final Feature doubleFeature = FeatureFactory.newNumericalFeature("double-feature", expected, range);

    final CounterfactualEntity entity = CounterfactualEntityFactory.from(doubleFeature);

    assertTrue(entity instanceof DoubleEntity);
    // Converting the entity back into a feature must yield the value we started with
    assertEquals(expected, entity.asFeature().getValue().asNumber());
}
Usage of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class CounterfactualScoreCalculatorTest, method testNullIntegerInput:
/**
 * Null values for input Integer features should not be accepted as valid.
 * <p>
 * Builds a prediction input that contains a numerical feature with a {@code null} value
 * between a valid numerical feature and a boolean feature. It then checks that
 * {@link CounterfactualEntityFactory#createEntities} rejects the input with an
 * {@link IllegalArgumentException} carrying the expected message.
 */
@Test
void testNullIntegerInput() throws ExecutionException, InterruptedException {
    final List<Feature> features = new ArrayList<>();
    // f-1: valid numerical feature
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    // f-2: numerical feature with a null value (the case under test)
    features.add(FeatureFactory.newNumericalFeature("f-2", null));
    // f-3: valid boolean feature
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    final PredictionInput input = new PredictionInput(features);

    final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
            () -> CounterfactualEntityFactory.createEntities(input));
    assertEquals("Null numeric features are not supported in counterfactuals", exception.getMessage());
}
Usage of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class CounterfactualScoreCalculatorTest, method testGoalSizeSmaller:
/**
 * Using a smaller number of features in the goals (1) than the model's output (2) should
 * throw an {@link IllegalArgumentException} with the appropriate message.
 * <p>
 * The test first checks its own preconditions (one goal output; one prediction with two
 * outputs from the model). It then scores a {@link CounterfactualSolution} built from
 * that mismatched goal.
 */
@Test
void testGoalSizeSmaller() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    final PredictionProvider model = TestUtils.getFeatureSkipModel(0);

    final List<Feature> features = new ArrayList<>();
    // f-1
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    // f-2
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    // f-3
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    final PredictionInput input = new PredictionInput(features);
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);

    // The goal deliberately holds a single output
    final List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));

    final List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertEquals(1, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two outputs, i.e. more than the goal has
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());

    final CounterfactualSolution solution = new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
            () -> scoreCalculator.calculateScore(solution));
    assertEquals("Prediction size must be equal to goal size", exception.getMessage());
}
Usage of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class ConversionUtilsTest, method testToFeatureDomain_UnitFixedNumber:
/**
 * Converting an "integer" unit search-domain value that has no range ({@code null})
 * is expected to produce an {@link EmptyFeatureDomain}.
 */
@Test
void testToFeatureDomain_UnitFixedNumber() {
    final CounterfactualSearchDomainUnitValue integerUnit = new CounterfactualSearchDomainUnitValue("integer", "integer", true, null);
    final FeatureDomain converted = ConversionUtils.toFeatureDomain(integerUnit);
    assertTrue(converted instanceof EmptyFeatureDomain);
}
Usage of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class ConversionUtilsTest, method testToFeatureDomain_UnitRangeDouble:
/**
 * Converting a "double" unit search-domain value with a [-273.15, Double.MAX_VALUE] range
 * is expected to produce a {@link NumericalFeatureDomain} with exactly those bounds and
 * no categories.
 */
@Test
void testToFeatureDomain_UnitRangeDouble() {
    final double lower = -273.15;
    final double upper = Double.MAX_VALUE;
    final CounterfactualDomainRange range = new CounterfactualDomainRange(DoubleNode.valueOf(lower), DoubleNode.valueOf(upper));
    final FeatureDomain converted = ConversionUtils.toFeatureDomain(new CounterfactualSearchDomainUnitValue("double", "double", true, range));

    assertTrue(converted instanceof NumericalFeatureDomain);
    final NumericalFeatureDomain numeric = (NumericalFeatureDomain) converted;
    assertEquals(lower, numeric.getLowerBound());
    assertEquals(upper, numeric.getUpperBound());
    assertNull(numeric.getCategories());
}
Aggregations