Usage of org.kie.kogito.explainability.utils.LocalSaliencyStability in the kogito-apps project by kiegroup.
Class LimeStabilityTest, method testStabilityDeterministic:
/**
 * Checks that LIME saliency stability is reproducible for a fixed seed.
 *
 * <p>The whole pipeline is built from scratch twice: model, mocked numeric features, prediction,
 * perturbation context seeded with {@code seed`, and LIME explainer. The top-1 and top-2
 * positive/negative stability scores of the {@code "sum-but0"} decision must then match across
 * the two runs.
 *
 * @param seed seed handed to the {@link PerturbationContext} of each run
 */
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testStabilityDeterministic(long seed) throws Exception {
    final int repetitions = 2;
    List<LocalSaliencyStability> results = new ArrayList<>();
    int run = 0;
    while (run < repetitions) {
        Random rng = new Random();
        PredictionProvider sumSkipModel = TestUtils.getSumSkipModel(0);

        List<Feature> features = new LinkedList<>();
        for (int idx = 0; idx < 5; idx++) {
            features.add(TestUtils.getMockedNumericFeature(idx));
        }
        PredictionInput predictionInput = new PredictionInput(features);

        List<PredictionOutput> outputs = sumSkipModel
                .predictAsync(List.of(predictionInput))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        Prediction simplePrediction = new SimplePrediction(predictionInput, outputs.get(0));

        // Each run gets its own Random, which is paired with the same seed.
        LimeExplainer limeExplainer = new LimeExplainer(
                new LimeConfig()
                        .withSamples(10)
                        .withPerturbationContext(new PerturbationContext(seed, rng, 1)));

        results.add(ExplainabilityMetrics.getLocalSaliencyStability(
                sumSkipModel, simplePrediction, limeExplainer, 2, 10));
        run++;
    }

    LocalSaliencyStability firstRun = results.get(0);
    LocalSaliencyStability secondRun = results.get(1);
    String decision = "sum-but0";
    for (int k = 1; k <= 2; k++) {
        assertThat(firstRun.getNegativeStabilityScore(decision, k))
                .isEqualTo(secondRun.getNegativeStabilityScore(decision, k));
        assertThat(firstRun.getPositiveStabilityScore(decision, k))
                .isEqualTo(secondRun.getPositiveStabilityScore(decision, k));
    }
}
Usage of org.kie.kogito.explainability.utils.LocalSaliencyStability in the kogito-apps project by kiegroup.
Class LimeStabilityScoreCalculator, method getStabilityScore:
/**
 * Computes the mean LIME saliency stability score of {@code model} over {@code predictions}.
 *
 * <p>For each prediction, a {@link LocalSaliencyStability} is computed with a single
 * {@link LimeExplainer} built from {@code config}. The evaluation uses top-2 features and the
 * class-level {@code NUM_RUNS} setting. Every decision reported by that stability contributes
 * one marginal score, obtained from {@code getDecisionMarginalScore}, to the sum. The result is
 * the sum divided by the number of contributing decisions.
 *
 * <p>Predictions whose evaluation fails or times out are logged and skipped. If the thread is
 * interrupted, the interrupt flag is restored and the remaining predictions are not evaluated.
 *
 * @param model       the model to explain
 * @param config      the LIME configuration used to build the explainer
 * @param predictions the predictions whose explanations are evaluated
 * @return the mean marginal stability score, or {@link BigDecimal#ZERO} if no decision was scored
 */
private BigDecimal getStabilityScore(PredictionProvider model, LimeConfig config, List<Prediction> predictions) {
    int topK = 2;
    int succeededEvaluations = 0;
    BigDecimal stabilityScore = BigDecimal.ZERO;
    LimeExplainer limeExplainer = new LimeExplainer(config);
    for (Prediction prediction : predictions) {
        try {
            LocalSaliencyStability stability = ExplainabilityMetrics.getLocalSaliencyStability(model, prediction, limeExplainer, topK, NUM_RUNS);
            for (String decision : stability.getDecisions()) {
                BigDecimal decisionMarginalScore = getDecisionMarginalScore(stability, decision, topK);
                stabilityScore = stabilityScore.add(decisionMarginalScore);
                succeededEvaluations++;
            }
        } catch (ExecutionException e) {
            // Pass the exception itself so the cause and stack trace are not lost.
            LOGGER.error("Saliency stability calculation returned an error {}", e.getMessage(), e);
        } catch (InterruptedException e) {
            LOGGER.error("Interrupted while waiting for saliency stability calculation {}", e.getMessage(), e);
            Thread.currentThread().interrupt();
            // Honour the interruption: don't start further blocking evaluations for the
            // remaining predictions.
            break;
        } catch (TimeoutException e) {
            LOGGER.error("Timed out while waiting for saliency stability calculation", e);
        }
    }
    if (succeededEvaluations > 0) {
        // divide(BigDecimal, RoundingMode) keeps the dividend's scale (1.5 / 2 -> 0.8). Use a
        // bounded precision instead so the mean keeps its fractional digits; CEILING is retained.
        stabilityScore = stabilityScore.divide(BigDecimal.valueOf(succeededEvaluations),
                new java.math.MathContext(16, RoundingMode.CEILING));
    }
    return stabilityScore;
}
Aggregations