Example 51 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class PartialDependencePlotExplainer method prepareInputs.

/**
 * Generate inputs for a particular feature, using 1) a specific discrete value from the data distribution of the
 * feature under analysis and 2) sampled values from a training data distribution for all the other features.
 * The resulting list of prediction inputs will therefore share the very same value for the feature under analysis,
 * and take values from the training data for all other features.
 *
 * @param featureXs specific value of the feature under analysis
 * @param trainingData training data
 * @return a list of prediction inputs
 */
private List<PredictionInput> prepareInputs(Feature featureXs, List<PredictionInput> trainingData) {
    List<PredictionInput> predictionInputs = new ArrayList<>(config.getSeriesLength());
    for (PredictionInput trainingSample : trainingData) {
        List<Feature> features = trainingSample.getFeatures();
        List<Feature> newFeatures = replaceFeatures(featureXs, features);
        predictionInputs.add(new PredictionInput(newFeatures));
    }
    return predictionInputs;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature)
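
The private replaceFeatures helper is not shown above; assuming it swaps features by name, the same fixed-feature substitution can be sketched in isolation (FixedFeatureSubstitution and fixFeature are illustrative names, not part of the project).

import java.util.ArrayList;
import java.util.List;

import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.PredictionInput;

public class FixedFeatureSubstitution {

    /**
     * Keep each training feature unless its name matches the feature under analysis,
     * in which case the fixed feature is substituted (assumed behaviour of replaceFeatures).
     */
    static List<PredictionInput> fixFeature(Feature featureX, List<PredictionInput> trainingData) {
        List<PredictionInput> result = new ArrayList<>(trainingData.size());
        for (PredictionInput sample : trainingData) {
            List<Feature> newFeatures = new ArrayList<>(sample.getFeatures().size());
            for (Feature feature : sample.getFeatures()) {
                newFeatures.add(feature.getName().equals(featureX.getName()) ? featureX : feature);
            }
            result.add(new PredictionInput(newFeatures));
        }
        return result;
    }
}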

Example 52 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class CounterFactualScoreCalculator method calculateInputScore.

private BendableBigDecimalScore calculateInputScore(CounterfactualSolution solution) {
    StringBuilder builder = new StringBuilder();
    int secondarySoftScore = 0;
    int secondaryHardScore = 0;
    // Calculate similarities between original inputs and proposed inputs
    double inputSimilarities = 0.0;
    final int numberOfEntities = solution.getEntities().size();
    for (CounterfactualEntity entity : solution.getEntities()) {
        final double entitySimilarity = entity.similarity();
        inputSimilarities += entitySimilarity / numberOfEntities;
        final Feature f = entity.asFeature();
        builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entitySimilarity));
        if (entity.isChanged()) {
            secondarySoftScore -= 1;
            if (entity.isConstrained()) {
                secondaryHardScore -= 1;
            }
        }
    }
    logger.debug("Current solution: {}", builder);
    // Calculate Gower distance from the similarities
    final double primarySoftScore = -Math.sqrt(Math.abs(1.0 - inputSimilarities));
    logger.debug("Changed constraints penalty: {}", secondaryHardScore);
    logger.debug("Feature distance: {}", -Math.abs(primarySoftScore));
    return BendableBigDecimalScore.of(
            new BigDecimal[] { BigDecimal.ZERO, BigDecimal.valueOf(secondaryHardScore), BigDecimal.ZERO },
            new BigDecimal[] { BigDecimal.valueOf(-Math.abs(primarySoftScore)), BigDecimal.valueOf(secondarySoftScore) });
}
Also used : CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Feature(org.kie.kogito.explainability.model.Feature)
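
The primary soft score above is a Gower-style distance derived from the mean of the per-entity similarities. A minimal, OptaPlanner-free sketch of that arithmetic (class and method names are illustrative; similarities are assumed to lie in [0, 1]):

public class GowerDistanceSketch {

    /** Mean of the per-entity similarities turned into a Gower-style distance. */
    static double gowerDistance(double[] entitySimilarities) {
        double meanSimilarity = 0.0;
        for (double similarity : entitySimilarities) {
            meanSimilarity += similarity / entitySimilarities.length;
        }
        // identical inputs (all similarities 1.0) yield distance 0.0
        return Math.sqrt(Math.abs(1.0 - meanSimilarity));
    }

    public static void main(String[] args) {
        double[] similarities = { 1.0, 0.8, 0.95 }; // illustrative values
        // the calculator negates the distance so that closer inputs score higher
        System.out.printf("primary soft score = %f%n", -gowerDistance(similarities));
    }
}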

Example 53 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class CounterFactualScoreCalculator method calculateScore.

/**
 * Calculates the counterfactual score for each proposed solution.
 * This method assumes that each model used as {@link org.kie.kogito.explainability.model.PredictionProvider} is
 * consistent, in the sense that for repeated operations, the size of the returned collection of
 * {@link PredictionOutput} is the same, if the size of {@link PredictionInput} doesn't change.
 *
 * @param solution Proposed solution
 * @return A {@link BendableBigDecimalScore} with three "hard" levels and two "soft" levels
 */
@Override
public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
    BendableBigDecimalScore currentScore = calculateInputScore(solution);
    final List<Feature> flattenedFeatures = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final List<Feature> input = CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, solution.getOriginalFeatures());
    final List<PredictionInput> inputs = List.of(new PredictionInput(input));
    final CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
    try {
        List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        solution.setPredictionOutputs(predictions);
        final BendableBigDecimalScore outputScore = calculateOutputScore(solution);
        currentScore = currentScore.add(outputScore);
    } catch (ExecutionException e) {
        logger.error("Prediction returned an error {}", e.getMessage());
    } catch (InterruptedException e) {
        logger.error("Interrupted while waiting for prediction {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        logger.error("Timed out while waiting for prediction");
    }
    return currentScore;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) List(java.util.List) ExecutionException(java.util.concurrent.ExecutionException) Feature(org.kie.kogito.explainability.model.Feature) TimeoutException(java.util.concurrent.TimeoutException)
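
The score calculator has to turn the asynchronous prediction into a synchronous value, which is why calculateScore blocks with a timeout and degrades failures instead of propagating them. A reduced sketch of that wait-and-degrade pattern (the 10-second timeout is an illustrative stand-in for Config.INSTANCE.getAsyncTimeout(); class and method names are not part of the project):

import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;

public class BlockingPredictionSketch {

    /** Await the async prediction, falling back to an empty result on failure or timeout. */
    static List<PredictionOutput> predictBlocking(PredictionProvider model, List<PredictionInput> inputs) {
        try {
            return model.predictAsync(inputs).get(10, TimeUnit.SECONDS);
        } catch (ExecutionException e) {
            // the model raised an error while predicting
            return List.of();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return List.of();
        } catch (TimeoutException e) {
            // the prediction did not complete within the budget
            return List.of();
        }
    }
}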

Example 54 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class CounterfactualExplainer method explainAsync.

@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(), solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet());
    });
}
Also used : SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) CompletableFuture(java.util.concurrent.CompletableFuture) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) SolverManager(org.optaplanner.core.api.solver.SolverManager) Function(java.util.function.Function) CompositeFeatureUtils(org.kie.kogito.explainability.utils.CompositeFeatureUtils) Duration(java.time.Duration) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Executor(java.util.concurrent.Executor) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) ExecutionException(java.util.concurrent.ExecutionException) Consumer(java.util.function.Consumer) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterfactualEntityFactory(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) SolverJob(org.optaplanner.core.api.solver.SolverJob)
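
explainAsync applies the caller's time budget only when one is supplied, by adjusting the solver configuration before a SolverManager is created from it. A minimal sketch of that step alone (class and method names are illustrative; the base SolverConfig is assumed to come from CounterfactualConfig):

import java.time.Duration;
import java.util.Objects;

import org.optaplanner.core.config.solver.SolverConfig;

public class TerminationLimitSketch {

    /** Cap the time spent solving a single counterfactual search, if a budget was provided. */
    static SolverConfig withOptionalBudget(SolverConfig solverConfig, Long maxRunningTimeSeconds) {
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        return solverConfig;
    }
}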

Example 55 with Feature

use of org.kie.kogito.explainability.model.Feature in project kogito-apps by kiegroup.

the class ShapSyntheticDataSample method createSyntheticData.

/**
 * Create synthetic data for this particular sample,
 * according to the conditions set up in the ShapSyntheticDataSample initialization.
 *
 * @return Synthetic data for this particular sample
 */
private List<PredictionInput> createSyntheticData() {
    List<Feature> piFeatures = this.x.getFeatures();
    List<PredictionInput> synthData = new ArrayList<>();
    for (int i = 0; i < this.background.getRowDimension(); i++) {
        List<Feature> maskedFeatures = new ArrayList<>();
        for (int j = 0; j < this.mask.length; j++) {
            Feature oldFeature = piFeatures.get(j);
            if (this.mask[j]) {
                maskedFeatures.add(oldFeature);
            } else {
                maskedFeatures.add(FeatureFactory.newNumericalFeature(oldFeature.getName(), this.background.getEntry(i, j)));
            }
        }
        synthData.add(new PredictionInput(maskedFeatures));
    }
    return synthData;
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) ArrayList(java.util.ArrayList) Feature(org.kie.kogito.explainability.model.Feature)
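
The masking step can also be read on its own: where the mask is true the original feature survives, otherwise the background value for that column is substituted. A standalone sketch, with the background simplified to a plain double[][] rather than the matrix type held by ShapSyntheticDataSample (class and method names are illustrative):

import java.util.ArrayList;
import java.util.List;

import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.FeatureFactory;
import org.kie.kogito.explainability.model.PredictionInput;

public class ShapMaskingSketch {

    /** One synthetic input per background row, keeping masked features and replacing the rest. */
    static List<PredictionInput> maskWithBackground(List<Feature> features, boolean[] mask, double[][] background) {
        List<PredictionInput> synthetic = new ArrayList<>(background.length);
        for (double[] backgroundRow : background) {
            List<Feature> masked = new ArrayList<>(mask.length);
            for (int j = 0; j < mask.length; j++) {
                Feature original = features.get(j);
                masked.add(mask[j]
                        ? original
                        : FeatureFactory.newNumericalFeature(original.getName(), backgroundRow[j]));
            }
            synthetic.add(new PredictionInput(masked));
        }
        return synthetic;
    }
}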

Aggregations

Feature (org.kie.kogito.explainability.model.Feature): 233
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 118
Test (org.junit.jupiter.api.Test): 107
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 107
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 104
Output (org.kie.kogito.explainability.model.Output): 102
ArrayList (java.util.ArrayList): 97
Random (java.util.Random): 92
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 78
Value (org.kie.kogito.explainability.model.Value): 74
LinkedList (java.util.LinkedList): 72
ValueSource (org.junit.jupiter.params.provider.ValueSource): 71
Prediction (org.kie.kogito.explainability.model.Prediction): 67
List (java.util.List): 51
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 46
PerturbationContext (org.kie.kogito.explainability.model.PerturbationContext): 42
Type (org.kie.kogito.explainability.model.Type): 39
NumericalFeatureDomain (org.kie.kogito.explainability.model.domain.NumericalFeatureDomain): 37
SimplePrediction (org.kie.kogito.explainability.model.SimplePrediction): 35
FeatureDomain (org.kie.kogito.explainability.model.domain.FeatureDomain): 33