Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class LimeImpactScoreCalculator, method calculateScore:
/**
 * Computes the score of a candidate LIME configuration via {@code getImpactScore}.
 * <p>
 * The configuration is rebuilt from {@code solution}. When the solution carries no predictions the score is
 * {@link BigDecimal#ZERO}; otherwise it is the value returned by {@code getImpactScore}.
 *
 * @param solution the solution holding the candidate configuration and the predictions to evaluate it on
 * @return the computed score wrapped as a {@link SimpleBigDecimalScore}
 */
@Override
public SimpleBigDecimalScore calculateScore(LimeConfigSolution solution) {
    LimeConfig limeConfig = LimeConfigEntityFactory.toLimeConfig(solution);
    List<Prediction> solutionPredictions = solution.getPredictions();
    // no predictions to evaluate against: fall back to a zero score
    BigDecimal score = solutionPredictions.isEmpty()
            ? BigDecimal.ZERO
            : getImpactScore(solution, limeConfig, solutionPredictions);
    return SimpleBigDecimalScore.of(score);
}
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class LimeStabilityScoreCalculator, method calculateScore:
/**
 * Computes the score of a candidate LIME configuration via {@code getStabilityScore}.
 * <p>
 * The configuration is rebuilt from {@code solution}. When the solution carries no predictions the score is
 * {@link BigDecimal#ZERO}; otherwise it is the value returned by {@code getStabilityScore} for the solution's model.
 *
 * @param solution the solution holding the model, the candidate configuration and the predictions to evaluate
 * @return the computed score wrapped as a {@link SimpleBigDecimalScore}
 */
@Override
public SimpleBigDecimalScore calculateScore(LimeConfigSolution solution) {
    LimeConfig limeConfig = LimeConfigEntityFactory.toLimeConfig(solution);
    List<Prediction> solutionPredictions = solution.getPredictions();
    // the model is only fetched when there is something to score
    BigDecimal score = solutionPredictions.isEmpty()
            ? BigDecimal.ZERO
            : getStabilityScore(solution.getModel(), limeConfig, solutionPredictions);
    return SimpleBigDecimalScore.of(score);
}
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class HighScoreNumericFeatureZonesProvider, method getHighScoreFeatureZones:
/**
 * Get a map of feature-name -> high score feature zones. Predictions in data distribution are sorted by (descending)
 * score, then the (aggregated) mean score is calculated and all the data points that are associated with a prediction
 * having a score strictly above the mean are selected (feature-wise), with an associated tolerance
 * (half the stdDev of the high score feature points).
 * <p>
 * The aggregated score of a prediction is the sum of the scores of all its outputs. An empty map is returned when
 * scoring the sampled inputs fails, times out or is interrupted, or when all sampled predictions have the same score.
 * Only features of type {@link Type#NUMBER} get an entry in the returned map.
 *
 * @param dataDistribution a data distribution
 * @param predictionProvider the model used to score the inputs
 * @param features the list of features to associate high score points with
 * @param maxNoOfSamples max no. of inputs used for discovering high score zones
 * @return a map feature name -> high score numeric feature zones (possibly empty, never null)
 */
public static Map<String, HighScoreNumericFeatureZones> getHighScoreFeatureZones(DataDistribution dataDistribution, PredictionProvider predictionProvider, List<Feature> features, int maxNoOfSamples) {
    Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap = new HashMap<>();
    List<Prediction> scoreSortedPredictions = new ArrayList<>();
    try {
        scoreSortedPredictions.addAll(DataUtils.getScoreSortedPredictions(predictionProvider, new PredictionInputsDataDistribution(dataDistribution.sample(maxNoOfSamples))));
    } catch (ExecutionException e) {
        // pass the exception itself as the last argument so the cause and stack trace are logged too
        LOGGER.error("Could not sort predictions by score {}", e.getMessage(), e);
    } catch (InterruptedException e) {
        LOGGER.error("Interrupted while waiting for sorting predictions by score {}", e.getMessage(), e);
        // restore the interrupt flag for callers up the stack
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        LOGGER.error("Timed out while waiting for sorting predictions by score", e);
    }
    if (scoreSortedPredictions.isEmpty()) {
        return numericFeatureZonesMap;
    }

    // compute each prediction's aggregated score once instead of re-summing its outputs on every pass
    double[] scores = scoreSortedPredictions.stream().mapToDouble(p -> sumOfOutputScores(p)).toArray();
    // predictions are sorted by descending score, so the extremes sit at both ends
    double max = scores[0];
    double min = scores[scores.length - 1];
    if (max == min) {
        // no score variation: there is no high score zone to discover
        return numericFeatureZonesMap;
    }
    double threshold = DoubleStream.of(scores).average().orElse((max + min) / 2);

    // filter out predictions whose score is in [min, threshold]
    List<Prediction> highScorePredictions = new ArrayList<>();
    for (int i = 0; i < scores.length; i++) {
        if (scores[i] > threshold) {
            highScorePredictions.add(scoreSortedPredictions.get(i));
        }
    }

    for (int j = 0; j < features.size(); j++) {
        Feature feature = features.get(j);
        if (Type.NUMBER.equals(feature.getType())) {
            int featureIndex = j;
            // NOTE(review): input features are looked up by position, which assumes every sampled input lists its
            // features in the same order as `features` -- confirm against DataDistribution#sample
            // distinct feature values associated with high score inputs
            double[] highScoreFeaturePoints = highScorePredictions.stream()
                    .mapToDouble(prediction -> prediction.getInput().getFeatures().get(featureIndex).getValue().asNumber())
                    .distinct()
                    .toArray();
            double center = DataUtils.getMean(highScoreFeaturePoints);
            double tolerance = DataUtils.getStdDev(highScoreFeaturePoints, center) / 2;
            numericFeatureZonesMap.put(feature.getName(), new HighScoreNumericFeatureZones(highScoreFeaturePoints, tolerance));
        }
    }
    return numericFeatureZonesMap;
}

