Search in sources:

Example 6 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in the kogito-apps project by kiegroup.

Class LimeConfigOptimizer, method optimize.

/**
 * Tunes a {@link LimeConfig} by running an OptaPlanner late-acceptance local search over the
 * enabled families of configuration parameters.
 *
 * <p>The {@code samplingEntities}, {@code proximityEntities}, {@code encodingEntities} and
 * {@code weightingEntities} flags select which families are turned into planning entities.
 * A positive {@code timeLimit} bounds the whole solve, in seconds. A positive
 * {@code stepCountLimit} bounds the local search phase, and the phase also stops once the
 * best score reaches {@code "1.0"}.
 *
 * @param config      the starting configuration; returned as-is when no entity family is enabled
 * @param predictions predictions passed into the {@code LimeConfigSolution} being optimized
 * @param model       model passed into the {@code LimeConfigSolution} being optimized
 * @return the configuration rebuilt from the solver's final best solution
 * @throws IllegalStateException if the solver job fails or the waiting thread is interrupted
 */
public LimeConfig optimize(LimeConfig config, List<Prediction> predictions, PredictionProvider model) {
    // Gather the planning entities for each enabled family of LIME parameters.
    List<LimeConfigEntity> tunableEntities = new ArrayList<>();
    if (samplingEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createSamplingEntities(config));
    }
    if (proximityEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createProximityEntities(config));
    }
    if (encodingEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createEncodingEntities(config));
    }
    if (weightingEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createWeightingEntities(config));
    }
    // Nothing to tune: hand back the caller's configuration untouched.
    if (tunableEntities.isEmpty()) {
        return config;
    }
    LimeConfigSolution startingSolution = new LimeConfigSolution(config, predictions, tunableEntities, model);

    SolverConfig solverConfig = new SolverConfig()
            .withEntityClasses(NumericLimeConfigEntity.class, BooleanLimeConfigEntity.class)
            .withSolutionClass(LimeConfigSolution.class);

    // Score solutions with the configured easy score calculator's class.
    ScoreDirectorFactoryConfig scoringConfig = new ScoreDirectorFactoryConfig();
    scoringConfig.setEasyScoreCalculatorClass(scoreCalculator.getClass());
    solverConfig.setScoreDirectorFactoryConfig(scoringConfig);

    // Solver-wide termination: a seconds budget, applied only when one was configured.
    TerminationConfig solverTermination = new TerminationConfig();
    if (timeLimit > 0) {
        solverTermination.setSecondsSpentLimit(timeLimit);
    }
    solverConfig.setTerminationConfig(solverTermination);

    if (!deterministic) {
        logger.debug("non reproducible execution, set the seed inside initial LimeConfig's PerturbationContext and enable deterministic execution to fix this");
    } else {
        // Reuse the perturbation seed (when present) and force reproducible mode.
        config.getPerturbationContext().getSeed().ifPresent(solverConfig::setRandomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    }

    // A single late-acceptance local search phase, optionally bounded by a step budget.
    LocalSearchPhaseConfig lateAcceptance = new LocalSearchPhaseConfig();
    lateAcceptance.setLocalSearchType(LocalSearchType.LATE_ACCEPTANCE);
    if (stepCountLimit > 0) {
        TerminationConfig phaseTermination = new TerminationConfig()
                .withStepCountLimit(stepCountLimit)
                .withBestScoreLimit("1.0");
        lateAcceptance.setTerminationConfig(phaseTermination);
    }
    @SuppressWarnings("rawtypes")
    List<PhaseConfig> phases = new ArrayList<>();
    phases.add(lateAcceptance);
    solverConfig.setPhaseConfigList(phases);

    try (SolverManager<LimeConfigSolution, UUID> manager = SolverManager.create(solverConfig, new SolverManagerConfig())) {
        SolverJob<LimeConfigSolution, UUID> job = manager.solve(UUID.randomUUID(), startingSolution);
        try {
            // Blocks until the solver job terminates.
            LimeConfigSolution best = job.getFinalBestSolution();
            LimeConfig optimizedConfig = LimeConfigEntityFactory.toLimeConfig(best);
            BigDecimal bestScore = best.getScore().getScore();
            logger.info("final best solution score {} with config {}", bestScore, optimizedConfig);
            return optimizedConfig;
        } catch (InterruptedException e) {
            // Restore the interrupt flag before surfacing the failure.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Solving failed (Thread interrupted)", e);
        } catch (ExecutionException e) {
            logger.error("Solving failed: {}", e.getMessage());
            throw new IllegalStateException("Prediction returned an error", e);
        }
    }
}
Also used: ArrayList(java.util.ArrayList) SolverManagerConfig(org.optaplanner.core.config.solver.SolverManagerConfig) UUID(java.util.UUID) ExecutionException(java.util.concurrent.ExecutionException) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) ScoreDirectorFactoryConfig(org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig) PhaseConfig(org.optaplanner.core.config.phase.PhaseConfig) LocalSearchPhaseConfig(org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig) LimeConfig(org.kie.kogito.explainability.local.lime.LimeConfig) BigDecimal(java.math.BigDecimal) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig)

