Search in sources:

Example 6 with Type

Use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

The class DataUtilsTest, method testDropLinearizedFeature.

/**
 * Verifies, for every {@link Type}, that dropping a target feature nested inside a
 * composite feature yields a feature that is not equal to the original composite.
 */
@Test
void testDropLinearizedFeature() {
    Type[] allTypes = Type.values();
    for (int i = 0; i < allTypes.length; i++) {
        Feature target = TestUtils.getMockedFeature(allTypes[i], new Value(1d));
        // the target sits between other numeric/text siblings inside the composite
        List<Feature> children = new LinkedList<>();
        children.add(TestUtils.getMockedNumericFeature());
        children.add(target);
        children.add(TestUtils.getMockedTextFeature("foo bar"));
        children.add(TestUtils.getMockedNumericFeature());
        Feature composite = FeatureFactory.newCompositeFeature("composite", children);
        Feature dropped = DataUtils.dropOnLinearizedFeatures(target, composite);
        assertNotEquals(composite, dropped);
    }
}
Also used : Type(org.kie.kogito.explainability.model.Type) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 7 with Type

Use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

The class DataUtilsTest, method testDropFeature.

/**
 * Verifies, for every {@link Type}, that dropping a target feature from a flat list
 * of features produces a list that is not equal to the original one.
 */
@Test
void testDropFeature() {
    Type[] allTypes = Type.values();
    for (int i = 0; i < allTypes.length; i++) {
        Feature target = TestUtils.getMockedFeature(allTypes[i], new Value(1d));
        // the target sits between other numeric/text features in the input list
        List<Feature> original = new LinkedList<>();
        original.add(TestUtils.getMockedNumericFeature());
        original.add(target);
        original.add(TestUtils.getMockedTextFeature("foo bar"));
        original.add(TestUtils.getMockedNumericFeature());
        List<Feature> result = DataUtils.dropFeature(original, target);
        assertNotEquals(original, result);
    }
}
Also used : Type(org.kie.kogito.explainability.model.Type) Value(org.kie.kogito.explainability.model.Value) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)

Example 8 with Type

Use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

The class DataUtils, method dropFeature.

/**
 * Drops a given feature from a list of existing features.
 * <p>
 * Every input feature contributes exactly one feature to the returned list, in the same order:
 * <ul>
 *   <li>a feature whose name equals the target's name and whose type and value also equal the
 *       target's is copied with its value replaced by {@code type.drop(value)};</li>
 *   <li>a feature whose name equals the target's name but whose type or value differs is passed to
 *       {@code dropOnLinearizedFeatures(target, feature)};</li>
 *   <li>a {@code Type.COMPOSITE} feature with a different name is rebuilt under the same name,
 *       recursively dropping the target from its nested features;</li>
 *   <li>any other feature is copied with its original value.</li>
 * </ul>
 *
 * @param features the existing features
 * @param target the feature to drop
 * @return a new list of features having the target feature dropped
 */
public static List<Feature> dropFeature(List<Feature> features, Feature target) {
    List<Feature> result = new ArrayList<>(features.size());
    for (Feature candidate : features) {
        String name = candidate.getName();
        Type type = candidate.getType();
        Value value = candidate.getValue();
        Feature replacement;
        if (!target.getName().equals(name)) {
            if (Type.COMPOSITE.equals(type)) {
                // the cast assumes a COMPOSITE feature wraps a List<Feature>
                List<Feature> children = (List<Feature>) value.getUnderlyingObject();
                replacement = FeatureFactory.newCompositeFeature(name, dropFeature(children, target));
            } else {
                // target not found here: keep an unchanged copy
                replacement = FeatureFactory.copyOf(candidate, value);
            }
        } else {
            boolean exactMatch = target.getType().equals(type) && target.getValue().equals(value);
            replacement = exactMatch
                    ? FeatureFactory.copyOf(candidate, type.drop(value))
                    : dropOnLinearizedFeatures(target, candidate);
        }
        result.add(replacement);
    }
    return result;
}
Also used: Type(org.kie.kogito.explainability.model.Type) ArrayList(java.util.ArrayList) Value(org.kie.kogito.explainability.model.Value) LinkedList(java.util.LinkedList) List(java.util.List) Feature(org.kie.kogito.explainability.model.Feature)

Example 9 with Type

Use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

The class ExplainabilityMetrics, method classificationFidelity.

