
Example 11 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

In class LimeImpactScoreCalculator, method calculateScore:

@Override
public SimpleBigDecimalScore calculateScore(LimeConfigSolution solution) {
    LimeConfig config = LimeConfigEntityFactory.toLimeConfig(solution);
    BigDecimal impactScore = BigDecimal.ZERO;
    List<Prediction> predictions = solution.getPredictions();
    if (!predictions.isEmpty()) {
        impactScore = getImpactScore(solution, config, predictions);
    }
    return SimpleBigDecimalScore.of(impactScore);
}
Also used: Prediction (org.kie.kogito.explainability.model.Prediction), LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig), BigDecimal (java.math.BigDecimal)

Example 12 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

In class LimeStabilityScoreCalculator, method calculateScore:

@Override
public SimpleBigDecimalScore calculateScore(LimeConfigSolution solution) {
    LimeConfig config = LimeConfigEntityFactory.toLimeConfig(solution);
    BigDecimal stabilityScore = BigDecimal.ZERO;
    List<Prediction> predictions = solution.getPredictions();
    if (!predictions.isEmpty()) {
        stabilityScore = getStabilityScore(solution.getModel(), config, predictions);
    }
    return SimpleBigDecimalScore.of(stabilityScore);
}
Also used: Prediction (org.kie.kogito.explainability.model.Prediction), LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig), BigDecimal (java.math.BigDecimal)
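
Both calculators above serve the same purpose: given a candidate LimeConfigSolution they rebuild a LimeConfig and score it against the solution's predictions, so a LIME hyperparameter search can compare candidate configurations. Below is a minimal sketch of how such a search might be started; the optimize(...) signature and the caller-supplied predictions/model arguments are assumptions based on how LimeConfigOptimizer appears in the aggregations at the end of this page, not a verified API.

// Sketch only: LimeConfigOptimizer#optimize(LimeConfig, List<Prediction>, PredictionProvider) is assumed;
// 'predictions' and 'model' must be supplied by the caller.
import java.util.List;

import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionProvider;

public class LimeConfigOptimizationSketch {

    public static LimeConfig tuneLimeConfig(List<Prediction> predictions, PredictionProvider model) {
        LimeConfig initialConfig = new LimeConfig().withSamples(100);
        // The optimizer evaluates candidate configurations using score calculators such as the
        // impact and stability calculators shown above, and returns the best-scoring LimeConfig.
        return new LimeConfigOptimizer().optimize(initialConfig, predictions, model);
    }
}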

Example 13 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

In class HighScoreNumericFeatureZonesProvider, method getHighScoreFeatureZones:

/**
 * Get a map of feature name -> high-score feature zones. Predictions in the data distribution are sorted by
 * (descending) score, the (aggregated) mean score is calculated, and all data points associated with a prediction
 * whose score lies between the mean and the maximum are selected (feature-wise), together with an associated
 * tolerance (half the standard deviation of the high-score feature points).
 *
 * @param dataDistribution a data distribution
 * @param predictionProvider the model used to score the inputs
 * @param features the list of features to associate high score points with
 * @param maxNoOfSamples max no. of inputs used for discovering high score zones
 * @return a map of feature name -> high-score numeric feature zones
 */
public static Map<String, HighScoreNumericFeatureZones> getHighScoreFeatureZones(DataDistribution dataDistribution, PredictionProvider predictionProvider, List<Feature> features, int maxNoOfSamples) {
    Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap = new HashMap<>();
    List<Prediction> scoreSortedPredictions = new ArrayList<>();
    try {
        scoreSortedPredictions.addAll(DataUtils.getScoreSortedPredictions(predictionProvider,
                new PredictionInputsDataDistribution(dataDistribution.sample(maxNoOfSamples))));
    } catch (ExecutionException e) {
        LOGGER.error("Could not sort predictions by score {}", e.getMessage());
    } catch (InterruptedException e) {
        LOGGER.error("Interrupted while waiting for sorting predictions by score {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        LOGGER.error("Timed out while waiting for sorting predictions by score", e);
    }
    if (!scoreSortedPredictions.isEmpty()) {
        // calculate min, max and mean scores
        double max = scoreSortedPredictions.get(0).getOutput().getOutputs().stream()
                .mapToDouble(Output::getScore).sum();
        double min = scoreSortedPredictions.get(scoreSortedPredictions.size() - 1).getOutput().getOutputs().stream()
                .mapToDouble(Output::getScore).sum();
        if (max != min) {
            double threshold = scoreSortedPredictions.stream()
                    .mapToDouble(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum())
                    .average()
                    .orElse((max + min) / 2);
            // filter out predictions whose score is in [min, threshold]
            scoreSortedPredictions = scoreSortedPredictions.stream()
                    .filter(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum() > threshold)
                    .collect(Collectors.toList());
            for (int j = 0; j < features.size(); j++) {
                Feature feature = features.get(j);
                if (Type.NUMBER.equals(feature.getType())) {
                    int finalJ = j;
                    // get feature values associated with high score inputs
                    List<Double> topValues = scoreSortedPredictions.stream()
                            .map(prediction -> prediction.getInput().getFeatures().get(finalJ).getValue().asNumber())
                            .distinct()
                            .collect(Collectors.toList());
                    // get high score points and tolerance
                    double[] highScoreFeaturePoints = topValues.stream().flatMapToDouble(DoubleStream::of).toArray();
                    double center = DataUtils.getMean(highScoreFeaturePoints);
                    double tolerance = DataUtils.getStdDev(highScoreFeaturePoints, center) / 2;
                    HighScoreNumericFeatureZones highScoreNumericFeatureZones = new HighScoreNumericFeatureZones(highScoreFeaturePoints, tolerance);
                    numericFeatureZonesMap.put(feature.getName(), highScoreNumericFeatureZones);
                }
            }
        }
    }
    return numericFeatureZonesMap;
}
Also used: DataUtils (org.kie.kogito.explainability.utils.DataUtils), Logger (org.slf4j.Logger), LoggerFactory (org.slf4j.LoggerFactory), PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution), PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext), Feature (org.kie.kogito.explainability.model.Feature), Prediction (org.kie.kogito.explainability.model.Prediction), DataDistribution (org.kie.kogito.explainability.model.DataDistribution), Type (org.kie.kogito.explainability.model.Type), PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider), Output (org.kie.kogito.explainability.model.Output), TimeoutException (java.util.concurrent.TimeoutException), ExecutionException (java.util.concurrent.ExecutionException), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), List (java.util.List), Map (java.util.Map), Collectors (java.util.stream.Collectors), DoubleStream (java.util.stream.DoubleStream)
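
A hedged usage sketch of the method follows. It builds a small numeric data distribution and a toy PredictionProvider whose single output score is the sum of the feature values, then asks for the high-score zones. The FeatureFactory, PredictionInput, PredictionOutput, Output and Value constructors used here are the usual org.kie.kogito.explainability.model types, but their exact signatures are assumptions rather than verified against a particular kogito-apps version.

// Sketch only: assumed to sit in the same package as HighScoreNumericFeatureZonesProvider
// (its package is not shown above), so the provider and HighScoreNumericFeatureZones need no import.
// Constructor signatures of the model classes are assumed, not verified.
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionInputsDataDistribution;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

public class HighScoreZonesSketch {

    public static void main(String[] args) {
        // Toy model: one numeric output whose score is the sum of the input feature values.
        PredictionProvider model = predictionInputs -> CompletableFuture.completedFuture(
                predictionInputs.stream()
                        .map(input -> {
                            double sum = input.getFeatures().stream()
                                    .mapToDouble(f -> f.getValue().asNumber())
                                    .sum();
                            return new PredictionOutput(
                                    List.of(new Output("sum", Type.NUMBER, new Value(sum), sum)));
                        })
                        .collect(Collectors.toList()));

        // A small distribution over two numeric features with varying scores.
        List<PredictionInput> inputs = new ArrayList<>();
        for (int i = 0; i < 20; i++) {
            inputs.add(new PredictionInput(List.of(
                    FeatureFactory.newNumericalFeature("x", i),
                    FeatureFactory.newNumericalFeature("y", i * 2))));
        }
        List<Feature> features = inputs.get(0).getFeatures();

        Map<String, HighScoreNumericFeatureZones> zones =
                HighScoreNumericFeatureZonesProvider.getHighScoreFeatureZones(
                        new PredictionInputsDataDistribution(inputs), model, features, 10);
        zones.forEach((name, zone) -> System.out.println(name + " -> " + zone));
    }
}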

Example 14 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

In class CounterfactualExplainerServiceHandlerTest, method testGetPredictionWithFlatOutputModelReordered:

@Test
public void testGetPredictionWithFlatOutputModelReordered() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            List.of(new NamedTypedValue("inputsAreValid", new UnitValue("boolean", BooleanNode.FALSE)),
                    new NamedTypedValue("canRequestLoan", new UnitValue("boolean", BooleanNode.TRUE)),
                    new NamedTypedValue("my-scoring-function", new UnitValue("number", new DoubleNode(0.85)))),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);
    Prediction prediction = handler.getPrediction(request);
    assertTrue(prediction instanceof CounterfactualPrediction);
    CounterfactualPrediction counterfactualPrediction = (CounterfactualPrediction) prediction;
    List<Output> outputs = counterfactualPrediction.getOutput().getOutputs();
    assertEquals(3, outputs.size());
    Output output1 = outputs.get(0);
    assertEquals("my-scoring-function", output1.getName());
    assertEquals(Type.NUMBER, output1.getType());
    assertEquals(0.85, output1.getValue().asNumber());
    Output output2 = outputs.get(1);
    assertEquals("inputsAreValid", output2.getName());
    assertEquals(Type.BOOLEAN, output2.getType());
    assertEquals(Boolean.FALSE, output2.getValue().getUnderlyingObject());
    Output output3 = outputs.get(2);
    assertEquals("canRequestLoan", output3.getName());
    assertEquals(Type.BOOLEAN, output3.getType());
    assertEquals(Boolean.TRUE, output3.getValue().getUnderlyingObject());
    assertTrue(counterfactualPrediction.getInput().getFeatures().isEmpty());
    assertEquals(counterfactualPrediction.getMaxRunningTimeSeconds(), request.getMaxRunningTimeSeconds());
}
Also used: CounterfactualExplainabilityRequest (org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest), NamedTypedValue (org.kie.kogito.explainability.api.NamedTypedValue), Prediction (org.kie.kogito.explainability.model.Prediction), CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction), PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput), Output (org.kie.kogito.explainability.model.Output), UnitValue (org.kie.kogito.tracing.typedvalue.UnitValue), CounterfactualSearchDomainUnitValue (org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue), DoubleNode (com.fasterxml.jackson.databind.node.DoubleNode), Test (org.junit.jupiter.api.Test)

Example 15 with Prediction

Use of org.kie.kogito.explainability.model.Prediction in project kogito-apps by kiegroup.

In class JITDMNServiceImpl, method evaluateModelAndExplain:

public DMNResultWithExplanation evaluateModelAndExplain(DMNEvaluator dmnEvaluator, Map<String, Object> context) {
    LocalDMNPredictionProvider localDMNPredictionProvider = new LocalDMNPredictionProvider(dmnEvaluator);
    DMNResult dmnResult = dmnEvaluator.evaluate(context);
    Prediction prediction = new SimplePrediction(LocalDMNPredictionProvider.toPredictionInput(context),
            LocalDMNPredictionProvider.toPredictionOutput(dmnResult));
    LimeConfig limeConfig = new LimeConfig()
            .withSamples(explainabilityLimeSampleSize)
            .withPerturbationContext(new PerturbationContext(new Random(), explainabilityLimeNoOfPerturbation));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap;
    try {
        saliencyMap = limeExplainer.explainAsync(prediction, localDMNPredictionProvider)
                .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    } catch (TimeoutException | InterruptedException | ExecutionException e) {
        if (e instanceof InterruptedException) {
            LOGGER.error("Critical InterruptedException occurred", e);
            Thread.currentThread().interrupt();
        }
        return new DMNResultWithExplanation(
                new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult),
                new SalienciesResponse(EXPLAINABILITY_FAILED, EXPLAINABILITY_FAILED_MESSAGE, null));
    }
    List<SaliencyResponse> saliencyModelResponse = buildSalienciesResponse(dmnEvaluator.getDmnModel(), saliencyMap);
    return new DMNResultWithExplanation(
            new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult),
            new SalienciesResponse(EXPLAINABILITY_SUCCEEDED, null, saliencyModelResponse));
}
Also used: SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction), Prediction (org.kie.kogito.explainability.model.Prediction), SalienciesResponse (org.kie.kogito.trusty.service.common.responses.SalienciesResponse), SaliencyResponse (org.kie.kogito.trusty.service.common.responses.SaliencyResponse), DMNResult (org.kie.dmn.api.core.DMNResult), JITDMNResult (org.kie.kogito.jitexecutor.dmn.responses.JITDMNResult), DMNResultWithExplanation (org.kie.kogito.jitexecutor.dmn.responses.DMNResultWithExplanation), PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext), LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer), LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig), Saliency (org.kie.kogito.explainability.model.Saliency), Random (java.util.Random), ExecutionException (java.util.concurrent.ExecutionException), TimeoutException (java.util.concurrent.TimeoutException)
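
For contrast with the DMN-specific wrapper above, here is a hedged sketch of the same LIME flow against a plain PredictionProvider. LimeConfig, PerturbationContext, LimeExplainer and explainAsync are used exactly as in evaluateModelAndExplain; the toy model and the model-class constructors (PredictionInput, PredictionOutput, Output, Value, FeatureFactory) are assumptions for illustration only.

// Sketch only: the generic LIME flow from the method above, applied to a toy non-DMN model.
// Model/value constructor signatures are assumed; LimeExplainer usage mirrors evaluateModelAndExplain.
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import org.kie.kogito.explainability.local.lime.LimeConfig;
import org.kie.kogito.explainability.local.lime.LimeExplainer;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.PerturbationContext;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Saliency;
import org.kie.kogito.explainability.model.SimplePrediction;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;

public class PlainLimeSketch {

    public static void main(String[] args) throws Exception {
        // Toy model: a single numeric output equal to 2 * x + y.
        PredictionProvider model = predictionInputs -> CompletableFuture.completedFuture(
                predictionInputs.stream()
                        .map(input -> {
                            double x = input.getFeatures().get(0).getValue().asNumber();
                            double y = input.getFeatures().get(1).getValue().asNumber();
                            double score = 2 * x + y;
                            return new PredictionOutput(
                                    List.of(new Output("result", Type.NUMBER, new Value(score), score)));
                        })
                        .collect(Collectors.toList()));

        PredictionInput input = new PredictionInput(List.of(
                FeatureFactory.newNumericalFeature("x", 2),
                FeatureFactory.newNumericalFeature("y", 3)));
        // PredictionProvider#predictAsync is the provider's single method; score the input once.
        PredictionOutput output = model.predictAsync(List.of(input)).get().get(0);
        Prediction prediction = new SimplePrediction(input, output);

        LimeConfig limeConfig = new LimeConfig()
                .withSamples(300)
                .withPerturbationContext(new PerturbationContext(new Random(), 1));
        LimeExplainer limeExplainer = new LimeExplainer(limeConfig);

        // One Saliency per output, keyed by output name, as in the DMN example above.
        Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get();
        saliencyMap.forEach((outputName, saliency) -> System.out.println(outputName + ": " + saliency));
    }
}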

Aggregations

Prediction (org.kie.kogito.explainability.model.Prediction): 134
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 117
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 107
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 105
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 96
Test (org.junit.jupiter.api.Test): 95
Random (java.util.Random): 65
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 61
LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig): 57
ArrayList (java.util.ArrayList): 51
Feature (org.kie.kogito.explainability.model.Feature): 48
Saliency (org.kie.kogito.explainability.model.Saliency): 48
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 42
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer): 40
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer): 28
DataDistribution (org.kie.kogito.explainability.model.DataDistribution): 24
ValueSource (org.junit.jupiter.params.provider.ValueSource): 22
FeatureImportance (org.kie.kogito.explainability.model.FeatureImportance): 22
Output (org.kie.kogito.explainability.model.Output): 22
LinkedList (java.util.LinkedList): 21