Use of org.kie.kogito.explainability.model.Feature in the kogito-apps project by kiegroup.
In class PartialDependencePlotExplainer, method prepareInputs:
/**
 * Generate inputs for a particular feature, using 1) a specific discrete value from the data distribution of the
 * feature under analysis for that particular feature and 2) values from the given training data for all the other
 * feature values.
 * Exactly one prediction input is produced per training sample, so the result has the same size as
 * {@code trainingData}. The feature under analysis is expected to be set to {@code featureXs} in every input (via
 * {@code replaceFeatures}), while all other features keep the values of the corresponding training sample.
 * This method does not sample: any sampling of the training data must be done by the caller.
 *
 * @param featureXs specific value of the feature under analysis
 * @param trainingData training data, one entry per prediction input to generate
 * @return a list of prediction inputs, one per training sample, in the same order as {@code trainingData}
 */
private List<PredictionInput> prepareInputs(Feature featureXs, List<PredictionInput> trainingData) {
    // Presize to the number of inputs actually produced (one per training sample), not to the series length.
    List<PredictionInput> predictionInputs = new ArrayList<>(trainingData.size());
    for (PredictionInput trainingSample : trainingData) {
        List<Feature> newFeatures = replaceFeatures(featureXs, trainingSample.getFeatures());
        predictionInputs.add(new PredictionInput(newFeatures));
    }
    return predictionInputs;
}
Use of org.kie.kogito.explainability.model.Feature in the kogito-apps project by kiegroup.
In class CounterFactualScoreCalculator, method calculateInputScore:
/**
 * Computes the input-related part of the score of a proposed counterfactual solution.
 * <p>
 * The returned score has three hard levels and two soft levels:
 * <ul>
 * <li>hard level at index 1: minus the number of entities that were changed although they are constrained
 * (hard levels at index 0 and 2 are always zero);</li>
 * <li>soft level at index 0: the feature distance {@code -sqrt(|1 - s|)}, where {@code s} is the mean
 * similarity between the original and the proposed entity values;</li>
 * <li>soft level at index 1: minus the number of changed entities.</li>
 * </ul>
 *
 * @param solution proposed solution whose entities are compared against the original inputs
 * @return the input part of the solution score
 */
private BendableBigDecimalScore calculateInputScore(CounterfactualSolution solution) {
    // The per-entity description is only used for debug logging; don't format it otherwise,
    // since this method runs for every candidate solution the solver evaluates.
    final boolean debugEnabled = logger.isDebugEnabled();
    final StringBuilder builder = new StringBuilder();
    int secondarySoftScore = 0;
    int secondaryHardScore = 0;
    // Mean of the similarities between original inputs and proposed inputs
    double inputSimilarities = 0.0;
    final int numberOfEntities = solution.getEntities().size();
    for (CounterfactualEntity entity : solution.getEntities()) {
        final double entitySimilarity = entity.similarity();
        inputSimilarities += entitySimilarity / numberOfEntities;
        if (debugEnabled) {
            final Feature f = entity.asFeature();
            if (builder.length() > 0) {
                // Separate entries so the logged solution stays readable
                builder.append(", ");
            }
            builder.append(String.format("%s=%s (d:%f)", f.getName(), f.getValue().getUnderlyingObject(), entitySimilarity));
        }
        if (entity.isChanged()) {
            secondarySoftScore -= 1;
            if (entity.isConstrained()) {
                secondaryHardScore -= 1;
            }
        }
    }
    logger.debug("Current solution: {}", builder);
    // Calculate Gower distance from the similarities; the sqrt is non-negative, so negating it
    // already yields a non-positive penalty (no further abs needed).
    final double featureDistance = -Math.sqrt(Math.abs(1.0 - inputSimilarities));
    logger.debug("Changed constraints penalty: {}", secondaryHardScore);
    logger.debug("Feature distance: {}", featureDistance);
    final BigDecimal[] hardScores = { BigDecimal.ZERO, BigDecimal.valueOf(secondaryHardScore), BigDecimal.ZERO };
    final BigDecimal[] softScores = { BigDecimal.valueOf(featureDistance), BigDecimal.valueOf(secondarySoftScore) };
    return BendableBigDecimalScore.of(hardScores, softScores);
}
Use of org.kie.kogito.explainability.model.Feature in the kogito-apps project by kiegroup.
In class CounterFactualScoreCalculator, method calculateScore:
/**
 * Calculates the counterfactual score for each proposed solution.
 * This method assumes that each model used as {@link org.kie.kogito.explainability.model.PredictionProvider} is
 * consistent, in the sense that for repeated operations, the size of the returned collection of
 * {@link PredictionOutput} is the same, if the size of {@link PredictionInput} doesn't change.
 * <p>
 * The score is the input score of the solution plus, when the model returns predictions for the proposed inputs
 * within the configured async timeout, the output score. The predictions are also stored on the solution.
 * If the prediction fails, times out, or the waiting thread is interrupted, the error is logged and only the
 * input score is returned.
 *
 * @param solution Proposed solution
 * @return A {@link BendableBigDecimalScore} with three "hard" levels and two "soft" levels
 */
