
Example 1 with LocalSaliencyStability

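Use of org.kie.kogito.explainability.utils.LocalSaliencyStability in project kogito-apps by kiegroup.

Source: class LimeStabilityTest, method testStabilityDeterministic. The test computes local saliency stability twice with the same seed and checks that the negative and positive stability scores for the "sum-but0" decision, at k = 1 and k = 2, are identical across the two runs.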

@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testStabilityDeterministic(long seed) throws Exception {
    List<LocalSaliencyStability> stabilities = new ArrayList<>();
    for (int j = 0; j < 2; j++) {
        Random random = new Random();
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        List<Feature> featureList = new LinkedList<>();
        for (int i = 0; i < 5; i++) {
            featureList.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(featureList);
        List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction prediction = new SimplePrediction(input, predictionOutputs.get(0));
        LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(new PerturbationContext(seed, random, 1));
        LimeExplainer explainer = new LimeExplainer(limeConfig);
        LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, explainer, 2, 10);
        stabilities.add(stability);
    }
    LocalSaliencyStability first = stabilities.get(0);
    LocalSaliencyStability second = stabilities.get(1);
    String decisionName = "sum-but0";
    assertThat(first.getNegativeStabilityScore(decisionName, 1)).isEqualTo(second.getNegativeStabilityScore(decisionName, 1));
    assertThat(first.getPositiveStabilityScore(decisionName, 1)).isEqualTo(second.getPositiveStabilityScore(decisionName, 1));
    assertThat(first.getNegativeStabilityScore(decisionName, 2)).isEqualTo(second.getNegativeStabilityScore(decisionName, 2));
    assertThat(first.getPositiveStabilityScore(decisionName, 2)).isEqualTo(second.getPositiveStabilityScore(decisionName, 2));
}
Also used: SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) ArrayList(java.util.ArrayList) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) LocalSaliencyStability(org.kie.kogito.explainability.utils.LocalSaliencyStability) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
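
Example 1 can only pass if two explainer runs that share a seed draw identical perturbations. The sketch below isolates that idea in plain Java, without the kogito-apps API. It assumes, as the test's use of PerturbationContext(seed, random, 1) suggests, that the seed is applied to the supplied Random before sampling. The perturb and run helpers are hypothetical stand-ins for LIME's sampling step, not library methods.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class SeededPerturbationDemo {

    // Hypothetical stand-in for a perturbation step: switches each of
    // n feature slots on or off using the supplied Random.
    static List<Boolean> perturb(Random random, int n) {
        List<Boolean> mask = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            mask.add(random.nextBoolean());
        }
        return mask;
    }

    // Draws `samples` perturbation masks from a fresh Random seeded with `seed`.
    static List<List<Boolean>> run(long seed, int samples, int features) {
        Random random = new Random();
        // Assumption: mirrors PerturbationContext applying the seed to the Random.
        random.setSeed(seed);
        List<List<Boolean>> masks = new ArrayList<>();
        for (int s = 0; s < samples; s++) {
            masks.add(perturb(random, features));
        }
        return masks;
    }

    public static void main(String[] args) {
        // Same seeds as the @ValueSource in the test above.
        for (long seed = 0; seed < 5; seed++) {
            List<List<Boolean>> first = run(seed, 10, 5);
            List<List<Boolean>> second = run(seed, 10, 5);
            System.out.println("seed " + seed + " identical: " + first.equals(second));
        }
    }
}
```

Running main prints "identical: true" for every seed from 0 to 4. Creating a new Random per run, as the test does inside its loop, is what makes the comparison meaningful: each run starts from the same seeded state rather than continuing a shared sequence.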

Example 2 with LocalSaliencyStability

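Use of org.kie.kogito.explainability.utils.LocalSaliencyStability in project kogito-apps by kiegroup.

Source: class LimeStabilityScoreCalculator, method getStabilityScore. For each prediction, the method computes top-2 local saliency stability over NUM_RUNS explanations, adds one marginal score per decision to a running total, and finally divides the total by the number of successful evaluations. Predictions whose calculation fails with an ExecutionException, InterruptedException or TimeoutException are logged and skipped.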

private BigDecimal getStabilityScore(PredictionProvider model, LimeConfig config, List<Prediction> predictions) {
    double succeededEvaluations = 0;
    int topK = 2;
    BigDecimal stabilityScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, limeExplainer, topK, NUM_RUNS);
            for (String decision : stability.getDecisions()) {
                BigDecimal decisionMarginalScore = getDecisionMarginalScore(stability, decision, topK);
                stabilityScore = stabilityScore.add(decisionMarginalScore);
                succeededEvaluations++;
            }
        } catch (ExecutionException e) {
            LOGGER.error("Saliency stability calculation returned an error {}", e.getMessage());
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency stability calculation {}", e.getMessage());
            Thread.currentThread().interrupt();
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency stability calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        stabilityScore = stabilityScore.divide(BigDecimal.valueOf(succeededEvaluations), RoundingMode.CEILING);
    }
    return stabilityScore;
}
Also used: LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) LocalSaliencyStability(org.kie.kogito.explainability.utils.LocalSaliencyStability) ExecutionException(java.util.concurrent.ExecutionException) BigDecimal(java.math.BigDecimal) TimeoutException(java.util.concurrent.TimeoutException)
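
The final score in getStabilityScore is a plain arithmetic mean: every per-decision marginal score is added to a BigDecimal accumulator, and the sum is divided by the number of successful evaluations. The sketch below reproduces only that averaging step with made-up input scores. getDecisionMarginalScore is not modelled because its body is not shown, and the mean helper is hypothetical. One property worth knowing: BigDecimal.divide(BigDecimal, RoundingMode) returns a result whose scale equals the dividend's scale, so the precision of the mean depends on the scale of the accumulated sum.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

public class MeanStabilityDemo {

    // Hypothetical helper: sums the scores and divides by their count,
    // following the same steps as getStabilityScore.
    static BigDecimal mean(List<BigDecimal> scores) {
        BigDecimal total = BigDecimal.ZERO;
        double count = 0; // a double counter, as in the original method
        for (BigDecimal score : scores) {
            total = total.add(score);
            count++;
        }
        if (count > 0) {
            // The quotient keeps total's scale and is rounded toward positive infinity.
            total = total.divide(BigDecimal.valueOf(count), RoundingMode.CEILING);
        }
        return total;
    }

    public static void main(String[] args) {
        List<BigDecimal> scores = List.of(
                new BigDecimal("1.0"), new BigDecimal("0.5"), new BigDecimal("0.5"));
        System.out.println(mean(scores));
    }
}
```

With the scores 1.0, 0.5 and 0.5 the sum is 2.0 (scale 1), so 2.0 / 3.0 is rounded up to 0.7 and main prints 0.7. With scale-0 inputs such as 1, 0 and 0, the same CEILING rounding turns a mean of 1/3 into 1. If more precision is wanted, the overload divide(divisor, scale, roundingMode) fixes the result scale explicitly.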

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction): 2
LocalSaliencyStability (org.kie.kogito.explainability.utils.LocalSaliencyStability): 2
BigDecimal (java.math.BigDecimal): 1
ArrayList (java.util.ArrayList): 1
LinkedList (java.util.LinkedList): 1
Random (java.util.Random): 1
ExecutionException (java.util.concurrent.ExecutionException): 1
TimeoutException (java.util.concurrent.TimeoutException): 1
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 1
ValueSource (org.junit.jupiter.params.provider.ValueSource): 1
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 1
Feature (org.kie.kogito.explainability.model.Feature): 1
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 1
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 1
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 1
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 1
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 1