Search in sources:

Example 1 with SolverManagerConfig

Use of org.optaplanner.core.config.solver.SolverManagerConfig in project kogito-apps by kiegroup.

In class LimeConfigOptimizer, method optimize.

/**
 * Searches for a better {@link LimeConfig} using an OptaPlanner late-acceptance local search.
 * <p>
 * The parts of the configuration that are searched are chosen by the {@code samplingEntities},
 * {@code proximityEntities}, {@code encodingEntities} and {@code weightingEntities} flags. If none of
 * the enabled groups yields any entity, {@code config} is returned untouched and no solver is started.
 *
 * @param config      the starting configuration; in deterministic mode its perturbation context also
 *                    supplies the solver's random seed (when one is set)
 * @param predictions predictions handed to the planning solution
 * @param model       model handed to the planning solution
 * @return the configuration derived from the solver's final best solution, or {@code config} itself
 *         when there is nothing to optimize
 * @throws IllegalStateException if solving fails, or if the calling thread is interrupted while
 *                               waiting for the final best solution
 */
public LimeConfig optimize(LimeConfig config, List<Prediction> predictions, PredictionProvider model) {
    // Collect the planning entities for every enabled config group (order is significant).
    List<LimeConfigEntity> searchable = new ArrayList<>();
    if (samplingEntities) {
        searchable.addAll(LimeConfigEntityFactory.createSamplingEntities(config));
    }
    if (proximityEntities) {
        searchable.addAll(LimeConfigEntityFactory.createProximityEntities(config));
    }
    if (encodingEntities) {
        searchable.addAll(LimeConfigEntityFactory.createEncodingEntities(config));
    }
    if (weightingEntities) {
        searchable.addAll(LimeConfigEntityFactory.createWeightingEntities(config));
    }
    if (searchable.isEmpty()) {
        // Nothing to tune: hand back the caller's config as-is.
        return config;
    }

    LimeConfigSolution problem = new LimeConfigSolution(config, predictions, searchable, model);

    // Domain model + easy score calculator.
    SolverConfig solverConfig = new SolverConfig()
            .withEntityClasses(NumericLimeConfigEntity.class, BooleanLimeConfigEntity.class)
            .withSolutionClass(LimeConfigSolution.class);
    ScoreDirectorFactoryConfig scoring = new ScoreDirectorFactoryConfig();
    scoring.setEasyScoreCalculatorClass(scoreCalculator.getClass());
    solverConfig.setScoreDirectorFactoryConfig(scoring);

    // Solver-level termination: a wall-clock budget, only when a positive limit is configured.
    TerminationConfig solverTermination = new TerminationConfig();
    if (timeLimit > 0) {
        solverTermination.setSecondsSpentLimit(timeLimit);
    }
    solverConfig.setTerminationConfig(solverTermination);

    // Reproducibility: pin the seed (if the perturbation context has one) and the environment mode.
    if (!deterministic) {
        logger.debug("non reproducible execution, set the seed inside initial LimeConfig's PerturbationContext and enable deterministic execution to fix this");
    } else {
        config.getPerturbationContext().getSeed().ifPresent(solverConfig::setRandomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    }

    // Single late-acceptance local search phase; when a step budget is configured the phase also
    // stops as soon as a best score of 1.0 is reached.
    LocalSearchPhaseConfig lateAcceptance = new LocalSearchPhaseConfig();
    lateAcceptance.setLocalSearchType(LocalSearchType.LATE_ACCEPTANCE);
    if (stepCountLimit > 0) {
        TerminationConfig phaseTermination = new TerminationConfig()
                .withStepCountLimit(stepCountLimit)
                .withBestScoreLimit("1.0");
        lateAcceptance.setTerminationConfig(phaseTermination);
    }
    @SuppressWarnings("rawtypes")
    List<PhaseConfig> phases = new ArrayList<>();
    phases.add(lateAcceptance);
    solverConfig.setPhaseConfigList(phases);

    try (SolverManager<LimeConfigSolution, UUID> manager = SolverManager.create(solverConfig, new SolverManagerConfig())) {
        SolverJob<LimeConfigSolution, UUID> job = manager.solve(UUID.randomUUID(), problem);
        // Failures are handled inside the resource block so they are processed before the manager closes.
        try {
            // Blocks until solving ends.
            LimeConfigSolution best = job.getFinalBestSolution();
            LimeConfig tuned = LimeConfigEntityFactory.toLimeConfig(best);
            logger.info("final best solution score {} with config {}", best.getScore().getScore(), tuned);
            return tuned;
        } catch (InterruptedException e) {
            // Restore the interrupt flag before surfacing the failure.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Solving failed (Thread interrupted)", e);
        } catch (ExecutionException e) {
            logger.error("Solving failed: {}", e.getMessage());
            throw new IllegalStateException("Prediction returned an error", e);
        }
    }
}
Also used: ArrayList (java.util.ArrayList), SolverManagerConfig (org.optaplanner.core.config.solver.SolverManagerConfig), UUID (java.util.UUID), ExecutionException (java.util.concurrent.ExecutionException), SolverConfig (org.optaplanner.core.config.solver.SolverConfig), ScoreDirectorFactoryConfig (org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig), PhaseConfig (org.optaplanner.core.config.phase.PhaseConfig), LocalSearchPhaseConfig (org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig), LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig), BigDecimal (java.math.BigDecimal), TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig)

Aggregations

BigDecimal (java.math.BigDecimal) 1, ArrayList (java.util.ArrayList) 1, UUID (java.util.UUID) 1, ExecutionException (java.util.concurrent.ExecutionException) 1, LimeConfig (org.kie.kogito.explainability.local.lime.LimeConfig) 1, LocalSearchPhaseConfig (org.optaplanner.core.config.localsearch.LocalSearchPhaseConfig) 1, PhaseConfig (org.optaplanner.core.config.phase.PhaseConfig) 1, ScoreDirectorFactoryConfig (org.optaplanner.core.config.score.director.ScoreDirectorFactoryConfig) 1, SolverConfig (org.optaplanner.core.config.solver.SolverConfig) 1, SolverManagerConfig (org.optaplanner.core.config.solver.SolverManagerConfig) 1, TerminationConfig (org.optaplanner.core.config.solver.termination.TerminationConfig) 1