
Example 21 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

The class LimeExplainer, method prepareInputs.

/**
 * Check the perturbed inputs so that the dataset of perturbed inputs / outputs contains more than one output
 * class; otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be
 * used as feature importance scores.
 * The check can be {@code strict}: if so, it throws a {@code DatasetNotSeparableException} when the dataset for
 * a given output is not separable.
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
    if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
        Map<Double, Long> rawClassesBalance;
        // calculate the no. of samples belonging to each output class
        Value fv = currentOutput.getValue();
        rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        Long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / (double) perturbedInputs.size();
        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        if (strict) {
            // check if the dataset is separable and also if the linear model should fit a regressor or a classifier
            if (rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio()) {
                // the dataset is separable enough: use it to train the linear model
                return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
            } else {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
        } else {
            LOGGER.warn("Using a hardly separable dataset for output '{}' of type '{}' with value '{}' ({})", currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
            return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
        }
    } else {
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }
}
Also used : Arrays(java.util.Arrays) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CompletableFuture.completedFuture(java.util.concurrent.CompletableFuture.completedFuture) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional)
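
A short aside on the check itself: it reduces to counting how many perturbed samples fall into each output class and rejecting the dataset when a single class dominates. Below is a minimal, self-contained sketch of that logic, assuming encoded outputs represented as plain doubles and a hypothetical SEPARABLE_DATASET_RATIO constant standing in for LimeConfig#getSeparableDatasetRatio.

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class SeparabilityCheckSketch {

    // hypothetical stand-in for LimeConfig#getSeparableDatasetRatio (assumption)
    private static final double SEPARABLE_DATASET_RATIO = 0.99;

    /**
     * Mirrors the check in prepareInputs: the dataset is usable when more than
     * one output class is present and no single class exceeds the ratio.
     */
    static boolean isSeparable(List<Double> encodedOutputs) {
        // count the no. of samples belonging to each output class
        Map<Double, Long> rawClassesBalance = encodedOutputs.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / encodedOutputs.size();
        return rawClassesBalance.size() > 1 && separationRatio < SEPARABLE_DATASET_RATIO;
    }

    public static void main(String[] args) {
        System.out.println(isSeparable(List.of(1d, 1d, 1d, 1d))); // false: a single class
        System.out.println(isSeparable(List.of(1d, 0d, 1d, 0d))); // true: two balanced classes
    }
}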

Example 22 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

The class LimeExplainer, method explainAsync.

@Override
public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction, PredictionProvider model, Consumer<Map<String, Saliency>> intermediateResultsConsumer) {
    PredictionInput originalInput = prediction.getInput();
    if (originalInput == null || originalInput.getFeatures() == null || originalInput.getFeatures().isEmpty()) {
        throw new LocalExplanationException("cannot explain a prediction whose input is empty");
    }
    List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
    PredictionInput targetInput = linearizedInputs.get(0);
    List<Feature> linearizedTargetInputFeatures = targetInput.getFeatures();
    if (linearizedTargetInputFeatures.isEmpty()) {
        throw new LocalExplanationException("input features linearization failed");
    }
    List<Output> actualOutputs = prediction.getOutput().getOutputs();
    LimeConfig executionConfig = limeConfig.copy();
    return explainWithExecutionConfig(model, originalInput, linearizedTargetInputFeatures, actualOutputs, executionConfig);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Feature(org.kie.kogito.explainability.model.Feature)
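
For context, a typical call site looks like the hedged sketch below: a toy PredictionProvider that sums numerical features (hypothetical), a SimplePrediction built from its output, and explainAsync producing one Saliency per output. It assumes the FeatureFactory helpers, LimeConfig#withSamples, and a two-argument explainAsync overload on LocalExplainer, none of which are shown above.

import static java.util.concurrent.CompletableFuture.completedFuture;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.*;

public class LimeUsageSketch {

    public static void main(String[] args) throws Exception {
        // hypothetical model: sums the numerical features of each input
        PredictionProvider model = inputs -> completedFuture(inputs.stream()
                .map(in -> {
                    double sum = in.getFeatures().stream()
                            .mapToDouble(f -> f.getValue().asNumber()).sum();
                    return new PredictionOutput(
                            List.of(new Output("sum", Type.NUMBER, new Value(sum), 1d)));
                })
                .collect(Collectors.toList()));

        PredictionInput input = new PredictionInput(List.of(
                FeatureFactory.newNumericalFeature("f1", 1.0),
                FeatureFactory.newNumericalFeature("f2", 2.0)));
        PredictionOutput output = model.predictAsync(List.of(input)).get().get(0);

        Prediction prediction = new SimplePrediction(input, output);
        Map<String, Saliency> saliencyMap =
                new LimeExplainer(new LimeConfig().withSamples(100))
                        .explainAsync(prediction, model).get();
        saliencyMap.forEach((outputName, saliency) ->
                System.out.println(outputName + " -> " + saliency.getPerFeatureImportance()));
    }
}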

Example 23 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

The class LocalDMNPredictionProvider, method predictAsync.

@Override
@SuppressWarnings("unchecked")
public CompletableFuture<List<PredictionOutput>> predictAsync(List<PredictionInput> inputs) {
    List<PredictionOutput> predictionOutputs = new ArrayList<>();
    for (PredictionInput input : inputs) {
        Map<String, Object> contextVariables = (Map<String, Object>) toMap(input.getFeatures()).get(DUMMY_DMN_CONTEXT_KEY);
        predictionOutputs.add(toPredictionOutput(dmnEvaluator.evaluate(contextVariables)));
    }
    return completedFuture(predictionOutputs);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) ArrayList(java.util.ArrayList) HashMap(java.util.HashMap) Map(java.util.Map)
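
The method above is essentially the whole PredictionProvider contract: one PredictionOutput per PredictionInput, wrapped in an already-completed future. Since the interface is functional, a test double can be written as a lambda; the echo behaviour below is hypothetical.

import static java.util.concurrent.CompletableFuture.completedFuture;

import java.util.List;
import java.util.stream.Collectors;

import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;

public class EchoProviderSketch {

    // echoes each input feature back as an output with full confidence (hypothetical)
    static final PredictionProvider ECHO = inputs -> completedFuture(inputs.stream()
            .map(in -> new PredictionOutput(in.getFeatures().stream()
                    .map(f -> new Output(f.getName(), f.getType(), f.getValue(), 1d))
                    .collect(Collectors.toList())))
            .collect(Collectors.toList()));
}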

Example 24 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

The class CounterfactualExplainerServiceHandler, method getPrediction.

@Override
public Prediction getPrediction(CounterfactualExplainabilityRequest request) {
    Collection<NamedTypedValue> goals = toMapBasedSorting(request.getGoals());
    Collection<CounterfactualSearchDomain> searchDomains = request.getSearchDomains();
    Collection<NamedTypedValue> originalInputs = request.getOriginalInputs();
    Long maxRunningTimeSeconds = request.getMaxRunningTimeSeconds();
    if (Objects.nonNull(maxRunningTimeSeconds)) {
        if (maxRunningTimeSeconds > kafkaMaxRecordAgeSeconds) {
            LOGGER.info(String.format("Maximum Running Timeout set to '%d's since the provided value '%d's exceeded the Messaging sub-system configuration '%d's.", kafkaMaxRecordAgeSeconds, maxRunningTimeSeconds, kafkaMaxRecordAgeSeconds));
            maxRunningTimeSeconds = kafkaMaxRecordAgeSeconds;
        }
    }
    // See https://issues.redhat.com/browse/FAI-473 and https://issues.redhat.com/browse/FAI-474
    if (isUnsupportedModel(originalInputs, goals, searchDomains)) {
        throw new IllegalArgumentException("Counterfactual explanations only support flat models.");
    }
    PredictionInput input = new PredictionInput(toFeatureList(originalInputs, searchDomains));
    PredictionOutput output = new PredictionOutput(toOutputList(goals));
    return new CounterfactualPrediction(input, output, null, UUID.fromString(request.getExecutionId()), maxRunningTimeSeconds);
}
Also used : NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction)
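
The timeout handling above reduces to clamping the requested running time at the messaging sub-system's maximum record age. A hedged sketch of that rule, with hypothetical names, is:

/**
 * Caps the requested running time at the Kafka max record age, mirroring the
 * guard in getPrediction. Returns null when no explicit timeout was requested.
 */
static Long clampMaxRunningTime(Long requestedSeconds, long kafkaMaxRecordAgeSeconds) {
    if (requestedSeconds == null) {
        return null;
    }
    return Math.min(requestedSeconds, kafkaMaxRecordAgeSeconds);
}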

Example 25 with PredictionInput

Use of org.kie.kogito.explainability.model.PredictionInput in project kogito-apps by kiegroup.

The class LimeExplainerServiceHandler, method getPrediction.

@Override
public Prediction getPrediction(LIMEExplainabilityRequest request) {
    Collection<NamedTypedValue> inputs = request.getInputs();
    Collection<NamedTypedValue> outputs = request.getOutputs();
    PredictionInput input = new PredictionInput(toFeatureList(inputs));
    PredictionOutput output = new PredictionOutput(toOutputList(outputs));
    return new SimplePrediction(input, output);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput)
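
What the two conversion helpers ultimately produce is a plain feature/output pairing. A hedged sketch of the resulting SimplePrediction, with hypothetical values and assuming the FeatureFactory helpers from explainability-core (not shown above), is:

import java.util.List;

import org.kie.kogito.explainability.model.*;

public class SimplePredictionSketch {

    static Prediction buildPrediction() {
        // hypothetical values; FeatureFactory is assumed from explainability-core
        PredictionInput input = new PredictionInput(List.of(
                FeatureFactory.newNumericalFeature("age", 38.0),
                FeatureFactory.newBooleanFeature("approved", true)));
        PredictionOutput output = new PredictionOutput(List.of(
                new Output("risk", Type.NUMBER, new Value(0.2), 1d)));
        return new SimplePrediction(input, output);
    }
}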

Aggregations

PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 187
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 143
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 135
Prediction (org.kie.kogito.explainability.model.Prediction): 126
Feature (org.kie.kogito.explainability.model.Feature): 109
Test (org.junit.jupiter.api.Test): 107
ArrayList (java.util.ArrayList): 97
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 95
Random (java.util.Random): 86
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 67
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 60
Output (org.kie.kogito.explainability.model.Output): 55
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 54
LinkedList (java.util.LinkedList): 53
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 52
Value (org.kie.kogito.explainability.model.Value): 52
Saliency (org.kie.kogito.explainability.model.Saliency): 50
List (java.util.List): 39
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 33
Type (org.kie.kogito.explainability.model.Type): 31