
Example 21 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

The class LimeExplainerTest, method testNormalizedWeights.

@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig().withNormalizeWeights(true).withPerturbationContext(new PerturbationContext(4L, random, 2)).withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    int nf = 4;
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < nf; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit()).get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencyMap).isNotNull();
    String decisionName = "sum-but0";
    Saliency saliency = saliencyMap.get(decisionName);
    List<FeatureImportance> perFeatureImportance = saliency.getPerFeatureImportance();
    for (FeatureImportance featureImportance : perFeatureImportance) {
        assertThat(featureImportance.getScore()).isBetween(0d, 1d);
    }
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
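A minimal follow-up sketch, not taken from the project, that inspects the normalized scores produced above; it reuses perFeatureImportance from the test and assumes Feature#getName() is available on org.kie.kogito.explainability.model.Feature:

// Hypothetical inspection of the normalized scores from the test above.
// Assumes Feature#getName() exists; each score is expected to fall in [0, 1].
for (FeatureImportance featureImportance : perFeatureImportance) {
    System.out.printf("%s -> %.3f%n", featureImportance.getFeature().getName(), featureImportance.getScore());
}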

Example 22 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

The class ExplainabilityMetrics, method classificationFidelity.

/**
 * Calculate the fidelity (accuracy) of boolean classification outputs using the saliency predictor function sign(sum(saliency.scores)).
 * See papers:
 * - Guidotti, Riccardo, et al. "A survey of methods for explaining black box models." ACM Computing Surveys (2018).
 * - Bodria, Francesco, et al. "Explainability Methods for Natural Language Processing: Applications to Sentiment Analysis (Discussion Paper)."
 *
 * @param pairs pairs composed of a saliency and the related prediction
 * @return the fidelity accuracy
 */
public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
    double acc = 0;
    double evals = 0;
    for (Pair<Saliency, Prediction> pair : pairs) {
        Saliency saliency = pair.getLeft();
        Prediction prediction = pair.getRight();
        for (Output output : prediction.getOutput().getOutputs()) {
            Type type = output.getType();
            if (Type.BOOLEAN.equals(type)) {
                double predictorOutput = saliency.getPerFeatureImportance().stream().map(FeatureImportance::getScore).mapToDouble(d -> d).sum();
                double v = output.getValue().asNumber();
                if ((v >= 0 && predictorOutput >= 0) || (v < 0 && predictorOutput < 0)) {
                    acc++;
                }
                evals++;
            }
        }
    }
    return evals == 0 ? 0 : acc / evals;
}
Also used : Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Function(java.util.function.Function) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections)
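A hedged usage sketch follows: one way the saliency/prediction pairs could be assembled from LIME explanations before calling classificationFidelity. The variables model, limeExplainer and the List<Prediction> predictions are assumed to exist (for instance, built as in the LimeExplainerTest example above); this is illustrative, not code from the project.

// Pair every per-decision saliency with the prediction it explains, then score fidelity.
List<Pair<Saliency, Prediction>> pairs = new ArrayList<>();
for (Prediction prediction : predictions) {
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        pairs.add(Pair.of(saliency, prediction));
    }
}
double fidelity = ExplainabilityMetrics.classificationFidelity(pairs);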

Example 23 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

The class ExplainabilityMetrics, method impactScore.

/**
 * Calculate the impact of dropping the most important features (given by {@link Saliency#getTopFeatures(int)}) from the input.
 * Highly important features are expected to have a high impact.
 * See paper: Lin, Zhong Qiu, et al. "Do Explanations Reflect Decisions? A Machine-centric Strategy to Quantify the
 * Performance of Explainability Algorithms." 2019.
 *
 * @param model the model to be explained
 * @param prediction a prediction
 * @param topFeatures the list of important features that should be dropped
 * @return the saliency impact
 */
public static double impactScore(PredictionProvider model, Prediction prediction, List<FeatureImportance> topFeatures) throws InterruptedException, ExecutionException, TimeoutException {
    List<Feature> copy = List.copyOf(prediction.getInput().getFeatures());
    for (FeatureImportance featureImportance : topFeatures) {
        copy = DataUtils.dropFeature(copy, featureImportance.getFeature());
    }
    PredictionInput predictionInput = new PredictionInput(copy);
    List<PredictionOutput> predictionOutputs;
    try {
        predictionOutputs = model.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    } catch (ExecutionException | TimeoutException e) {
        LOGGER.error("Impossible to obtain prediction {}", e.getMessage());
        throw new IllegalStateException("Impossible to obtain prediction", e);
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException("Impossible to obtain prediction (Thread interrupted)", e);
    }
    double impact = 0d;
    for (PredictionOutput predictionOutput : predictionOutputs) {
        double size = predictionOutput.getOutputs().size();
        for (int i = 0; i < size; i++) {
            Output original = prediction.getOutput().getOutputs().get(i);
            Output modified = predictionOutput.getOutputs().get(i);
            impact += (!original.getValue().asString().equals(modified.getValue().asString()) || modified.getScore() < original.getScore() * CONFIDENCE_DROP_RATIO) ? 1d / size : 0d;
        }
    }
    return impact;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Feature(org.kie.kogito.explainability.model.Feature) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) ExecutionException(java.util.concurrent.ExecutionException) TimeoutException(java.util.concurrent.TimeoutException)

Example 24 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

The class ShapKernelExplainer, method saliencyFromMatrix.

/**
 * Given an n x m matrix of feature importances (n outputs, m features), return an array of Saliencies.
 *
 * @param m the n x m matrix of feature importances
 * @param pi the prediction input
 * @param po the prediction output
 *
 * @return an array of n saliencies, one for each output of the model. Each Saliency lists the importance
 *         of each input feature with respect to that particular output
 */
public static Saliency[] saliencyFromMatrix(RealMatrix m, PredictionInput pi, PredictionOutput po) {
    Saliency[] saliencies = new Saliency[m.getRowDimension()];
    for (int i = 0; i < m.getRowDimension(); i++) {
        List<FeatureImportance> fis = new ArrayList<>();
        for (int j = 0; j < m.getColumnDimension(); j++) {
            fis.add(new FeatureImportance(pi.getFeatures().get(j), m.getEntry(i, j)));
        }
        saliencies[i] = new Saliency(po.getOutputs().get(i), fis);
    }
    return saliencies;
}
Also used : FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency)
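A hedged usage sketch (the 2x3 matrix values are made up, predictionInput and predictionOutput are assumed to hold three features and two outputs respectively, and RealMatrix/MatrixUtils are assumed to be the Apache Commons Math types used by the explainer):

// Wrap raw per-output, per-feature scores into Saliency objects.
// Row i holds the importances of every input feature for output i.
RealMatrix scores = MatrixUtils.createRealMatrix(new double[][] {
        { 0.1, -0.4, 0.3 },
        { 0.0, 0.2, -0.1 }
});
Saliency[] saliencies = ShapKernelExplainer.saliencyFromMatrix(scores, predictionInput, predictionOutput);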

Example 25 with FeatureImportance

Use of org.kie.kogito.explainability.model.FeatureImportance in project kogito-apps by kiegroup.

The class PrequalificationDmnLimeExplainerTest, method testPrequalificationDMNExplanation.

@Test
void testPrequalificationDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    PredictionInput predictionInput = getTestInput();
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(0L, random, 1);
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(perturbationContext);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(predictionInput, predictionOutputs.get(0));
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertNotNull(saliency);
        List<FeatureImportance> topFeatures = saliency.getTopFeatures(2);
        if (!topFeatures.isEmpty()) {
            assertThat(ExplainabilityMetrics.impactScore(model, prediction, topFeatures)).isPositive();
        }
    }
    assertDoesNotThrow(() -> ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.3, 0.3));
    String decision = "LLPA";
    List<PredictionInput> inputs = new ArrayList<>();
    for (int n = 0; n < 10; n++) {
        inputs.add(new PredictionInput(DataUtils.perturbFeatures(predictionInput.getFeatures(), perturbationContext)));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 2;
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    AssertionsForClassTypes.assertThat(f1).isBetween(0.5d, 1d);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) Saliency(org.kie.kogito.explainability.model.Saliency) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Random(java.util.Random) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test)

Aggregations

FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance) 25
Saliency (org.kie.kogito.explainability.model.Saliency) 23
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 19
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 19
ArrayList (java.util.ArrayList) 18
Prediction (org.kie.kogito.explainability.model.Prediction) 18
Feature (org.kie.kogito.explainability.model.Feature) 17
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 16
Random (java.util.Random) 14
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction) 13
DataDistribution (org.kie.kogito.explainability.model.DataDistribution) 12
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext) 12
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 10
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution) 10
ValueSource (org.junit.jupiter.params.provider.ValueSource) 9
LinkedList (java.util.LinkedList) 8
Test (org.junit.jupiter.api.Test) 7
Output (org.kie.kogito.explainability.model.Output) 7
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer) 6
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig) 5