Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
The class CounterfactualExplainer, method explainAsync.
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction,
        PredictionProvider model,
        Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();

    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();

    Function<UUID, CounterfactualSolution> initial =
            uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal,
                    UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());

    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager =
                this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob =
                    solverManager.solveAndListen(executionId, initial,
                            assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)),
                            null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());

    final CompletableFuture<List<PredictionOutput>> cfOutputs =
            cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));

    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(),
                solution.getOriginalFeatures(),
                cfOutputs.join(),
                solution.getScore().isFeasible(),
                UUID.randomUUID(),
                solution.getExecutionId(),
                sequenceId.incrementAndGet());
    });
}
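A minimal caller sketch, not taken from the project source: the explainer, model, cfPrediction and LOGGER references are assumed to be configured elsewhere; only the explainAsync signature shown above comes from the excerpt.

// Hypothetical usage sketch: 'explainer', 'model', 'cfPrediction' and 'LOGGER' are assumptions.
CompletableFuture<CounterfactualResult> future = explainer.explainAsync(
        cfPrediction,
        model,
        intermediate -> LOGGER.debug("intermediate counterfactual result: {}", intermediate));
// join() blocks until the solver terminates (bounded by maxRunningTimeSeconds when it is set)
CounterfactualResult result = future.join();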
Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
The class HighScoreNumericFeatureZonesProvider, method getHighScoreFeatureZones.
/**
 * Get a map of feature name -> high score feature zones. Predictions in the data distribution are sorted by
 * (descending) score, then the (aggregated) mean score is calculated and all the data points associated with a
 * prediction whose score lies between the mean and the maximum are selected (feature-wise), with an associated
 * tolerance (half the standard deviation of the high score feature points).
 *
 * @param dataDistribution a data distribution
 * @param predictionProvider the model used to score the inputs
 * @param features the list of features to associate high score points with
 * @param maxNoOfSamples the maximum number of inputs used for discovering high score zones
 * @return a map of feature name -> high score numeric feature zones
 */
public static Map<String, HighScoreNumericFeatureZones> getHighScoreFeatureZones(DataDistribution dataDistribution,
        PredictionProvider predictionProvider, List<Feature> features, int maxNoOfSamples) {
    Map<String, HighScoreNumericFeatureZones> numericFeatureZonesMap = new HashMap<>();

    List<Prediction> scoreSortedPredictions = new ArrayList<>();
    try {
        scoreSortedPredictions.addAll(DataUtils.getScoreSortedPredictions(predictionProvider,
                new PredictionInputsDataDistribution(dataDistribution.sample(maxNoOfSamples))));
    } catch (ExecutionException e) {
        LOGGER.error("Could not sort predictions by score {}", e.getMessage());
    } catch (InterruptedException e) {
        LOGGER.error("Interrupted while waiting for sorting predictions by score {}", e.getMessage());
        Thread.currentThread().interrupt();
    } catch (TimeoutException e) {
        LOGGER.error("Timed out while waiting for sorting predictions by score", e);
    }

    if (!scoreSortedPredictions.isEmpty()) {
        // calculate min, max and mean scores
        double max = scoreSortedPredictions.get(0).getOutput().getOutputs().stream()
                .mapToDouble(Output::getScore).sum();
        double min = scoreSortedPredictions.get(scoreSortedPredictions.size() - 1).getOutput().getOutputs().stream()
                .mapToDouble(Output::getScore).sum();
        if (max != min) {
            double threshold = scoreSortedPredictions.stream()
                    .map(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum())
                    .mapToDouble(d -> d).average().orElse((max + min) / 2);
            // filter out predictions whose score is in [min, threshold]
            scoreSortedPredictions = scoreSortedPredictions.stream()
                    .filter(p -> p.getOutput().getOutputs().stream().mapToDouble(Output::getScore).sum() > threshold)
                    .collect(Collectors.toList());
            for (int j = 0; j < features.size(); j++) {
                Feature feature = features.get(j);
                if (Type.NUMBER.equals(feature.getType())) {
                    int finalJ = j;
                    // get feature values associated with high score inputs
                    List<Double> topValues = scoreSortedPredictions.stream()
                            .map(prediction -> prediction.getInput().getFeatures().get(finalJ).getValue().asNumber())
                            .distinct().collect(Collectors.toList());
                    // get high score points and tolerance
                    double[] highScoreFeaturePoints = topValues.stream().flatMapToDouble(DoubleStream::of).toArray();
                    double center = DataUtils.getMean(highScoreFeaturePoints);
                    double tolerance = DataUtils.getStdDev(highScoreFeaturePoints, center) / 2;
                    HighScoreNumericFeatureZones highScoreNumericFeatureZones =
                            new HighScoreNumericFeatureZones(highScoreFeaturePoints, tolerance);
                    numericFeatureZonesMap.put(feature.getName(), highScoreNumericFeatureZones);
                }
            }
        }
    }
    return numericFeatureZonesMap;
}
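A hedged usage sketch of the provider above; the dataDistribution, model and features variables and the sample size of 100 are assumptions, not part of the project code.

// Hypothetical usage sketch: dataDistribution, model and features are assumed to exist.
Map<String, HighScoreNumericFeatureZones> zones =
        HighScoreNumericFeatureZonesProvider.getHighScoreFeatureZones(dataDistribution, model, features, 100);
// Only numeric features get an entry, and none are produced when all sampled scores are identical.
HighScoreNumericFeatureZones ageZone = zones.get("age");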
Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
The class LimeExplainer, method prepareInputs.
/**
 * Check the perturbed inputs so that the dataset of perturbed inputs / outputs contains more than just one output
 * class, otherwise it would be impossible to linearly separate it, and hence to learn meaningful weights to be
 * used as feature importance scores.
 * The check can be {@code strict} or not; when strict, a {@code DatasetNotSeparableException} is thrown if the
 * dataset for a given output is not separable.
 */
private LimeInputs prepareInputs(List<PredictionInput> perturbedInputs,
        List<PredictionOutput> perturbedOutputs,
        List<Feature> linearizedTargetInputFeatures,
        int o,
        Output currentOutput,
        boolean strict) {
    if (currentOutput.getValue() != null && currentOutput.getValue().getUnderlyingObject() != null) {
        Map<Double, Long> rawClassesBalance;

        // calculate the no. of samples belonging to each output class
        Value fv = currentOutput.getValue();
        rawClassesBalance = getClassBalance(perturbedOutputs, fv, o);
        Long max = rawClassesBalance.values().stream().max(Long::compareTo).orElse(1L);
        double separationRatio = (double) max / (double) perturbedInputs.size();

        List<Output> outputs = perturbedOutputs.stream().map(po -> po.getOutputs().get(o)).collect(Collectors.toList());
        boolean classification = rawClassesBalance.size() == 2;
        if (strict) {
            // check if the dataset is separable and also if the linear model should fit a regressor or a classifier
            if (rawClassesBalance.size() > 1 && separationRatio < limeConfig.getSeparableDatasetRatio()) {
                // if dataset creation process succeeds use it to train the linear model
                return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
            } else {
                throw new DatasetNotSeparableException(currentOutput, rawClassesBalance);
            }
        } else {
            LOGGER.warn("Using a hardly separable dataset for output '{}' of type '{}' with value '{}' ({})",
                    currentOutput.getName(), currentOutput.getType(), currentOutput.getValue(), rawClassesBalance);
            return new LimeInputs(classification, linearizedTargetInputFeatures, currentOutput, perturbedInputs, outputs);
        }
    } else {
        return new LimeInputs(false, linearizedTargetInputFeatures, currentOutput, emptyList(), emptyList());
    }
}
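The getClassBalance helper is not part of this excerpt; as a rough illustration of what the balance represents (an assumption about its intent, not the project's implementation), it can be thought of as counting, at output index o, how many perturbed outputs keep the original value versus how many change it:

// Illustrative sketch only; the real getClassBalance implementation may differ.
Map<Double, Long> classBalance = perturbedOutputs.stream()
        .map(po -> po.getOutputs().get(o).getValue().getUnderlyingObject())
        .collect(Collectors.groupingBy(
                obj -> Objects.equals(obj, fv.getUnderlyingObject()) ? 1d : 0d, // 1 = same class as the original output
                Collectors.counting()));
double ratio = (double) Collections.max(classBalance.values()) / perturbedInputs.size();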
Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
The class LimeExplainer, method explainAsync.
@Override
public CompletableFuture<Map<String, Saliency>> explainAsync(Prediction prediction,
        PredictionProvider model,
        Consumer<Map<String, Saliency>> intermediateResultsConsumer) {
    PredictionInput originalInput = prediction.getInput();
    if (originalInput == null || originalInput.getFeatures() == null || originalInput.getFeatures().isEmpty()) {
        throw new LocalExplanationException("cannot explain a prediction whose input is empty");
    }
    List<PredictionInput> linearizedInputs = DataUtils.linearizeInputs(List.of(originalInput));
    PredictionInput targetInput = linearizedInputs.get(0);
    List<Feature> linearizedTargetInputFeatures = targetInput.getFeatures();
    if (linearizedTargetInputFeatures.isEmpty()) {
        throw new LocalExplanationException("input features linearization failed");
    }
    List<Output> actualOutputs = prediction.getOutput().getOutputs();

    LimeConfig executionConfig = limeConfig.copy();
    return explainWithExecutionConfig(model, originalInput, linearizedTargetInputFeatures, actualOutputs, executionConfig);
}
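A hedged usage sketch of the method above; model and prediction are assumed to exist, and the no-argument LimeExplainer constructor and Saliency#getTopFeatures call are assumptions about the surrounding API rather than part of this excerpt.

// Hypothetical usage sketch: 'model' and 'prediction' are assumed to exist.
LimeExplainer limeExplainer = new LimeExplainer();
Map<String, Saliency> saliencyMap = limeExplainer
        .explainAsync(prediction, model, partial -> {}) // no-op intermediate consumer
        .join();
saliencyMap.forEach((outputName, saliency) ->
        System.out.println(outputName + ": " + saliency.getTopFeatures(2)));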
Use of org.kie.kogito.explainability.model.Output in project kogito-apps by kiegroup.
The class CounterfactualExplainerServiceHandlerTest, method testCreateIntermediateResult.
@Test
public void testCreateIntermediateResult() {
    CounterfactualExplainabilityRequest request = new CounterfactualExplainabilityRequest(EXECUTION_ID,
            SERVICE_URL,
            MODEL_IDENTIFIER,
            COUNTERFACTUAL_ID,
            Collections.emptyList(),
            Collections.emptyList(),
            Collections.emptyList(),
            MAX_RUNNING_TIME_SECONDS);

    List<CounterfactualEntity> entities = List.of(
            DoubleEntity.from(new Feature("input1", Type.NUMBER, new Value(123.0d)), 0, 1000));

    CounterfactualResult counterfactuals = new CounterfactualResult(entities,
            entities.stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList()),
            List.of(new PredictionOutput(List.of(new Output("output1", Type.NUMBER, new Value(555.0d), 1.0)))),
            true,
            UUID.fromString(SOLUTION_ID),
            UUID.fromString(EXECUTION_ID),
            0);

    BaseExplainabilityResult base = handler.createIntermediateResult(request, counterfactuals);

    assertTrue(base instanceof CounterfactualExplainabilityResult);
    CounterfactualExplainabilityResult result = (CounterfactualExplainabilityResult) base;

    assertEquals(ExplainabilityStatus.SUCCEEDED, result.getStatus());
    assertEquals(CounterfactualExplainabilityResult.Stage.INTERMEDIATE, result.getStage());
    assertEquals(EXECUTION_ID, result.getExecutionId());
    assertEquals(COUNTERFACTUAL_ID, result.getCounterfactualId());

    assertEquals(1, result.getInputs().size());
    assertTrue(result.getInputs().stream().anyMatch(i -> i.getName().equals("input1")));
    NamedTypedValue input1 = result.getInputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), input1.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, input1.getValue().getKind());
    assertEquals(123.0, input1.getValue().toUnit().getValue().asDouble());

    assertEquals(1, result.getOutputs().size());
    assertTrue(result.getOutputs().stream().anyMatch(o -> o.getName().equals("output1")));
    NamedTypedValue output1 = result.getOutputs().iterator().next();
    assertEquals(Double.class.getSimpleName(), output1.getValue().getType());
    assertEquals(TypedValue.Kind.UNIT, output1.getValue().getKind());
    assertEquals(555.0, output1.getValue().toUnit().getValue().asDouble());
}