Example 7 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in the kogito-apps project by kiegroup.

Class AbstractTaskAssigningCoreTest, method createNonDaemonSolver.

/**
 * Builds a solver from {@link #createBaseConfig()} with two phases. The first is a
 * FIRST_FIT construction heuristic. The second is a local search phase that terminates
 * after {@code stepCountLimit} steps.
 *
 * @param stepCountLimit step count limit applied to the local search phase
 * @return a newly built {@link Solver} for {@code TaskAssigningSolution}
 */
protected Solver<TaskAssigningSolution> createNonDaemonSolver(int stepCountLimit) {
    SolverConfig solverConfig = createBaseConfig();

    ConstructionHeuristicPhaseConfig firstFitPhase = new ConstructionHeuristicPhaseConfig();
    firstFitPhase.setConstructionHeuristicType(ConstructionHeuristicType.FIRST_FIT);

    LocalSearchPhaseConfig boundedLocalSearch = new LocalSearchPhaseConfig();
    boundedLocalSearch.setTerminationConfig(new TerminationConfig().withStepCountLimit(stepCountLimit));

    // Phases run in list order: construct an initial solution, then improve it.
    solverConfig.setPhaseConfigList(Arrays.asList(firstFitPhase, boundedLocalSearch));
    return SolverFactory.<TaskAssigningSolution> create(solverConfig).buildSolver();
}
Also used: TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) TaskAssigningSolution(org.kie.kogito.taskassigning.core.model.TaskAssigningSolution) LocalSearchPhaseConfig(org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig) ConstructionHeuristicPhaseConfig(org.optaplanner.core.config.constructionheuristic.ConstructionHeuristicPhaseConfig) SolverConfig(org.optaplanner.core.config.solver.SolverConfig)

Example 8 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in the kogito-apps project by kiegroup.

Class AbstractTaskAssigningCoreTest, method createBaseConfig.

/**
 * Creates the solver configuration shared by the task-assigning tests. It configures:
 * <ul>
 *   <li>the {@code TaskAssigningSolution} planning solution class;</li>
 *   <li>the {@code ChainElement} and {@code TaskAssignment} entity classes;</li>
 *   <li>constraint-stream scoring via {@code DefaultTaskAssigningConstraintProvider}.</li>
 * </ul>
 * No phases or termination are configured here.
 *
 * @return a new {@link SolverConfig} that callers can extend, for example with phases
 */
protected SolverConfig createBaseConfig() {
    ScoreDirectorFactoryConfig constraintScoring = new ScoreDirectorFactoryConfig()
            .withConstraintProviderClass(DefaultTaskAssigningConstraintProvider.class);
    SolverConfig baseConfig = new SolverConfig()
            .withSolutionClass(TaskAssigningSolution.class)
            .withEntityClasses(ChainElement.class, TaskAssignment.class);
    baseConfig.setScoreDirectorFactoryConfig(constraintScoring);
    return baseConfig;
}
Also used: ScoreDirectorFactoryConfig(org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig) ChainElement(org.kie.kogito.taskassigning.core.model.ChainElement) DefaultTaskAssigningConstraintProvider(org.kie.kogito.taskassigning.core.model.solver.DefaultTaskAssigningConstraintProvider) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) TaskAssignment(org.kie.kogito.taskassigning.core.model.TaskAssignment)

