Search in sources:

Example 1 with SolverManager

Use of org.optaplanner.core.api.solver.SolverManager in project kogito-apps by kiegroup.

The class CounterfactualExplainerTest, method testSequenceIds.

@ParameterizedTest
@ValueSource(ints = { 1, 2, 3, 5, 8 })
@SuppressWarnings("unchecked")
void testSequenceIds(int numberOfIntermediateSolutions) throws ExecutionException, InterruptedException, TimeoutException {
    // Collects the sequence id of every CounterfactualResult pushed to the intermediate-results consumer.
    final List<Long> observedIds = new ArrayList<>();
    final Consumer<CounterfactualResult> recordSequenceId = cf -> observedIds.add(cf.getSequenceId());

    // forClass(Consumer.class) yields a raw Consumer captor, hence the unchecked suppression on the method.
    final ArgumentCaptor<Consumer<CounterfactualSolution>> listenerCaptor = ArgumentCaptor.forClass(Consumer.class);
    final CounterfactualResult finalResult = mockExplainerInvocation(recordSequenceId, null);

    // Retrieve the best-solution listener that was handed to the (mocked) solver manager.
    verify(solverManager).solveAndListen(any(), any(), listenerCaptor.capture(), any());
    final Consumer<CounterfactualSolution> listener = listenerCaptor.getValue();

    // Simulate the solver reporting the requested number of intermediate solutions.
    for (int n = 0; n < numberOfIntermediateSolutions; n++) {
        final CounterfactualSolution intermediateSolution = mock(CounterfactualSolution.class);
        final BendableBigDecimalScore zeroScore = BendableBigDecimalScore.zero(0, 0);
        when(intermediateSolution.getScore()).thenReturn(zeroScore);
        listener.accept(intermediateSolution);
    }

    // Intermediate results plus the final one must each carry a distinct sequence id.
    observedIds.add(finalResult.getSequenceId());
    final int expectedCount = numberOfIntermediateSolutions + 1;
    assertEquals(expectedCount, observedIds.size());
    assertEquals(expectedCount, (int) observedIds.stream().distinct().count());
}
Also used: BeforeEach(org.junit.jupiter.api.BeforeEach) FeatureFactory(org.kie.kogito.explainability.model.FeatureFactory) Feature(org.kie.kogito.explainability.model.Feature) LoggerFactory(org.slf4j.LoggerFactory) Assertions.assertNotEquals(org.junit.jupiter.api.Assertions.assertNotEquals) TimeoutException(java.util.concurrent.TimeoutException) Random(java.util.Random) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) Value(org.kie.kogito.explainability.model.Value) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) Assertions.assertFalse(org.junit.jupiter.api.Assertions.assertFalse) FeatureDistribution(org.kie.kogito.explainability.model.FeatureDistribution) EmptyFeatureDomain(org.kie.kogito.explainability.model.domain.EmptyFeatureDomain) Mockito.atLeast(org.mockito.Mockito.atLeast) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) CategoricalFeatureDomain(org.kie.kogito.explainability.model.domain.CategoricalFeatureDomain) DataUtils(org.kie.kogito.explainability.utils.DataUtils) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) Test(org.junit.jupiter.api.Test) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Stream(java.util.stream.Stream) NormalDistribution(org.apache.commons.math3.distribution.NormalDistribution) Output(org.kie.kogito.explainability.model.Output) Assertions.assertTrue(org.junit.jupiter.api.Assertions.assertTrue) SolverJob(org.optaplanner.core.api.solver.SolverJob) Mockito.mock(org.mockito.Mockito.mock) IntStream(java.util.stream.IntStream) ArgumentMatchers.any(org.mockito.ArgumentMatchers.any) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Assertions.assertNotNull(org.junit.jupiter.api.Assertions.assertNotNull) PerturbationContext(org.kie.kogito.explainability.model.PerturbationContext) Prediction(org.kie.kogito.explainability.model.Prediction) 
DataDomain(org.kie.kogito.explainability.model.DataDomain) Assertions.assertNull(org.junit.jupiter.api.Assertions.assertNull) EnvironmentMode(org.optaplanner.core.config.solver.EnvironmentMode) CompletableFuture(java.util.concurrent.CompletableFuture) SolverManager(org.optaplanner.core.api.solver.SolverManager) Function(java.util.function.Function) ArrayList(java.util.ArrayList) MockCounterFactualScoreCalculator(org.kie.kogito.explainability.local.counterfactual.score.MockCounterFactualScoreCalculator) ArgumentCaptor(org.mockito.ArgumentCaptor) NumericFeatureDistribution(org.kie.kogito.explainability.model.NumericFeatureDistribution) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) Assertions.assertEquals(org.junit.jupiter.api.Assertions.assertEquals) LinkedList(java.util.LinkedList) BendableBigDecimalScore(org.optaplanner.core.api.score.buildin.bendablebigdecimal.BendableBigDecimalScore) ValueSource(org.junit.jupiter.params.provider.ValueSource) Logger(org.slf4j.Logger) Mockito.when(org.mockito.Mockito.when) Type(org.kie.kogito.explainability.model.Type) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Mockito.verify(org.mockito.Mockito.verify) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest) TestUtils(org.kie.kogito.explainability.TestUtils) NumericalFeatureDomain(org.kie.kogito.explainability.model.domain.NumericalFeatureDomain) Config(org.kie.kogito.explainability.Config) Collections(java.util.Collections) FeatureDomain(org.kie.kogito.explainability.model.domain.FeatureDomain)

