Search in sources:

Example 71 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

the class PartialDependencePlotExplainer method explainFromDataDistribution.

/**
 * Builds one partial dependence graph per (feature, output) pair, using values sampled from the
 * given data distribution.
 *
 * @param model the prediction provider handed to {@code getPartialDependenceGraph}
 * @param outputSize number of outputs; a graph is produced for every output index in {@code [0, outputSize)}
 * @param dataDistribution source of the per-feature distributions and of the sampled training inputs
 * @return the graphs, grouped by feature (in the order of {@code asFeatureDistributions()}) and, within
 *         each feature, ordered by output index
 * @throws InterruptedException propagated from the graph computation
 * @throws ExecutionException propagated from the graph computation
 * @throws TimeoutException propagated from the graph computation
 */
private List<PartialDependenceGraph> explainFromDataDistribution(PredictionProvider model, int outputSize, DataDistribution dataDistribution) throws InterruptedException, ExecutionException, TimeoutException {
    long start = System.currentTimeMillis();
    List<PartialDependenceGraph> pdps = new ArrayList<>();
    List<FeatureDistribution> featureDistributions = dataDistribution.asFeatureDistributions();
    // fetch entire data distributions for all features
    List<PredictionInput> trainingData = dataDistribution.sample(config.getSeriesLength());
    // Single comparator, built once: order by numeric value first and fall back to alphanumeric order
    // whenever the numeric values tie (e.g. when Value#asNumber is NaN for both). This yields the same
    // order as a stable sort by string followed by a stable sort by number, in one pass.
    Comparator<Value> valueOrder = Comparator.comparingDouble(Value::asNumber).thenComparing(Value::asString);
    // create a PDP for each feature
    for (FeatureDistribution featureDistribution : featureDistributions) {
        // generate (further) samples for the feature under analysis
        // TBD: maybe just reuse trainingData
        List<Value> xsValues = featureDistribution.sample(config.getSeriesLength()).stream()
                .sorted(valueOrder)
                .distinct()
                .collect(Collectors.toList());
        // transform sampled Values into Features
        List<Feature> featureXSvalues = xsValues.stream()
                .map(v -> FeatureFactory.copyOf(featureDistribution.getFeature(), v))
                .collect(Collectors.toList());
        // create a PDP for each feature and each output
        for (int outputIndex = 0; outputIndex < outputSize; outputIndex++) {
            PartialDependenceGraph partialDependenceGraph = getPartialDependenceGraph(model, trainingData, xsValues, featureXSvalues, outputIndex);
            pdps.add(partialDependenceGraph);
        }
    }
    long end = System.currentTimeMillis();
    LOGGER.debug("explanation time: {}ms", (end - start));
    return pdps;
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) GlobalExplainer(org.kie.kogito.explainability.global.GlobalExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderMetadata(org.kie.kogito.explainability.model.PredictionProviderMetadata) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) Value(org.kie.kogito.explainability.model.Value) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph)

Example 72 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

the class PartialDependencePlotExplainer method collapseMarginalImpacts.

/**
 * Collapse value counts into marginal impacts.
 * For numbers ({@code Type.NUMBER.equals(type)}) each position collapses to the sum, over its entries, of
 * {@code value * count / config.getSeriesLength()}.
 * For all other types each position collapses to the string form of the most frequent {@link Value}
 * (the first one encountered wins on ties); a position without any positive count yields a
 * {@link Value} wrapping {@code null}.
 *
 * @param valueCounts the frequency of each value at each position
 * @param type the type of the output
 * @return the marginal impacts, one per position of {@code valueCounts}
 */
private List<Value> collapseMarginalImpacts(List<Map<Value, Long>> valueCounts, Type type) {
    if (!Type.NUMBER.equals(type)) {
        List<Value> mostFrequentValues = new ArrayList<>();
        for (Map<Value, Long> counts : valueCounts) {
            String winner = null;
            long winnerCount = 0;
            for (Map.Entry<Value, Long> candidate : counts.entrySet()) {
                long count = candidate.getValue();
                // strict comparison: an equally frequent later entry never replaces the current winner
                if (count > winnerCount) {
                    winnerCount = count;
                    winner = candidate.getKey().asString();
                }
            }
            mostFrequentValues.add(new Value(winner));
        }
        return mostFrequentValues;
    }
    // numeric outputs: count-weighted sum of the observed values, scaled by the series length
    return valueCounts.stream()
            .map(counts -> counts.entrySet().stream()
                    .mapToDouble(e -> e.getKey().asNumber() * e.getValue() / config.getSeriesLength())
                    .sum())
            .map(Value::new)
            .collect(Collectors.toList());
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) GlobalExplainer(org.kie.kogito.explainability.global.GlobalExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderMetadata(org.kie.kogito.explainability.model.PredictionProviderMetadata) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config) Value(org.kie.kogito.explainability.model.Value) ArrayList(java.util.ArrayList) HashMap(java.util.HashMap) Map(java.util.Map)

