
Example 96 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class LimeCombinedScoreCalculatorTest, method testScoreWithEmptyPredictions.

@Test
void testScoreWithEmptyPredictions() {
    LimeCombinedScoreCalculator scoreCalculator = new LimeCombinedScoreCalculator();
    LimeConfig config = new LimeConfig();
    List<Prediction> predictions = Collections.emptyList();
    List<LimeConfigEntity> entities = Collections.emptyList();
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    LimeConfigSolution solution = new LimeConfigSolution(config, predictions, entities, model);
    SimpleBigDecimalScore score = scoreCalculator.calculateScore(solution);
    assertThat(score).isNotNull();
    assertThat(score.getScore()).isNotNull().isEqualTo(BigDecimal.valueOf(0));
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) Prediction(org.kie.kogito.explainability.model.Prediction) SimpleBigDecimalScore(org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Test(org.junit.jupiter.api.Test)
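A note on the final assertion: AssertJ's `isEqualTo` compares `BigDecimal` values with `equals()`, which also compares scale, so the check passes only if the calculator returns zero with scale 0. If the returned zero could ever carry a different scale (for example `0.00` after arithmetic on scaled values), `isEqualByComparingTo(BigDecimal.ZERO)` compares by numeric value instead. The standalone sketch below uses only JDK calls to show the difference; the class and method names are illustrative and not part of kogito-apps.

```java
import java.math.BigDecimal;

// Illustrative demo class, not part of kogito-apps.
public final class BigDecimalEqualityDemo {

    /** equals() compares value AND scale (what isEqualTo relies on for BigDecimal). */
    public static boolean equalByEquals(BigDecimal a, BigDecimal b) {
        return a.equals(b);
    }

    /** compareTo() compares numeric value only (what isEqualByComparingTo relies on). */
    public static boolean equalByValue(BigDecimal a, BigDecimal b) {
        return a.compareTo(b) == 0;
    }

    public static void main(String[] args) {
        BigDecimal expected = BigDecimal.valueOf(0);   // zero with scale 0
        BigDecimal actual = new BigDecimal("0.00");    // zero with scale 2
        System.out.println(equalByEquals(expected, actual)); // false
        System.out.println(equalByValue(expected, actual));  // true
    }
}
```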

Example 97 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class LimeConfigEntityFactoryTest, method testConversion.

@Test
void testConversion() {
    PredictionProvider model = TestUtils.getDummyTextClassifier();
    LimeConfig config = new LimeConfig();
    List<Prediction> predictions = Collections.emptyList();
    List<LimeConfigEntity> entities = Collections.emptyList();
    LimeConfigSolution solution = new LimeConfigSolution(config, predictions, entities, model);
    LimeConfig limeConfig = LimeConfigEntityFactory.toLimeConfig(solution);
    assertThat(limeConfig).isNotNull();
}
Also used : Prediction(org.kie.kogito.explainability.model.Prediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Test(org.junit.jupiter.api.Test)

Example 98 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class ExplainabilityMetrics, method classificationFidelity.

/**
 * Calculate the fidelity (accuracy) of boolean classification outputs, using the saliency-based predictor function sign(sum(saliency.scores)).
 * See papers:
 * - Guidotti, Riccardo, et al. "A survey of methods for explaining black box models." ACM Computing Surveys (2018).
 * - Bodria, Francesco, et al. "Explainability Methods for Natural Language Processing: Applications to Sentiment Analysis (Discussion Paper)."
 *
 * @param pairs pairs, each composed of a saliency and its related prediction
 * @return the fidelity accuracy
 */
public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
    double acc = 0;
    double evals = 0;
    for (Pair<Saliency, Prediction> pair : pairs) {
        Saliency saliency = pair.getLeft();
        Prediction prediction = pair.getRight();
        for (Output output : prediction.getOutput().getOutputs()) {
            Type type = output.getType();
            if (Type.BOOLEAN.equals(type)) {
                double predictorOutput = saliency.getPerFeatureImportance().stream().map(FeatureImportance::getScore).mapToDouble(d -> d).sum();
                double v = output.getValue().asNumber();
                if ((v >= 0 && predictorOutput >= 0) || (v < 0 && predictorOutput < 0)) {
                    acc++;
                }
                evals++;
            }
        }
    }
    return evals == 0 ? 0 : acc / evals;
}
Also used : Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Function(java.util.function.Function) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ExecutionException(java.util.concurrent.ExecutionException) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional) Comparator(java.util.Comparator) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections)
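The fidelity computation can be restated without the Kogito model classes. This is a hypothetical sketch: the `Explained` record and `FidelitySketch` class are not part of kogito-apps, and it assumes `true` is the positive class and `false` the negative one, treating a saliency sum of exactly 0 as positive to mirror the `>= 0` check above. Making the mapping explicit matters because the original compares the sign of `output.getValue().asNumber()`; if booleans are converted to 1 and 0 there (an assumption worth verifying against `Value`), `v >= 0` holds for both values, so a `false` output could only ever agree with a non-negative saliency sum.

```java
import java.util.List;

// Hypothetical sketch, not a kogito-apps API.
public final class FidelitySketch {

    /** One explained boolean prediction: per-feature saliency scores plus the model's output. */
    public record Explained(double[] saliencyScores, boolean output) { }

    /** Fraction of outputs whose value agrees with sign(sum(saliency scores)). */
    public static double classificationFidelity(List<Explained> explained) {
        int agreements = 0;
        for (Explained e : explained) {
            double sum = 0;
            for (double score : e.saliencyScores()) {
                sum += score;
            }
            // Assumption: true = positive class, false = negative class; a sum of 0 counts as positive
            boolean predictedPositive = sum >= 0;
            if (predictedPositive == e.output()) {
                agreements++;
            }
        }
        // Like the original, no evaluations yields a fidelity of 0
        return explained.isEmpty() ? 0 : (double) agreements / explained.size();
    }

    public static void main(String[] args) {
        List<Explained> data = List.of(
                new Explained(new double[] {0.4, 0.2}, true),    // sum > 0, output true: agrees
                new Explained(new double[] {-0.5, 0.1}, false),  // sum < 0, output false: agrees
                new Explained(new double[] {0.3, -0.1}, false),  // sum > 0, output false: disagrees
                new Explained(new double[] {-0.2, -0.3}, true)); // sum < 0, output true: disagrees
        System.out.println(classificationFidelity(data)); // 0.5
    }
}
```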

Example 99 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class FairnessMetrics, method countMatchingOutputSelector.

/**
 * Count true / false favorable and true / false unfavorable outputs with respect to a specified output selector.
 *
 * @param dataset dataset used to match predictions with labels
 * @param predictionOutputs predictions to match with the dataset labels
 * @param outputSelector selector to define positive labelled samples / predictions
 * @return a map containing counts for true positives ("tp"), true negatives ("tn"), false positives ("fp"), false negatives ("fn")
 */
private static Map<String, Integer> countMatchingOutputSelector(Dataset dataset, List<PredictionOutput> predictionOutputs, Predicate<PredictionOutput> outputSelector) {
    assert predictionOutputs.size() == dataset.getData().size() : "dataset and predictions must have same size";
    int tp = 0;
    int tn = 0;
    int fp = 0;
    int fn = 0;
    int i = 0;
    for (Prediction trainingExample : dataset.getData()) {
        if (outputSelector.test(trainingExample.getOutput())) {
            // positive
            if (outputSelector.test(predictionOutputs.get(i))) {
                tp++;
            } else {
                fn++;
            }
        } else {
            // negative
            if (outputSelector.test(predictionOutputs.get(i))) {
                fp++;
            } else {
                tn++;
            }
        }
        i++;
    }
    Map<String, Integer> map = new HashMap<>();
    map.put("tp", tp);
    map.put("tn", tn);
    map.put("fp", fp);
    map.put("fn", fn);
    return map;
}
Also used : HashMap(java.util.HashMap) Prediction(org.kie.kogito.explainability.model.Prediction)
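The same counting logic can be written as a hypothetical generic helper over plain label values instead of `Dataset` and `PredictionOutput`; the `ConfusionCounts` class is illustrative, not a kogito-apps API. One deliberate difference: it checks the sizes with an exception, because a Java `assert` statement like the one above is skipped unless the JVM is started with assertions enabled (`-ea`).

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical helper, not a kogito-apps API: labels and predictions are plain values of type T.
public final class ConfusionCounts {

    public static <T> Map<String, Integer> count(List<T> labels, List<T> predictions, Predicate<T> isPositive) {
        // Explicit check: an `assert` would be skipped unless the JVM runs with -ea
        if (labels.size() != predictions.size()) {
            throw new IllegalArgumentException("labels and predictions must have the same size");
        }
        int tp = 0;
        int tn = 0;
        int fp = 0;
        int fn = 0;
        for (int i = 0; i < labels.size(); i++) {
            boolean actualPositive = isPositive.test(labels.get(i));
            boolean predictedPositive = isPositive.test(predictions.get(i));
            if (actualPositive) {
                if (predictedPositive) { tp++; } else { fn++; }
            } else {
                if (predictedPositive) { fp++; } else { tn++; }
            }
        }
        Map<String, Integer> counts = new HashMap<>();
        counts.put("tp", tp);
        counts.put("tn", tn);
        counts.put("fp", fp);
        counts.put("fn", fn);
        return counts;
    }

    public static void main(String[] args) {
        List<String> labels = List.of("approve", "approve", "deny", "deny");
        List<String> predictions = List.of("approve", "deny", "approve", "deny");
        Map<String, Integer> counts = count(labels, predictions, "approve"::equals);
        // prints: tp=1 fn=1 fp=1 tn=1
        System.out.println("tp=" + counts.get("tp") + " fn=" + counts.get("fn")
                + " fp=" + counts.get("fp") + " tn=" + counts.get("tn"));
    }
}
```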

Example 100 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

From the class LimeStabilityScoreCalculator, method getStabilityScore.

private BigDecimal getStabilityScore(PredictionProvider model, LimeConfig config, List<Prediction> predictions) {
    double succeededEvaluations = 0;
    int topK = 2;
    BigDecimal stabilityScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, limeExplainer, topK, NUM_RUNS);
            for (String decision : stability.getDecisions()) {
                BigDecimal decisionMarginalScore = getDecisionMarginalScore(stability, decision, topK);
                stabilityScore = stabilityScore.add(decisionMarginalScore);
                succeededEvaluations++;
            }
        } catch (ExecutionException e) {
            LOGGER.error("Saliency stability calculation returned an error {}", e.getMessage());
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency stability calculation {}", e.getMessage());
            Thread.currentThread().interrupt();
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency stability calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        stabilityScore = stabilityScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return stabilityScore;
}
Also used : LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) LocalSaliencyStability(org.kie.kogito.explainability.utils.LocalSaliencyStability) ExecutionException(java.util.concurrent.ExecutionException) BigDecimal(java.math.BigDecimal) TimeoutException(java.util.concurrent.TimeoutException)
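One detail of the final division: `BigDecimal.divide(BigDecimal, RoundingMode)` returns a result with the scale of the dividend, so the averaged score keeps whatever scale `stabilityScore` accumulated from the marginal scores. If that scale were 0, an average of 1/3 would round up to 1 under `RoundingMode.CEILING`. The sketch below uses a hypothetical `AverageScoreSketch` class, with plain numbers standing in for per-decision marginal scores, to contrast that with the overload that takes an explicit scale; the scale of 4 is an arbitrary illustrative choice.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

// Hypothetical illustration, not a kogito-apps API.
public final class AverageScoreSketch {

    private static BigDecimal sum(List<BigDecimal> scores) {
        BigDecimal total = BigDecimal.ZERO;
        for (BigDecimal score : scores) {
            total = total.add(score);
        }
        return total;
    }

    /** Same shape as the original: the result keeps the scale of the accumulated sum. */
    public static BigDecimal averageKeepingScale(List<BigDecimal> scores) {
        BigDecimal total = sum(scores);
        if (scores.isEmpty()) {
            return total; // no successful evaluations: zero, as in the original
        }
        return total.divide(BigDecimal.valueOf(scores.size()), RoundingMode.CEILING);
    }

    /** Variant that fixes the result scale explicitly. */
    public static BigDecimal averageWithScale(List<BigDecimal> scores, int scale) {
        BigDecimal total = sum(scores);
        if (scores.isEmpty()) {
            return total.setScale(scale);
        }
        return total.divide(BigDecimal.valueOf(scores.size()), scale, RoundingMode.CEILING);
    }

    public static void main(String[] args) {
        List<BigDecimal> scores = List.of(BigDecimal.ONE, BigDecimal.ZERO, BigDecimal.ZERO);
        System.out.println(averageKeepingScale(scores)); // 1
        System.out.println(averageWithScale(scores, 4)); // 0.3334
    }
}
```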

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction): 134
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 117
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 107
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 105
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 96
Test (org.junit.jupiter.api.Test): 95
Random (java.util.Random): 65
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 61
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 57
ArrayList (java.util.ArrayList): 51
Feature (org.kie.kogito.explainability.model.Feature): 48
Saliency (org.kie.kogito.explainability.model.Saliency): 48
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 42
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 40
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 28
DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 24
ValueSource (org.junit.jupiter.params.provider.ValueSource): 22
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 22
Output (org.kie.kogito.explainability.model.Output): 22
LinkedList (java.util.LinkedList): 21