Example 2 with SolverManager

Use of org.optaplanner.core.api.solver.SolverManager in project kogito-apps by kiegroup.

The class CounterfactualExplainer, method explainAsync.

/**
 * Asynchronously searches for a counterfactual of the given prediction using the OptaPlanner solver.
 * <p>
 * Every best solution reported by the solver is converted into a {@link CounterfactualResult} and passed to
 * {@code intermediateResultsConsumer}. The returned future completes with a final result built from the solver's
 * final best solution and the model's prediction for it. All emitted results (intermediate and final) draw their
 * sequence id from the same counter, so each one carries a distinct, increasing id.
 *
 * @param prediction the prediction to explain; cast to {@link CounterfactualPrediction}, so any other
 *                   implementation fails with {@link ClassCastException}
 * @param model the model queried for the outputs of the final counterfactual
 * @param intermediateResultsConsumer receives a result for each intermediate best solution
 * @return a future completing with the final counterfactual result. It completes exceptionally with an
 *         {@link IllegalStateException} cause if solving fails or the solving thread is interrupted.
 */
@Override
public CompletableFuture<CounterfactualResult> explainAsync(Prediction prediction, PredictionProvider model, Consumer<CounterfactualResult> intermediateResultsConsumer) {
    // Shared by the intermediate-solution consumer and the final result so sequence ids never repeat.
    final AtomicLong sequenceId = new AtomicLong(0);
    final CounterfactualPrediction cfPrediction = (CounterfactualPrediction) prediction;
    final UUID executionId = cfPrediction.getExecutionId();
    final Long maxRunningTimeSeconds = cfPrediction.getMaxRunningTimeSeconds();
    final List<CounterfactualEntity> entities = CounterfactualEntityFactory.createEntities(prediction.getInput());
    final List<Output> goal = prediction.getOutput().getOutputs();
    // Original features kept as structural reference to re-assemble composite features
    final List<Feature> originalFeatures = prediction.getInput().getFeatures();
    // The solver's problem id is ignored; each initial solution gets a fresh random solution id.
    final Function<UUID, CounterfactualSolution> initial = uuid -> new CounterfactualSolution(entities, originalFeatures, model, goal, UUID.randomUUID(), executionId, this.counterfactualConfig.getGoalThreshold());
    final CompletableFuture<CounterfactualSolution> cfSolution = CompletableFuture.supplyAsync(() -> {
        // Work on a copy of the template configuration. withTerminationSpentLimit mutates the config in place,
        // so applying it to the configured instance could leak this request's time limit into later requests
        // and race with concurrent ones.
        final SolverConfig solverConfig = new SolverConfig(this.counterfactualConfig.getSolverConfig());
        if (Objects.nonNull(maxRunningTimeSeconds)) {
            solverConfig.withTerminationSpentLimit(Duration.ofSeconds(maxRunningTimeSeconds));
        }
        try (SolverManager<CounterfactualSolution, UUID> solverManager = this.counterfactualConfig.getSolverManagerFactory().apply(solverConfig)) {
            final SolverJob<CounterfactualSolution, UUID> solverJob = solverManager.solveAndListen(executionId, initial, assignSolutionId.andThen(createSolutionConsumer(intermediateResultsConsumer, sequenceId)), null);
            try {
                // Wait until the solving ends
                return solverJob.getFinalBestSolution();
            } catch (ExecutionException e) {
                // Pass the exception as the trailing argument so SLF4J logs the full stack trace.
                logger.error("Solving failed: {}", e.getMessage(), e);
                throw new IllegalStateException("Prediction returned an error", e);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Solving failed (Thread interrupted)", e);
            }
        }
    }, this.counterfactualConfig.getExecutor());
    // Ask the model for the outputs of the best counterfactual found by the solver.
    final CompletableFuture<List<PredictionOutput>> cfOutputs = cfSolution.thenCompose(s -> model.predictAsync(buildInput(s.getEntities())));
    // Combine the solution with its outputs; the final result takes the next sequence id after all intermediates.
    return cfSolution.thenCombine(cfOutputs, (solution, outputs) -> new CounterfactualResult(solution.getEntities(), solution.getOriginalFeatures(), outputs, solution.getScore().isFeasible(), UUID.randomUUID(), solution.getExecutionId(), sequenceId.incrementAndGet()));
}
Also used: SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Feature(org.kie.kogito.explainability.model.Feature) Prediction(org.kie.kogito.explainability.model.Prediction) LoggerFactory(org.slf4j.LoggerFactory) CompletableFuture(java.util.concurrent.CompletableFuture) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) SolverManager(org.optaplanner.core.api.solver.SolverManager) Function(java.util.function.Function) CompositeFeatureUtils(org.kie.kogito.explainability.utils.CompositeFeatureUtils) Duration(java.time.Duration) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Logger(org.slf4j.Logger) Executor(java.util.concurrent.Executor) LocalExplainer(org.kie.kogito.explainability.local.LocalExplainer) UUID(java.util.UUID) Collectors(java.util.stream.Collectors) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Objects(java.util.Objects) ExecutionException(java.util.concurrent.ExecutionException) Consumer(java.util.function.Consumer) AtomicLong(java.util.concurrent.atomic.AtomicLong) CounterfactualEntityFactory(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntityFactory) PredictionInput(org.kie.kogito.explainability.model.PredictionInput) List(java.util.List) Output(org.kie.kogito.explainability.model.Output) SolverJob(org.optaplanner.core.api.solver.SolverJob)

Aggregations

List (java.util.List): 2, UUID (java.util.UUID): 2, CompletableFuture (java.util.concurrent.CompletableFuture): 2, ExecutionException (java.util.concurrent.ExecutionException): 2, Consumer (java.util.function.Consumer): 2, Function (java.util.function.Function): 2, Collectors (java.util.stream.Collectors): 2, CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 2, CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction): 2, Feature (org.kie.kogito.explainability.model.Feature): 2, Output (org.kie.kogito.explainability.model.Output): 2, Prediction (org.kie.kogito.explainability.model.Prediction): 2, PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 2, PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 2, PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 2, SolverJob (org.optaplanner.core.api.solver.SolverJob): 2, SolverManager (org.optaplanner.core.api.solver.SolverManager): 2, SolverConfig (org.optaplanner.core.config.solver.SolverConfig): 2, Logger (org.slf4j.Logger): 2, LoggerFactory (org.slf4j.LoggerFactory): 2