Use of org.optaplanner.core.config.solver.termination.TerminationConfig in project kogito-apps by kiegroup.
Class CounterfactualExplainerTest, method testFinalUniqueIds.
/**
 * Checks the identifiers attached to counterfactual results produced by an asynchronous search.
 * Every intermediate result must have a distinct solution id. Every intermediate result must carry
 * the execution id supplied with the prediction. The final result's solution id must differ from
 * the last intermediate one.
 *
 * @param seed random seed handed to the solver configuration (run in REPRODUCIBLE mode)
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testFinalUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    final List<Output> goal = new ArrayList<>();
    final List<Feature> features = List.of(FeatureFactory.newNumericalFeature("f-num1", 10.0, NumericalFeatureDomain.create(0, 20)));
    final PredictionProvider model = TestUtils.getFeaturePassModel(0);

    // Reproducibility comes from the solver's own seed; no separate java.util.Random is needed.
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(100_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

    // Record the ids carried by every intermediate result handed to the consumer.
    final List<UUID> intermediateIds = new ArrayList<>();
    final List<UUID> executionIds = new ArrayList<>();
    final Consumer<CounterfactualResult> captureIntermediateIds = counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
    final Consumer<CounterfactualResult> captureExecutionIds = counterfactual -> executionIds.add(counterfactual.getExecutionId());

    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    solverConfig.withEasyScoreCalculatorClass(MockCounterFactualScoreCalculator.class);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);

    final PredictionInput input = new PredictionInput(features);
    final PredictionOutput output = new PredictionOutput(goal);
    final UUID executionId = UUID.randomUUID();
    final Prediction prediction = new CounterfactualPrediction(input, output, null, executionId, null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer
            .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }

    // There should be at least one intermediate id and one execution id
    assertTrue(!intermediateIds.isEmpty(), "expected at least one intermediate result");
    assertTrue(!executionIds.isEmpty(), "expected at least one captured execution id");
    // All intermediate ids should be unique
    assertEquals(intermediateIds.size(), (int) intermediateIds.stream().distinct().count(),
            "intermediate solution ids must be unique");
    // One execution id is captured per intermediate result
    assertEquals(intermediateIds.size(), executionIds.size());
    // All execution ids should be the same ...
    assertEquals(1, (int) executionIds.stream().distinct().count());
    // ... and equal to the one supplied with the prediction (expected value first)
    assertEquals(executionId, executionIds.get(0));
    // The last intermediate id must be different from the final result id
    assertNotEquals(intermediateIds.get(intermediateIds.size() - 1), counterfactualResult.getSolutionId(),
            "final solution id must differ from the last intermediate one");
}
Use of org.optaplanner.core.config.solver.termination.TerminationConfig in project kogito-apps by kiegroup.
Class CounterfactualExplainerTest, method testIntermediateUniqueIds.
/**
 * Checks that every intermediate counterfactual result produced by an asynchronous search has a
 * distinct solution id. It also checks that all intermediate results carry the execution id
 * supplied with the prediction.
 *
 * @param seed random seed handed to the solver configuration (run in REPRODUCIBLE mode)
 */
