Search in sources :

Example 11 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

In the class CounterfactualExplainerTest, the method testSparsity.

/**
 * The test rationale is to find the solution to (f-num1 + f-num2 = 10), for f-num1 with an initial
 * value of 0 and f-num2 with an initial value of 5 and both varying in [0, 10].
 * All the possible solutions will have the same distance, but the sparsity
 * criteria will select the ones which leave one of the inputs (either f-num1 or f-num2) unchanged.
 *
 * @param seed random seed, passed straight to the counterfactual search for reproducibility
 * @throws ExecutionException if the asynchronous prediction fails
 * @throws InterruptedException if the search is interrupted
 * @throws TimeoutException if the search does not complete in time
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testSparsity(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    // Goal: the model's boolean "inside" output must become true.
    final List<Output> goal = List.of(new Output("inside", Type.BOOLEAN, new Value(true), 0.0));
    List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newNumericalFeature("f-num1", 0, NumericalFeatureDomain.create(0, 10)));
    features.add(FeatureFactory.newNumericalFeature("f-num2", 5, NumericalFeatureDomain.create(0, 10)));
    final double center = 10.0;
    final double epsilon = 0.1;
    // NOTE: the seed is handed directly to the search; a previously created (and seeded)
    // java.util.Random local was never read and has been removed as dead code.
    final CounterfactualResult result =
            runCounterfactualSearch((long) seed, goal, features, TestUtils.getSumThresholdModel(center, epsilon), DEFAULT_GOAL_THRESHOLD);
    // Sparsity criterion: at least one of the two entities must be left unchanged.
    assertTrue(!result.getEntities().get(0).isChanged() || !result.getEntities().get(1).isChanged());
    assertTrue(result.isValid());
}
Also used : Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 12 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

In the class CounterfactualScoreCalculatorTest, the method testGoalSizeSmaller.

/**
 * Using a smaller number of features in the goals (1) than the model's output (2) should
 * throw an {@link IllegalArgumentException} with the appropriate message.
 */
