Search in sources:

Example 1 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

the class DataUtils method readCSV.

/**
 * Parse a CSV file (RFC 4180 format, first record treated as the header) into a {@link DataDistribution}.
 * <p>
 * Every data record is turned into one {@link PredictionInput}. Each cell of a record becomes a {@link Feature}
 * named after the header at the same position, typed by the {@code schema} entry at the same position and
 * wrapping the raw cell text in a {@link Value}.
 *
 * @param file the path of the CSV file to read
 * @param schema the ordered {@link Type}s of the columns, one entry per column
 * @return a {@link PredictionInputsDataDistribution} holding one {@link PredictionInput} per data record
 * @throws IOException if the CSV file cannot be read
 * @throws MalformedInputException if a record holds a number of values different from the size of the schema
 */
public static DataDistribution readCSV(Path file, List<Type> schema) throws IOException {
    List<PredictionInput> predictionInputs = new ArrayList<>();
    try (BufferedReader csvReader = Files.newBufferedReader(file)) {
        Iterable<CSVRecord> rows = CSVFormat.RFC4180.withFirstRecordAsHeader().parse(csvReader);
        for (CSVRecord row : rows) {
            int columns = row.size();
            // reject any record whose width does not match the schema
            if (columns != schema.size()) {
                throw new MalformedInputException(columns);
            }
            List<String> headerNames = row.getParser().getHeaderNames();
            List<Feature> rowFeatures = new ArrayList<>();
            int column = 0;
            while (column < columns) {
                String cell = row.get(column);
                Type columnType = schema.get(column);
                rowFeatures.add(new Feature(headerNames.get(column), columnType, new Value(cell)));
                column++;
            }
            predictionInputs.add(new PredictionInput(rowFeatures));
        }
    }
    return new PredictionInputsDataDistribution(predictionInputs);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature) Type(org.kie.kogito.explainability.model.Type) BufferedReader(java.io.BufferedReader) Value(org.kie.kogito.explainability.model.Value) MalformedInputException(java.nio.charset.MalformedInputException) CSVRecord(org.apache.commons.csv.CSVRecord) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution)

Example 2 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

the class LimeExplainer method prepareInputs.

/**
 * Build the {@code LimeInputs} used to train the linear model for a single output.
 * <p>
 * The perturbed dataset must contain more than one output class, otherwise it cannot be linearly separated and no
 * meaningful weights (feature importance scores) can be learned from it.
 * <ul>
 * <li>When the current output has no value (or no underlying object), inputs with empty perturbed inputs / outputs
 * and {@code classification == false} are returned.</li>
 * <li>When {@code strict} is {@code true}, the dataset is accepted only if it holds more than one class and the
 * share of the most frequent class is below {@code limeConfig.getSeparableDatasetRatio()}; otherwise a
 * {@code DatasetNotSeparableException} is thrown.</li>
 * <li>When {@code strict} is {@code false}, no separability check is performed: a warning is logged and the dataset
 * is used as is.</li>
 * </ul>
 * The {@code classification} flag of the result is set when exactly two output classes were observed.
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
    Value targetValue = currentOutput.getValue();
    if (targetValue == null || targetValue.getUnderlyingObject() == null) {
        // nothing to compare the perturbed outputs against
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }

    // number of samples falling into each output class
    Map<Double, Long> classCounts = getClassBalance(perturbedOutputs, targetValue, o);
    long largestClass = classCounts.values().stream().mapToLong(Long::longValue).max().orElse(1L);
    double largestClassShare = (double) largestClass / (double) perturbedInputs.size();

    // the o-th output of every perturbed prediction
    List<Output> targetOutputs = perturbedOutputs.stream()
            .map(prediction -> prediction.getOutputs().get(o))
            .collect(Collectors.toList());
    boolean binaryClassification = classCounts.size() == 2;

    if (strict) {
        // the dataset is usable only when it has several classes and none of them dominates
        boolean separable = classCounts.size() > 1 && largestClassShare < limeConfig.getSeparableDatasetRatio();
        if (!separable) {
            throw new DatasetNotSeparableException(currentOutput, classCounts);
        }
    } else {
        LOGGER.warn("Using an hardly separable dataset for output '{}' of type '{}' with value '{}' ({})", currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), classCounts);
    }
    return new LimeInputs(binaryClassification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, targetOutputs);
}
Also used : Arrays(java.util.Arrays) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CompletableFuture.completedFuture(java.util.concurrent.CompletableFuture.completedFuture) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value)

Example 3 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

the class ConversionUtils method toTypedValue.

/**
 * Convert an {@link Output} into a {@link TypedValue} by delegating to
 * {@code toTypedValue(name, type, value)} with the output's name, type and value.
 *
 * @param output the output to convert
 * @return the result of the three-argument {@code toTypedValue} overload
 */