@ParameterizedTest
@ValueSource(ints = { 0, 1, 2 })
void testIntermediateUniqueIds(int seed) throws ExecutionException, InterruptedException, TimeoutException {
    final List<Output> goal = new ArrayList<>();
    final List<Feature> features = List.of(FeatureFactory.newNumericalFeature("f-num1", 10.0, NumericalFeatureDomain.create(0, 20)));
    final PredictionProvider model = TestUtils.getFeaturePassModel(0);

    // Reproducibility comes from the solver's own seed; no separate java.util.Random is needed.
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(100_000L);
    final SolverConfig solverConfig = SolverConfigBuilder.builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed((long) seed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

    // Record the ids carried by every intermediate result handed to the consumer.
    final List<UUID> intermediateIds = new ArrayList<>();
    final List<UUID> executionIds = new ArrayList<>();
    final Consumer<CounterfactualResult> captureIntermediateIds = counterfactual -> intermediateIds.add(counterfactual.getSolutionId());
    final Consumer<CounterfactualResult> captureExecutionIds = counterfactual -> executionIds.add(counterfactual.getExecutionId());

    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig().withSolverConfig(solverConfig);
    solverConfig.withEasyScoreCalculatorClass(MockCounterFactualScoreCalculator.class);
    final CounterfactualExplainer counterfactualExplainer = new CounterfactualExplainer(counterfactualConfig);

    final PredictionInput input = new PredictionInput(features);
    final PredictionOutput output = new PredictionOutput(goal);
    final UUID executionId = UUID.randomUUID();
    final Prediction prediction = new CounterfactualPrediction(input, output, null, executionId, null);
    final CounterfactualResult counterfactualResult = counterfactualExplainer
            .explainAsync(prediction, model, captureIntermediateIds.andThen(captureExecutionIds))
            .get(Config.INSTANCE.getAsyncTimeout(), Config.INSTANCE.getAsyncTimeUnit());
    for (CounterfactualEntity entity : counterfactualResult.getEntities()) {
        logger.debug("Entity: {}", entity);
    }

    // all intermediate ids must be distinct
    assertEquals(intermediateIds.size(), (int) intermediateIds.stream().distinct().count(),
            "intermediate solution ids must be unique");
    // exactly one execution id was seen (this also fails first if none were captured) ...
    assertEquals(1, (int) executionIds.stream().distinct().count());
    // ... and it is the one supplied with the prediction (expected value first)
    assertEquals(executionId, executionIds.get(0));
}
Use of org.optaplanner.core.config.solver.termination.TerminationConfig in project kogito-apps by kiegroup.
Class CounterfactualExplainerTest, method runCounterfactualSearch.
/**
 * Runs a reproducible counterfactual search and waits for its result.
 *
 * @param randomSeed   seed for the solver's random generator
 * @param goal         desired outputs the counterfactual should reach
 * @param features     input features of the original prediction
 * @param model        model used to evaluate candidate inputs
 * @param goalThresold goal threshold passed to the counterfactual configuration
 * @return the counterfactual result, waiting at most {@code predictionTimeOut} {@code predictionTimeUnit}
 */
private CounterfactualResult runCounterfactualSearch(Long randomSeed, List<Output> goal, List<Feature> features, PredictionProvider model, double goalThresold) throws InterruptedException, ExecutionException, TimeoutException {
    // The search budget is bounded by the number of score calculations (field `steps`).
    final SolverConfig config = SolverConfigBuilder.builder()
            .withTerminationConfig(new TerminationConfig().withScoreCalculationCountLimit(steps))
            .build();
    config.setRandomSeed(randomSeed);
    config.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);

    final CounterfactualConfig cfConfig = new CounterfactualConfig();
    cfConfig.withSolverConfig(config).withGoalThreshold(goalThresold);
    final CounterfactualExplainer cfExplainer = new CounterfactualExplainer(cfConfig);

    final Prediction request = new CounterfactualPrediction(new PredictionInput(features), new PredictionOutput(goal), null, UUID.randomUUID(), null);
    return cfExplainer.explainAsync(request, model).get(predictionTimeOut, predictionTimeUnit);
}
Use of org.optaplanner.core.config.solver.termination.TerminationConfig in project kogito-apps by kiegroup.
Class LimeConfigOptimizer, method optimize.
/**
 * Tunes a {@link LimeConfig} with an OptaPlanner late-acceptance local search. Only the parameter
 * groups enabled on this optimizer (sampling, proximity, encoding, weighting) are tuned.
 *
 * @param config      the starting configuration; returned unchanged when no parameter group is enabled
 * @param predictions predictions used to build the planning solution
 * @param model       model used to build the planning solution
 * @return the configuration derived from the solver's final best solution
 * @throws IllegalStateException if solving fails or the calling thread is interrupted
 */