Example 9 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in the kogito-apps project by kiegroup.

Class LoanEligibilityDmnCounterfactualExplainerTest, method testLoanEligibilityDMNExplanation.

/**
 * Runs a counterfactual search against the loan-eligibility DMN model. The goal outputs are
 * "Is Enought?" = 100, "Eligibility" = "No" and "Decide" = true. The test then re-evaluates
 * the model on the counterfactual features and checks that the "Decide" output is true.
 *
 * <p>Every asynchronous call is awaited with the shared {@code Config} async timeout, so a
 * stuck model or solver fails the test instead of hanging the build.
 */
@Test
void testLoanEligibilityDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    // "Is Enought?" must match the decision name declared in the DMN model.
    final List<Output> goal = List.of(new Output("Is Enought?", Type.NUMBER, new Value(100), 0.0d), new Output("Eligibility", Type.TEXT, new Value("No"), 0.0d), new Output("Decide", Type.BOOLEAN, new Value(true), 0.0d));
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    CounterfactualConfig config = new CounterfactualConfig();
    config.withSolverConfig(solverConfig);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(config);
    PredictionInput input = getTestInput();
    PredictionOutput output = new PredictionOutput(goal);
    // Smoke-test the model on the original input; a failure here surfaces as an exception.
    model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    // Bounded wait: previously this get() could block forever if the search never completed.
    CounterfactualResult counterfactualResult = explainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertTrue(counterfactualResult.isValid());
    // Re-evaluate the model on the counterfactual features, restoring any composite structure first.
    List<Feature> cfFeatures = counterfactualResult.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    List<Feature> unflattened = CompositeFeatureUtils.unflattenFeatures(cfFeatures, input.getFeatures());
    List<PredictionOutput> outputs = model.predictAsync(List.of(new PredictionInput(unflattened)))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output decideOutput = outputs.get(0).getOutputs().get(2);
    assertEquals("Decide", decideOutput.getName());
    assertTrue((Boolean) decideOutput.getValue().getUnderlyingObject());
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) Value(org.kie.kogito.explainability.model.Value) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test)

Example 10 with SolverConfig

Use of org.optaplanner.core.config.solver.SolverConfig in the kogito-apps project by kiegroup.

Class ComplexEligibilityDmnCounterfactualExplainerTest, method testDMNScoringFunction.

/**
 * Explains a prediction against the complex-eligibility DMN model using the goal from
 * {@code generateGoal(true, true, 1.0)}.
 *
 * <p>The test checks three things:
 * <ul>
 *   <li>the counterfactual result is valid;</li>
 *   <li>"inputsAreValid" and "canRequestLoan" are true, and "my-scoring-function" is about 1.0;</li>
 *   <li>the counterfactual age, hasReferral and monthlySalary values are the expected ones.</li>
 * </ul>
 */
