Use of `org.optaplanner.core.config.solver.SolverConfig` in the kogito-apps project by kiegroup.
Class `LimeConfigOptimizer`, method `optimize`.
/**
 * Searches for a better {@link LimeConfig} by running an OptaPlanner late-acceptance local search over the
 * configuration entities enabled on this optimizer (sampling, proximity, encoding, weighting).
 *
 * @param config      the starting LIME configuration; returned as-is when no optimizable entities are produced
 * @param predictions predictions handed to the {@code LimeConfigSolution} (presumably consumed by the score
 *                    calculator — TODO confirm)
 * @param model       the model handed to the {@code LimeConfigSolution}
 * @return the configuration rebuilt from the best solution found by the solver
 * @throws IllegalStateException if solving fails, or if the waiting thread is interrupted
 */
public LimeConfig optimize(LimeConfig config, List<Prediction> predictions, PredictionProvider model) {
    List<LimeConfigEntity> entities = new ArrayList<>();
    if (samplingEntities) {
        entities.addAll(LimeConfigEntityFactory.createSamplingEntities(config));
    }
    if (proximityEntities) {
        entities.addAll(LimeConfigEntityFactory.createProximityEntities(config));
    }
    if (encodingEntities) {
        entities.addAll(LimeConfigEntityFactory.createEncodingEntities(config));
    }
    if (weightingEntities) {
        entities.addAll(LimeConfigEntityFactory.createWeightingEntities(config));
    }
    if (entities.isEmpty()) {
        // Nothing to optimize: hand back the caller's configuration untouched.
        return config;
    }
    LimeConfigSolution initialSolution = new LimeConfigSolution(config, predictions, entities, model);
    SolverConfig solverConfig = new SolverConfig().withEntityClasses(NumericLimeConfigEntity.class, BooleanLimeConfigEntity.class).withSolutionClass(LimeConfigSolution.class);
    ScoreDirectorFactoryConfig scoreDirectorFactoryConfig = new ScoreDirectorFactoryConfig();
    // Only the class of scoreCalculator is passed, so any state held by that instance is not carried over.
    // NOTE(review): confirm the calculator is stateless / default-constructible.
    scoreDirectorFactoryConfig.setEasyScoreCalculatorClass(scoreCalculator.getClass());
    solverConfig.setScoreDirectorFactoryConfig(scoreDirectorFactoryConfig);
    TerminationConfig terminationConfig = new TerminationConfig();
    if (timeLimit > 0) {
        terminationConfig.setSecondsSpentLimit(timeLimit);
    }
    // NOTE(review): if both timeLimit <= 0 and stepCountLimit <= 0, neither the solver nor the phase gets a
    // termination, and getFinalBestSolution() below may block indefinitely.
    solverConfig.setTerminationConfig(terminationConfig);
    if (deterministic) {
        Optional<Long> seed = config.getPerturbationContext().getSeed();
        seed.ifPresent(solverConfig::setRandomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    } else {
        logger.debug("non reproducible execution, set the seed inside initial LimeConfig's PerturbationContext and enable deterministic execution to fix this");
    }
    LocalSearchPhaseConfig localSearchPhaseConfig = new LocalSearchPhaseConfig();
    localSearchPhaseConfig.setLocalSearchType(LocalSearchType.LATE_ACCEPTANCE);
    if (stepCountLimit > 0) {
        // Bound the phase by step count, and stop early once the best score reaches 1.0.
        localSearchPhaseConfig.setTerminationConfig(new TerminationConfig().withStepCountLimit(stepCountLimit).withBestScoreLimit("1.0"));
    }
    @SuppressWarnings("rawtypes") List<PhaseConfig> phaseConfigs = new ArrayList<>();
    phaseConfigs.add(localSearchPhaseConfig);
    solverConfig.setPhaseConfigList(phaseConfigs);
    try (SolverManager<LimeConfigSolution, UUID> solverManager = SolverManager.create(solverConfig, new SolverManagerConfig())) {
        UUID executionId = UUID.randomUUID();
        SolverJob<LimeConfigSolution, UUID> solverJob = solverManager.solve(executionId, initialSolution);
        try {
            // Wait until the solving ends
            LimeConfigSolution finalBestSolution = solverJob.getFinalBestSolution();
            LimeConfig finalConfig = LimeConfigEntityFactory.toLimeConfig(finalBestSolution);
            BigDecimal score = finalBestSolution.getScore().getScore();
            logger.info("final best solution score {} with config {}", score, finalConfig);
            return finalConfig;
        } catch (ExecutionException e) {
            // Pass the exception itself as the trailing argument so the stack trace is logged, not just the message.
            logger.error("Solving failed: {}", e.getMessage(), e);
            throw new IllegalStateException("Prediction returned an error", e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Solving failed (Thread interrupted)", e);
        }
    }
}
Use of `org.optaplanner.core.config.solver.SolverConfig` in the kogito-apps project by kiegroup.
Class `AbstractTaskAssigningCoreTest`, method `createNonDaemonSolver`.
/**
 * Builds a solver on top of {@link #createBaseConfig()} that runs a FIRST_FIT construction heuristic
 * followed by a local search phase capped at {@code stepCountLimit} steps.
 *
 * @param stepCountLimit step-count termination applied to the local search phase
 * @return a freshly built solver for {@code TaskAssigningSolution}
 */
protected Solver<TaskAssigningSolution> createNonDaemonSolver(int stepCountLimit) {
    SolverConfig solverConfig = createBaseConfig();

    ConstructionHeuristicPhaseConfig construction = new ConstructionHeuristicPhaseConfig();
    construction.setConstructionHeuristicType(ConstructionHeuristicType.FIRST_FIT);

    LocalSearchPhaseConfig localSearch = new LocalSearchPhaseConfig();
    TerminationConfig localSearchTermination = new TerminationConfig().withStepCountLimit(stepCountLimit);
    localSearch.setTerminationConfig(localSearchTermination);

    // Phases run in list order: construct an initial solution first, then improve it.
    solverConfig.setPhaseConfigList(Arrays.asList(construction, localSearch));
    return SolverFactory.<TaskAssigningSolution>create(solverConfig).buildSolver();
}
Use of `org.optaplanner.core.config.solver.SolverConfig` in the kogito-apps project by kiegroup.
Class `AbstractTaskAssigningCoreTest`, method `createBaseConfig`.
/**
 * Creates the shared solver configuration for task-assigning tests: the {@code TaskAssigningSolution}
 * solution class, the {@code ChainElement} and {@code TaskAssignment} entity classes, and scoring through
 * {@code DefaultTaskAssigningConstraintProvider}. No phases or terminations are set here.
 *
 * @return a new, mutable {@link SolverConfig}
 */
protected SolverConfig createBaseConfig() {
    ScoreDirectorFactoryConfig scoring = new ScoreDirectorFactoryConfig()
            .withConstraintProviderClass(DefaultTaskAssigningConstraintProvider.class);
    SolverConfig baseConfig = new SolverConfig()
            .withSolutionClass(TaskAssigningSolution.class)
            .withEntityClasses(ChainElement.class, TaskAssignment.class);
    baseConfig.setScoreDirectorFactoryConfig(scoring);
    return baseConfig;
}
Use of `org.optaplanner.core.config.solver.SolverConfig` in the kogito-apps project by kiegroup.
Class `LoanEligibilityDmnCounterfactualExplainerTest`, method `testLoanEligibilityDMNExplanation`.
/**
 * Runs a reproducible counterfactual search against the loan-eligibility DMN model and checks that the
 * result is valid and that re-evaluating the counterfactual input yields {@code Decide == true}.
 */
@Test
void testLoanEligibilityDMNExplanation() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    // Output names presumably have to match the DMN model's decision names exactly (including "Is Enought?").
    final List<Output> goal = List.of(new Output("Is Enought?", Type.NUMBER, new Value(100), 0.0d), new Output("Eligibility", Type.TEXT, new Value("No"), 0.0d), new Output("Decide", Type.BOOLEAN, new Value(true), 0.0d));
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(steps);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    CounterfactualConfig config = new CounterfactualConfig();
    config.withSolverConfig(solverConfig);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(config);
    PredictionInput input = getTestInput();
    PredictionOutput output = new PredictionOutput(goal);
    // Smoke-check that the model evaluates the original input within the async timeout.
    model.predictAsync(List.of(input)).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), null);
    // Bounded waits so a stuck search or model call fails the test instead of hanging the build.
    CounterfactualResult counterfactualResult = explainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    assertTrue(counterfactualResult.isValid());
    List<Feature> cfFeatures = counterfactualResult.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList());
    // Rebuild the composite (nested) input structure from the flat counterfactual features before re-predicting.
    List<Feature> unflattened = CompositeFeatureUtils.unflattenFeatures(cfFeatures, input.getFeatures());
    List<PredictionOutput> outputs = model.predictAsync(List.of(new PredictionInput(unflattened))).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    final Output decideOutput = outputs.get(0).getOutputs().get(2);
    assertEquals("Decide", decideOutput.getName());
    assertTrue((Boolean) decideOutput.getValue().getUnderlyingObject());
}
Use of `org.optaplanner.core.config.solver.SolverConfig` in the kogito-apps project by kiegroup.
Class `ComplexEligibilityDmnCounterfactualExplainerTest`, method `testDMNScoringFunction`.
/**
 * Runs a reproducible counterfactual search on the complex-eligibility DMN model, targeting a scoring
 * function value of 1.0. It checks the resulting outputs, and checks that the age and monthly salary found
 * satisfy the linear relation implied by the scoring function.
 */
