
Example 46 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class DataUtils method boostrapFeatureDistributions.

/**
 * Generate feature distributions from an existing (possibly small) {@link DataDistribution} for each {@link Feature}.
 * Each feature's interval (min, max) and density information (mean, stdDev) are estimated via bootstrap, then
 * data points are sampled from a normal distribution (see {@link #generateData(double, double, int, Random)}).
 *
 * @param dataDistribution data distribution to take feature values from
 * @param perturbationContext perturbation context
 * @param featureDistributionSize desired size of generated feature distributions
 * @param draws number of times sampling from feature values is performed
 * @param sampleSize size of each sample draw
 * @param numericFeatureZonesMap map of high score zones per numeric feature
 * @return a map from feature name to the generated feature distribution
 */
public static Map<String, FeatureDistribution> boostrapFeatureDistributions(DataDistribution dataDistribution, PerturbationContext perturbationContext, int featureDistributionSize, int draws, int sampleSize, Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap) {
    Map<String, FeatureDistribution> featureDistributions = new HashMap<>();
    for (FeatureDistribution featureDistribution : dataDistribution.asFeatureDistributions()) {
        Feature feature = featureDistribution.getFeature();
        if (Type.NUMBER.equals(feature.getType())) {
            List<Value> values = featureDistribution.getAllSamples();
            double[] means = new double[draws];
            double[] stdDevs = new double[draws];
            double[] mins = new double[draws];
            double[] maxs = new double[draws];
            for (int i = 0; i < draws; i++) {
                List<Value> sampledValues = DataUtils.sampleWithReplacement(values, sampleSize, perturbationContext.getRandom());
                double[] data = sampledValues.stream().mapToDouble(Value::asNumber).toArray();
                double mean = DataUtils.getMean(data);
                // store the variance (squared stdDev) so the draws' variances can be averaged and square-rooted below
                double stdDev = Math.pow(DataUtils.getStdDev(data, mean), 2);
                double min = Arrays.stream(data).min().orElse(Double.MIN_VALUE);
                double max = Arrays.stream(data).max().orElse(Double.MAX_VALUE);
                means[i] = mean;
                stdDevs[i] = stdDev;
                mins[i] = min;
                maxs[i] = max;
            }
            double finalMean = DataUtils.getMean(means);
            double finalStdDev = Math.sqrt(DataUtils.getMean(stdDevs));
            double finalMin = DataUtils.getMean(mins);
            double finalMax = DataUtils.getMean(maxs);
            double[] doubles = DataUtils.generateData(finalMean, finalStdDev, featureDistributionSize, perturbationContext.getRandom());
            double[] boundedData = Arrays.stream(doubles).map(d -> Math.min(Math.max(d, finalMin), finalMax)).toArray();
            HighScoreNumericFeatureZones highScoreNumericFeatureZones = numericFeatureZonesMap.get(feature.getName());
            double[] finaldata;
            if (highScoreNumericFeatureZones != null) {
                double[] filteredData = DoubleStream.of(boundedData).filter(highScoreNumericFeatureZones::test).toArray();
                // only use the filtered data if it's not discarding more than 50% of the points
                if (filteredData.length > featureDistributionSize / 2) {
                    finaldata = filteredData;
                } else {
                    finaldata = boundedData;
                }
            } else {
                finaldata = boundedData;
            }
            NumericFeatureDistribution numericFeatureDistribution = new NumericFeatureDistribution(feature, finaldata);
            featureDistributions.put(feature.getName(), numericFeatureDistribution);
        }
    }
    return featureDistributions;
}
Also used : HashMap(java.util.HashMap) Feature(org.kie.kogito.explainability.model.Feature) HighScoreNumericFeatureZones(org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) Value(org.kie.kogito.explainability.model.Value)
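A minimal usage sketch for this method (not from the original source): it builds a tiny numeric DataDistribution, bootstraps a per-feature distribution from it and leaves the high score zone map empty. The PerturbationContext(Random, int) constructor and the FeatureFactory.newNumericalFeature helper are assumed to exist with these signatures; all sizes and values are illustrative.

// Illustrative sketch only: the values, sizes and helper signatures below are assumptions.
static Map<String, FeatureDistribution> bootstrapAgeDistribution() {
    Random random = new Random(0);
    // assumed constructor: PerturbationContext(Random, int noOfPerturbations)
    PerturbationContext perturbationContext = new PerturbationContext(random, 1);
    List<PredictionInput> inputs = new ArrayList<>();
    for (int i = 0; i < 10; i++) {
        // assumed helper: FeatureFactory.newNumericalFeature(String, Number)
        Feature age = FeatureFactory.newNumericalFeature("age", 20 + i);
        inputs.add(new PredictionInput(List.of(age)));
    }
    DataDistribution dataDistribution = new PredictionInputsDataDistribution(inputs);
    return DataUtils.boostrapFeatureDistributions(
            dataDistribution,
            perturbationContext,
            100, // featureDistributionSize: points per generated distribution
            10, // draws: bootstrap iterations
            5, // sampleSize: values drawn with replacement per iteration
            Collections.emptyMap()); // no high score zones to filter on
}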

Example 47 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class DataUtils method linearizeInputs.

/**
 * Transform a list of prediction inputs into a new list of prediction inputs whose features have been linearized.
 *
 * @param predictionInputs a list of prediction inputs
 * @return a list of prediction inputs with linearized features
 */
public static List<PredictionInput> linearizeInputs(List<PredictionInput> predictionInputs) {
    List<PredictionInput> newInputs = new LinkedList<>();
    for (PredictionInput predictionInput : predictionInputs) {
        List<Feature> originalFeatures = predictionInput.getFeatures();
        List<Feature> flattenedFeatures = getLinearizedFeatures(originalFeatures);
        newInputs.add(new PredictionInput(flattenedFeatures));
    }
    return newInputs;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList)
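A short usage sketch (not from the original source): a prediction input holding a composite feature is linearized into a flat feature list. FeatureFactory.newCompositeFeature is used with the same signature as in replaceFeatures below; FeatureFactory.newNumericalFeature is an assumed helper.

// Illustrative sketch only.
static List<PredictionInput> linearizeExample() {
    // assumed helper: FeatureFactory.newNumericalFeature(String, Number)
    Feature number = FeatureFactory.newNumericalFeature("number", 7);
    Feature zip = FeatureFactory.newNumericalFeature("zip", 10115);
    Feature address = FeatureFactory.newCompositeFeature("address", List.of(number, zip));
    PredictionInput nested = new PredictionInput(List.of(address));
    // each returned input should expose the leaf features instead of the composite one
    return DataUtils.linearizeInputs(List.of(nested));
}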

Example 48 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class DataUtils method perturbFeatures.

/**
 * Perform perturbations on a fixed number of features in the given input.
 * A map of feature distributions to draw values from (covering all, none or some of the features) is given.
 * Which features will be perturbed is non-deterministic.
 *
 * @param originalFeatures the input features that need to be perturbed
 * @param perturbationContext the perturbation context
 * @param featureDistributionsMap the map of feature distributions
 * @return a perturbed copy of the input features
 */
public static List<Feature> perturbFeatures(List<Feature> originalFeatures, PerturbationContext perturbationContext, Map<String, FeatureDistribution> featureDistributionsMap) {
    List<Feature> newFeatures = new ArrayList<>(originalFeatures);
    if (!newFeatures.isEmpty()) {
        // the number of perturbed features lies in [min(noOfPerturbations, |features|/2), max(noOfPerturbations, |features|/2)]
        int lowerBound = (int) Math.min(perturbationContext.getNoOfPerturbations(), 0.5d * newFeatures.size());
        int upperBound = (int) Math.max(perturbationContext.getNoOfPerturbations(), 0.5d * newFeatures.size());
        upperBound = Math.min(upperBound, newFeatures.size());
        // lower bound should always be greater than zero (not ok to not perturb)
        lowerBound = Math.max(1, lowerBound);
        int perturbationSize = 0;
        if (lowerBound == upperBound) {
            perturbationSize = lowerBound;
        } else if (upperBound > lowerBound) {
            perturbationSize = perturbationContext.getRandom().ints(1, lowerBound, 1 + upperBound).findFirst().orElse(1);
        }
        if (perturbationSize > 0) {
            int[] indexesToBePerturbed = perturbationContext.getRandom().ints(0, newFeatures.size()).distinct().limit(perturbationSize).toArray();
            for (int index : indexesToBePerturbed) {
                Feature feature = newFeatures.get(index);
                Value newValue;
                if (featureDistributionsMap.containsKey(feature.getName())) {
                    newValue = featureDistributionsMap.get(feature.getName()).sample();
                } else {
                    newValue = feature.getType().perturb(feature.getValue(), perturbationContext);
                }
                Feature perturbedFeature = FeatureFactory.copyOf(feature, newValue);
                newFeatures.set(index, perturbedFeature);
            }
        }
    }
    return newFeatures;
}
Also used : ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature)
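A short usage sketch (not from the original source): two numeric features are passed in and one of them is perturbed, either by sampling from a matching bootstrapped distribution or by falling back to Type#perturb. The PerturbationContext constructor and FeatureFactory.newNumericalFeature are assumed, as in the sketches above.

// Illustrative sketch only.
static List<Feature> perturbExample(Map<String, FeatureDistribution> bootstrappedDistributions) {
    // assumed constructor: PerturbationContext(Random, int noOfPerturbations)
    PerturbationContext perturbationContext = new PerturbationContext(new Random(0), 1);
    List<Feature> original = List.of(
            FeatureFactory.newNumericalFeature("age", 37),
            FeatureFactory.newNumericalFeature("income", 42_000));
    // features whose name is present in the map are sampled from it, the others fall back to Type#perturb
    return DataUtils.perturbFeatures(original, perturbationContext, bootstrappedDistributions);
}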

Example 49 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class DataUtils method readCSV.

/**
 * Read a CSV file into a {@link DataDistribution} object.
 *
 * @param file the path to the CSV file
 * @param schema an ordered list of {@link Type}s, one per column, used to determine
 *        the {@link Type} of each feature / column
 * @return the parsed CSV as a {@link DataDistribution}
 * @throws IOException if reading the CSV file fails
 * @throws MalformedInputException if any record in the CSV has a size different from the specified schema
 */
public static DataDistribution readCSV(Path file, List<Type> schema) throws IOException {
    List<PredictionInput> inputs = new ArrayList<>();
    try (BufferedReader reader = Files.newBufferedReader(file)) {
        Iterable<CSVRecord> records = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(reader);
        for (CSVRecord record : records) {
            int size = record.size();
            if (schema.size() == size) {
                List<Feature> features = new ArrayList<>();
                for (int i = 0; i < size; i++) {
                    String s = record.get(i);
                    Type type = schema.get(i);
                    features.add(new Feature(record.getParser().getHeaderNames().get(i), type, new Value(s)));
                }
                inputs.add(new PredictionInput(features));
            } else {
                throw new MalformedInputException(size);
            }
        }
    }
    return new PredictionInputsDataDistribution(inputs);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) Type(org.kie.kogito.explainability.model.Type) BufferedReader(java.io.BufferedReader) Value(org.kie.kogito.explainability.model.Value) MalformedInputException(java.nio.charset.MalformedInputException) CSVRecord(org.apache.commons.csv.CSVRecord) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution)
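A short usage sketch (not from the original source): a two-column numeric CSV with a header row (required by withFirstRecordAsHeader) is loaded and then inspected via its per-feature distributions. The file path and schema are illustrative.

// Illustrative sketch only: the path and schema are assumptions.
static void readCsvExample() throws IOException {
    Path csv = Paths.get("data.csv"); // header row expected, e.g. "age,income"
    List<Type> schema = List.of(Type.NUMBER, Type.NUMBER); // one Type per column, in order
    DataDistribution distribution = DataUtils.readCSV(csv, schema);
    for (FeatureDistribution featureDistribution : distribution.asFeatureDistributions()) {
        System.out.println(featureDistribution.getFeature().getName());
    }
}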

Example 50 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class DataUtils method replaceFeatures.

/**
 * Replace an existing feature in a list with another feature.
 * The feature to be replaced is the one whose name is equal to the name of the feature to use as replacement.
 *
 * @param featureToUse feature to use as replacement
 * @param existingFeatures list of features containing the feature to be replaced
 * @return a new list of features having the "replaced" feature
 */
public static List<Feature> replaceFeatures(Feature featureToUse, List<Feature> existingFeatures) {
    List<Feature> newFeatures = new ArrayList<>();
    for (Feature f : existingFeatures) {
        Feature newFeature;
        if (f.getName().equals(featureToUse.getName())) {
            newFeature = FeatureFactory.copyOf(f, featureToUse.getValue());
        } else {
            if (Type.COMPOSITE == f.getType()) {
                List<Feature> elements = (List<Feature>) f.getValue().getUnderlyingObject();
                newFeature = FeatureFactory.newCompositeFeature(f.getName(), replaceFeatures(featureToUse, elements));
            } else {
                newFeature = FeatureFactory.copyOf(f, f.getValue());
            }
        }
        newFeatures.add(newFeature);
    }
    return newFeatures;
}
Also used : ArrayList(java.util.ArrayList) LinkedList(java.util.LinkedList) List(java.util.List) Feature(org.kie.kogito.explainability.model.Feature)
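A short usage sketch (not from the original source): the value of an "age" feature is replaced wherever it appears, including inside composite features. FeatureFactory.newNumericalFeature is an assumed helper; the name and value are illustrative.

// Illustrative sketch only.
static List<Feature> replaceAge(List<Feature> existingFeatures) {
    // assumed helper: FeatureFactory.newNumericalFeature(String, Number)
    Feature replacement = FeatureFactory.newNumericalFeature("age", 99);
    // every feature named "age" (also when nested in a composite) gets the new value; the rest are copied as-is
    return DataUtils.replaceFeatures(replacement, existingFeatures);
}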

Aggregations

Feature (org.kie.kogito.explainability.model.Feature): 233
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 118
Test (org.junit.jupiter.api.Test): 107
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 107
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 104
Output (org.kie.kogito.explainability.model.Output): 102
ArrayList (java.util.ArrayList): 97
Random (java.util.Random): 92
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 78
Value (org.kie.kogito.explainability.model.Value): 74
LinkedList (java.util.LinkedList): 72
ValueSource (org.junit.jupiter.params.provider.ValueSource): 71
Prediction (org.kie.kogito.explainability.model.Prediction): 67
List (java.util.List): 51
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 46
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 42
Type (org.kie.kogito.explainability.model.Type): 39
NumericalFeatureDomain (org.kie.kogito.explainability.model.domain.NumericalFeatureDomain): 37
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 35
FeatureDomain (org.kie.kogito.explainability.model.domain.FeatureDomain): 33