Search in sources:

Example 66 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

Source: method `unravel` of class `CompositeFeatureUtils`.

/**
 * Rebuild {@code f} with the same nesting structure, taking leaf values in order from
 * {@code flattenedFeatures}.
 *
 * <p>An UNDEFINED feature wrapping another {@link Feature}, and a COMPOSITE feature wrapping a
 * {@link List}, are rebuilt recursively. Every other feature is treated as a leaf. A leaf is
 * copied with the value of the next flattened feature. {@code tracker} holds the index of that
 * next feature and is advanced by one per leaf.
 *
 * @param flattenedFeatures the flat, ordered source of leaf values
 * @param tracker index of the next flattened feature to consume; mutated by this call
 * @param f the (possibly nested) feature whose structure is reproduced
 * @return a copy of {@code f} whose leaves carry the consumed flattened values
 */
private static Feature unravel(List<Feature> flattenedFeatures, AtomicInteger tracker, Feature f) {
    final Object underlying = f.getValue().getUnderlyingObject();
    switch (f.getType()) {
        case UNDEFINED:
            if (underlying instanceof Feature) {
                // a feature wrapping another feature: rebuild the wrapped one, then re-wrap it
                Feature rebuiltInner = unravel(flattenedFeatures, tracker, (Feature) underlying);
                return FeatureFactory.copyOf(f, new Value(rebuiltInner));
            }
            break;
        case COMPOSITE:
            if (underlying instanceof List) {
                @SuppressWarnings("unchecked")
                List<Feature> children = (List<Feature>) underlying;
                return FeatureFactory.newCompositeFeature(f.getName(),
                        children.stream()
                                .map(child -> unravel(flattenedFeatures, tracker, child))
                                .collect(Collectors.toList()));
            }
            break;
        default:
            break;
    }
    // leaf: take the value of the next flattened feature and advance the cursor
    Feature source = flattenedFeatures.get(tracker.getAndIncrement());
    return FeatureFactory.copyOf(f, source.getValue());
}
Also used: FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) List(java.util.List) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Feature(org.kie.kogito.explainability.model.Feature) Value(org.kie.kogito.explainability.model.Value) Collectors(java.util.stream.Collectors)

Example 67 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

Source: method `toCSV` of class `DataUtils`.

/**
 * Persist a {@link PartialDependenceGraph} into a CSV file.
 *
 * <p>The file starts with a header row made of the graph's feature name and output name. It is
 * followed by one record per graph point: the x value and the y value at the same index, both
 * rendered with {@code asString()}.
 *
 * @param partialDependenceGraph the PDP to persist
 * @param path the path to the CSV file to be created
 * @throws IOException if any I/O error occurs while writing the CSV
 */
public static void toCSV(PartialDependenceGraph partialDependenceGraph, Path path) throws IOException {
    // Gather everything from the graph before opening (and truncating) the target file.
    List<Value> xAxis = partialDependenceGraph.getX();
    List<Value> yAxis = partialDependenceGraph.getY();
    CSVFormat format = CSVFormat.DEFAULT.withHeader(partialDependenceGraph.getFeature().getName(), partialDependenceGraph.getOutput().getName());
    // The printer is AutoCloseable too: managing it here guarantees it is flushed and closed.
    // Closing the writer again afterwards is harmless.
    try (Writer writer = Files.newBufferedWriter(path);
            CSVPrinter printer = new CSVPrinter(writer, format)) {
        for (int i = 0; i < xAxis.size(); i++) {
            printer.printRecord(xAxis.get(i).asString(), yAxis.get(i).asString());
        }
    }
}
Also used : CSVPrinter(org.apache.commons.csv.CSVPrinter) Value(org.kie.kogito.explainability.model.Value) CSVFormat(org.apache.commons.csv.CSVFormat) Writer(java.io.Writer)

Example 68 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

Source: method `dropFeature` of class `DataUtils`.

/**
 * Drop a given feature from a list of existing features.
 *
 * <p>The returned list has one entry per input feature, in the same order. Features are not
 * removed; each entry is produced as follows:
 * <ul>
 *   <li>Name, type and value all equal to {@code target}: copied with its value replaced by the
 *       type's dropped value.</li>
 *   <li>Name equal to {@code target} but type or value different: delegated to
 *       {@code dropOnLinearizedFeatures}.</li>
 *   <li>Name different, type COMPOSITE: rebuilt with {@code target} dropped from its nested
 *       features.</li>
 *   <li>Anything else: copied unchanged.</li>
 * </ul>
 *
 * @param features the existing features
 * @param target the feature to drop
 * @return a new list of features having the target feature dropped
 */
public static List<Feature> dropFeature(List<Feature> features, Feature target) {
    final List<Feature> result = new ArrayList<>(features.size());
    for (Feature candidate : features) {
        final String name = candidate.getName();
        final Type type = candidate.getType();
        final Value value = candidate.getValue();

        if (!target.getName().equals(name)) {
            if (Type.COMPOSITE.equals(type)) {
                // recurse into the nested features of a composite
                @SuppressWarnings("unchecked")
                List<Feature> nested = (List<Feature>) value.getUnderlyingObject();
                result.add(FeatureFactory.newCompositeFeature(name, dropFeature(nested, target)));
            } else {
                // not found
                result.add(FeatureFactory.copyOf(candidate, value));
            }
            continue;
        }

        final boolean exactMatch = target.getType().equals(type) && target.getValue().equals(value);
        result.add(exactMatch
                ? FeatureFactory.copyOf(candidate, type.drop(value))
                : dropOnLinearizedFeatures(target, candidate));
    }
    return result;
}
Also used: Type(org.kie.kogito.explainability.model.Type) ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) LinkedList(java.util.LinkedList) List(java.util.List) Feature(org.kie.kogito.explainability.model.Feature)

