use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.
the class CounterFactualScoreCalculator method outputDistance.
/**
 * Computes the distance between a predicted output and the desired (goal) output.
 * <p>
 * The metric depends on the output type:
 * <ul>
 * <li>{@code NUMBER}: absolute difference when either value is zero, otherwise the absolute
 * difference divided by the larger of the two magnitudes;</li>
 * <li>{@code DURATION}: same scheme as numbers, computed on whole seconds; {@code 1.0} when either
 * underlying value is {@code null};</li>
 * <li>{@code TIME}: absolute difference in seconds divided by the number of seconds between
 * {@link LocalTime#MIN} and {@link LocalTime#MAX}; {@code 1.0} when either underlying value is
 * {@code null};</li>
 * <li>types contained in {@code SUPPORTED_CATEGORICAL_TYPES}: {@code 0.0} when the underlying
 * objects are equal, {@code DEFAULT_DISTANCE} otherwise.</li>
 * </ul>
 * For the numeric, duration and time cases a distance strictly below {@code threshold} is
 * reported as {@code 0}.
 *
 * @param prediction the output produced by the model
 * @param goal the desired output
 * @param threshold distances strictly below this value are treated as zero
 * @return the distance between {@code prediction} and {@code goal}; {@code DEFAULT_DISTANCE} when
 *         the types differ and the prediction holds no underlying value
 * @throws IllegalArgumentException if the types differ while the prediction holds a value, if a
 *         numeric value is NaN, or if the type is not supported
 */