@Test
void testGoalSizeSmaller() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    // Model that skips output 0, so three inputs produce only two outputs.
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    // NOTE: the former featureDomains/constraints lists and the PredictionFeatureDomain built
    // from them were never read by anything below, so they have been removed as dead code.
    List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    PredictionInput input = new PredictionInput(features);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    // Deliberately supply a single goal output while the model produces two.
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertEquals(1, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two features
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    final CounterfactualSolution solution =
            new CounterfactualSolution(entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    IllegalArgumentException exception =
            assertThrows(IllegalArgumentException.class, () -> scoreCalculator.calculateScore(solution));
    assertEquals("Prediction size must be equal to goal size", exception.getMessage());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) PredictionFeatureDomain(org.kie.kogito.explainability.model.PredictionFeatureDomain) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 13 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

In the class DataUtils, the method boostrapFeatureDistributions.

/**
 * Generate feature distributions from an existing (eventually small) {@link DataDistribution} for each {@link Feature}.
 * Each feature intervals (min, max) and density information (mean, stdDev) are generated using bootstrap, then
 * data points are sampled from a normal distribution (see {@link #generateData(double, double, int, Random)}).
 * Only features of numeric {@link Type} are processed; all other feature types are skipped.
 *
 * @param dataDistribution data distribution to take feature values from
 * @param perturbationContext perturbation context (provides the {@link Random} used for sampling)
 * @param featureDistributionSize desired size of generated feature distributions
 * @param draws number of times sampling from feature values is performed
 * @param sampleSize size of each sample draw
 * @param numericFeatureZonesMap high feature score zones
 * @return a map feature name -> generated feature distribution
 */
public static Map<String, FeatureDistribution> boostrapFeatureDistributions(DataDistribution dataDistribution, PerturbationContext perturbationContext, int featureDistributionSize, int draws, int sampleSize, Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap) {
    Map<String, FeatureDistribution> featureDistributions = new HashMap<>();
    for (FeatureDistribution featureDistribution : dataDistribution.asFeatureDistributions()) {
        Feature feature = featureDistribution.getFeature();
        // only numerical features can be bootstrapped into a normal distribution
        if (Type.NUMBER.equals(feature.getType())) {
            List<Value> values = featureDistribution.getAllSamples();
            double[] means = new double[draws];
            // NOTE(review): despite the name, each entry holds a variance (stdDev squared);
            // the sqrt further below converts the averaged variances back to a std deviation.
            double[] stdDevs = new double[draws];
            double[] mins = new double[draws];
            double[] maxs = new double[draws];
            // bootstrap: repeatedly resample with replacement and record per-draw statistics
            for (int i = 0; i < draws; i++) {
                List<Value> sampledValues = DataUtils.sampleWithReplacement(values, sampleSize, perturbationContext.getRandom());
                double[] data = sampledValues.stream().mapToDouble(Value::asNumber).toArray();
                double mean = DataUtils.getMean(data);
                double stdDev = Math.pow(DataUtils.getStdDev(data, mean), 2);
                // NOTE(review): Double.MIN_VALUE is the smallest *positive* double, not the most
                // negative one; harmless here only because 'data' is non-empty when sampleSize > 0.
                double min = Arrays.stream(data).min().orElse(Double.MIN_VALUE);
                double max = Arrays.stream(data).max().orElse(Double.MAX_VALUE);
                means[i] = mean;
                stdDevs[i] = stdDev;
                mins[i] = min;
                maxs[i] = max;
            }
            // aggregate the bootstrap draws by averaging each statistic across all draws
            double finalMean = DataUtils.getMean(means);
            double finalStdDev = Math.sqrt(DataUtils.getMean(stdDevs));
            double finalMin = DataUtils.getMean(mins);
            double finalMax = DataUtils.getMean(maxs);
            // sample synthetic points from N(finalMean, finalStdDev) and clamp them to [finalMin, finalMax]
            double[] doubles = DataUtils.generateData(finalMean, finalStdDev, featureDistributionSize, perturbationContext.getRandom());
            double[] boundedData = Arrays.stream(doubles).map(d -> Math.min(Math.max(d, finalMin), finalMax)).toArray();
            HighScoreNumericFeatureZones highScoreNumericFeatureZones = numericFeatureZonesMap.get(feature.getName());
            double[] finaldata;
            if (highScoreNumericFeatureZones != null) {
                // keep only points falling inside the high-score zones for this feature
                double[] filteredData = DoubleStream.of(boundedData).filter(highScoreNumericFeatureZones::test).toArray();
                // only use the filtered data if it's not discarding more than 50% of the points
                if (filteredData.length > featureDistributionSize / 2) {
                    finaldata = filteredData;
                } else {
                    finaldata = boundedData;
                }
            } else {
                finaldata = boundedData;
            }
            NumericFeatureDistribution numericFeatureDistribution = new NumericFeatureDistribution(feature, finaldata);
            featureDistributions.put(feature.getName(), numericFeatureDistribution);
        }
    }
    return featureDistributions;
}
Also used : IntStream(java.util.stream.IntStream) FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Arrays(java.util.Arrays) MalformedInputException(java.nio.charset.MalformedInputException) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CSVRecord(org.apache.commons.csv.CSVRecord) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Random(java.util.Random) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) CSVFormat(org.apache.commons.csv.CSVFormat) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) LinkedList(java.util.LinkedList) Path(java.nio.file.Path) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) IndependentFeaturesDataDistribution(org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution) SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Files(java.nio.file.Files) IOException(java.io.IOException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DoubleStream(java.util.stream.DoubleStream) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Writer(java.io.Writer) Optional(java.util.Optional) 
HighScoreNumericFeatureZones(org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones) BufferedReader(java.io.BufferedReader) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections) CSVPrinter(org.apache.commons.csv.CSVPrinter) HashMap(java.util.HashMap) Feature(org.kie.kogito.explainability.model.Feature) HighScoreNumericFeatureZones(org.kie.kogito.explainability.local.lime.HighScoreNumericFeatureZones) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) Value(org.kie.kogito.explainability.model.Value) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution)

Example 14 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

In the class DataUtils, the method perturbFeatures.

/**
 * Perform perturbations on a fixed number of features in the given input.
 * A map of feature distributions to draw from (all, none or some of them) is given.
 * Which features get perturbed is non deterministic (driven by the context's {@code Random}).
 *
 * @param originalFeatures the input features that need to be perturbed
 * @param perturbationContext the perturbation context
 * @param featureDistributionsMap the map of feature distributions
 * @return a perturbed copy of the input features
 */
