Usage of org.optaplanner.core.api.solver.SolverManager in the kogito-apps project by kiegroup.
Class CounterfactualExplainerTest, method testSequenceIds.
/**
 * Checks that every result produced by one explainer invocation carries its own sequence id:
 * the final result plus each intermediate result reported through the solver listener.
 *
 * @param numberOfIntermediateSolutions how many intermediate solutions to push through the listener
 */
@ParameterizedTest
@ValueSource(ints = { 1, 2, 3, 5, 8 })
@SuppressWarnings("unchecked")
void testSequenceIds(int numberOfIntermediateSolutions) throws ExecutionException, InterruptedException, TimeoutException {
    // Collects the sequence id of each intermediate result handed to the consumer.
    final List<Long> observedIds = new ArrayList<>();
    final Consumer<CounterfactualResult> recorder = cf -> observedIds.add(cf.getSequenceId());

    final CounterfactualResult finalResult = mockExplainerInvocation(recorder, null);

    // Retrieve the intermediate-solution listener the explainer registered with the solver manager.
    final ArgumentCaptor<Consumer<CounterfactualSolution>> listenerCaptor = ArgumentCaptor.forClass(Consumer.class);
    verify(solverManager).solveAndListen(any(), any(), listenerCaptor.capture(), any());
    final Consumer<CounterfactualSolution> listener = listenerCaptor.getValue();

    // Simulate the solver publishing intermediate best solutions.
    for (int n = 0; n < numberOfIntermediateSolutions; n++) {
        final BendableBigDecimalScore zeroScore = BendableBigDecimalScore.zero(0, 0);
        final CounterfactualSolution intermediate = mock(CounterfactualSolution.class);
        when(intermediate.getScore()).thenReturn(zeroScore);
        listener.accept(intermediate);
    }

    // Final plus intermediate results must all be present, with no duplicated id.
    observedIds.add(finalResult.getSequenceId());
    final int expectedCount = numberOfIntermediateSolutions + 1;
    assertEquals(expectedCount, observedIds.size());
    assertEquals(expectedCount, (int) observedIds.stream().distinct().count());
}
Usage of org.optaplanner.core.api.solver.SolverManager in the kogito-apps project by kiegroup.
Class CounterfactualExplainer, method explainAsync.
/**
 * Asynchronously searches for a counterfactual of {@code prediction} against {@code model}.
 * <p>
 * The prediction is cast to {@link CounterfactualPrediction}. Its execution id is used as the solver
 * problem id. Its optional max running time (in seconds) becomes the spent-time termination limit for
 * this request only; the shared template solver configuration is never modified.
 * <p>
 * Intermediate best solutions reported by the solver are forwarded to {@code intermediateResultsConsumer}
 * through {@code createSolutionConsumer}. Once solving ends, the final solution's entities are evaluated
 * with {@code model} and wrapped in a {@link CounterfactualResult}. Intermediate and final results draw
 * their sequence ids from the same per-call counter.
 *
 * @param prediction the prediction to explain; must be a {@link CounterfactualPrediction}
 * @param model the model used to evaluate the final solution
 * @param intermediateResultsConsumer receives intermediate results while the solver runs
 * @return a future completing with the final counterfactual result. It completes exceptionally, with an
 *         {@link IllegalStateException} as cause, if solving fails or the solving thread is interrupted.
 */
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    // Handed to createSolutionConsumer and also used for the final result's id,
    // so every result of this call draws from one counter.
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        SolverConfig solverConfig = this.counterfactualConfig.getSolverConfig();
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            // Copy the template first. It is shared by every call, so mutating it in place would leak
            // this request's time limit into later requests and race with concurrent ones.
            solverConfig = new SolverConfig(solverConfig);
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                // Trailing throwable argument keeps the stack trace in the log.
                logger.error("Solving failed: {}", e.getMessage(), e);
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    // Evaluate the final solution's entities with the model, then assemble the result.
    // The final sequence id is taken only after the outputs are available.
    return cfSolution.thenCompose(solution -> model.predictAsync(buildInput(solution.getEntities()))
            .thenApply(outputs -> new CounterfactualResult(solution.getEntities(),
                    solution.getOriginalFeatures(),
                    outputs,
                    solution.getScore().isFeasible(),
                    UUID.randomUUID(),
                    solution.getExecutionId(),
                    sequenceId.incrementAndGet())));
}
Aggregations