Search in sources:

Example 11 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

In class CounterFactualScoreCalculator, method outputDistance.

/**
 * Computes the distance between a predicted {@link Output} and the goal {@link Output}.
 * <p>
 * Numeric and duration outputs use the relative change (or the absolute difference when one of
 * the values is zero). Time outputs use the difference in seconds normalised by the length of a
 * day. Supported categorical outputs use 0 when equal and {@code DEFAULT_DISTANCE} otherwise.
 * Numeric, duration and time distances strictly below {@code threshold} are reported as 0.
 *
 * @param prediction the predicted output
 * @param goal       the desired output
 * @param threshold  distances below this value are considered a match (0)
 * @return the non-negative distance between the two outputs
 * @throws IllegalArgumentException if the types differ while the prediction has a non-null value,
 *                                  if a numeric value is NaN, or if the type is unsupported
 */
public static Double outputDistance(Output prediction, Output goal, double threshold) throws IllegalArgumentException {
    final Type predictionType = prediction.getType();
    final Type goalType = goal.getType();
    // there could be a type difference (e.g. a numerical feature is predicted as a textual "null")
    if (predictionType != goalType) {
        if (Objects.nonNull(prediction.getValue().getUnderlyingObject())) {
            String message = String.format("Features must have the same type. Feature '%s', has type '%s' and '%s'", prediction.getName(), predictionType.toString(), goalType.toString());
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        return DEFAULT_DISTANCE;
    }
    if (predictionType == Type.NUMBER) {
        final double predictionValue = prediction.getValue().asNumber();
        final double goalValue = goal.getValue().asNumber();
        if (Double.isNaN(predictionValue) || Double.isNaN(goalValue)) {
            String message = String.format("Unsupported NaN or NULL for numeric feature '%s'", prediction.getName());
            logger.error(message);
            throw new IllegalArgumentException(message);
        }
        final double difference = Math.abs(predictionValue - goalValue);
        final double distance;
        if (predictionValue == 0 || goalValue == 0) {
            // If any of the values is zero use the difference instead of the change rate
            distance = difference;
        } else {
            // Change rate relative to the larger magnitude: using the signed maximum would make the
            // distance negative (hence always "below threshold") when both values are negative.
            distance = difference / Math.max(Math.abs(predictionValue), Math.abs(goalValue));
        }
        return distance < threshold ? 0d : distance;
    } else if (predictionType == Type.DURATION) {
        final Duration predictionValue = (Duration) prediction.getValue().getUnderlyingObject();
        final Duration goalValue = (Duration) goal.getValue().getUnderlyingObject();
        if (Objects.isNull(predictionValue) || Objects.isNull(goalValue)) {
            return 1.0;
        }
        // Duration distances calculated from value in (whole) seconds
        final double difference = predictionValue.minus(goalValue).abs().getSeconds();
        // Magnitudes as doubles: avoids negative scales and Math.abs(Long.MIN_VALUE) overflow
        final double scale = Math.max(Math.abs((double) predictionValue.getSeconds()), Math.abs((double) goalValue.getSeconds()));
        final double distance;
        if (predictionValue.isZero() || goalValue.isZero() || scale == 0) {
            // Zero values, or non-zero sub-second durations (0 whole seconds), would make the
            // change rate NaN/Infinity: fall back to the plain difference.
            distance = difference;
        } else {
            distance = difference / scale;
        }
        return distance < threshold ? 0d : distance;
    } else if (predictionType == Type.TIME) {
        final LocalTime predictionValue = (LocalTime) prediction.getValue().getUnderlyingObject();
        final LocalTime goalValue = (LocalTime) goal.getValue().getUnderlyingObject();
        if (Objects.isNull(predictionValue) || Objects.isNull(goalValue)) {
            return 1.0;
        }
        final double interval = LocalTime.MIN.until(LocalTime.MAX, ChronoUnit.SECONDS);
        // Time distances calculated from value in seconds, normalised by the length of a day
        final double distance = Math.abs(predictionValue.until(goalValue, ChronoUnit.SECONDS)) / interval;
        return distance < threshold ? 0d : distance;
    } else if (SUPPORTED_CATEGORICAL_TYPES.contains(predictionType)) {
        final Object goalValueObject = goal.getValue().getUnderlyingObject();
        final Object predictionValueObject = prediction.getValue().getUnderlyingObject();
        return Objects.equals(goalValueObject, predictionValueObject) ? 0.0 : DEFAULT_DISTANCE;
    } else {
        String message = String.format("Feature '%s' has unsupported type '%s'", prediction.getName(), predictionType.toString());
        logger.error(message);
        throw new IllegalArgumentException(message);
    }
}
Also used : Type(org.kie.kogito.explainability.model.Type) LocalTime(java.time.LocalTime) Duration(java.time.Duration)

Example 12 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

In class CounterfactualEntityFactory, method from.

/**
 * Builds the {@link CounterfactualEntity} matching a feature's type and constraint status.
 * <p>
 * Constrained features are wrapped in the corresponding {@code Fixed*Entity}. Unconstrained
 * features get a searchable entity built from the feature's domain.
 *
 * @param feature             the feature to convert (checked with {@link #validateFeature(Feature)})
 * @param featureDistribution distribution forwarded to the numeric and duration entities
 * @return the entity representing the feature
 * @throws IllegalArgumentException if the feature type is unsupported, if an unconstrained
 *                                  TEXT/VECTOR/COMPOSITE feature is given, or if a NUMBER
 *                                  feature holds a value that is not a Double, Long or Integer
 */
public static CounterfactualEntity from(Feature feature, FeatureDistribution featureDistribution) {
    validateFeature(feature);
    final Type type = feature.getType();
    final FeatureDomain featureDomain = feature.getDomain();
    final boolean isConstrained = feature.isConstrained();
    final Object valueObject = feature.getValue().getUnderlyingObject();
    final CounterfactualEntity entity;
    if (type == Type.NUMBER) {
        if (valueObject instanceof Double) {
            if (isConstrained) {
                entity = FixedDoubleEntity.from(feature);
            } else {
                entity = DoubleEntity.from(feature, featureDomain.getLowerBound(), featureDomain.getUpperBound(), featureDistribution, isConstrained);
            }
        } else if (valueObject instanceof Long) {
            if (isConstrained) {
                entity = FixedLongEntity.from(feature);
            } else {
                // NOTE(review): intValue() truncates Long domain bounds outside the int range;
                // confirm LongEntity.from accepts long bounds and switch to longValue().
                entity = LongEntity.from(feature, featureDomain.getLowerBound().intValue(), featureDomain.getUpperBound().intValue(), featureDistribution, isConstrained);
            }
        } else if (valueObject instanceof Integer) {
            if (isConstrained) {
                entity = FixedIntegerEntity.from(feature);
            } else {
                entity = IntegerEntity.from(feature, featureDomain.getLowerBound().intValue(), featureDomain.getUpperBound().intValue(), featureDistribution, isConstrained);
            }
        } else {
            // Previously this silently returned null; validateFeature guarantees valueObject != null here.
            throw new IllegalArgumentException("Unsupported numeric feature value type: " + valueObject.getClass().getName());
        }
    } else if (type == Type.BOOLEAN) {
        if (isConstrained) {
            entity = FixedBooleanEntity.from(feature);
        } else {
            entity = BooleanEntity.from(feature, isConstrained);
        }
    } else if (type == Type.TEXT) {
        if (isConstrained) {
            entity = FixedTextEntity.from(feature);
        } else {
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
    } else if (type == Type.BINARY) {
        if (isConstrained) {
            entity = FixedBinaryEntity.from(feature);
        } else {
            entity = BinaryEntity.from(feature, ((BinaryFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.URI) {
        if (isConstrained) {
            entity = FixedURIEntity.from(feature);
        } else {
            entity = URIEntity.from(feature, ((URIFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.TIME) {
        if (isConstrained) {
            entity = FixedTimeEntity.from(feature);
        } else {
            // Time domain bounds are expressed as seconds since midnight
            final LocalTime lowerBound = LocalTime.MIN.plusSeconds(featureDomain.getLowerBound().longValue());
            final LocalTime upperBound = LocalTime.MIN.plusSeconds(featureDomain.getUpperBound().longValue());
            entity = TimeEntity.from(feature, lowerBound, upperBound, isConstrained);
        }
    } else if (type == Type.DURATION) {
        if (isConstrained) {
            entity = FixedDurationEntity.from(feature);
        } else {
            DurationFeatureDomain domain = (DurationFeatureDomain) featureDomain;
            entity = DurationEntity.from(feature, Duration.of(domain.getLowerBound().longValue(), domain.getUnit()), Duration.of(domain.getUpperBound().longValue(), domain.getUnit()), featureDistribution, isConstrained);
        }
    } else if (type == Type.VECTOR) {
        if (isConstrained) {
            entity = FixedVectorEntity.from(feature);
        } else {
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
    } else if (type == Type.COMPOSITE) {
        if (isConstrained) {
            entity = FixedCompositeEntity.from(feature);
        } else {
            throw new IllegalArgumentException("Unsupported feature type: " + type);
        }
    } else if (type == Type.CURRENCY) {
        if (isConstrained) {
            entity = FixedCurrencyEntity.from(feature);
        } else {
            entity = CurrencyEntity.from(feature, ((CurrencyFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.CATEGORICAL) {
        if (isConstrained) {
            entity = FixedCategoricalEntity.from(feature);
        } else {
            entity = CategoricalEntity.from(feature, ((CategoricalFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else if (type == Type.UNDEFINED) {
        if (isConstrained) {
            entity = FixedObjectEntity.from(feature);
        } else {
            entity = ObjectEntity.from(feature, ((ObjectFeatureDomain) featureDomain).getCategories(), isConstrained);
        }
    } else {
        throw new IllegalArgumentException("Unsupported feature type: " + type);
    }
    return entity;
}
Also used : LocalTime(java.time.LocalTime) URIFeatureDomain(org.kie.kogito.explainability.model.domain.URIFeatureDomain) DurationFeatureDomain(org.kie.kogito.explainability.model.domain.DurationFeatureDomain) ObjectFeatureDomain(org.kie.kogito.explainability.model.domain.ObjectFeatureDomain) URIFeatureDomain(org.kie.kogito.explainability.model.domain.URIFeatureDomain) CategoricalFeatureDomain(org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain) DurationFeatureDomain(org.kie.kogito.explainability.model.domain.DurationFeatureDomain) BinaryFeatureDomain(org.kie.kogito.explainability.model.domain.BinaryFeatureDomain) CurrencyFeatureDomain(org.kie.kogito.explainability.model.domain.CurrencyFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) CategoricalFeatureDomain(org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain) Type(org.kie.kogito.explainability.model.Type)

Example 13 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

In class CounterfactualEntityFactory, method validateFeature.

/**
 * Checks that a feature can be turned into a counterfactual entity.
 * <p>
 * Currently the only rule is that {@link Type#NUMBER} features must carry a non-null
 * underlying value.
 *
 * @param feature {@link Feature} to be validated
 * @throws IllegalArgumentException if a numeric feature has a {@code null} underlying value
 */
public static void validateFeature(Feature feature) {
    final boolean numeric = feature.getType() == Type.NUMBER;
    final Object underlying = feature.getValue().getUnderlyingObject();
    if (numeric && underlying == null) {
        throw new IllegalArgumentException("Null numeric features are not supported in counterfactuals");
    }
}
Also used : Type(org.kie.kogito.explainability.model.Type)

Example 14 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

In class RemotePredictionProvider, method toPredictionOutput.

/**
 * Converts the JSON response of the remote prediction service into a {@link PredictionOutput}.
 * <p>
 * Only outputs whose name appears in {@code predictionOutputs} are kept. Requested outputs that
 * are missing from the response are filled with an {@link Type#UNDEFINED} output holding a
 * {@code null} value. These fill-ins follow the order of {@code predictionOutputs}.
 *
 * @param mainObj the response body; expected to contain a {@code "result"} object
 * @return the converted outputs, or {@code null} if {@code mainObj} is null or has no "result" key
 */
protected PredictionOutput toPredictionOutput(JsonObject mainObj) {
    if (mainObj == null || !mainObj.containsKey("result")) {
        LOG.error("Malformed json {}", mainObj);
        return null;
    }
    List<Output> resultOutputs = toOutputList(mainObj.getJsonObject("result"));
    // Hash-based lookup instead of List.contains inside the stream below (was O(n*m)).
    java.util.Set<String> resultOutputNames = resultOutputs.stream()
            .map(Output::getName)
            .collect(Collectors.toCollection(java.util.HashSet::new));
    // Only the requested names matter. A name set avoids the NPE (null values) and
    // IllegalStateException (duplicate names) of Collectors.toMap. An insertion-ordered set
    // keeps the filled-in outputs in request order instead of HashMap iteration order.
    java.util.Set<String> requestedOutputNames = predictionOutputs.stream()
            .map(HasNameValue::getName)
            .collect(Collectors.toCollection(java.util.LinkedHashSet::new));
    // It's possible that some outputs are missing in the response from the prediction service
    // (e.g. when the generated perturbed inputs don't make sense and a decision is skipped).
    // The explainer, however, may throw exceptions if it can't find all the outputs that were
    // specified in the execution request.
    // Here we take the outputs received from the prediction service and we fill (only if needed)
    // the missing ones with Output objects containing "null" values of type UNDEFINED, to make
    // the explainer happy.
    List<Output> outputs = Stream.concat(
            resultOutputs.stream().filter(output -> requestedOutputNames.contains(output.getName())),
            requestedOutputNames.stream()
                    .filter(name -> !resultOutputNames.contains(name))
                    .map(name -> new Output(name, Type.UNDEFINED, new Value(null), 1d)))
            .collect(toList());
    return new PredictionOutput(outputs);
}
Also used : WebClientOptions(io.vertx.ext.web.client.WebClientOptions) Feature(org.kie.kogito.explainability.model.Feature) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) Map(java.util.Map) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) JsonObject(io.vertx.core.json.JsonObject) URI(java.net.URI) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictInput(org.kie.kogito.explainability.models.PredictInput) Logger(org.slf4j.Logger) Executor(java.util.concurrent.Executor) Collection(java.util.Collection) ThreadContext(org.eclipse.microprofile.context.ThreadContext) ConversionUtils.toOutputList(org.kie.kogito.explainability.ConversionUtils.toOutputList) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) Objects(java.util.Objects) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) JsonArray(io.vertx.core.json.JsonArray) List(java.util.List) Collectors.toList(java.util.stream.Collectors.toList) Stream(java.util.stream.Stream) Output(org.kie.kogito.explainability.model.Output) Vertx(io.vertx.mutiny.core.Vertx) WebClient(io.vertx.mutiny.ext.web.client.WebClient) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue)

Example 15 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

In class ConversionUtils, method toTypedValue.

/**
 * Converts a {@link Feature} into a {@link TypedValue}.
 * <p>
 * Passes the feature's name, type and value to the {@code toTypedValue(name, type, value)} overload.
 *
 * @param feature the feature to convert
 * @return the typed value built from the feature's name, type and value
 */
static TypedValue toTypedValue(Feature feature) {
    return toTypedValue(feature.getName(), feature.getType(), feature.getValue());
}
Also used : Type(org.kie.kogito.explainability.model.Type) CounterfactualSearchDomainValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainValue) Value(org.kie.kogito.explainability.model.Value) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue)

Aggregations

Type (org.kie.kogito.explainability.model.Type)15 Value (org.kie.kogito.explainability.model.Value)10 Feature (org.kie.kogito.explainability.model.Feature)8 ArrayList (java.util.ArrayList)6 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)6 List (java.util.List)5 Output (org.kie.kogito.explainability.model.Output)5 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)5 HashMap (java.util.HashMap)4 LinkedList (java.util.LinkedList)4 Map (java.util.Map)4 Collectors (java.util.stream.Collectors)4 DataDistribution (org.kie.kogito.explainability.model.DataDistribution)4 HasNameValue (org.kie.kogito.explainability.api.HasNameValue)3 Prediction (org.kie.kogito.explainability.model.Prediction)3 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)3 Logger (org.slf4j.Logger)3 LoggerFactory (org.slf4j.LoggerFactory)3 LocalTime (java.time.LocalTime)2 Collection (java.util.Collection)2