Example 73 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

the class PmmlRegressionCategoricalLimeExplainerTest method getModel.

/**
 * Creates a {@link PredictionProvider} that asynchronously evaluates every input against
 * {@code categoricalVariableRegressionRuntime}.
 * The first two features of each input are passed (as strings) to the executor, and the
 * {@code "result"} variable of the PMML result becomes a single NUMBER output with score {@code 1d}.
 *
 * @return the prediction provider wrapping the PMML execution
 */
private PredictionProvider getModel() {
    return inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> predictions = new ArrayList<>();
        for (PredictionInput predictionInput : inputs) {
            List<Feature> inputFeatures = predictionInput.getFeatures();
            CategoricalVariablesRegressionExecutor executor = new CategoricalVariablesRegressionExecutor(
                    inputFeatures.get(0).getValue().asString(),
                    inputFeatures.get(1).getValue().asString());
            PMML4Result pmmlResult = executor.execute(categoricalVariableRegressionRuntime);
            String score = pmmlResult.getResultVariables().get("result").toString();
            Output scoreOutput = new Output("result", Type.NUMBER, new Value(score), 1d);
            predictions.add(new PredictionOutput(List.of(scoreOutput)));
        }
        return predictions;
    });
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PMMLRuntime(org.kie.pmml.api.runtime.PMMLRuntime) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) URISyntaxException(java.net.URISyntaxException) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) AssertionsForClassTypes(org.assertj.core.api.AssertionsForClassTypes) TimeoutException(java.util.concurrent.TimeoutException) Random(java.util.Random) CompletableFuture(java.util.concurrent.CompletableFuture) PMML4Result(org.kie.api.pmml.PMML4Result) Disabled(org.junit.jupiter.api.Disabled) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) PMMLRuntimeFactoryInternal.getPMMLRuntime(org.kie.pmml.evaluator.assembler.factories.PMMLRuntimeFactoryInternal.getPMMLRuntime) ArrayList(java.util.ArrayList) BeforeAll(org.junit.jupiter.api.BeforeAll) Map(java.util.Map) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LimeConfigOptimizer(org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) 
ExplainabilityMetrics(org.kie.kogito.explainability.utils.ExplainabilityMetrics) Output(org.kie.kogito.explainability.model.Output) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) Config(org.kie.kogito.explainability.Config) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) PMML4Result(org.kie.api.pmml.PMML4Result) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature)

Example 74 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

the class RemotePredictionProvider method toPredictionOutput.

/**
 * Converts the JSON response of the prediction service into a {@link PredictionOutput}.
 * Only received outputs whose name appears in {@code predictionOutputs} are kept; every name in
 * {@code predictionOutputs} that is absent from the response is appended as a placeholder
 * {@link Output} of type {@code UNDEFINED} holding a {@code null} value.
 *
 * @param mainObj the JSON object returned by the prediction service
 * @return the converted prediction output, or {@code null} when {@code mainObj} is {@code null} or has
 *         no {@code "result"} key (an error is logged in that case)
 */
protected PredictionOutput toPredictionOutput(JsonObject mainObj) {
    boolean hasResult = mainObj != null && mainObj.containsKey("result");
    if (!hasResult) {
        LOG.error("Malformed json {}", mainObj);
        return null;
    }
    List<Output> resultOutputs = toOutputList(mainObj.getJsonObject("result"));
    List<String> resultOutputNames = resultOutputs.stream().map(Output::getName).collect(toList());
    Map<String, TypedValue> mappedOutputs = predictionOutputs.stream().collect(Collectors.toMap(HasNameValue::getName, HasNameValue::getValue));
    // The prediction service may omit some outputs (e.g. when the generated perturbed inputs don't
    // make sense and a decision is skipped), while the explainer may throw if an output specified in
    // the execution request cannot be found. So the received outputs are kept and, only where needed,
    // the missing ones are filled in with "null"-valued Outputs of type UNDEFINED.
    Stream<Output> receivedOutputs = resultOutputs.stream()
            .filter(output -> mappedOutputs.containsKey(output.getName()));
    Stream<Output> missingOutputs = mappedOutputs.keySet().stream()
            .filter(key -> !resultOutputNames.contains(key))
            .map(key -> new Output(key, Type.UNDEFINED, new Value(null), 1d));
    List<Output> outputs = Stream.concat(receivedOutputs, missingOutputs).collect(toList());
    return new PredictionOutput(outputs);
}
Also used : WebClientOptions(io.vertx.ext.web.client.WebClientOptions) Feature(org.kie.kogito.explainability.model.Feature) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) Map(java.util.Map) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) JsonObject(io.vertx.core.json.JsonObject) URI(java.net.URI) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictInput(org.kie.kogito.explainability.models.PredictInput) Logger(org.slf4j.Logger) Executor(java.util.concurrent.Executor) Collection(java.util.Collection) ThreadContext(org.eclipse.microprofile.context.ThreadContext) ConversionUtils.toOutputList(org.kie.kogito.explainability.ConversionUtils.toOutputList) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) Objects(java.util.Objects) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) JsonArray(io.vertx.core.json.JsonArray) List(java.util.List) Collectors.toList(java.util.stream.Collectors.toList) Stream(java.util.stream.Stream) Output(org.kie.kogito.explainability.model.Output) Vertx(io.vertx.mutiny.core.Vertx) WebClient(io.vertx.mutiny.ext.web.client.WebClient) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue)