/**
 * Aggregated score of a prediction: the sum of the scores of all its outputs.
 *
 * @param prediction the prediction to score
 * @return the sum of {@link Output#getScore()} over the prediction's outputs
 */
private static double sumOfOutputScores(Prediction prediction) {
    return prediction.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum();
}
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModelReordered:
/**
 * Verifies that a request with a flat output model is turned into a {@link CounterfactualPrediction} whose outputs
 * come back as: the numeric "my-scoring-function" first, then the boolean outputs in their original relative order.
 * Also checks that the (empty) inputs and the max running time are carried over from the request.
 */
@Test
public void testGetPredictionWithFlatOutputModelReordered() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(new NamedTypedValue("inputsAreValid", new UnitValue("boolean", BooleanNode.FALSE)),
                    new NamedTypedValue("canRequestLoan", new UnitValue("boolean", BooleanNode.TRUE)),
                    new NamedTypedValue("my-scoring-function", new UnitValue("number", new DoubleNode(0.85)))),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;

    List<Output> outputs = counterfactualPrediction.getOutput().getOutputs();
    assertEquals(3, outputs.size());

    Output output1 = outputs.get(0);
    assertEquals("my-scoring-function", output1.getName());
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(0.85, output1.getValue().asNumber());

    Output output2 = outputs.get(1);
    assertEquals("inputsAreValid", output2.getName());
    assertEquals(Type.BOOLEAN, output2.getType());
    assertEquals(Boolean.FALSE, output2.getValue().getUnderlyingObject());

    Output output3 = outputs.get(2);
    assertEquals("canRequestLoan", output3.getName());
    assertEquals(Type.BOOLEAN, output3.getType());
    assertEquals(Boolean.TRUE, output3.getValue().getUnderlyingObject());

    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    // expected value (from the request) first, actual value (from the prediction) second
    assertEquals(request.getMaxRunningTimeSeconds(), counterfactualPrediction.getMaxRunningTimeSeconds());
}
Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.
Class JITDMNServiceImpl, method evaluateModelAndExplain:
/**
 * Evaluates the DMN model on the given context and explains the outcome with a LIME explainer.
 * <p>
 * The DMN result is always returned. If the explanation does not complete within the configured async timeout,
 * fails, or the calling thread is interrupted, the failure is logged and the saliencies response is marked as
 * {@code EXPLAINABILITY_FAILED} instead of propagating the exception.
 *
 * @param dmnEvaluator the evaluator for the DMN model
 * @param context the input context the model is evaluated on
 * @return the DMN result together with the (successful or failed) saliencies response
 */
public DMNResultWithExplanation evaluateModelAndExplain(DMNEvaluator dmnEvaluator, Map<String, Object> context) {
    LocalDMNPredictionProvider localDMNPredictionProvider = new LocalDMNPredictionProvider(dmnEvaluator);
    DMNResult dmnResult = dmnEvaluator.evaluate(context);
    Prediction prediction = new SimplePrediction(LocalDMNPredictionProvider.toPredictionInput(context), LocalDMNPredictionProvider.toPredictionOutput(dmnResult));
    LimeConfig limeConfig = new LimeConfig().withSamples(explainabilityLimeSampleSize).withPerturbationContext(new PerturbationContext(new Random(), explainabilityLimeNoOfPerturbation));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);

    SalienciesResponse salienciesResponse;
    try {
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, localDMNPredictionProvider).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        List<SaliencyResponse> saliencyModelResponse = buildSalienciesResponse(dmnEvaluator.getDmnModel(), saliencyMap);
        salienciesResponse = new SalienciesResponse(EXPLAINABILITY_SUCCEEDED, null, saliencyModelResponse);
    } catch (InterruptedException e) {
        LOGGER.error("Critical InterruptedException occurred", e);
        // restore the interrupt flag so callers up the stack can still observe it
        Thread.currentThread().interrupt();
        salienciesResponse = new SalienciesResponse(EXPLAINABILITY_FAILED, EXPLAINABILITY_FAILED_MESSAGE, null);
    } catch (TimeoutException | ExecutionException e) {
        // previously swallowed silently; log so explanation failures/timeouts are diagnosable
        LOGGER.error("LIME explanation failed or timed out", e);
        salienciesResponse = new SalienciesResponse(EXPLAINABILITY_FAILED, EXPLAINABILITY_FAILED_MESSAGE, null);
    }
    return new DMNResultWithExplanation(new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult), salienciesResponse);
}
Aggregations