@Test
void testDMNScoringFunction() throws ExecutionException, InterruptedException, TimeoutException {
    PredictionProvider model = getModel();
    final List<Output> goal = generateGoal(true, true, 1.0);
    List<Feature> features = new LinkedList<>();
    features.add(FeatureFactory.newNumericalFeature("age", 40, NumericalFeatureDomain.create(18, 60)));
    features.add(FeatureFactory.newBooleanFeature("hasReferral", true));
    features.add(FeatureFactory.newNumericalFeature("monthlySalary", 500, NumericalFeatureDomain.create(10, 100_000)));
    // For the purpose of this test, a small budget of score calculations is enough.
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(10_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(23L);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig).withGoalThreshold(0.01);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);
    PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    Prediction prediction = new CounterfactualPrediction(input, output, null, UUID.randomUUID(), 60L);
    final CounterfactualResult counterfactualResult = counterfactualExplainer.explainAsync(prediction, model).get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    List<Output> cfOutputs = counterfactualResult.getOutput().get(0).getOutputs();
    assertTrue(counterfactualResult.isValid());
    assertEquals("inputsAreValid", cfOutputs.get(0).getName());
    assertTrue((Boolean) cfOutputs.get(0).getValue().getUnderlyingObject());
    assertEquals("canRequestLoan", cfOutputs.get(1).getName());
    assertTrue((Boolean) cfOutputs.get(1).getValue().getUnderlyingObject());
    assertEquals("my-scoring-function", cfOutputs.get(2).getName());
    assertEquals(1.0, ((BigDecimal) cfOutputs.get(2).getValue().getUnderlyingObject()).doubleValue(), 0.01);

    List<CounterfactualEntity> entities = counterfactualResult.getEntities();
    final Feature ageFeature = entities.get(0).asFeature();
    final Feature referralFeature = entities.get(1).asFeature();
    final Feature salaryFeature = entities.get(2).asFeature();
    assertEquals("age", ageFeature.getName());
    assertEquals(18, ageFeature.getValue().asNumber());
    assertEquals("hasReferral", referralFeature.getName());
    assertTrue((Boolean) referralFeature.getValue().getUnderlyingObject());
    assertEquals("monthlySalary", salaryFeature.getName());
    final double monthlySalary = salaryFeature.getValue().asNumber();
    assertEquals(7900, monthlySalary, 10);
    // The scoring function is (0.6 * ((42 - age + 18) / 42)) + (0.4 * (monthlySalary / 8000)).
    // Setting it to 1.0 and solving for age gives age = (7 * monthlySalary) / 2000 - 10.
    assertEquals(18, (7 * monthlySalary) / 2000.0 - 10.0, 0.5);
}
Aggregations