@Test
void testDMNScoringFunction() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    final List<Output> goal = generateGoal(true, true, 1.0);

    List<Feature> searchSpace = new LinkedList<>();
    searchSpace.add(FeatureFactory.newNumericalFeature("age", 40, NumericalFeatureDomain.create(18, 60)));
    searchSpace.add(FeatureFactory.newBooleanFeature("hasReferral", true));
    searchSpace.add(FeatureFactory.newNumericalFeature("monthlySalary", 500, NumericalFeatureDomain.create(10, 100_000)));

    // For the purpose of this test, only a few steps are necessary.
    final SolverConfig solverConfig = SolverConfigBuilder.builder()
            .withTerminationConfig(new TerminationConfig().withScoreCalculationCountLimit(10_000L))
            .build();
    solverConfig.setRandomSeed(23L);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(
            new CounterfactualConfig().withSolverConfig(solverConfig).withGoalThreshold(0.01));

    final Prediction prediction = new CounterfactualPrediction(
            new PredictionInput(searchSpace), new PredictionOutput(goal), null, UUID.randomUUID(), 60L);
    final CounterfactualResult result = counterfactualExplainer.explainAsync(prediction, model)
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());

    // Outputs produced by the model for the counterfactual.
    List<Output> cfOutputs = result.getOutput().get(0).getOutputs();
    assertTrue(result.isValid());
    final Output inputsAreValid = cfOutputs.get(0);
    assertEquals("inputsAreValid", inputsAreValid.getName());
    assertTrue((Boolean) inputsAreValid.getValue().getUnderlyingObject());
    final Output canRequestLoan = cfOutputs.get(1);
    assertEquals("canRequestLoan", canRequestLoan.getName());
    assertTrue((Boolean) canRequestLoan.getValue().getUnderlyingObject());
    final Output scoring = cfOutputs.get(2);
    assertEquals("my-scoring-function", scoring.getName());
    final BigDecimal scoreValue = (BigDecimal) scoring.getValue().getUnderlyingObject();
    assertEquals(1.0, scoreValue.doubleValue(), 0.01);

    // Feature values chosen by the counterfactual search.
    List<CounterfactualEntity> entities = result.getEntities();
    final Feature age = entities.get(0).asFeature();
    assertEquals("age", age.getName());
    assertEquals(18, age.getValue().asNumber());
    final Feature hasReferral = entities.get(1).asFeature();
    assertEquals("hasReferral", hasReferral.getName());
    assertTrue((Boolean) hasReferral.getValue().getUnderlyingObject());
    final Feature salary = entities.get(2).asFeature();
    assertEquals("monthlySalary", salary.getName());
    final double monthlySalary = salary.getValue().asNumber();
    assertEquals(7900, monthlySalary, 10);
    // The scoring function is 0.6 * ((42 - age + 18) / 42) + 0.4 * (monthlySalary / 8000).
    // For a result of 1.0 this requires age = (7 * monthlySalary) / 2000 - 10.
    assertEquals(18, (7 * monthlySalary) / 2000.0 - 10.0, 0.5);
}
Also used: PredictionInput(org.kie.kogito.explainability.model.PredictionInput) Prediction(org.kie.kogito.explainability.model.Prediction) CounterfactualPrediction(org.kie.kogito.explainability.model.CounterfactualPrediction) PredictionProvider(org.kie.kogito.explainability.model.PredictionProvider) Feature(org.kie.kogito.explainability.model.Feature) LinkedList(java.util.LinkedList) CounterfactualResult(org.kie.kogito.explainability.local.counterfactual.CounterfactualResult) CounterfactualEntity(org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity) TerminationConfig(org.optaplanner.core.config.solver.termination.TerminationConfig) PredictionOutput(org.kie.kogito.explainability.model.PredictionOutput) Output(org.kie.kogito.explainability.model.Output) CounterfactualConfig(org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig) CounterfactualExplainer(org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer) SolverConfig(org.optaplanner.core.config.solver.SolverConfig) Test(org.junit.jupiter.api.Test)

Aggregations

SolverConfig (org.optaplanner.core.config.solver.SolverConfig): 54, TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig): 17, ScoreDirectorFactoryConfig (org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig): 14, LocalSearchPhaseConfig (org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig): 12, CounterfactualPrediction (org.kie.kogito.explainability.model.CounterfactualPrediction): 11, Prediction (org.kie.kogito.explainability.model.Prediction): 11, PredictionInput (org.kie.kogito.explainability.model.PredictionInput): 11, PredictionOutput (org.kie.kogito.explainability.model.PredictionOutput): 11, Feature (org.kie.kogito.explainability.model.Feature): 10, Output (org.kie.kogito.explainability.model.Output): 10, PredictionProvider (org.kie.kogito.explainability.model.PredictionProvider): 10, Test (org.junit.jupiter.api.Test): 9, ConstructionHeuristicPhaseConfig (org.optaplanner.core.config.constructionheuristic.ConstructionHeuristicPhaseConfig): 9, LinkedList (java.util.LinkedList): 7, CounterfactualEntity (org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity): 7, LocalSearchForagerConfig (org.optaplanner.core.config.localsearch.decider.forager.LocalSearchForagerConfig): 7, VehicleRoutingSolution (org.optaplanner.examples.vehiclerouting.domain.VehicleRoutingSolution): 7, ParameterizedTest (org.junit.jupiter.params.ParameterizedTest): 6, Value (org.kie.kogito.explainability.model.Value): 6, UnionMoveSelectorConfig (org.optaplanner.core.config.heuristic.selector.move.composite.UnionMoveSelectorConfig): 6