Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
From class LimeExplainerTest, method testZeroSampleSize.
@Test
void testZeroSampleSize() throws ExecutionException, InterruptedException, TimeoutException {
    LimeConfig limeConfig = new LimeConfig().withSamples(0);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
            .get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertNotNull(saliencyMap);
}
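Once explainAsync completes, each Saliency in the returned map can be inspected per feature. The following is a minimal sketch reusing the fixtures above; it relies only on the Saliency and FeatureImportance accessors that already appear in these snippets.

// Sketch: print the per-feature scores of every decision in the saliency map,
// reusing limeExplainer, prediction and model from the test above.
Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
for (Map.Entry<String, Saliency> entry : saliencyMap.entrySet()) {
    for (FeatureImportance featureImportance : entry.getValue().getPerFeatureImportance()) {
        System.out.printf("%s -> %s: %f%n",
                entry.getKey(),
                featureImportance.getFeature().getName(),
                featureImportance.getScore());
    }
}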
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
From class LimeExplainerTest, method testDeterministic.
@ParameterizedTest
@ValueSource(longs = { 0, 1, 2, 3, 4 })
void testDeterministic(long seed) throws ExecutionException, InterruptedException, TimeoutException {
    List<Saliency> saliencies = new ArrayList<>();
    for (int j = 0; j < 2; j++) {
        Random random = new Random();
        LimeConfig limeConfig = new LimeConfig()
                .withPerturbationContext(new PerturbationContext(seed, random, DEFAULT_NO_OF_PERTURBATIONS))
                .withSamples(10);
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
        List<Feature> features = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            features.add(TestUtils.getMockedNumericFeature(i));
        }
        PredictionInput input = new PredictionInput(features);
        PredictionProvider model = TestUtils.getSumSkipModel(0);
        PredictionOutput output = model.predictAsync(List.of(input))
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
                .get(0);
        Prediction prediction = new SimplePrediction(input, output);
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        saliencies.add(saliencyMap.get("sum-but0"));
    }
    assertThat(saliencies.get(0).getPerFeatureImportance().stream()
            .map(FeatureImportance::getScore)
            .collect(Collectors.toList()))
                    .isEqualTo(saliencies.get(1).getPerFeatureImportance().stream()
                            .map(FeatureImportance::getScore)
                            .collect(Collectors.toList()));
}
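The reproducibility checked by this test hinges on the seed passed to PerturbationContext. As a hedged convenience sketch (the helper name and its placement are assumptions, not part of the project), a seeded configuration could be factored out like this:

// Sketch of a helper that builds a LimeConfig whose perturbation sampling is reproducible
// for a given seed, mirroring the configuration used in testDeterministic above.
static LimeConfig seededConfig(long seed, int perturbations, int samples) {
    return new LimeConfig()
            .withPerturbationContext(new PerturbationContext(seed, new Random(), perturbations))
            .withSamples(samples);
}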
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
From class LimeExplainerTest, method testNormalizedWeights.
@Test
void testNormalizedWeights() throws InterruptedException, ExecutionException, TimeoutException {
    Random random = new Random();
    LimeConfig limeConfig = new LimeConfig()
            .withNormalizeWeights(true)
            .withPerturbationContext(new PerturbationContext(4L, random, 2))
            .withSamples(10);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    int nf = 4;
    List<Feature> features = new ArrayList<>();
    for (int i = 0; i < nf; i++) {
        features.add(TestUtils.getMockedNumericFeature(i));
    }
    PredictionInput input = new PredictionInput(features);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionOutput output = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit())
            .get(0);
    Prediction prediction = new SimplePrediction(input, output);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertThat(saliencyMap).isNotNull();
    String decisionName = "sum-but0";
    Saliency saliency = saliencyMap.get(decisionName);
    List<FeatureImportance> perFeatureImportance = saliency.getPerFeatureImportance();
    for (FeatureImportance featureImportance : perFeatureImportance) {
        assertThat(featureImportance.getScore()).isBetween(0d, 1d);
    }
}
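As a usage note, the score-range check above can also be written as a single AssertJ assertion; this is just an equivalent sketch of the same [0, 1] bound on the normalized weights:

// Equivalent to the for-loop assertion in testNormalizedWeights.
assertThat(perFeatureImportance)
        .allSatisfy(featureImportance -> assertThat(featureImportance.getScore()).isBetween(0d, 1d));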
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
From class LimeStabilityTest, method assertStable.
private void assertStable(LimeExplainer limeExplainer, PredictionProvider model, List<Feature> featureList) throws Exception {
    PredictionInput input = new PredictionInput(featureList);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (PredictionOutput predictionOutput : predictionOutputs) {
        Prediction prediction = new SimplePrediction(input, predictionOutput);
        List<Saliency> saliencies = new LinkedList<>();
        for (int i = 0; i < 100; i++) {
            Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model)
                    .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
            saliencies.addAll(saliencyMap.values());
        }
        // check that the topmost important feature is stable
        List<String> names = new LinkedList<>();
        saliencies.stream()
                .map(s -> s.getPositiveFeatures(1))
                .filter(f -> !f.isEmpty())
                .forEach(f -> names.add(f.get(0).getFeature().getName()));
        Map<String, Long> frequencyMap = names.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        boolean topFeature = false;
        for (Map.Entry<String, Long> entry : frequencyMap.entrySet()) {
            if (entry.getValue() >= TOP_FEATURE_THRESHOLD) {
                topFeature = true;
                break;
            }
        }
        assertTrue(topFeature);
        // check that the impact is stable
        List<Double> impacts = new ArrayList<>(saliencies.size());
        for (Saliency saliency : saliencies) {
            double v = ExplainabilityMetrics.impactScore(model, prediction, saliency.getTopFeatures(2));
            impacts.add(v);
        }
        Map<Double, Long> impactMap = impacts.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        boolean topImpact = false;
        for (Map.Entry<Double, Long> entry : impactMap.entrySet()) {
            if (entry.getValue() >= TOP_FEATURE_THRESHOLD) {
                topImpact = true;
                break;
            }
        }
        assertTrue(topImpact);
    }
}
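For context, here is a hedged sketch of how a test in LimeStabilityTest could drive this helper. The classes and TestUtils calls are the same ones used in the snippets above; the specific values (sample count, seed, number of perturbations and features) are illustrative assumptions.

// Sketch: build a small LIME setup and run the stability assertions on it.
LimeConfig limeConfig = new LimeConfig()
        .withSamples(10)
        .withPerturbationContext(new PerturbationContext(0L, new Random(), 1));
LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
List<Feature> features = new ArrayList<>();
for (int i = 0; i < 4; i++) {
    features.add(TestUtils.getMockedNumericFeature(i));
}
PredictionProvider model = TestUtils.getSumSkipModel(0);
assertStable(limeExplainer, model, features);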
Use of org.kie.kogito.explainability.model.Saliency in project kogito-apps by kiegroup.
From class ExplainabilityMetrics, method classificationFidelity.
/**
 * Calculates the fidelity (accuracy) of boolean classification outputs, using the saliency-based
 * predictor function sign(sum(saliency.scores)).
 * See the papers:
 * - Guidotti, Riccardo, et al. "A Survey of Methods for Explaining Black Box Models." ACM Computing Surveys (2018).
 * - Bodria, Francesco, et al. "Explainability Methods for Natural Language Processing: Applications to Sentiment Analysis (Discussion Paper)."
 *
 * @param pairs pairs composed of a saliency and the related prediction
 * @return the fidelity accuracy
 */
public static double classificationFidelity(List<Pair<Saliency, Prediction>> pairs) {
    double acc = 0;
    double evals = 0;
    for (Pair<Saliency, Prediction> pair : pairs) {
        Saliency saliency = pair.getLeft();
        Prediction prediction = pair.getRight();
        for (Output output : prediction.getOutput().getOutputs()) {
            Type type = output.getType();
            if (Type.BOOLEAN.equals(type)) {
                double predictorOutput = saliency.getPerFeatureImportance().stream()
                        .map(FeatureImportance::getScore)
                        .mapToDouble(d -> d)
                        .sum();
                double v = output.getValue().asNumber();
                if ((v >= 0 && predictorOutput >= 0) || (v < 0 && predictorOutput < 0)) {
                    acc++;
                }
                evals++;
            }
        }
    }
    return evals == 0 ? 0 : acc / evals;
}
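A worked illustration of the sign-agreement rule implemented above, using plain doubles with hypothetical values (not tied to any real model, nor to how a boolean Value maps to a number):

// Each entry stands for one boolean output: outputsAsNumber is the output encoded as a number,
// saliencySums is the corresponding sum of per-feature saliency scores.
double[] saliencySums = { 0.7, -0.2, 0.3 };
double[] outputsAsNumber = { 1.0, -1.0, -1.0 };
double acc = 0;
for (int i = 0; i < saliencySums.length; i++) {
    boolean agree = (outputsAsNumber[i] >= 0 && saliencySums[i] >= 0)
            || (outputsAsNumber[i] < 0 && saliencySums[i] < 0);
    if (agree) {
        acc++;
    }
}
double fidelity = acc / saliencySums.length; // 2 agreements out of 3 evaluations -> ~0.67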