public static List<Feature> perturbFeatures(List<Feature> originalFeatures, PerturbationContext perturbationContext, Map<String, FeatureDistribution> featureDistributionsMap) {
    List<Feature> perturbed = new ArrayList<>(originalFeatures);
    if (perturbed.isEmpty()) {
        return perturbed;
    }
    // perturb at most in the range [|features|/2), noOfPerturbations]
    final double half = 0.5d * perturbed.size();
    final int requested = perturbationContext.getNoOfPerturbations();
    // lower bound must stay strictly positive (not perturbing at all is not allowed)
    final int minCount = Math.max(1, (int) Math.min(requested, half));
    final int maxCount = Math.min((int) Math.max(requested, half), perturbed.size());
    final int howMany;
    if (minCount == maxCount) {
        howMany = minCount;
    } else if (maxCount > minCount) {
        // draw the perturbation count uniformly from [minCount, maxCount]
        howMany = perturbationContext.getRandom().ints(1, minCount, 1 + maxCount).findFirst().orElse(1);
    } else {
        howMany = 0;
    }
    if (howMany > 0) {
        // pick 'howMany' distinct feature indexes at random
        int[] targets = perturbationContext.getRandom().ints(0, perturbed.size()).distinct().limit(howMany).toArray();
        for (int index : targets) {
            Feature original = perturbed.get(index);
            Value replacement;
            if (featureDistributionsMap.containsKey(original.getName())) {
                // a known distribution exists for this feature: sample from it
                replacement = featureDistributionsMap.get(original.getName()).sample();
            } else {
                // otherwise fall back to the type-specific perturbation
                replacement = original.getType().perturb(original.getValue(), perturbationContext);
            }
            perturbed.set(index, FeatureFactory.copyOf(original, replacement));
        }
    }
    return perturbed;
}
Also used : ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature)

Example 15 with Value

use of org.kie.kogito.explainability.model.Value in project kogito-apps by kiegroup.

In the class DataUtils, the method readCSV.

/**
 * Read a CSV file into a {@link DataDistribution} object.
 *
 * @param file the path to the CSV file
 * @param schema an ordered list of {@link Type}s as the 'schema', used to determine
 *        the {@link Type} of each feature / column
 * @return the parsed CSV as a {@link DataDistribution}
 * @throws IOException when failing at reading the CSV file
 * @throws MalformedInputException if any record in CSV has different size with respect to the specified schema
 */
public static DataDistribution readCSV(Path file, List<Type> schema) throws IOException {
    final List<PredictionInput> inputs = new ArrayList<>();
    try (BufferedReader reader = Files.newBufferedReader(file)) {
        for (CSVRecord row : CSVFormat.RFC4180.withFirstRecordAsHeader().parse(reader)) {
            final int columns = row.size();
            // every record must match the schema width, otherwise the input is malformed
            if (columns != schema.size()) {
                throw new MalformedInputException(columns);
            }
            final List<String> headers = row.getParser().getHeaderNames();
            final List<Feature> rowFeatures = new ArrayList<>(columns);
            for (int col = 0; col < columns; col++) {
                // column header becomes the feature name, schema entry its type, cell text its value
                rowFeatures.add(new Feature(headers.get(col), schema.get(col), new Value(row.get(col))));
            }
            inputs.add(new PredictionInput(rowFeatures));
        }
    }
    return new PredictionInputsDataDistribution(inputs);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) Type(org.kie.kogito.explainability.model.Type) BufferedReader(java.io.BufferedReader) Value(org.kie.kogito.explainability.model.Value) MalformedInputException(java.nio.charset.MalformedInputException) CSVRecord(org.apache.commons.csv.CSVRecord) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution)

Aggregations

Value (org.kie.kogito.explainability.model.Value)80 Feature (org.kie.kogito.explainability.model.Feature)69 Output (org.kie.kogito.explainability.model.Output)59 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)54 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)49 ArrayList (java.util.ArrayList)42 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)42 LinkedList (java.util.LinkedList)36 Type (org.kie.kogito.explainability.model.Type)36 Test (org.junit.jupiter.api.Test)35 List (java.util.List)33 Prediction (org.kie.kogito.explainability.model.Prediction)33 Random (java.util.Random)31 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)23 Arrays (java.util.Arrays)16 Map (java.util.Map)16 Optional (java.util.Optional)16 CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity)16 FeatureFactory (org.kie.kogito.explainability.model.FeatureFactory)16 Collectors (java.util.stream.Collectors)15