public LimeConfig optimize(LimeConfig config, List<Prediction> predictions, PredictionProvider model) {
    final List<LimeConfigEntity> tunableEntities = new ArrayList<>();
    if (samplingEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createSamplingEntities(config));
    }
    if (proximityEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createProximityEntities(config));
    }
    if (encodingEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createEncodingEntities(config));
    }
    if (weightingEntities) {
        tunableEntities.addAll(LimeConfigEntityFactory.createWeightingEntities(config));
    }
    if (tunableEntities.isEmpty()) {
        // nothing to tune: hand back the caller's configuration as-is
        return config;
    }
    final LimeConfigSolution startingSolution = new LimeConfigSolution(config, predictions, tunableEntities, model);

    final SolverConfig solverConfig = new SolverConfig()
            .withEntityClasses(NumericLimeConfigEntity.class, BooleanLimeConfigEntity.class)
            .withSolutionClass(LimeConfigSolution.class);
    final ScoreDirectorFactoryConfig scoringConfig = new ScoreDirectorFactoryConfig();
    scoringConfig.setEasyScoreCalculatorClass(scoreCalculator.getClass());
    solverConfig.setScoreDirectorFactoryConfig(scoringConfig);

    // Solver-wide termination: a wall-clock limit, applied only when timeLimit is positive.
    final TerminationConfig solverTermination = new TerminationConfig();
    if (timeLimit > 0) {
        solverTermination.setSecondsSpentLimit(timeLimit);
    }
    solverConfig.setTerminationConfig(solverTermination);

    if (!deterministic) {
        logger.debug("non reproducible execution, set the seed inside initial LimeConfig's PerturbationContext and enable deterministic execution to fix this");
    } else {
        config.getPerturbationContext().getSeed().ifPresent(solverConfig::setRandomSeed);
        solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    }

    final LocalSearchPhaseConfig lateAcceptance = new LocalSearchPhaseConfig();
    lateAcceptance.setLocalSearchType(LocalSearchType.LATE_ACCEPTANCE);
    if (stepCountLimit > 0) {
        // stop after stepCountLimit steps, or earlier once the best score reaches 1.0
        lateAcceptance.setTerminationConfig(new TerminationConfig().withStepCountLimit(stepCountLimit).withBestScoreLimit("1.0"));
    }
    // NOTE(review): if both timeLimit and stepCountLimit are <= 0, no termination is configured
    // anywhere; presumably the local search would then never stop on its own — confirm the field
    // defaults/validation rule this case out.
    @SuppressWarnings("rawtypes")
    final List<PhaseConfig> phases = new ArrayList<>();
    phases.add(lateAcceptance);
    solverConfig.setPhaseConfigList(phases);

    try (SolverManager<LimeConfigSolution, UUID> solverManager = SolverManager.create(solverConfig, new SolverManagerConfig())) {
        final SolverJob<LimeConfigSolution, UUID> job = solverManager.solve(UUID.randomUUID(), startingSolution);
        try {
            // blocks until the solving ends
            final LimeConfigSolution best = job.getFinalBestSolution();
            final LimeConfig tunedConfig = LimeConfigEntityFactory.toLimeConfig(best);
            logger.info("final best solution score {} with config {}", best.getScore().getScore(), tunedConfig);
            return tunedConfig;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Solving failed (Thread interrupted)", e);
        } catch (ExecutionException e) {
            logger.error("Solving failed: {}", e.getMessage());
            throw new IllegalStateException("Prediction returned an error", e);
        }
    }
}
Use of org.optaplanner.core.config.solver.termination.TerminationConfig in project kogito-apps by kiegroup.
Class AbstractTaskAssigningCoreTest, method createNonDaemonSolver.
/**
 * Builds a solver from {@code createBaseConfig()} with two phases. The first is a FIRST_FIT
 * construction heuristic; the second is a local search that terminates after
 * {@code stepCountLimit} steps.
 *
 * @param stepCountLimit step limit for the local search phase
 * @return a newly built solver for {@link TaskAssigningSolution}
 */
protected Solver<TaskAssigningSolution> createNonDaemonSolver(int stepCountLimit) {
    final SolverConfig solverConfig = createBaseConfig();

    final ConstructionHeuristicPhaseConfig firstFitPhase = new ConstructionHeuristicPhaseConfig();
    firstFitPhase.setConstructionHeuristicType(ConstructionHeuristicType.FIRST_FIT);

    final TerminationConfig stepLimit = new TerminationConfig().withStepCountLimit(stepCountLimit);
    final LocalSearchPhaseConfig boundedLocalSearch = new LocalSearchPhaseConfig();
    boundedLocalSearch.setTerminationConfig(stepLimit);

    // construction heuristic listed first, followed by the step-bounded local search
    solverConfig.setPhaseConfigList(Arrays.asList(firstFitPhase, boundedLocalSearch));

    final SolverFactory<TaskAssigningSolution> factory = SolverFactory.create(solverConfig);
    return factory.buildSolver();
}
Aggregations