/**
 * Calculates the fidelity (accuracy) of boolean classification outputs, using the saliency
 * predictor function {@code sign(sum(saliency.scores))}.
 * See papers:
 * - Guidotti Riccardo, et al. "A survey of methods for explaining black box models." ACM computing surveys (2018).
 * - Bodria, Francesco, et al. "Explainability Methods for Natural Language Processing: Applications to Sentiment Analysis (Discussion Paper)."
 * <p>
 * Only outputs whose type is {@code Type.BOOLEAN} are evaluated. For each such output, the sum of the
 * per-feature importance scores of the paired saliency is compared with the output's numeric value.
 * The evaluation counts as a hit when both are non-negative or both are negative.
 *
 * @param pairs pairs composed by the saliency and the related prediction
 * @return the fraction of evaluated boolean outputs that were hits, or {@code 0} when no boolean output exists
 */
public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
    double hits = 0;
    double evaluated = 0;
    for (Pair<Saliency, Prediction> pair : pairs) {
        Saliency saliency = pair.getLeft();
        Prediction prediction = pair.getRight();
        for (Output output : prediction.getOutput().getOutputs()) {
            if (!Type.BOOLEAN.equals(output.getType())) {
                continue;
            }
            double saliencySum = saliency.getPerFeatureImportance().stream()
                    .mapToDouble(FeatureImportance::getScore)
                    .sum();
            // NOTE(review): if Value.asNumber() maps boolean false to 0, a false output is treated
            // as non-negative here — confirm this matches the intended sign convention.
            double actual = output.getValue().asNumber();
            boolean bothNonNegative = actual >= 0 && saliencySum >= 0;
            boolean bothNegative = actual < 0 && saliencySum < 0;
            if (bothNonNegative || bothNegative) {
                hits++;
            }
            evaluated++;
        }
    }
    if (evaluated == 0) {
        return 0;
    }
    return hits / evaluated;
}
Also used: Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Function(java.util.function.Function) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections)

Example 10 with Type

Use of org.kie.kogito.explainability.model.Type in project kogito-apps by kiegroup.

The class PartialDependencePlotExplainer, method collapseMarginalImpacts.

/**
 * Collapses value counts into marginal impacts, producing one {@link Value} per position.
 * <p>
 * When {@code Type.NUMBER.equals(type)}, each position yields the sum over its entries of
 * {@code value * count / config.getSeriesLength()}. For all other types, each position yields the
 * string form of its most frequent value. On ties, the first value encountered in the map's
 * iteration order wins. If no entry has a positive count, the position yields a value wrapping
 * {@code null}.
 *
 * @param valueCounts the frequency of each value at each position
 * @param type the type of the output
 * @return the marginal impacts
 */
private List<Value> collapseMarginalImpacts(List<Map<Value, Long>> valueCounts, Type type) {
    if (Type.NUMBER.equals(type)) {
        return valueCounts.stream()
                .map(counts -> counts.entrySet().stream()
                        .mapToDouble(e -> e.getKey().asNumber() * e.getValue() / config.getSeriesLength())
                        .sum())
                .map(Value::new)
                .collect(Collectors.toList());
    }
    List<Value> modes = new ArrayList<>();
    for (Map<Value, Long> counts : valueCounts) {
        String mostFrequent = null;
        long bestCount = 0;
        for (Map.Entry<Value, Long> entry : counts.entrySet()) {
            long count = entry.getValue();
            // strict '>' keeps the first value seen when counts tie
            if (count > bestCount) {
                bestCount = count;
                mostFrequent = entry.getKey().asString();
            }
        }
        modes.add(new Value(mostFrequent));
    }
    return modes;
}
Also used: FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) ArrayList(java.util.ArrayList) PartialDependenceGraph(org.kie.kogito.explainability.model.PartialDependenceGraph) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) GlobalExplainer(org.kie.kogito.explainability.global.GlobalExplainer) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Collection(java.util.Collection) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) PredictionProviderMetadata(org.kie.kogito.explainability.model.PredictionProviderMetadata) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config)

Aggregations

Type (org.kie.kogito.explainability.model.Type)15 Value (org.kie.kogito.explainability.model.Value)10 Feature (org.kie.kogito.explainability.model.Feature)8 ArrayList (java.util.ArrayList)6 PredictionInput (org.kie.kogito.explainability.model.PredictionInput)6 List (java.util.List)5 Output (org.kie.kogito.explainability.model.Output)5 PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput)5 HashMap (java.util.HashMap)4 LinkedList (java.util.LinkedList)4 Map (java.util.Map)4 Collectors (java.util.stream.Collectors)4 DataDistribution (org.kie.kogito.explainability.model.DataDistribution)4 HasNameValue (org.kie.kogito.explainability.api.HasNameValue)3 Prediction (org.kie.kogito.explainability.model.Prediction)3 PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider)3 Logger (org.slf4j.Logger)3 LoggerFactory (org.slf4j.LoggerFactory)3 LocalTime (java.time.LocalTime)2 Collection (java.util.Collection)2