
Example 36 with Output

Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.

The class CounterfactualExplainer, method explainAsync.

@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(), solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet());
    });
}
Also used : SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) CompletableFuture(java.util.concurrent.CompletableFuture) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) SolverManager(org.optaplanner.core.api.solver.SolverManager) Function(java.util.function.Function) CompositeFeatureUtils(org.kie.kogito.explainability.utils.CompositeFeatureUtils) Duration(java.time.Duration) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Executor(java.util.concurrent.Executor) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) ExecutionException(java.util.concurrent.ExecutionException) Consumer(java.util.function.Consumer) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterfactualEntityFactory(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) SolverJob(org.optaplanner.core.api.solver.SolverJob)
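
The method above is essentially a CompletableFuture composition: the blocking OptaPlanner solve runs via supplyAsync on a dedicated executor, thenCompose chains the asynchronous model call on the best solution found, and allOf + thenApply + join combines both results once they are done. The standalone sketch below shows the same JDK pattern outside the library; the Solution and ModelOutput records are hypothetical stand-ins for CounterfactualSolution and the model's prediction outputs, not kogito-apps types.

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class AsyncCompositionSketch {

    // Hypothetical stand-ins for the solver's best solution and the model outputs.
    record Solution(List<Double> entities) { }
    record ModelOutput(List<Double> scores) { }

    // Plays the role of solverJob.getFinalBestSolution(): a blocking call kept off the caller thread.
    static Solution solve() {
        return new Solution(List.of(1.0, 2.0));
    }

    // Plays the role of model.predictAsync(...): it already returns a future.
    static CompletableFuture<ModelOutput> predictAsync(Solution s) {
        return CompletableFuture.completedFuture(new ModelOutput(s.entities()));
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        // 1) run the blocking solve on a dedicated executor
        CompletableFuture<Solution> solution = CompletableFuture.supplyAsync(AsyncCompositionSketch::solve, executor);
        // 2) chain the asynchronous model call on the solver result
        CompletableFuture<ModelOutput> outputs = solution.thenCompose(AsyncCompositionSketch::predictAsync);
        // 3) wait for both futures, then combine them; join() cannot block here because allOf guarantees completion
        CompletableFuture<String> result = CompletableFuture.allOf(outputs, solution)
                .thenApply(v -> "entities=" + solution.join().entities() + ", outputs=" + outputs.join().scores());
        System.out.println(result.join());
        executor.shutdown();
    }
}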

Example 37 with Output

Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.

The class HighScoreNumericFeatureZonesProvider, method getHighScoreFeatureZones.

/**
 * Get a map of feature name -> high score feature zones. Predictions in the data distribution are sorted by
 * score (descending), the mean score is calculated, and all data points associated with a prediction whose score
 * lies between the mean and the maximum are selected (feature-wise), with an associated tolerance
 * (half the standard deviation of the high score feature points).
 *
 * @param dataDistribution a data distribution
 * @param predictionProvider the model used to score the inputs
 * @param features the list of features to associate high score points with
 * @param maxNoOfSamples max no. of inputs used for discovering high score zones
 * @return a map of feature name -> high score numeric feature zones
 */
public static Map<String, HighScoreNumericFeatureZones> getHighScoreFeatureZones(DataDistribution dataDistribution, PredictionProvider predictionProvider, List<Feature> features, int maxNoOfSamples) {
    Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap = new HashMap<>();
    List<Prediction> scoreSortedPredictions = new ArrayList<>();
    try {
        scoreSortedPredictions.addAll(DataUtils.getScoreSortedPredictions(predictionProvider, new PredictionInputsDataDistribution(dataDistribution.sample(maxNoOfSamples))));
    } catch (ExecutionException e) {
        LOGGER.error("Could not sort predictions by score {}", e.getMessage());
    } catch (InterruptedException e) {
        LOGGER.error("Interrupted while waiting for sorting predictions by score {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        LOGGER.error("Timed out while waiting for sorting predictions by score", e);
    }
    if (!scoreSortedPredictions.isEmpty()) {
        // calculate min, max and mean scores
        double max = scoreSortedPredictions.get(0).getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum();
        double min = scoreSortedPredictions.get(scoreSortedPredictions.size() - 1).getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum();
        if (max != min) {
            double threshold = scoreSortedPredictions.stream().map(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum()).mapToDouble(d -> d).average().orElse((max + min) / 2);
            // filter out predictions whose score is in [min, threshold]
            scoreSortedPredictions = scoreSortedPredictions.stream().filter(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum() > threshold).collect(Collectors.toList());
            for (int j = 0; j < features.size(); j++) {
                Feature feature = features.get(j);
                if (Type.NUMBER.equals(feature.getType())) {
                    int finalJ = j;
                    // get feature values associated with high score inputs
                    List<Double> topValues = scoreSortedPredictions.stream().map(prediction -> prediction.getInput().getFeatures().get(finalJ).getValue().asNumber()).distinct().collect(Collectors.toList());
                    // get high score points and tolerance
                    double[] highScoreFeaturePoints = topValues.stream().flatMapToDouble(DoubleStream::of).toArray();
                    double center = DataUtils.getMean(highScoreFeaturePoints);
                    double tolerance = DataUtils.getStdDev(highScoreFeaturePoints, center) / 2;
                    HighScoreNumericFeatureZones highScoreNumericFeatureZones = new HighScoreNumericFeatureZones(highScoreFeaturePoints, tolerance);
                    numericFeatureZonesMap.put(feature.getName(), highScoreNumericFeatureZones);
                }
            }
        }
    }
    return numericFeatureZonesMap;
}
Also used : DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) PredictionInputsDataDistribution(org.kie.kogito.explainability.model.PredictionInputsDataDistribution) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) TimeoutException(java.util.concurrent.TimeoutException) HashMap(java.util.HashMap) Collectors(java.util.stream.Collectors) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) ArrayList(java.util.ArrayList) DoubleStream(java.util.stream.DoubleStream) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) Map(java.util.Map) Output(org.kie.kogito.explainability.model.Output)
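
Stripped of the kogito-apps types, the numeric core of getHighScoreFeatureZones is: take the per-prediction score sums (already sorted by score), use their average as a threshold with the min/max midpoint as a fallback, keep only the above-threshold values, and set the tolerance to half their standard deviation. Below is a minimal sketch of that step on plain doubles; the mean and stdDev helpers are hand-rolled stand-ins for DataUtils.getMean and DataUtils.getStdDev.

import java.util.Arrays;

public class HighScoreZoneSketch {

    static double mean(double[] xs) {
        return Arrays.stream(xs).average().orElse(0d);
    }

    static double stdDev(double[] xs, double mean) {
        double variance = Arrays.stream(xs).map(x -> (x - mean) * (x - mean)).average().orElse(0d);
        return Math.sqrt(variance);
    }

    public static void main(String[] args) {
        // per-prediction score sums, already sorted descending as in getScoreSortedPredictions
        double[] scores = { 0.95, 0.90, 0.70, 0.40, 0.10 };
        double max = scores[0];
        double min = scores[scores.length - 1];
        // average score as threshold, midpoint as fallback (mirrors orElse((max + min) / 2))
        double threshold = Arrays.stream(scores).average().orElse((max + min) / 2);
        // keep only the "high score" points, i.e. strictly above the threshold
        double[] highScorePoints = Arrays.stream(scores).filter(s -> s > threshold).toArray();
        double center = mean(highScorePoints);
        double tolerance = stdDev(highScorePoints, center) / 2;
        System.out.printf("threshold=%.3f, points=%s, tolerance=%.3f%n",
                threshold, Arrays.toString(highScorePoints), tolerance);
    }
}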

Example 38 with Output

Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.

The class LimeExplainer, method prepareInputs.

/**
 * Check the perturbed inputs so that the dataset of perturbed inputs / outputs contains more than just one output
 * class, otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be used
 * as feature importance scores.
 * The check can be {@code strict} or not; when strict, a {@code DatasetNotSeparableException} is thrown if the
 * dataset for a given output is not separable.
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs, List<PredictionOutput> perturbedOutputs, List<Feature> linearizedTargetInputFeatures, int o, Output currentOutput, boolean strict) {
    if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
        Map<Double, Long> rawClassesBalance;
        // calculate the no. of samples belonging to each output class
        Value fv = currentOutput.getValue();
        rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        Long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / (double) perturbedInputs.size();
        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        if (strict) {
            // check if the dataset is separable and also if the linear model should fit a regressor or a classifier
            if (rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio()) {
                // if dataset creation process succeeds use it to train the linear model
                return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
            } else {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
        } else {
            LOGGER.warn("Using a hardly separable dataset for output '{}' of type '{}' with value '{}' ({})", currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
            return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
        }
    } else {
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }
}
Also used : Arrays(java.util.Arrays) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) CompletableFuture.completedFuture(java.util.concurrent.CompletableFuture.completedFuture) LoggerFactory(org.slf4j.LoggerFactory) HashMap(java.util.HashMap) CompletableFuture(java.util.concurrent.CompletableFuture) Value(org.kie.kogito.explainability.model.Value) DataDistribution(org.kie.kogito.explainability.model.DataDistribution) Saliency(org.kie.kogito.explainability.model.Saliency) ArrayList(java.util.ArrayList) LinearModel(org.kie.kogito.explainability.utils.LinearModel) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) DataUtils(org.kie.kogito.explainability.utils.DataUtils) Logger(org.slf4j.Logger) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) Collections.emptyList(java.util.Collections.emptyList) FeatureImportance(org.kie.kogito.explainability.model.FeatureImportance) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Collectors(java.util.stream.Collectors) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) Consumer(java.util.function.Consumer) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) Optional(java.util.Optional)
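
The separability check above reduces to counting how many perturbed samples fall into each output class and comparing the share of the dominant class against a configured ratio (limeConfig.getSeparableDatasetRatio() in the snippet). The following self-contained sketch shows that check on raw double outputs; the 0.99 ratio is an assumed value standing in for the LimeConfig lookup, and the groupingBy/counting collector mirrors what getClassBalance computes.

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class SeparabilityCheckSketch {

    public static void main(String[] args) {
        // output values of the perturbed predictions for one output index
        List<Double> perturbedOutputValues = List.of(1.0, 1.0, 0.0, 1.0, 0.0, 1.0);
        // no. of samples belonging to each output class (mirrors getClassBalance)
        Map<Double, Long> rawClassesBalance = perturbedOutputValues.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
        long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / perturbedOutputValues.size();
        // two distinct classes -> treat as a classification problem, otherwise regression
        boolean classification = rawClassesBalance.size() == 2;
        double separableDatasetRatio = 0.99; // assumed threshold, stands in for limeConfig.getSeparableDatasetRatio()
        boolean separable = rawClassesBalance.size() > 1 && separationRatio < separableDatasetRatio;
        System.out.println("balance=" + rawClassesBalance
                + ", separationRatio=" + separationRatio
                + ", classification=" + classification
                + ", separable=" + separable);
    }
}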

Example 39 with Output

Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.

The class LimeExplainer, method explainAsync.

@Override
public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction, PredictionProvider model, Consumer<Map<String, Saliency>> intermediateResultsConsumer) {
    PredictionInput originalInput = prediction.getInput();
    if (originalInput == null || originalInput.getFeatures() == null || originalInput.getFeatures().isEmpty()) {
        throw new LocalExplanationException("cannot explain a prediction whose input is empty");
    }
    List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
    PredictionInput targetInput = linearizedInputs.get(0);
    List<Feature> linearizedTargetInputFeatures = targetInput.getFeatures();
    if (linearizedTargetInputFeatures.isEmpty()) {
        throw new LocalExplanationException("input features linearization failed");
    }
    List<Output> actualOutputs = prediction.getOutput().getOutputs();
    LimeConfig executionConfig = limeConfig.copy();
    return explainWithExecutionConfig(model, originalInput, linearizedTargetInputFeatures, actualOutputs, executionConfig);
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) LocalExplanationException(org.kie.kogito.explainability.local.LocalExplanationException) Feature(org.kie.kogito.explainability.model.Feature)
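
Before explaining, LimeExplainer flattens ("linearizes") possibly nested inputs into a single flat feature list via DataUtils.linearizeInputs. The sketch below only illustrates the flattening idea on a hypothetical nested map; it is not the DataUtils or CompositeFeatureUtils implementation, just the concept of turning a hierarchy into named leaf values.

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class LinearizationSketch {

    // Recursively walk a nested structure and collect "path -> leaf value" entries.
    static void flatten(String prefix, Object node, List<Map.Entry<String, Object>> out) {
        if (node instanceof Map<?, ?> map) {
            for (Map.Entry<?, ?> e : map.entrySet()) {
                flatten(prefix.isEmpty() ? e.getKey().toString() : prefix + "." + e.getKey(), e.getValue(), out);
            }
        } else {
            out.add(Map.entry(prefix, node));
        }
    }

    public static void main(String[] args) {
        Map<String, Object> nestedInput = new LinkedHashMap<>();
        nestedInput.put("age", 34);
        nestedInput.put("address", Map.of("city", "Milan", "zip", "20100"));
        List<Map.Entry<String, Object>> flat = new ArrayList<>();
        flatten("", nestedInput, flat);
        // e.g. [age=34, address.city=Milan, address.zip=20100] (Map.of iteration order may vary)
        System.out.println(flat);
    }
}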

Example 40 with Output

Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.

The class CounterfactualExplainerServiceHandlerTest, method testCreateIntermediateResult.

@Test
public void testCreateIntermediateResult() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID, SERVICE_URL, MODEL_IDENTIFIER, COUNTERFACTUAL_ID, Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), MAX_RUNNING_TIME_SECONDS);
    List<CounterfactualEntity> entities = List.of(DoubleEntity.from(new Feature("input1", Type.NUMBER, new Value(123.0d)), 0, 1000));
    CounterfactualResult counterfactuals = new CounterfactualResult(entities, entities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList()), List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)))), true, UUID.fromString(SOLUTION_ID), UUID.fromString(EXECUTION_ID), 0);
    BaseExplainabilityResult base = handler.createIntermediateResult(request, counterfactuals);
    assertTrue(base instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult result = (CounterfactualExplainabilityResult) base;
    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, result.getStage());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, result.getCounterfactualId());
    assertEquals(1, result.getInputs().size());
    assertTrue(result.getInputs().stream().anyMatch(i -> i.getName().equals("input1")));
    NamedTypedValue input1 = result.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), input1.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, input1.getValue().getKind());
    assertEquals(123.0, input1.getValue().toUnit().getValue().asDouble());
    assertEquals(1, result.getOutputs().size());
    assertTrue(result.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));
    NamedTypedValue output1 = result.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), output1.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, output1.getValue().getKind());
    assertEquals(555.0, output1.getValue().toUnit().getValue().asDouble());
}
Also used : CounterfactualExplainabilityRequest(org.kie.kogito.explainability.api.CounterfactualExplainabilityRequest) BeforeEach(org.junit.jupiter.api.BeforeEach) BaseExplainabilityRequest(org.kie.kogito.explainability.api.BaseExplainabilityRequest) Feature(org.kie.kogito.explainability.model.Feature) ArgumentMatchers.eq(org.mockito.ArgumentMatchers.eq) CounterfactualDomainRange(org.kie.kogito.explainability.api.CounterfactualDomainRange) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) Map(java.util.Map) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) BaseExplainabilityResult(org.kie.kogito.explainability.api.BaseExplainabilityResult) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) NamedTypedValue(org.kie.kogito.explainability.api.NamedTypedValue) ExplainabilityStatus(org.kie.kogito.explainability.api.ExplainabilityStatus) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) TypedValue(org.kie.kogito.tracing.typedvalue.TypedValue) UnitValue(org.kie.kogito.tracing.typedvalue.UnitValue) Test(org.junit.jupiter.api.Test) List(java.util.List) CollectionValue(org.kie.kogito.tracing.typedvalue.CollectionValue) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) Optional(java.util.Optional) CounterfactualExplainabilityResult(org.kie.kogito.explainability.api.CounterfactualExplainabilityResult) CounterfactualSearchDomain(org.kie.kogito.explainability.api.CounterfactualSearchDomain) Mockito.mock(org.mockito.Mockito.mock) Assertions.assertThrows(org.junit.jupiter.api.Assertions.assertThrows) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) IntNode(com.fasterxml.jackson.databind.node.IntNode) Prediction(org.kie.kogito.explainability.model.Prediction) StructureValue(org.kie.kogito.tracing.typedvalue.StructureValue) CounterfactualSearchDomainStructureValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainStructureValue) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProviderFactory(org.kie.kogito.explainability.PredictionProviderFactory) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) CounterfactualSearchDomainCollectionValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainCollectionValue) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) Consumer(java.util.function.Consumer) DoubleEntity(org.kie.kogito.explainability.local.counterfactual.entities.DoubleEntity) CounterfactualSearchDomainUnitValue(org.kie.kogito.explainability.api.CounterfactualSearchDomainUnitValue) DoubleNode(com.fasterxml.jackson.databind.node.DoubleNode) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) BooleanNode(com.fasterxml.jackson.databind.node.BooleanNode) Collections(java.util.Collections) ModelIdentifier(org.kie.kogito.explainability.api.ModelIdentifier)

Aggregations

Output (org.kie.kogito.explainability.model.Output) 120
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput) 109
Feature (org.kie.kogito.explainability.model.Feature) 102
Value (org.kie.kogito.explainability.model.Value) 63
Random (java.util.Random) 61
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 59
PredictionInput (org.kie.kogito.explainability.model.PredictionInput) 57
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider) 52
ArrayList (java.util.ArrayList) 47
ValueSource (org.junit.jupiter.params.provider.ValueSource) 47
Prediction (org.kie.kogito.explainability.model.Prediction) 46
Test (org.junit.jupiter.api.Test) 42
List (java.util.List) 39
Type (org.kie.kogito.explainability.model.Type) 36
LinkedList (java.util.LinkedList) 35
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) 23
Mockito.mock (org.mockito.Mockito.mock) 20
Optional (java.util.Optional) 19
ExecutionException (java.util.concurrent.ExecutionException) 19
Collectors (java.util.stream.Collectors) 18