
Example 51 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in project kogito-apps by kiegroup.

From class CounterfactualExplainerTest, method testNonEmptyInput.

@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testNonEmptyInput(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    // seeding makes the generated feature values identical on every run for a given seed
    Random random = new Random(seed);
    final List<Output> goal = List.of(new Output("class", Type.NUMBER, new Value(10.0), 0.0d));
    List<Feature> features = new LinkedList<>();
    for (int i = 0; i < 4; i++) {
        features.add(FeatureFactory.newNumericalFeature("f-" + i, random.nextDouble(), NumericalFeatureDomain.create(0.0, 1000.0)));
    }
    // for the purpose of this test, only a few steps are necessary
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionProvider model = TestUtils.getSumSkipModel(0);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    // check the result before dereferencing it in the debug logging below
    assertNotNull(counterfactualResult);
    assertNotNull(counterfactualResult.getEntities());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }
    logger.debug("Outputs: {}", counterfactualResult.getOutput().get(0).getOutputs());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) Random(java.util.Random) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) Value(org.kie.kogito.explainability.model.Value) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) ValueSource(org.junit.jupiter.params.provider.ValueSource) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
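The test pins three knobs for speed and reproducibility: a seeded Random for the input features, setRandomSeed plus EnvironmentMode.REPRODUCIBLE on the solver, and a score-calculation count limit of 10. The sketch below is a toy, plain-Java stand-in (not OptaPlanner's solver and not kogito's explainer) showing why those knobs matter together: with a fixed seed and a fixed budget of score calculations, two runs make exactly the same moves and return the same answer. The sum model, the hill-climbing move, and every name in it are hypothetical.

```java
import java.util.Random;

/**
 * Toy stand-in (NOT OptaPlanner) for a seeded local search that stops after a
 * fixed number of score calculations, in the spirit of
 * TerminationConfig.withScoreCalculationCountLimit(10L).
 */
final class ToySeededSearch {

    /** Hypothetical model: plain sum of the features. */
    static double model(double[] features) {
        double sum = 0.0;
        for (double f : features) {
            sum += f;
        }
        return sum;
    }

    /** Score: negative distance between model output and goal (higher is better). */
    static double score(double[] features, double goal) {
        return -Math.abs(model(features) - goal);
    }

    /**
     * Hill climbing: each step replaces one random feature with a random value in
     * [min, max) and keeps the candidate if the score does not get worse.
     * Stops once scoreCalculationLimit scores have been computed.
     */
    static double[] search(double[] start, double goal, double min, double max,
                           long scoreCalculationLimit, long seed) {
        Random random = new Random(seed);       // same seed -> same move sequence
        double[] best = start.clone();           // never mutate the caller's array
        double bestScore = score(best, goal);
        long scoreCalculations = 1;              // the initial score counts too
        while (scoreCalculations < scoreCalculationLimit) {
            double[] candidate = best.clone();
            int i = random.nextInt(candidate.length);
            candidate[i] = min + (max - min) * random.nextDouble();
            double candidateScore = score(candidate, goal);
            scoreCalculations++;
            if (candidateScore >= bestScore) {
                best = candidate;
                bestScore = candidateScore;
            }
        }
        return best;
    }
}
```

Calling search twice with the same start, seed, and limit yields identical arrays; changing only the seed generally changes the result, which is why the parameterized test runs seeds 0, 1 and 2.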

Example 52 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in project kogito-apps by kiegroup.

From class CounterfactualExplainerTest, method testTerminationSpentLimitWhenDefined.

@Test
@SuppressWarnings("unchecked")
void testTerminationSpentLimitWhenDefined() throws ExecutionException, InterruptedException, TimeoutException {
    // capture the SolverConfig that the explainer hands to the (mocked) solver manager factory
    ArgumentCaptor<SolverConfig> solverConfigArgumentCaptor = ArgumentCaptor.forClass(SolverConfig.class);
    mockExplainerInvocation(mock(Consumer.class), MAX_RUNNING_TIME_SECONDS);
    verify(solverManagerFactory).apply(solverConfigArgumentCaptor.capture());
    SolverConfig solverConfig = solverConfigArgumentCaptor.getValue();
    // MAX_RUNNING_TIME_SECONDS must show up as the spent limit (in seconds) of the termination config
    assertEquals(MAX_RUNNING_TIME_SECONDS, solverConfig.getTerminationConfig().getSpentLimit().getSeconds());
}
Also used : Consumer(java.util.function.Consumer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test) ParameterizedTest(org.junit.jupiter.params.ParameterizedTest)
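Example 52 never runs a solver: it captures the SolverConfig that the explainer passes to the mocked solverManagerFactory and asserts on its spent-limit termination. The plain-Java sketch below (no Mockito) makes that capture-then-assert pattern explicit. ToySolverConfig, CapturingFactory and ToyExplainer are hypothetical names, and the mapping from a running time in seconds to a spent-limit Duration is an assumption that mirrors the assertion in the test.

```java
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Hypothetical stand-in for SolverConfig: only the spent limit matters here. */
final class ToySolverConfig {
    final Duration spentLimit;

    ToySolverConfig(Duration spentLimit) {
        this.spentLimit = spentLimit;
    }
}

/** Hand-rolled analogue of an ArgumentCaptor wrapped around a factory function. */
final class CapturingFactory<T, R> implements Function<T, R> {
    private final List<T> captured = new ArrayList<>();
    private final Function<T, R> delegate;

    CapturingFactory(Function<T, R> delegate) {
        this.delegate = delegate;
    }

    @Override
    public R apply(T argument) {
        captured.add(argument);               // remember what the code under test passed in
        return delegate.apply(argument);
    }

    /** Like ArgumentCaptor.getValue(): the most recently captured argument. */
    T getValue() {
        if (captured.isEmpty()) {
            throw new IllegalStateException("factory was never invoked");
        }
        return captured.get(captured.size() - 1);
    }
}

/** Hypothetical code under test: builds a config from a time budget and hands it to the factory. */
final class ToyExplainer {
    static <R> R explain(long maxRunningTimeSeconds, Function<ToySolverConfig, R> solverManagerFactory) {
        ToySolverConfig config = new ToySolverConfig(Duration.ofSeconds(maxRunningTimeSeconds));
        return solverManagerFactory.apply(config);
    }
}
```

The design point is the same as in the real test: the assertion targets the configuration the collaborator received, not the collaborator's behavior, so no solver has to run.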

Example 53 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in project kogito-apps by kiegroup.

From class AbstractTaskAssigningCoreTest, method createDaemonSolver.

protected Solver<TaskAssigningSolution> createDaemonSolver() {
    SolverConfig config = createBaseConfig();
    // daemon mode: solve() blocks instead of returning when the termination is reached
    config.setDaemon(true);
    SolverFactory<TaskAssigningSolution> solverFactory = SolverFactory.create(config);
    return solverFactory.buildSolver();
}
Also used : TaskAssigningSolution(org.kie.kogito.taskassigning.core.model.TaskAssigningSolution) SolverConfig(org.optaplanner.core.config.solver.SolverConfig)
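createDaemonSolver() only flips setDaemon(true) on the base configuration. As described in the OptaPlanner documentation on daemon mode, a daemon solver does not return from solve() when its termination is reached: it blocks its thread until a problem change arrives and then resumes solving, and only terminateEarly() makes it return. The toy below models just that control flow with a BlockingQueue and a poison pill; ToyDaemonSolver, its int "solution", and the function-valued "problem change" are hypothetical and are not OptaPlanner's Solver API.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.IntUnaryOperator;

/**
 * Toy model of daemon-mode semantics (NOT OptaPlanner's Solver): the solution
 * is a single int and a "problem change" is a function applied to it.
 */
final class ToyDaemonSolver {
    // Poison pill queued by terminateEarly(); compared by reference.
    private static final IntUnaryOperator STOP = x -> x;

    private final boolean daemon;
    private final BlockingQueue<IntUnaryOperator> changes = new LinkedBlockingQueue<>();

    ToyDaemonSolver(boolean daemon) {
        this.daemon = daemon;
    }

    int solve(int initialSolution) {
        int solution = initialSolution;
        try {
            while (true) {
                // Solver phases would run here until the termination is reached.
                if (!daemon) {
                    return solution;                      // non-daemon: return on termination
                }
                IntUnaryOperator change = changes.take(); // daemon: block (no CPU) until more work
                if (change == STOP) {
                    return solution;                      // only terminateEarly() ends solve()
                }
                solution = change.applyAsInt(solution);   // apply the change, then "solve" again
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();           // restore the flag, return what we have
            return solution;
        }
    }

    void addProblemChange(IntUnaryOperator change) {
        changes.add(change);
    }

    void terminateEarly() {
        changes.add(STOP);
    }
}
```

Because changes and the stop signal share one FIFO queue, every change submitted before terminateEarly() is applied before solve() returns, regardless of thread timing.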

Example 54 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in project kogito-apps by kiegroup.

From class PrequalificationDmnCounterfactualExplainerTest, method testValidCounterfactual.

@Test
void testValidCounterfactual() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    final List<Output> goal = List.of(new Output("Qualified?", Type.BOOLEAN, new Value(true), 0.0d));
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    CounterfactualConfig config = new CounterfactualConfig().withGoalThreshold(0.1);
    config.withSolverConfig(solverConfig);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(config);
    PredictionInput input = getTestInputVariable();
    PredictionOutput output = new PredictionOutput(goal);
    // sanity check: on the fixed test input the model must predict "Qualified?" = false
    List<PredictionOutput> predictionOutputs = model.predictAsync(List.of(getTestInputFixed()))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output predictionOutput = predictionOutputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", predictionOutput.getName());
    assertFalse((Boolean) predictionOutput.getValue().getUnderlyingObject());
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
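    // bounded waits, consistent with the sanity check above, so a stuck future fails instead of hanging
    CounterfactualResult counterfactualResult = explainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    List<Feature> cfFeatures = counterfactualResult.getEntities().stream()
            .map(CounterfactualEntity::asFeature)
            .collect(Collectors.toList());
    // rebuild the original feature structure from the flattened counterfactual features and re-run the model
    List<Feature> unflattened = CompositeFeatureUtils.unflattenFeatures(cfFeatures, input.getFeatures());
    List<PredictionOutput> outputs = model.predictAsync(List.of(new PredictionInput(unflattened)))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());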
    assertTrue(counterfactualResult.isValid());
    final Output decideOutput = outputs.get(0).getOutputs().get(0);
    assertEquals("Qualified?", decideOutput.getName());
    assertTrue((Boolean) decideOutput.getValue().getUnderlyingObject());
}
Also used : PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) Value(org.kie.kogito.explainability.model.Value) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test)
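Example 54 checks the result twice: counterfactualResult.isValid(), and a round trip that feeds the counterfactual features back into the model and expects "Qualified?" to flip from false to true. The sketch below shows that flip-and-verify loop on a hypothetical two-feature rule, using a brute-force search over a single feature. The rule, the feature names, and the search strategy are assumptions for illustration only; they stand in for neither the DMN model nor the OptaPlanner-based search the explainer actually uses.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

final class ToyCounterfactual {

    /** Hypothetical prequalification rule (NOT the test's DMN model): payment at most 30% of income. */
    static boolean qualified(Map<String, Double> applicant) {
        double income = applicant.get("monthlyIncome");
        double payment = applicant.get("monthlyPayment");
        return 10.0 * payment <= 3.0 * income;   // integer-friendly form of payment <= 0.3 * income
    }

    /**
     * Brute-force counterfactual over one feature: try values from lo to hi in
     * fixed steps and return the first modified input the model accepts.
     * An empty result plays the role of an invalid counterfactual.
     */
    static Optional<Map<String, Double>> search(Map<String, Double> input, String feature,
                                                double lo, double hi, double step,
                                                Predicate<Map<String, Double>> model) {
        for (double v = lo; v <= hi; v += step) {
            Map<String, Double> candidate = new LinkedHashMap<>(input); // copy, keep the original intact
            candidate.put(feature, v);
            if (model.test(candidate)) {
                return Optional.of(candidate);
            }
        }
        return Optional.empty();
    }
}
```

As in the real test, the final check does not trust the search: it re-runs the model on the returned input and asserts the decision actually flipped.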

Aggregations

SolverConfig (org.optaplanner.core.config.solver.SolverConfig): 54
TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig): 17
ScoreDirectorFactoryConfig (org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig): 14
LocalSearchPhaseConfig (org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig): 12
CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction): 11
Prediction (org.kie.kogito.explainability.model.Prediction): 11
PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 11
PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 11
Feature (org.kie.kogito.explainability.model.Feature): 10
Output (org.kie.kogito.explainability.model.Output): 10
PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 10
Test (org.junit.jupiter.api.Test): 9
ConstructionHeuristicPhaseConfig (org.optaplanner.core.config.constructionheuristic.ConstructionHeuristicPhaseConfig): 9
LinkedList (java.util.LinkedList): 7
CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 7
LocalSearchForagerConfig (org.optaplanner.core.config.localsearch.decider.forager.LocalSearchForagerConfig): 7
VehicleRoutingSolution (org.optaplanner.examples.vehiclerouting.domain.VehicleRoutingSolution): 7
ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 6
Value (org.kie.kogito.explainability.model.Value): 6
UnionMoveSelectorConfig (org.optaplanner.core.config.heuristic.selector.move.composite.UnionMoveSelectorConfig): 6