Example 1 with LimeConfig

Use of org.kie.kogito.explainability.local.lime.LimeConfig in project kogito-apps by kiegroup.

From class LimeConfigOptimizer, method optimize.

public LimeConfig optimize(LimeConfig config, List<Prediction> predictions, PredictionProvider model) {
    List<LimeConfigEntity> entities = new ArrayList<>();
    if (samplingEntities) {
        entities.addAll(LimeConfigEntityFactory.createSamplingEntities(config));
    }
    if (proximityEntities) {
        entities.addAll(LimeConfigEntityFactory.createProximityEntities(config));
    }
    if (encodingEntities) {
        entities.addAll(LimeConfigEntityFactory.createEncodingEntities(config));
    }
    if (weightingEntities) {
        entities.addAll(LimeConfigEntityFactory.createWeightingEntities(config));
    }
    if (entities.isEmpty()) {
        return config;
    }
    LimeConfigSolution initialSolution = new LimeConfigSolution(config, predictions, entities, model);
    SolverConfig solverConfig = new SolverConfig().withEntityClasses(NumericLimeConfigEntity.class, BooleanLimeConfigEntity.class).withSolutionClass(LimeConfigSolution.class);
    ScoreDirectorFactoryConfig scoreDirectorFactoryConfig = new ScoreDirectorFactoryConfig();
    scoreDirectorFactoryConfig.setEasyScoreCalculatorClass(scoreCalculator.getClass());
    solverConfig.setScoreDirectorFactoryConfig(scoreDirectorFactoryConfig);
    TerminationConfig terminationConfig = new TerminationConfig();
    if (timeLimit > 0) {
        terminationConfig.setSecondsSpentLimit(timeLimit);
    }
    solverConfig.setTerminationConfig(terminationConfig);
    LocalSearchPhaseConfig localSearchPhaseConfig = new LocalSearchPhaseConfig();
    if (deterministic) {
        Optional<Long> seed = config.getPerturbationContext().getSeed();
        seed.ifPresent(solverConfig::setRandomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    } else {
        logger.debug("Non-reproducible execution: set the seed inside the initial LimeConfig's PerturbationContext and enable deterministic execution to fix this");
    }
    localSearchPhaseConfig.setLocalSearchType(LocalSearchType.LATE_ACCEPTANCE);
    if (stepCountLimit > 0) {
        localSearchPhaseConfig.setTerminationConfig(new TerminationConfig().withStepCountLimit(stepCountLimit).withBestScoreLimit("1.0"));
    }
    @SuppressWarnings("rawtypes") List<PhaseConfig> phaseConfigs = new ArrayList<>();
    phaseConfigs.add(localSearchPhaseConfig);
    solverConfig.setPhaseConfigList(phaseConfigs);
    try (SolverManager<LimeConfigSolution, UUID> solverManager = SolverManager.create(solverConfig, new SolverManagerConfig())) {
        UUID executionId = UUID.randomUUID();
        SolverJob<LimeConfigSolution, UUID> solverJob = solverManager.solve(executionId, initialSolution);
        try {
            // Wait until the solving ends
            LimeConfigSolution finalBestSolution = solverJob.getFinalBestSolution();
            LimeConfig finalConfig = LimeConfigEntityFactory.toLimeConfig(finalBestSolution);
            BigDecimal score = finalBestSolution.getScore().getScore();
            logger.info("final best solution score {} with config {}", score, finalConfig);
            return finalConfig;
        } catch (ExecutionException e) {
            logger.error("Solving failed: {}", e.getMessage());
            throw new IllegalStateException("Prediction returned an error", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Solving failed (Thread interrupted)", e);
        }
    }
}
Also used : ArrayList(java.util.ArrayList) SolverManagerConfig(org.optaplanner.core.config.solver.SolverManagerConfig) UUID(java.util.UUID) ExecutionException(java.util.concurrent.ExecutionException) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) ScoreDirectorFactoryConfig(org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig) PhaseConfig(org.optaplanner.core.config.phase.PhaseConfig) LocalSearchPhaseConfig(org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) BigDecimal(java.math.BigDecimal) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig)
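Example 1 selects LocalSearchType.LATE_ACCEPTANCE. It stops the local search phase after stepCountLimit steps or once the best score reaches 1.0, and when deterministic is set it seeds the solver from the PerturbationContext. The stdlib-only sketch below illustrates the late acceptance rule itself on a toy one-parameter objective. It is not OptaPlanner's implementation. The class name, the objective, the Gaussian move and the history length are assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.DoubleUnaryOperator;

class LateAcceptanceSketch {

    /** Maximizes score(x) with late acceptance hill climbing; returns the best x found. */
    static double optimize(DoubleUnaryOperator score, double start, long seed,
                           int stepCountLimit, int lateAcceptanceSize) {
        Random random = new Random(seed); // fixed seed: identical runs for identical inputs
        double current = start;
        double currentScore = score.applyAsDouble(current);
        double best = current;
        double bestScore = currentScore;
        // Scores recorded over the last lateAcceptanceSize steps, initialised to the start score.
        double[] history = new double[lateAcceptanceSize];
        Arrays.fill(history, currentScore);
        // Stop on the step limit or once the best score reaches 1.0 (cf. withBestScoreLimit("1.0")).
        for (int step = 0; step < stepCountLimit && bestScore < 1.0; step++) {
            double candidate = current + random.nextGaussian() * 0.1; // small random move
            double candidateScore = score.applyAsDouble(candidate);
            int slot = step % lateAcceptanceSize;
            // Late acceptance: accept if not worse than now OR than lateAcceptanceSize steps ago.
            if (candidateScore >= currentScore || candidateScore >= history[slot]) {
                current = candidate;
                currentScore = candidateScore;
            }
            history[slot] = currentScore;
            if (currentScore > bestScore) {
                best = current;
                bestScore = currentScore;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical objective in (0, 1], peaking at x = 0.7.
        DoubleUnaryOperator score = x -> 1.0 / (1.0 + (x - 0.7) * (x - 0.7));
        double first = optimize(score, 0.0, 42L, 500, 20);
        double second = optimize(score, 0.0, 42L, 500, 20);
        System.out.println("best x = " + first + ", reproducible = " + (first == second));
    }
}
```

A candidate is accepted when it is at least as good as the current score, or as good as the score recorded lateAcceptanceSize steps earlier. This lets the search cross slightly worse regions. Two runs with the same seed return the same parameter. The REPRODUCIBLE environment mode and setRandomSeed aim for that same property in the real solver.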

Example 2 with LimeConfig

Use of org.kie.kogito.explainability.local.lime.LimeConfig in project kogito-apps by kiegroup.

From class LimeImpactScoreCalculator, method calculateScore.

@Override
public SimpleBigDecimalScore calculateScore(LimeConfigSolution solution) {
    LimeConfig config = LimeConfigEntityFactory.toLimeConfig(solution);
    BigDecimal impactScore = BigDecimal.ZERO;
    List<Prediction> predictions = solution.getPredictions();
    if (!predictions.isEmpty()) {
        impactScore = getImpactScore(solution, config, predictions);
    }
    return SimpleBigDecimalScore.of(impactScore);
}
Also used : Prediction(org.kie.kogito.explainability.model.Prediction) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) BigDecimal(java.math.BigDecimal)
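calculateScore returns zero when the solution holds no predictions. Otherwise it delegates to getImpactScore, whose body is not shown. One plausible notion of impact is the share of predictions whose output changes once the feature an explanation ranks highest is removed. This is offered as a hypothetical illustration only. The sketch below assumes numeric feature vectors, a label-returning model, and "removal" meaning that the feature is set to 0.0. None of these names come from kogito-apps.

```java
import java.util.List;
import java.util.function.Function;
import java.util.function.ToIntFunction;

class ImpactScoreSketch {

    /**
     * Hypothetical impact score: share of inputs whose predicted label changes when the
     * feature the explainer ranks highest is zeroed out.
     */
    static double impactScore(List<double[]> inputs,
                              Function<double[], String> model,
                              ToIntFunction<double[]> topFeature) {
        if (inputs.isEmpty()) {
            return 0.0; // mirrors the guard in calculateScore: no predictions, zero score
        }
        int changed = 0;
        for (double[] input : inputs) {
            String original = model.apply(input);
            double[] withoutTop = input.clone();
            withoutTop[topFeature.applyAsInt(input)] = 0.0; // "remove" the top feature
            if (!original.equals(model.apply(withoutTop))) {
                changed++;
            }
        }
        return (double) changed / inputs.size();
    }
}
```

Under this reading, a higher value means the explanation points at features the model actually depends on.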

Example 3 with LimeConfig

Use of org.kie.kogito.explainability.local.lime.LimeConfig in project kogito-apps by kiegroup.

From class LimeStabilityScoreCalculator, method calculateScore.

@Override
public SimpleBigDecimalScore calculateScore(LimeConfigSolution solution) {
    LimeConfig config = LimeConfigEntityFactory.toLimeConfig(solution);
    BigDecimal stabilityScore = BigDecimal.ZERO;
    List<Prediction> predictions = solution.getPredictions();
    if (!predictions.isEmpty()) {
        stabilityScore = getStabilityScore(solution.getModel(), config, predictions);
    }
    return SimpleBigDecimalScore.of(stabilityScore);
}
Also used : Prediction(org.kie.kogito.explainability.model.Prediction) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) BigDecimal(java.math.BigDecimal)
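The body of getStabilityScore is not shown either. For a randomized explainer such as LIME, one common reading of stability is how often repeated runs on the same input agree. The sketch below is a hypothetical single-input version. It runs an explainer that draws from a seeded Random several times and returns the share of runs that agree on the most frequent top feature. The method name, the Function<Random, Integer> explainer shape and the top-1 criterion are all assumptions.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;

class StabilityScoreSketch {

    /**
     * Hypothetical stability score: runs a randomized explainer several times on the same
     * input and returns the share of runs agreeing on the most frequent top feature.
     */
    static double stabilityScore(Function<Random, Integer> explainTopFeature, int runs, long seed) {
        if (runs <= 0) {
            return 0.0; // nothing to compare, analogous to the empty-predictions guard
        }
        Random random = new Random(seed); // one seeded source shared by all runs
        Map<Integer, Integer> counts = new HashMap<>();
        for (int i = 0; i < runs; i++) {
            counts.merge(explainTopFeature.apply(random), 1, Integer::sum);
        }
        int mostFrequent = counts.values().stream().mapToInt(Integer::intValue).max().orElse(0);
        return (double) mostFrequent / runs; // 1.0 means every run agreed
    }
}
```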

Example 4 with LimeConfig

Use of org.kie.kogito.explainability.local.lime.LimeConfig in project kogito-apps by kiegroup.

From class JITDMNServiceImpl, method evaluateModelAndExplain.

public DMNResultWithExplanation evaluateModelAndExplain(DMNEvaluator dmnEvaluator, Map<String, Object> context) {
    LocalDMNPredictionProvider localDMNPredictionProvider = new LocalDMNPredictionProvider(dmnEvaluator);
    DMNResult dmnResult = dmnEvaluator.evaluate(context);
    Prediction prediction = new SimplePrediction(LocalDMNPredictionProvider.toPredictionInput(context), LocalDMNPredictionProvider.toPredictionOutput(dmnResult));
    LimeConfig limeConfig = new LimeConfig().withSamples(explainabilityLimeSampleSize).withPerturbationContext(new PerturbationContext(new Random(), explainabilityLimeNoOfPerturbation));
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap;
    try {
        saliencyMap = limeExplainer.explainAsync(prediction, localDMNPredictionProvider).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    } catch (TimeoutException | InterruptedException | ExecutionException e) {
        if (e instanceof InterruptedException) {
            LOGGER.error("Critical InterruptedException occurred", e);
            Thread.currentThread().interrupt();
        }
        return new DMNResultWithExplanation(new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult), new SalienciesResponse(EXPLAINABILITY_FAILED, EXPLAINABILITY_FAILED_MESSAGE, null));
    }
    List<SaliencyResponse> saliencyModelResponse = buildSalienciesResponse(dmnEvaluator.getDmnModel(), saliencyMap);
    return new DMNResultWithExplanation(new JITDMNResult(dmnEvaluator.getNamespace(), dmnEvaluator.getName(), dmnResult), new SalienciesResponse(EXPLAINABILITY_SUCCEEDED, null, saliencyModelResponse));
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) SalienciesResponse(org.kie.kogito.trusty.service.common.responses.SalienciesResponse) DMNResult(org.kie.dmn.api.core.DMNResult) JITDMNResult(org.kie.kogito.jitexecutor.dmn.responses.JITDMNResult) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) SaliencyResponse(org.kie.kogito.trusty.service.common.responses.SaliencyResponse) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) Prediction(org.kie.kogito.explainability.model.Prediction) DMNResultWithExplanation(org.kie.kogito.jitexecutor.dmn.responses.DMNResultWithExplanation) Saliency(org.kie.kogito.explainability.model.Saliency) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) Random(java.util.Random) ExecutionException(java.util.concurrent.ExecutionException) TimeoutException(java.util.concurrent.TimeoutException)
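evaluateModelAndExplain waits on explainAsync with a bounded timeout. On TimeoutException, InterruptedException or ExecutionException it still returns the DMN result, paired with an EXPLAINABILITY_FAILED saliency response, and it restores the interrupt flag when the thread was interrupted. The stdlib sketch below reproduces that pattern with CompletableFuture. The Outcome class, the status strings and the Map<String, Double> saliency shape are hypothetical stand-ins for the kogito response types.

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

class TimedExplanationSketch {

    /** Hypothetical stand-in for SalienciesResponse: a status plus saliencies (empty on failure). */
    static final class Outcome {
        final String status;
        final Map<String, Double> saliencies;

        Outcome(String status, Map<String, Double> saliencies) {
            this.status = status;
            this.saliencies = saliencies;
        }
    }

    static Outcome explainWithTimeout(Supplier<Map<String, Double>> explainer,
                                      long timeout, TimeUnit unit) {
        // Run the (possibly slow) explanation asynchronously, like explainAsync(...).
        CompletableFuture<Map<String, Double>> future = CompletableFuture.supplyAsync(explainer);
        try {
            // Bounded wait: the caller gets an answer even if the explainer hangs.
            return new Outcome("SUCCEEDED", future.get(timeout, unit));
        } catch (TimeoutException | InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt(); // restore the flag for upstream code
            }
            // Degrade gracefully: report failure instead of propagating the exception.
            return new Outcome("FAILED", Map.of());
        }
    }
}
```

The design keeps the primary result (here, whatever the caller already computed) available and treats the explanation as best-effort. An explainer failure or a slow explainer therefore never blocks or breaks the main evaluation.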

Example 5 with LimeConfig

Use of org.kie.kogito.explainability.local.lime.LimeConfig in project kogito-apps by kiegroup.

From class DummyDmnModelsLimeExplainerTest, method testAllTypesDMNExplanation.

@Test
void testAllTypesDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    DMNRuntime dmnRuntime = DMNKogito.createGenericDMNRuntime(new InputStreamReader(getClass().getResourceAsStream("/dmn/allTypes.dmn")));
    assertThat(dmnRuntime.getModels().size()).isEqualTo(1);
    final String namespace = "https://kiegroup.org/dmn/_24B9EC8C-2F02-40EB-B6BB-E8CDE82FBF08";
    final String name = "new-file";
    DecisionModel decisionModel = new DmnDecisionModel(dmnRuntime, namespace, name);
    PredictionProvider model = new DecisionModelWrapper(decisionModel);
    Map<String, Object> context = new HashMap<>();
    context.put("stringInput", "test");
    context.put("listOfStringInput", Collections.singletonList("test"));
    context.put("numberInput", 1);
    context.put("listOfNumbersInput", Collections.singletonList(1));
    context.put("booleanInput", true);
    context.put("listOfBooleansInput", Collections.singletonList(true));
    context.put("timeInput", "h09:00");
    context.put("dateInput", "2020-04-02");
    context.put("dateAndTimeInput", "2020-04-02T09:00:00");
    context.put("daysAndTimeDurationInput", "P1DT1H");
    context.put("yearsAndMonthDurationInput", "P1Y1M");
    Map<String, Object> complexInput = new HashMap<>();
    complexInput.put("aNestedListOfNumbers", Collections.singletonList(1));
    complexInput.put("aNestedString", "test");
    complexInput.put("aNestedComplexInput", Collections.singletonMap("doubleNestedNumber", 1));
    context.put("complexInput", complexInput);
    context.put("listOfComplexInput", Collections.singletonList(complexInput));
    List<Feature> features = new ArrayList<>();
    features.add(FeatureFactory.newCompositeFeature("context", context));
    PredictionInput predictionInput = new PredictionInput(features);
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(predictionInput)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new SimplePrediction(predictionInput, predictionOutputs.get(0));
    Random random = new Random();
    PerturbationContext perturbationContext = new PerturbationContext(0L, random, 3);
    LimeConfig limeConfig = new LimeConfig().withSamples(10).withPerturbationContext(perturbationContext);
    LimeExplainer limeExplainer = new LimeExplainer(limeConfig);
    Map<String, Saliency> saliencyMap = limeExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (Saliency saliency : saliencyMap.values()) {
        assertThat(saliency).isNotNull();
    }
    assertThatCode(() -> ValidationUtils.validateLocalSaliencyStability(model, prediction, limeExplainer, 1, 0.5, 0.2)).doesNotThrowAnyException();
    String decision = "myDecision";
    List<PredictionInput> inputs = new ArrayList<>();
    for (int n = 0; n < 10; n++) {
        inputs.add(new PredictionInput(DataUtils.perturbFeatures(features, perturbationContext)));
    }
    DataDistribution distribution = new PredictionInputsDataDistribution(inputs);
    int k = 2;
    int chunkSize = 5;
    double precision = ExplainabilityMetrics.getLocalSaliencyPrecision(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(precision).isBetween(0d, 1d);
    double recall = ExplainabilityMetrics.getLocalSaliencyRecall(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(recall).isBetween(0d, 1d);
    double f1 = ExplainabilityMetrics.getLocalSaliencyF1(decision, model, limeExplainer, distribution, k, chunkSize);
    assertThat(f1).isBetween(0d, 1d);
}
Also used : SimplePrediction(org.kie.kogito.explainability.model.SimplePrediction) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) HashMap(java.util.HashMap) LimeExplainer(org.kie.kogito.explainability.local.lime.LimeExplainer) ArrayList(java.util.ArrayList) DecisionModel(org.kie.kogito.decision.DecisionModel) DmnDecisionModel(org.kie.kogito.dmn.DmnDecisionModel) Saliency(org.kie.kogito.explainability.model.Saliency) Feature(org.kie.kogito.explainability.model.Feature) Random(java.util.Random) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) InputStreamReader(java.io.InputStreamReader) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) DMNRuntime(org.kie.dmn.api.core.DMNRuntime) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Test(org.junit.jupiter.api.Test)
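The test only asserts that precision, recall and F1 each lie between 0 and 1. This sketch assumes that ExplainabilityMetrics.getLocalSaliencyF1 combines the other two with the standard F1 formula; the snippet does not show its body. Under that assumption the bound follows, because the harmonic mean of two values in [0, 1] lies between them. A minimal sketch of that formula, with a hypothetical class name:

```java
class SaliencyMetricsSketch {

    /** Standard F1 score: harmonic mean of precision and recall (0 when both are 0). */
    static double f1(double precision, double recall) {
        if (precision < 0 || precision > 1 || recall < 0 || recall > 1) {
            throw new IllegalArgumentException("precision and recall must lie in [0, 1]");
        }
        if (precision + recall == 0) {
            return 0.0; // avoid 0 / 0
        }
        return 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        System.out.println(f1(0.5, 0.5)); // prints 0.5
    }
}
```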

Aggregations

LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig) 65
Test (org.junit.jupiter.api.Test) 56
Prediction (org.kie.kogito.explainability.model.Prediction) 56
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 53
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction) 49
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 48
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 47
Random (java.util.Random) 45
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext) 45
LimeExplainer (org.kie.kogito.explainability.local.lime.LimeExplainer) 36
LimeConfigOptimizer (org.kie.kogito.explainability.local.lime.optim.LimeConfigOptimizer) 28
Saliency (org.kie.kogito.explainability.model.Saliency) 16
DataDistribution (org.kie.kogito.explainability.model.DataDistribution) 14
PredictionInputsDataDistribution (org.kie.kogito.explainability.model.PredictionInputsDataDistribution) 13
ArrayList (java.util.ArrayList) 11
Feature (org.kie.kogito.explainability.model.Feature) 10
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 6
LinkedList (java.util.LinkedList) 5
SimpleBigDecimalScore (org.optaplanner.core.api.score.buildin.simplebigdecimal.SimpleBigDecimalScore) 5
InputStreamReader (java.io.InputStreamReader) 4