Example 69 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

Source: method `getFavorableLabelProbability` of class `FairnessMetrics`.

/**
 * Compute the fraction of predictions, within the group chosen by {@code groupSelector}, whose
 * output named like {@code favorableOutput} equals the favorable value.
 *
 * <p>The group's predictions come from {@code getSelectedPredictionOutputs}. When it returns no
 * predictions the result is {@code 0 / 0.0}, i.e. {@link Double#NaN}.
 *
 * @param groupSelector selects the inputs belonging to the group under analysis
 * @param samples the candidate inputs
 * @param model the model producing the predictions
 * @param favorableOutput name and value of the output considered favorable
 * @return the favorable-outcome ratio for the selected group, or NaN for an empty selection
 * @throws ExecutionException if obtaining the predictions fails
 * @throws InterruptedException if interrupted while obtaining the predictions
 * @throws java.util.NoSuchElementException if a selected prediction has no output with the
 *         favorable output's name
 */
private static double getFavorableLabelProbability(Predicate<PredictionInput> groupSelector, List<PredictionInput> samples, PredictionProvider model, Output favorableOutput) throws ExecutionException, InterruptedException {
    String outputName = favorableOutput.getName();
    Value outputValue = favorableOutput.getValue();
    List<PredictionOutput> selectedOutputs = getSelectedPredictionOutputs(groupSelector, samples, model);
    double numSelected = selectedOutputs.size();
    long numFavorableSelected = selectedOutputs.stream()
            // Same exception type as a bare Optional.get(), but with a message naming the missing output.
            .map(po -> po.getByName(outputName)
                    .orElseThrow(() -> new java.util.NoSuchElementException(
                            "prediction output has no output named '" + outputName + "'")))
            .filter(o -> o.getValue().equals(outputValue))
            .count();
    return numFavorableSelected / numSelected;
}
Also used: Predicate(java.util.function.Predicate) Prediction(org.kie.kogito.explainability.model.Prediction) BiFunction(java.util.function.BiFunction) HashMap(java.util.HashMap) Dataset(org.kie.kogito.explainability.model.Dataset) Value(org.kie.kogito.explainability.model.Value) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Map(java.util.Map) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput)

Example 70 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

Source: method `getPartialDependenceGraph` of class `PartialDependencePlotExplainer`.

/**
 * Build the partial dependence graph of one model output with respect to one feature.
 *
 * <p>For each value of the feature under analysis, inputs are prepared from the training data and
 * sent to the model as one batch. The output at {@code outputIndex} of every prediction is tallied
 * into that value's counts. The counts are then collapsed into the y values of the graph.
 *
 * @param model the model to query
 * @param trainingData the data used to prepare the model inputs
 * @param xsValues the x values of the resulting graph
 * @param featureXSvalues one feature instance per analysed value of the feature
 * @param outputIndex position of the output of interest inside each prediction
 * @return the partial dependence graph
 * @throws IllegalArgumentException if no prediction output was produced at all
 */
private PartialDependenceGraph getPartialDependenceGraph(PredictionProvider model, List<PredictionInput> trainingData, List<Value> xsValues, List<Feature> featureXSvalues, int outputIndex) throws InterruptedException, ExecutionException, TimeoutException {
    // one map of output-value counts per analysed feature value
    List<Map<Value, Long>> valueCounts = new ArrayList<>(featureXSvalues.size());
    // the PDP carries a value-less copy of the first feature (stays null if there is none)
    Feature feature = featureXSvalues.isEmpty()
            ? null
            : FeatureFactory.copyOf(featureXSvalues.get(0), new Value(null));
    Output outputDecision = null;

    int position = 0;
    for (Feature xsFeature : featureXSvalues) {
        // prediction requests are batched per analysed feature value
        List<PredictionInput> batch = prepareInputs(xsFeature, trainingData);
        for (PredictionOutput predictionOutput : getOutputs(model, batch)) {
            Output output = predictionOutput.getOutputs().get(outputIndex);
            if (outputDecision == null) {
                outputDecision = new Output(output.getName(), output.getType());
            }
            updateValueCounts(valueCounts, position, output);
        }
        position++;
    }

    if (outputDecision == null) {
        throw new IllegalArgumentException("cannot produce PDP for null decision");
    }
    List<Value> yValues = collapseMarginalImpacts(valueCounts, outputDecision.getType());
    return new PartialDependenceGraph(feature, outputDecision, xsValues, yValues);
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) HashMap(java.util.HashMap) Map(java.util.Map) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph)

Aggregations

Value (org.kie.kogito.explainability.model.Value)80 Feature (org.kie.kogito.explainability.model.Feature)69 Output (org.kie.kogito.explainability.model.Output)59 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)54 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)49 ArrayList (java.util.ArrayList)42 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)42 LinkedList (java.util.LinkedList)36 Type (org.kie.kogito.explainability.model.Type)36 Test (org.junit.jupiter.api.Test)35 List (java.util.List)33 Prediction (org.kie.kogito.explainability.model.Prediction)33 Random (java.util.Random)31 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)23 Arrays (java.util.Arrays)16 Map (java.util.Map)16 Optional (java.util.Optional)16 CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity)16 FeatureFactory (org.kie.kogito.explainability.model.FeatureFactory)16 Collectors (java.util.stream.Collectors)15