Use of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class LongEntityTest, method distanceScaled:
/**
 * Checks that a {@link LongEntity}'s distance, scaled using the supplied feature distribution,
 * falls strictly between 0.2 and 0.3 when the proposed value moves from 20 to 40.
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void distanceScaled(int seed) {
    // Seeding through the constructor is documented as equivalent to new Random() + setSeed(seed).
    final Random rng = new Random(seed);
    final FeatureDomain domain = NumericalFeatureDomain.create(0, 100);
    final Feature longFeature = FeatureFactory.newNumericalFeature("feature-long", 20L, domain);

    // 5000 samples drawn uniformly from [10, 40), widened to double.
    final double[] samples = rng.longs(5000, 10, 40).asDoubleStream().toArray();
    final FeatureDistribution distribution = new NumericFeatureDistribution(longFeature, samples);

    final LongEntity longEntity = (LongEntity) CounterfactualEntityFactory.from(longFeature, distribution);
    longEntity.proposedValue = 40L;

    final double scaledDistance = longEntity.distance();
    assertTrue(scaledDistance > 0.2 && scaledDistance < 0.3);
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class FeatureDomainTest, method getStart:
/** A numerical domain created over [0.0, 10.0] reports 0.0 as its lower bound. */
@Test
void getStart() {
    final double expectedLowerBound = 0.0;
    final FeatureDomain numericDomain = NumericalFeatureDomain.create(0.0, 10.0);
    assertEquals(expectedLowerBound, numericDomain.getLowerBound());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class FeatureDomainTest, method getEnd:
/** A numerical domain created over [-10, -5] reports -5.0 as its upper bound. */
@Test
void getEnd() {
    final double expectedUpperBound = -5.0;
    final FeatureDomain negativeDomain = NumericalFeatureDomain.create(-10, -5);
    assertEquals(expectedUpperBound, negativeDomain.getUpperBound());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class FeatureDomainTest, method getCategories:
/** Duplicate values passed to a categorical domain collapse into a set of distinct categories. */
@Test
void getCategories() {
    final Set<String> expectedCategories = Set.of("foo", "bar");
    final FeatureDomain categoricalDomain = CategoricalFeatureDomain.create("foo", "bar", "foo", "bar", "bar");
    assertEquals(expectedCategories, categoricalDomain.getCategories());
}
Use of org.kie.kogito.explainability.model.domain.FeatureDomain in the kogito-apps project by kiegroup.
From class CounterfactualEntityFactory, method from:
/**
 * Builds the {@link CounterfactualEntity} that represents {@code feature} in a counterfactual search.
 * <p>
 * Constrained features are mapped to their {@code Fixed*Entity} counterpart. Unconstrained features
 * are mapped to a searchable entity whose bounds or categories are taken from the feature's
 * {@link FeatureDomain}. Unconstrained TEXT, VECTOR and COMPOSITE features are not supported.
 *
 * @param feature             the feature to convert; it is checked with {@code validateFeature} first
 * @param featureDistribution distribution handed to the numeric and duration entities; other
 *                            entity types do not receive it
 * @return the entity created by the factory matching the feature's type and constraint
 * @throws IllegalArgumentException if the feature type is unsupported for the requested constraint,
 *                                  or a NUMBER feature's value is not a {@link Double}, {@link Long}
 *                                  or {@link Integer}
 */
public static CounterfactualEntity from(Feature feature, FeatureDistribution featureDistribution) {
    validateFeature(feature);
    final Type type = feature.getType();
    final FeatureDomain featureDomain = feature.getDomain();
    final boolean isConstrained = feature.isConstrained();
    // Read up front for every type, so value access happens uniformly before dispatch.
    final Object valueObject = feature.getValue().getUnderlyingObject();

    if (type == Type.NUMBER) {
        return numericEntityFrom(feature, featureDomain, featureDistribution, isConstrained, valueObject);
    } else if (type == Type.BOOLEAN) {
        if (isConstrained) {
            return FixedBooleanEntity.from(feature);
        }
        return BooleanEntity.from(feature, isConstrained);
    } else if (type == Type.TEXT) {
        if (isConstrained) {
            return FixedTextEntity.from(feature);
        }
        throw new IllegalArgumentException("Unsupported feature type: " + type);
    } else if (type == Type.BINARY) {
        if (isConstrained) {
            return FixedBinaryEntity.from(feature);
        }
        return BinaryEntity.from(feature, ((BinaryFeatureDomain) featureDomain).getCategories(), isConstrained);
    } else if (type == Type.URI) {
        if (isConstrained) {
            return FixedURIEntity.from(feature);
        }
        return URIEntity.from(feature, ((URIFeatureDomain) featureDomain).getCategories(), isConstrained);
    } else if (type == Type.TIME) {
        if (isConstrained) {
            return FixedTimeEntity.from(feature);
        }
        // Domain bounds are interpreted as seconds after midnight.
        // NOTE(review): LocalTime.plusSeconds wraps around midnight, so bounds >= 86400 wrap back
        // to early times — confirm the domain is always kept within a single day.
        final LocalTime lowerBound = LocalTime.MIN.plusSeconds(featureDomain.getLowerBound().longValue());
        final LocalTime upperBound = LocalTime.MIN.plusSeconds(featureDomain.getUpperBound().longValue());
        return TimeEntity.from(feature, lowerBound, upperBound, isConstrained);
    } else if (type == Type.DURATION) {
        if (isConstrained) {
            return FixedDurationEntity.from(feature);
        }
        // Duration bounds are expressed in the domain's own temporal unit.
        final DurationFeatureDomain domain = (DurationFeatureDomain) featureDomain;
        return DurationEntity.from(feature,
                Duration.of(domain.getLowerBound().longValue(), domain.getUnit()),
                Duration.of(domain.getUpperBound().longValue(), domain.getUnit()),
                featureDistribution, isConstrained);
    } else if (type == Type.VECTOR) {
        if (isConstrained) {
            return FixedVectorEntity.from(feature);
        }
        throw new IllegalArgumentException("Unsupported feature type: " + type);
    } else if (type == Type.COMPOSITE) {
        if (isConstrained) {
            return FixedCompositeEntity.from(feature);
        }
        throw new IllegalArgumentException("Unsupported feature type: " + type);
    } else if (type == Type.CURRENCY) {
        if (isConstrained) {
            return FixedCurrencyEntity.from(feature);
        }
        return CurrencyEntity.from(feature, ((CurrencyFeatureDomain) featureDomain).getCategories(), isConstrained);
    } else if (type == Type.CATEGORICAL) {
        if (isConstrained) {
            return FixedCategoricalEntity.from(feature);
        }
        return CategoricalEntity.from(feature, ((CategoricalFeatureDomain) featureDomain).getCategories(), isConstrained);
    } else if (type == Type.UNDEFINED) {
        if (isConstrained) {
            return FixedObjectEntity.from(feature);
        }
        return ObjectEntity.from(feature, ((ObjectFeatureDomain) featureDomain).getCategories(), isConstrained);
    }
    throw new IllegalArgumentException("Unsupported feature type: " + type);
}

/**
 * Builds the entity for a NUMBER feature by dispatching on the boxed Java type of its value.
 *
 * @throws IllegalArgumentException if the value is not a {@link Double}, {@link Long} or
 *                                  {@link Integer} (previously such values silently produced
 *                                  a {@code null} entity)
 */
private static CounterfactualEntity numericEntityFrom(Feature feature, FeatureDomain featureDomain,
        FeatureDistribution featureDistribution, boolean isConstrained, Object valueObject) {
    if (valueObject instanceof Double) {
        if (isConstrained) {
            return FixedDoubleEntity.from(feature);
        }
        return DoubleEntity.from(feature, featureDomain.getLowerBound(), featureDomain.getUpperBound(),
                featureDistribution, isConstrained);
    }
    if (valueObject instanceof Long) {
        if (isConstrained) {
            return FixedLongEntity.from(feature);
        }
        // NOTE(review): the bounds are narrowed with intValue(), which truncates domains outside
        // the int range; if LongEntity.from accepts long bounds, longValue() should be used here.
        return LongEntity.from(feature, featureDomain.getLowerBound().intValue(),
                featureDomain.getUpperBound().intValue(), featureDistribution, isConstrained);
    }
    if (valueObject instanceof Integer) {
        if (isConstrained) {
            return FixedIntegerEntity.from(feature);
        }
        return IntegerEntity.from(feature, featureDomain.getLowerBound().intValue(),
                featureDomain.getUpperBound().intValue(), featureDistribution, isConstrained);
    }
    throw new IllegalArgumentException("Unsupported numeric value type: "
            + (valueObject == null ? "null" : valueObject.getClass().getName()));
}
Aggregations