Use of org.kie.kogito.explainability.model.domain.FeatureDomain in project kogito-apps by kiegroup.
The class ConversionUtils, method toFeatureList.
// //////////////////////////
// TO EXPLAINABILITY MODEL
// //////////////////////////
/*
* ---------------------------------------
* Feature conversion
* ---------------------------------------
*/
/**
 * Converts named typed values into explainability {@link Feature}s, pairing each value with the
 * feature domain and constraint flag found at the same iteration position of {@code searchDomains}.
 * <p>
 * When {@code searchDomains} is empty, conversion is delegated to the single-argument
 * {@code toFeatureList(values)} overload.
 *
 * @param values named typed values to convert; the result preserves their iteration order
 * @param searchDomains counterfactual search domains, matched to {@code values} by iteration position
 * @return the converted features, one per entry of {@code values}
 * @throws IllegalArgumentException if {@code searchDomains} is non-empty but yields fewer feature
 *         domains or constraint flags than there are values
 */
public static List<Feature> toFeatureList(Collection<? extends HasNameValue<TypedValue>> values, Collection<CounterfactualSearchDomain> searchDomains) {
    if (searchDomains.isEmpty()) {
        return toFeatureList(values);
    }
    final List<FeatureDomain> featureDomains = toFeatureDomainList(searchDomains);
    final List<Boolean> featureConstraints = toFeatureConstraintList(searchDomains);
    // Fail fast with a descriptive message rather than an IndexOutOfBoundsException
    // thrown from inside the stream lambda below.
    if (featureDomains.size() < values.size() || featureConstraints.size() < values.size()) {
        throw new IllegalArgumentException(String.format(
                "Expected at least %d search domains and constraints (one per value) but got %d domains and %d constraints",
                values.size(), featureDomains.size(), featureConstraints.size()));
    }
    // Positional cursor shared with the lambda; correct only because this stream is sequential.
    final AtomicInteger index = new AtomicInteger();
    return values.stream().map(hnv -> {
        final String name = hnv.getName();
        final TypedValue value = hnv.getValue();
        final int i = index.getAndIncrement();
        return toFeature(name, value, featureDomains.get(i), featureConstraints.get(i));
    }).collect(Collectors.toList());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in project kogito-apps by kiegroup.
The class CounterfactualScoreCalculatorTest, method testGoalSizeMatch.
/**
 * When the goal coincides with the model's outputs, every hard and soft score level must be zero.
 */
@Test
void testGoalSizeMatch() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator calculator = new CounterFactualScoreCalculator();
    final PredictionProvider skipModel = TestUtils.getFeatureSkipModel(0);

    final List<Feature> inputFeatures = new ArrayList<>();
    final List<FeatureDomain> inputDomains = new ArrayList<>();
    final List<Boolean> inputConstraints = new ArrayList<>();

    // f-1: numerical, domain [0, 10], constraint flag false
    inputFeatures.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    inputDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    inputConstraints.add(false);

    // f-2: numerical, domain [0, 10], constraint flag false
    inputFeatures.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    inputDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    inputConstraints.add(false);

    // f-3: boolean, empty domain, constraint flag false
    inputFeatures.add(FeatureFactory.newBooleanFeature("f-3", true));
    inputDomains.add(EmptyFeatureDomain.create());
    inputConstraints.add(false);

    final PredictionInput predictionInput = new PredictionInput(inputFeatures);
    // NOTE(review): this local is never read below; consider removing it together with its imports.
    final PredictionFeatureDomain predictionDomains = new PredictionFeatureDomain(inputDomains);
    final List<CounterfactualEntity> counterfactualEntities = CounterfactualEntityFactory.createEntities(predictionInput);

    // presumably the model skips feature 0, so its outputs (f-2, f-3) match this goal — confirm in TestUtils
    final List<Output> goalOutputs = new ArrayList<>();
    goalOutputs.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    goalOutputs.add(new Output("f-3", Type.BOOLEAN, new Value(true), 0.0));

    final CounterfactualSolution solution = new CounterfactualSolution(counterfactualEntities, inputFeatures, skipModel, goalOutputs, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    final BendableBigDecimalScore score = calculator.calculateScore(solution);
    final List<PredictionOutput> predictions = skipModel.predictAsync(List.of(predictionInput)).get();

    assertTrue(score.isFeasible());
    assertEquals(2, goalOutputs.size());
    // exactly one prediction, carrying two outputs
    assertEquals(1, predictions.size());
    assertEquals(2, predictions.get(0).getOutputs().size());

    for (int level = 0; level < 3; level++) {
        assertEquals(0, score.getHardScore(level).compareTo(BigDecimal.ZERO));
    }
    for (int level = 0; level < 2; level++) {
        assertEquals(0, score.getSoftScore(level).compareTo(BigDecimal.ZERO));
    }
    assertEquals(3, score.getHardLevelsSize());
    assertEquals(2, score.getSoftLevelsSize());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in project kogito-apps by kiegroup.
The class DoubleEntityTest, method distanceUnscaled.
/**
 * A double entity's distance is the raw gap between proposed and original values,
 * not normalised by the width of its feature domain.
 */
@Test
void distanceUnscaled() {
    final FeatureDomain domain = NumericalFeatureDomain.create(0.0, 40.0);
    final double originalValue = 20.0;
    final double proposal = 30.0;

    final Feature feature = FeatureFactory.newNumericalFeature("feature-double", originalValue, domain);
    final DoubleEntity doubleEntity = (DoubleEntity) CounterfactualEntityFactory.from(feature);
    doubleEntity.proposedValue = proposal;

    // 30.0 - 20.0 = 10.0 (a range-normalised distance would be 10 / 40 instead)
    assertEquals(10.0, doubleEntity.distance());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in project kogito-apps by kiegroup.
The class IntegerEntityTest, method distanceUnscaled.
/**
 * An integer entity's distance is the raw gap between proposed and original values,
 * not normalised by the width of its feature domain.
 */
@Test
void distanceUnscaled() {
    final FeatureDomain domain = NumericalFeatureDomain.create(0, 100);
    final int originalValue = 20;
    final int proposal = 40;

    final Feature feature = FeatureFactory.newNumericalFeature("feature-integer", originalValue, domain);
    final IntegerEntity integerEntity = (IntegerEntity) CounterfactualEntityFactory.from(feature);
    integerEntity.proposedValue = proposal;

    // 40 - 20 = 20 (a range-normalised distance would be 20 / 100 instead)
    assertEquals(20.0, integerEntity.distance());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in project kogito-apps by kiegroup.
The class LongEntityTest, method distanceUnscaled.
/**
 * A long entity's distance is the raw gap between proposed and original values,
 * not normalised by the width of its feature domain.
 */
@Test
void distanceUnscaled() {
    final FeatureDomain domain = NumericalFeatureDomain.create(0, 100);
    final long originalValue = 20L;
    final long proposal = 40L;

    final Feature feature = FeatureFactory.newNumericalFeature("feature-long", originalValue, domain);
    final LongEntity longEntity = (LongEntity) CounterfactualEntityFactory.from(feature);
    longEntity.proposedValue = proposal;

    // 40 - 20 = 20 (a range-normalised distance would be 20 / 100 instead)
    assertEquals(20.0, longEntity.distance());
}
Aggregations