@Override
public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
    BendableBigDecimalScore currentScore = calculateInputScore(solution);
    // Re-assemble the (possibly composite) original feature structure from the flattened entities
    final List<Feature> flattenedFeatures = solution.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    final List<Feature> input = CompositeFeatureUtils.unflattenFeatures(flattenedFeatures, solution.getOriginalFeatures());
    final List<PredictionInput> inputs = List.of(new PredictionInput(input));
    final CompletableFuture<List<PredictionOutput>> predictionAsync = solution.getModel().predictAsync(inputs);
    try {
        List<PredictionOutput> predictions = predictionAsync.get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
        solution.setPredictionOutputs(predictions);
        final BendableBigDecimalScore outputScore = calculateOutputScore(solution);
        currentScore = currentScore.add(outputScore);
    } catch (ExecutionException e) {
        // Pass the exception itself so the stack trace and root cause are logged, not only the message
        logger.error("Prediction returned an error {}", e.getMessage(), e);
    } catch (InterruptedException e) {
        logger.error("Interrupted while waiting for prediction {}", e.getMessage(), e);
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        logger.error("Timed out while waiting for prediction (timeout: {} {})",
                Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    }
    return currentScore;
}
Use of org.kie.kogito.explainability.model.Feature in the kogito-apps project by kiegroup.
In class CounterfactualExplainer, method explainAsync:
/**
 * Asynchronously searches for a counterfactual of the given prediction.
 * <p>
 * The prediction must be a {@link CounterfactualPrediction}; its outputs are used as the goal, and its optional
 * maximum running time (in seconds) bounds the solver. Intermediate best solutions are passed to
 * {@code intermediateResultsConsumer}. Once solving ends, the model is queried with the final solution's entities
 * and a {@link CounterfactualResult} is built from the solution and those outputs.
 *
 * @param prediction the prediction to explain (cast to {@link CounterfactualPrediction})
 * @param model the model used to evaluate candidate and final solutions
 * @param intermediateResultsConsumer consumer of intermediate results produced while solving
 * @return a future completed with the final counterfactual result; completed exceptionally with an
 * {@link IllegalStateException} (wrapped by the future) if solving fails or is interrupted
 */
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            // Apply the per-request limit to a copy: the config held by counterfactualConfig is reused across
            // calls to this explainer, so mutating it would leak this limit into later (or concurrent) requests.
            solverConfig = solverConfig.copyConfig()
                    .withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    // Evaluate the model on the final solution's entities
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(), solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet());
    });
}
Use of org.kie.kogito.explainability.model.Feature in the kogito-apps project by kiegroup.
In class ShapSyntheticDataSample, method createSyntheticData:
/**
 * Builds the synthetic data for this sample: one prediction input per row of the background matrix.
 * <p>
 * For each column index {@code j} of the mask, the resulting feature is the sample's own feature when
 * {@code mask[j]} is {@code true}. Otherwise it is a new numerical feature with the same name whose value is
 * taken from the background matrix at that row and column.
 *
 * @return Synthetic data for this particular sample, in background-row order
 */
private List<PredictionInput> createSyntheticData() {
    final List<Feature> sampleFeatures = this.x.getFeatures();
    final List<PredictionInput> result = new ArrayList<>();
    for (int row = 0; row < this.background.getRowDimension(); row++) {
        final List<Feature> rowFeatures = new ArrayList<>();
        for (int col = 0; col < this.mask.length; col++) {
            final Feature sampleFeature = sampleFeatures.get(col);
            // Keep the sample's value for unmasked-in features, substitute the background value otherwise
            final Feature chosen = this.mask[col]
                    ? sampleFeature
                    : FeatureFactory.newNumericalFeature(sampleFeature.getName(), this.background.getEntry(row, col));
            rowFeatures.add(chosen);
        }
        result.add(new PredictionInput(rowFeatures));
    }
    return result;
}
Aggregations