Example 75 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

the class PmmlRegressionLimeExplainerTest method getModel.

/**
 * Creates a {@link PredictionProvider} that asynchronously evaluates every input against
 * {@code logisticRegressionIrisRuntime}.
 * The first four features of each input are passed (as numbers) to the executor. The predicted
 * {@code "Species"} becomes a single TEXT output whose score is the matching
 * {@code "Probability_<species>"} result variable.
 *
 * @return the prediction provider wrapping the PMML execution
 */
private PredictionProvider getModel() {
    return inputs -> CompletableFuture.supplyAsync(() -> {
        List<PredictionOutput> predictions = new ArrayList<>();
        for (PredictionInput predictionInput : inputs) {
            List<Feature> inputFeatures = predictionInput.getFeatures();
            LogisticRegressionIrisDataExecutor executor = new LogisticRegressionIrisDataExecutor(
                    inputFeatures.get(0).getValue().asNumber(),
                    inputFeatures.get(1).getValue().asNumber(),
                    inputFeatures.get(2).getValue().asNumber(),
                    inputFeatures.get(3).getValue().asNumber());
            PMML4Result pmmlResult = executor.execute(logisticRegressionIrisRuntime);
            String species = pmmlResult.getResultVariables().get("Species").toString();
            // the probability variable is named after the predicted species
            String probabilityText = pmmlResult.getResultVariables().get("Probability_" + species).toString();
            double score = Double.parseDouble(probabilityText);
            Output speciesOutput = new Output("species", Type.TEXT, new Value(species), score);
            predictions.add(new PredictionOutput(List.of(speciesOutput)));
        }
        return predictions;
    });
}
Also used : FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PMMLRuntime(org.kie.pmml.api.runtime.PMMLRuntime) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) URISyntaxException(java.net.URISyntaxException) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) AssertionsForClassTypes(org.assertj.core.api.AssertionsForClassTypes) TimeoutException(java.util.concurrent.TimeoutException) Random(java.util.Random) CompletableFuture(java.util.concurrent.CompletableFuture) PMML4Result(org.kie.api.pmml.PMML4Result) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) PMMLRuntimeFactoryInternal.getPMMLRuntime(org.kie.pmml.evaluator.assembler.factories.PMMLRuntimeFactoryInternal.getPMMLRuntime) ArrayList(java.util.ArrayList) BeforeAll(org.junit.jupiter.api.BeforeAll) Map(java.util.Map) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LimeConfigOptimizer(org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) ExplainabilityMetrics(org.kie.kogito.explainability.utils.ExplainabilityMetrics) 
Output(org.kie.kogito.explainability.model.Output) ValidationUtils(org.kie.kogito.explainability.utils.ValidationUtils) Config(org.kie.kogito.explainability.Config) Assertions.assertDoesNotThrow(org.junit.jupiter.api.Assertions.assertDoesNotThrow) PMML4Result(org.kie.api.pmml.PMML4Result) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value)

Aggregations

Value (org.kie.kogito.explainability.model.Value)80 Feature (org.kie.kogito.explainability.model.Feature)69 Output (org.kie.kogito.explainability.model.Output)59 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)54 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)49 ArrayList (java.util.ArrayList)42 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)42 LinkedList (java.util.LinkedList)36 Type (org.kie.kogito.explainability.model.Type)36 Test (org.junit.jupiter.api.Test)35 List (java.util.List)33 Prediction (org.kie.kogito.explainability.model.Prediction)33 Random (java.util.Random)31 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)23 Arrays (java.util.Arrays)16 Map (java.util.Map)16 Optional (java.util.Optional)16 CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity)16 FeatureFactory (org.kie.kogito.explainability.model.FeatureFactory)16 Collectors (java.util.stream.Collectors)15