public static Double outputDistance(Output prediction, Output goal, double threshold) throws IllegalArgumentException {
    final Type predictionType = prediction.getType();
    final Type goalType = goal.getType();

    // there could be a type difference (e.g. a numerical feature is predicted as a textual "null")
    if (predictionType != goalType) {
        if (Objects.nonNull(prediction.getValue().getUnderlyingObject())) {
            String message = String.format("Features must have the same type. Feature '%s', has type '%s' and '%s'", prediction.getName(), predictionType, goalType);
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        return DEFAULT_DISTANCE;
    }

    if (predictionType == Type.NUMBER) {
        final double predictionValue = prediction.getValue().asNumber();
        final double goalValue = goal.getValue().asNumber();
        if (Double.isNaN(predictionValue) || Double.isNaN(goalValue)) {
            String message = String.format("Unsupported NaN or NULL for numeric feature '%s'", prediction.getName());
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        final double difference = Math.abs(predictionValue - goalValue);
        // If any of the values is zero use the difference instead of change
        // If neither of the values is zero use the change rate
        final double distance;
        if (predictionValue == 0 || goalValue == 0) {
            distance = difference;
        } else {
            // Scale by the larger magnitude: dividing by a signed maximum would produce a negative
            // "distance" for negative values, which would then always fall below the threshold.
            distance = difference / Math.max(Math.abs(predictionValue), Math.abs(goalValue));
        }
        return distance < threshold ? 0d : distance;
    } else if (predictionType == Type.DURATION) {
        final Duration predictionValue = (Duration) prediction.getValue().getUnderlyingObject();
        final Duration goalValue = (Duration) goal.getValue().getUnderlyingObject();
        if (Objects.isNull(predictionValue) || Objects.isNull(goalValue)) {
            return 1.0;
        }
        // Duration distances calculated from value in seconds
        final double difference = predictionValue.minus(goalValue).abs().getSeconds();
        final double scale = Math.max(predictionValue.abs().getSeconds(), goalValue.abs().getSeconds());
        // If any of the values is zero use the difference instead of change
        // If neither of the values is zero use the change rate
        final double distance;
        if (predictionValue.isZero() || goalValue.isZero() || scale == 0) {
            // scale == 0 happens when both durations are shorter than one second; dividing would
            // yield NaN, so fall back to the plain difference in whole seconds.
            distance = difference;
        } else {
            distance = difference / scale;
        }
        return distance < threshold ? 0d : distance;
    } else if (predictionType == Type.TIME) {
        final LocalTime predictionValue = (LocalTime) prediction.getValue().getUnderlyingObject();
        final LocalTime goalValue = (LocalTime) goal.getValue().getUnderlyingObject();
        if (Objects.isNull(predictionValue) || Objects.isNull(goalValue)) {
            return 1.0;
        }
        final double interval = LocalTime.MIN.until(LocalTime.MAX, ChronoUnit.SECONDS);
        // Time distances calculated from value in seconds
        final double distance = Math.abs(predictionValue.until(goalValue, ChronoUnit.SECONDS)) / interval;
        return distance < threshold ? 0d : distance;
    } else if (SUPPORTED_CATEGORICAL_TYPES.contains(predictionType)) {
        final Object goalValueObject = goal.getValue().getUnderlyingObject();
        final Object predictionValueObject = prediction.getValue().getUnderlyingObject();
        return Objects.equals(goalValueObject, predictionValueObject) ? 0.0 : DEFAULT_DISTANCE;
    } else {
        String message = String.format("Feature '%s' has unsupported type '%s'", prediction.getName(), predictionType);
        logger.error(message);
        throw new IllegalArgumentException(message);
    }
}
use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.
the class CounterfactualEntityFactory method from.
/**
 * Creates the {@link CounterfactualEntity} matching the type of the supplied feature.
 * <p>
 * The feature is first checked with {@code validateFeature}. A constrained feature is mapped to
 * the corresponding {@code Fixed*Entity}; an unconstrained feature is mapped to the searchable
 * entity for its type, using bounds or categories taken from the feature's domain. For
 * {@code TIME} features the domain bounds are applied as second offsets from
 * {@link LocalTime#MIN}.
 *
 * @param feature the feature to convert
 * @param featureDistribution distribution passed through to the numeric and duration entities
 * @return the entity for {@code feature}
 * @throws IllegalArgumentException if the feature fails validation, if the type is not supported
 *         (including unconstrained {@code TEXT}, {@code VECTOR} and {@code COMPOSITE} features),
 *         or if a {@code NUMBER} feature holds a value that is not a {@code Double},
 *         {@code Long} or {@code Integer}
 */
public static CounterfactualEntity from(Feature feature, FeatureDistribution featureDistribution) {
    validateFeature(feature);
    final Type type = feature.getType();
    final FeatureDomain featureDomain = feature.getDomain();
    final boolean isConstrained = feature.isConstrained();
    final Object valueObject = feature.getValue().getUnderlyingObject();

    final CounterfactualEntity entity;
    if (type == Type.NUMBER) {
        if (valueObject instanceof Double) {
            if (isConstrained) {
                entity = FixedDoubleEntity.from(feature);
            } else {
                entity = DoubleEntity.from(feature, featureDomain.getLowerBound(), featureDomain.getUpperBound(), featureDistribution, isConstrained);
            }
        } else if (valueObject instanceof Long) {
            if (isConstrained) {
                entity = FixedLongEntity.from(feature);
            } else {
                entity = LongEntity.from(feature, featureDomain.getLowerBound().intValue(), featureDomain.getUpperBound().intValue(), featureDistribution, isConstrained);
            }
        } else if (valueObject instanceof Integer) {
            if (isConstrained) {
                entity = FixedIntegerEntity.from(feature);
            } else {
                entity = IntegerEntity.from(feature, featureDomain.getLowerBound().intValue(), featureDomain.getUpperBound().intValue(), featureDistribution, isConstrained);
            }
        } else {
            // Fail loudly rather than handing a null entity back to the caller.
            throw new IllegalArgumentException("Unsupported numeric feature value type: " + (valueObject == null ? "null" : valueObject.getClass().getName()));
        }
    } else if (type == Type.BOOLEAN) {
        if (isConstrained) {
            entity = FixedBooleanEntity.from(feature);
        } else {
            entity = BooleanEntity.from(feature, isConstrained);
        }
    } else if (type == Type.TEXT) {
        if (isConstrained) {
            entity = FixedTextEntity.from(feature);
        } else {
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
    } else if (type == Type.BINARY) {
        if (isConstrained) {
            entity = FixedBinaryEntity.from(feature);
        } else {
            entity = BinaryEntity.from(feature, ((BinaryFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.URI) {
        if (isConstrained) {
            entity = FixedURIEntity.from(feature);
        } else {
            entity = URIEntity.from(feature, ((URIFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.TIME) {
        if (isConstrained) {
            entity = FixedTimeEntity.from(feature);
        } else {
            // Domain bounds are applied as offsets in seconds from the start of the day.
            final LocalTime lowerBound = LocalTime.MIN.plusSeconds(featureDomain.getLowerBound().longValue());
            final LocalTime upperBound = LocalTime.MIN.plusSeconds(featureDomain.getUpperBound().longValue());
            entity = TimeEntity.from(feature, lowerBound, upperBound, isConstrained);
        }
    } else if (type == Type.DURATION) {
        if (isConstrained) {
            entity = FixedDurationEntity.from(feature);
        } else {
            final DurationFeatureDomain domain = (DurationFeatureDomain) featureDomain;
            entity = DurationEntity.from(feature, Duration.of(domain.getLowerBound().longValue(), domain.getUnit()), Duration.of(domain.getUpperBound().longValue(), domain.getUnit()), featureDistribution, isConstrained);
        }
    } else if (type == Type.VECTOR) {
        if (isConstrained) {
            entity = FixedVectorEntity.from(feature);
        } else {
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
    } else if (type == Type.COMPOSITE) {
        if (isConstrained) {
            entity = FixedCompositeEntity.from(feature);
        } else {
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
    } else if (type == Type.CURRENCY) {
        if (isConstrained) {
            entity = FixedCurrencyEntity.from(feature);
        } else {
            entity = CurrencyEntity.from(feature, ((CurrencyFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.CATEGORICAL) {
        if (isConstrained) {
            entity = FixedCategoricalEntity.from(feature);
        } else {
            entity = CategoricalEntity.from(feature, ((CategoricalFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.UNDEFINED) {
        if (isConstrained) {
            entity = FixedObjectEntity.from(feature);
        } else {
            entity = ObjectEntity.from(feature, ((ObjectFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else {
        throw new IllegalArgumentException("Unsupported feature type: " + type);
    }
    return entity;
}
use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.
the class CounterfactualEntityFactory method validateFeature.
/**
 * Checks that a feature can be used to build a counterfactual entity.
 * <p>
 * The only rule enforced here is that a feature of type {@code NUMBER} must carry a non-null
 * underlying value.
 *
 * @param feature {@link Feature} to be validated
 * @throws IllegalArgumentException if the feature is numeric and its underlying value is {@code null}
 */
public static void validateFeature(Feature feature) {
    final Type featureType = feature.getType();
    final Object underlying = feature.getValue().getUnderlyingObject();
    final boolean nullNumber = featureType == Type.NUMBER && underlying == null;
    if (nullNumber) {
        throw new IllegalArgumentException("Null numeric features are not supported in counterfactuals");
    }
}
use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.
the class RemotePredictionProvider method toPredictionOutput.
/**
 * Builds a {@link PredictionOutput} from the JSON response of the prediction service.
 * <p>
 * Outputs found under the {@code "result"} key are kept only if their name is one of the
 * requested {@code predictionOutputs}. Every requested output that is absent from the response
 * is added as an {@link Output} of type {@code UNDEFINED} wrapping a {@code null} value with
 * score {@code 1}.
 *
 * @param mainObj the JSON response; may be {@code null}
 * @return the prediction output, or {@code null} when {@code mainObj} is {@code null} or has no
 *         {@code "result"} entry
 */
protected PredictionOutput toPredictionOutput(JsonObject mainObj) {
    if (mainObj == null || !mainObj.containsKey("result")) {
        LOG.error("Malformed json {}", mainObj);
        return null;
    }
    final List<Output> resultOutputs = toOutputList(mainObj.getJsonObject("result"));
    // Only the names are needed for the membership checks below. Collecting them into sets
    // (rather than a name -> value map) avoids the NullPointerException that Collectors.toMap
    // raises for null values, and makes each lookup constant time.
    final java.util.Set<String> receivedNames = resultOutputs.stream().map(Output::getName).collect(Collectors.toSet());
    final java.util.Set<String> requestedNames = predictionOutputs.stream().map(HasNameValue::getName).collect(Collectors.toSet());

    // It's possible that some outputs are missing in the response from the prediction service
    // (e.g. when the generated perturbed inputs don't make sense and a decision is skipped).
    // The explainer, however, may throw exceptions if it can't find all the inputs that were
    // specified in the execution request.
    // Here we take the outputs received from the prediction service and we fill (only if needed)
    // the missing ones with Output objects containing "null" values of type UNDEFINED, to make
    // the explainer happy.
    final Stream<Output> received = resultOutputs.stream()
            .filter(output -> requestedNames.contains(output.getName()));
    final Stream<Output> missing = requestedNames.stream()
            .filter(name -> !receivedNames.contains(name))
            .map(name -> new Output(name, Type.UNDEFINED, new Value(null), 1d));
    final List<Output> outputs = Stream.concat(received, missing).collect(toList());
    return new PredictionOutput(outputs);
}
use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.
the class ConversionUtils method toTypedValue.
/**
 * Converts a {@link Feature} into a {@link TypedValue} by handing its name, type and value to
 * the three-argument {@code toTypedValue} overload.
 *
 * @param feature the feature to convert
 * @return the result of the three-argument overload for this feature
 */
static TypedValue toTypedValue(Feature feature) {
    return toTypedValue(feature.getName(), feature.getType(), feature.getValue());
}
Aggregations