Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
The class PartialDependencePlotExplainerTest, method testPdpNumericClassifier.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2, 3, 4 })
void testPdpNumericClassifier(int seed) throws Exception {
    Random random = new Random();
    random.setSeed(seed);
    PredictionProvider modelInfo = TestUtils.getSumSkipModel(0);
    PartialDependencePlotExplainer partialDependencePlotProvider = new PartialDependencePlotExplainer();
    List<PartialDependenceGraph> pdps = partialDependencePlotProvider.explainFromMetadata(modelInfo, getMetadata(random));
    assertNotNull(pdps);
    for (PartialDependenceGraph pdp : pdps) {
        assertNotNull(pdp.getFeature());
        assertNotNull(pdp.getX());
        assertNotNull(pdp.getY());
        assertEquals(pdp.getX().size(), pdp.getY().size());
        assertGraph(pdp);
    }
    // The first feature is always skipped by the model, so the predictions are not affected;
    // hence the PDP Y values for that feature are constant.
    PartialDependenceGraph fixedFeatureGraph = pdps.get(0);
    assertEquals(1, fixedFeatureGraph.getY().stream().distinct().count());
    // The other two features instead vary in their Y values.
    assertThat(pdps.get(1).getY().stream().distinct().count()).isGreaterThan(1);
    assertThat(pdps.get(2).getY().stream().distinct().count()).isGreaterThan(1);
}
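
TestUtils.getSumSkipModel(0) comes from the project's test utilities and its source is not shown here. As a rough sketch of what such a model might look like (an assumption for illustration, not the actual TestUtils code), a PredictionProvider can be written as a lambda that sums every numeric feature except the one at the skipped index, which is why the PDP for that feature stays flat:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.kie.kogito.explainability.model.*;

// Illustrative sketch only: sums all numeric features except the one at skipIndex.
// PredictionProvider's single method is
// predictAsync(List<PredictionInput>) -> CompletableFuture<List<PredictionOutput>>.
static PredictionProvider sumSkipModel(int skipIndex) {
    return inputs -> {
        List<PredictionOutput> outputs = new ArrayList<>();
        for (PredictionInput input : inputs) {
            double sum = 0;
            List<Feature> features = input.getFeatures();
            for (int i = 0; i < features.size(); i++) {
                if (i != skipIndex) {
                    sum += features.get(i).getValue().asNumber();
                }
            }
            outputs.add(new PredictionOutput(
                    List.of(new Output("sum-but-" + skipIndex, Type.NUMBER, new Value(sum), 1.0))));
        }
        return CompletableFuture.completedFuture(outputs);
    };
}
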
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
The class CounterfactualExplainerTest, method testFinalUniqueIds.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testFinalUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random random = new Random();
    random.setSeed(seed);
    final List<Output> goal = new ArrayList<>();
    List<Feature> features = List.of(
            FeatureFactory.newNumericalFeature("f-num1", 10.0, NumericalFeatureDomain.create(0, 20)));
    PredictionProvider model = TestUtils.getFeaturePassModel(0);
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(100_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final List<UUID> intermediateIds = new ArrayList<>();
    final List<UUID> executionIds = new ArrayList<>();
    final Consumer<CounterfactualResult> captureIntermediateIds =
            counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
    final Consumer<CounterfactualResult> captureExecutionIds =
            counterfactual -> executionIds.add(counterfactual.getExecutionId());
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    solverConfig.withEasyScoreCalculatorClass(MockCounterFactualScoreCalculator.class);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    final UUID executionId = UUID.randomUUID();
    Prediction prediction = new CounterfactualPrediction(input, output, null, executionId, null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer
            .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }
    // All intermediate ids should be unique
    assertEquals((int) intermediateIds.stream().distinct().count(), intermediateIds.size());
    // There should be at least one intermediate id
    assertTrue(intermediateIds.size() > 0);
    // There should be at least one execution id
    assertTrue(executionIds.size() > 0);
    // We should have the same number of execution ids as intermediate ids (captured from intermediate results)
    assertEquals(executionIds.size(), intermediateIds.size());
    // All execution ids should be the same
    assertEquals(1, (int) executionIds.stream().distinct().count());
    // The last intermediate id must be different from the final result id
    assertNotEquals(intermediateIds.get(intermediateIds.size() - 1), counterfactualResult.getSolutionId());
    // Captured execution ids should be the same as the one provided
    assertEquals(executionIds.get(0), executionId);
}
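
One detail worth noting: solveAndListen delivers intermediate best solutions on a solver-side thread, while the test thread reads intermediateIds and executionIds only after the future completes. A defensive variant of the capture consumers (an illustrative hardening, not what the test above does) would collect the ids into thread-safe lists:

import java.util.List;
import java.util.UUID;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;

// Sketch: capture solution and execution ids from intermediate results into
// thread-safe lists, since the consumer may run on a solver-side thread.
List<UUID> intermediateIds = new CopyOnWriteArrayList<>();
List<UUID> executionIds = new CopyOnWriteArrayList<>();
Consumer<CounterfactualResult> capture = result -> {
    intermediateIds.add(result.getSolutionId());
    executionIds.add(result.getExecutionId());
};
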
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
The class CounterfactualExplainerTest, method testIntermediateUniqueIds.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testIntermediateUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    Random random = new Random();
    random.setSeed(seed);
    final List<Output> goal = new ArrayList<>();
    List<Feature> features = List.of(
            FeatureFactory.newNumericalFeature("f-num1", 10.0, NumericalFeatureDomain.create(0, 20)));
    PredictionProvider model = TestUtils.getFeaturePassModel(0);
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(100_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final List<UUID> intermediateIds = new ArrayList<>();
    final List<UUID> executionIds = new ArrayList<>();
    final Consumer<CounterfactualResult> captureIntermediateIds =
            counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
    final Consumer<CounterfactualResult> captureExecutionIds =
            counterfactual -> executionIds.add(counterfactual.getExecutionId());
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    solverConfig.withEasyScoreCalculatorClass(MockCounterFactualScoreCalculator.class);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    final UUID executionId = UUID.randomUUID();
    Prediction prediction = new CounterfactualPrediction(input, output, null, executionId, null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer
            .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }
    // All intermediate ids must be distinct
    assertEquals((int) intermediateIds.stream().distinct().count(), intermediateIds.size());
    assertEquals(1, (int) executionIds.stream().distinct().count());
    assertEquals(executionIds.get(0), executionId);
}
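
Both tests plug MockCounterFactualScoreCalculator into the solver config; its source is not shown here. For orientation, an OptaPlanner easy score calculator for this solution type implements a single calculateScore method. The sketch below is hypothetical: it assumes the solution's planning score is a bendable BigDecimal score with three hard and two soft levels, and the array sizes must match whatever @PlanningScore declares on CounterfactualSolution.

import java.math.BigDecimal;
import org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore;
import org.optaplanner.core.api.score.calculator.EasyScoreCalculator;

// Hypothetical stand-in for a mock calculator: gives every candidate the same
// score, so the solver stops purely on the score-calculation count limit.
public class ConstantScoreCalculator
        implements EasyScoreCalculator<CounterfactualSolution, BendableBigDecimalScore> {

    @Override
    public BendableBigDecimalScore calculateScore(CounterfactualSolution solution) {
        // Assumed level counts (3 hard, 2 soft); adjust to the real annotation.
        return BendableBigDecimalScore.of(
                new BigDecimal[] { BigDecimal.ZERO, BigDecimal.ZERO, BigDecimal.ZERO },
                new BigDecimal[] { BigDecimal.ZERO, BigDecimal.ZERO });
    }
}
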
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
The class CounterfactualScoreCalculatorTest, method testGoalSizeSmaller.

/**
 * Using a smaller number of features in the goals (1) than the model's output (2) should
 * throw an {@link IllegalArgumentException} with the appropriate message.
 */
@Test
void testGoalSizeSmaller() throws ExecutionException, InterruptedException {
    final CounterFactualScoreCalculator scoreCalculator = new CounterFactualScoreCalculator();
    PredictionProvider model = TestUtils.getFeatureSkipModel(0);
    List<Feature> features = new ArrayList<>();
    List<FeatureDomain> featureDomains = new ArrayList<>();
    List<Boolean> constraints = new ArrayList<>();
    // f-1
    features.add(FeatureFactory.newNumericalFeature("f-1", 1.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-2
    features.add(FeatureFactory.newNumericalFeature("f-2", 2.0));
    featureDomains.add(NumericalFeatureDomain.create(0.0, 10.0));
    constraints.add(false);
    // f-3
    features.add(FeatureFactory.newBooleanFeature("f-3", true));
    featureDomains.add(EmptyFeatureDomain.create());
    constraints.add(false);
    PredictionInput input = new PredictionInput(features);
    PredictionFeatureDomain domains = new PredictionFeatureDomain(featureDomains);
    List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(input);
    List<Output> goal = new ArrayList<>();
    goal.add(new Output("f-2", Type.NUMBER, new Value(2.0), 0.0));
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(input)).get();
    assertEquals(1, goal.size());
    // A single prediction is expected
    assertEquals(1, predictionOutputs.size());
    // Single prediction with two features
    assertEquals(2, predictionOutputs.get(0).getOutputs().size());
    final CounterfactualSolution solution = new CounterfactualSolution(
            entities, features, model, goal, UUID.randomUUID(), UUID.randomUUID(), 0.0);
    IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
            () -> scoreCalculator.calculateScore(solution));
    assertEquals("Prediction size must be equal to goal size", exception.getMessage());
}
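
TestUtils.getFeatureSkipModel(0) is likewise a test utility whose source is not shown. The assertions imply it emits one output per input feature except the one at the skipped index, so three features yield two outputs while the goal lists only one, which triggers the size check. A plausible sketch (an assumption, not the project's implementation):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.kie.kogito.explainability.model.*;

// Illustrative sketch only: echoes every feature except the one at skipIndex
// as an Output with the same name, type and value.
static PredictionProvider featureSkipModel(int skipIndex) {
    return inputs -> {
        List<PredictionOutput> outputs = new ArrayList<>();
        for (PredictionInput input : inputs) {
            List<Output> outs = new ArrayList<>();
            List<Feature> features = input.getFeatures();
            for (int i = 0; i < features.size(); i++) {
                if (i != skipIndex) {
                    Feature f = features.get(i);
                    outs.add(new Output(f.getName(), f.getType(), f.getValue(), 1.0));
                }
            }
            outputs.add(new PredictionOutput(outs));
        }
        return CompletableFuture.completedFuture(outputs);
    };
}
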
Use of org.kie.kogito.explainability.model.PredictionProvider in project kogito-apps by kiegroup.
The class CounterfactualExplainer, method explainAsync.

@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction,
        PredictionProvider model,
        Consumer<CounterfactualResult> intermediateResultsConsumer) {
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial =
            uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal,
                    UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager =
                this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial,
                    assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                logger.error("Solving failed: {}", e.getMessage());
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    final CompletableFuture<List<PredictionOutput>> cfOutputs =
            cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    return CompletableFuture.allOf(cfOutputs, cfSolution).thenApply(v -> {
        CounterfactualSolution solution = cfSolution.join();
        return new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), cfOutputs.join(),
                solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(),
                sequenceId.incrementAndGet());
    });
}
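
For completeness, a caller drives this method by building a CounterfactualPrediction and supplying an intermediate-results consumer, then waiting on the returned future. A minimal usage sketch, assuming a model, features and goal prepared as in the tests above, default CounterfactualConfig settings, and an enclosing method that declares ExecutionException, InterruptedException and TimeoutException:

// Usage sketch: run the explainer and log each intermediate solution id
// as the solver improves on its best solution.
CounterfactualConfig config = new CounterfactualConfig();
CounterfactualExplainer explainer = new CounterfactualExplainer(config);

Prediction prediction = new CounterfactualPrediction(
        new PredictionInput(features), new PredictionOutput(goal), null, UUID.randomUUID(), null);

CounterfactualResult result = explainer
        .explainAsync(prediction, model, r -> logger.debug("intermediate solution: {}", r.getSolutionId()))
        .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());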