static TypedValue toTypedValue(Output output) {
    return toTypedValue(output.getName(), output.getType(), output.getValue());
}
Also used : Type(org.kie.kogito.explainability.model.Type) CounterfactualSearchDomainValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainValue) Value(org.kie.kogito.explainability.model.Value) HasNameValue(org.kie.kogito.explainability.api.HasNameValue) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue)

Example 4 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

the class DecisionModelWrapper method predictAsync.

/**
 * Evaluate the decision model once per input and collect the decision results as prediction outputs.
 * <p>
 * For every input, its features are converted to a map of context variables, all decisions are evaluated and each
 * decision whose name is not in {@code skippedDecisions} yields one {@link Output}. The output type is derived from
 * the runtime class of the decision result: {@code Boolean} maps to {@code Type.BOOLEAN}, {@code null} and
 * {@code String} map to {@code Type.TEXT}, anything else maps to {@code Type.NUMBER}.
 *
 * @param inputs the inputs to evaluate
 * @return the value of {@code completedFuture(...)} applied to the outputs, one {@link PredictionOutput} per input
 *         and in the same order
 */
@Override
public CompletableFuture<List<PredictionOutput>> predictAsync(List<PredictionInput> inputs) {
    List<PredictionOutput> results = new LinkedList<>();
    for (PredictionInput predictionInput : inputs) {
        Map<String, Object> variables = toMap(predictionInput.getFeatures());
        final DMNContext dmnContext = decisionModel.newContext(variables);
        DMNResult evaluation = decisionModel.evaluateAll(dmnContext);

        List<Output> decisionOutputs = new LinkedList<>();
        for (DMNDecisionResult decision : evaluation.getDecisionResults()) {
            String name = decision.getDecisionName();
            if (skippedDecisions.contains(name)) {
                continue;
            }
            Object rawResult = decision.getResult();
            Value wrapped = new Value(rawResult);

            // pick the output type from the runtime class of the decision result
            final Type resultType;
            if (rawResult instanceof Boolean) {
                resultType = Type.BOOLEAN;
            } else if (rawResult == null || rawResult instanceof String) {
                resultType = Type.TEXT;
            } else {
                resultType = Type.NUMBER;
            }
            decisionOutputs.add(new Output(name, resultType, wrapped, 1d));
        }
        results.add(new PredictionOutput(decisionOutputs));
    }
    return completedFuture(results);
}
Also used : DMNResult(org.kie.dmn.api.core.DMNResult) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) DMNContext(org.kie.dmn.api.core.DMNContext) LinkedList(java.util.LinkedList) Type(org.kie.kogito.explainability.model.Type) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DMNDecisionResult(org.kie.dmn.api.core.DMNDecisionResult) Value(org.kie.kogito.explainability.model.Value)

Example 5 with Type

use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

the class DataUtilsTest method testReadCsv.

/**
 * Reads the {@code /mini-train.csv} test resource with a 13-column schema and checks that a non-null
 * distribution with 10 samples is produced.
 */
@Test
void testReadCsv() throws IOException {
    // schema layout: 1 categorical column, 10 boolean columns, 2 numeric columns
    List<Type> columnTypes = new ArrayList<>();
    columnTypes.add(Type.CATEGORICAL);
    for (int i = 0; i < 10; i++) {
        columnTypes.add(Type.BOOLEAN);
    }
    columnTypes.add(Type.NUMBER);
    columnTypes.add(Type.NUMBER);

    String csvLocation = getClass().getResource("/mini-train.csv").getFile();
    DataDistribution distribution = DataUtils.readCSV(Paths.get(csvLocation), columnTypes);

    assertThat(distribution).isNotNull();
    assertThat(distribution.getAllSamples()).hasSize(10);
}
Also used : Type(org.kie.kogito.explainability.model.Type) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) IndependentFeaturesDataDistribution(org.kie.kogito.explainability.model.IndependentFeaturesDataDistribution) ArrayList(java.util.ArrayList) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Aggregations

Type (org.kie.kogito.explainability.model.Type)15 Value (org.kie.kogito.explainability.model.Value)10 Feature (org.kie.kogito.explainability.model.Feature)8 ArrayList (java.util.ArrayList)6 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)6 List (java.util.List)5 Output (org.kie.kogito.explainability.model.Output)5 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)5 HashMap (java.util.HashMap)4 LinkedList (java.util.LinkedList)4 Map (java.util.Map)4 Collectors (java.util.stream.Collectors)4 DataDistribution (org.kie.kogito.explainability.model.DataDistribution)4 HasNameValue (org.kie.kogito.explainability.api.HasNameValue)3 Prediction (org.kie.kogito.explainability.model.Prediction)3 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)3 Logger (org.slf4j.Logger)3 LoggerFactory (org.slf4j.LoggerFactory)3 LocalTime (java.time.LocalTime)2 Collection (java.util.Collection)2