Use of org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator in the Apache Ignite project.
The class AbstractCrossValidation, method scoreBruteForceHyperparameterOptimization.
/**
 * Finds the best set of hyperparameters based on a brute-force (exhaustive) approach: every combination of
 * values from the parameter grid is evaluated.
 *
 * @return Cross-validation result holding the scores collected for each evaluated parameter set.
 */
private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
    List<Double[]> allCombinations = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // One scoring job per combination of hyperparameter values.
    List<IgniteSupplier<TaskResult>> jobs = new ArrayList<>(allCombinations.size());
    for (Double[] combination : allCombinations)
        jobs.add(() -> calculateScoresForFixedParamSet(combination));

    // Wait for every job first, then record the scores in submission order.
    List<TaskResult> results = new ArrayList<>(jobs.size());
    for (Promise<TaskResult> pending : environment.parallelismStrategy().submit(jobs))
        results.add(pending.unsafeGet());

    CrossValidationResult res = new CrossValidationResult();
    for (TaskResult taskRes : results)
        res.addScores(taskRes.locScores, taskRes.paramMap);

    return res;
}
Use of org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator in the Apache Ignite project.
The class AbstractCrossValidation, method scoreRandomSearchHyperparameterOptimization.
/**
 * Finds the best set of hyperparameters based on Random Search: all combinations from the parameter grid are
 * shuffled with the strategy's seed, and at most {@code maxTries} of them are evaluated.
 *
 * @return Cross-validation result holding the scores collected for each evaluated parameter set.
 */
private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
    RandomStrategy stgy = (RandomStrategy) paramGrid.getHyperParameterTuningStrategy();

    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // Shuffle a copy; the fixed seed makes the sampled subset reproducible.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    // Clamp to the number of available combinations: subList(0, n) throws IndexOutOfBoundsException
    // when n > size(), which previously happened whenever maxTries exceeded the grid size.
    int triesCnt = Math.min(stgy.getMaxTries(), paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, triesCnt);

    List<IgniteSupplier<TaskResult>> tasks = rndParamSets.stream()
        .map(paramSet -> (IgniteSupplier<TaskResult>)(() -> calculateScoresForFixedParamSet(paramSet)))
        .collect(Collectors.toList());

    // Block until every task completes, then record scores in submission order.
    List<TaskResult> taskResults = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    CrossValidationResult cvRes = new CrossValidationResult();
    taskResults.forEach(tr -> cvRes.addScores(tr.locScores, tr.paramMap));

    return cvRes;
}
Use of org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator in the Apache Ignite project.
The class AbstractCrossValidation, method scoreEvolutionAlgorithmSearchHyperparameterOptimization.
/**
 * Finds the best set of hyper-parameters based on a Genetic Programming approach.
 *
 * The initial population is a seeded random sample of the parameter-grid combinations. A genetic algorithm
 * configured from the {@link EvolutionOptimizationStrategy} then evolves it. Every fitness evaluation is also
 * recorded in the returned result.
 *
 * @return Cross-validation result holding the scores of every parameter set evaluated by the algorithm.
 */
private CrossValidationResult scoreEvolutionAlgorithmSearchHyperparameterOptimization() {
    EvolutionOptimizationStrategy stgy = (EvolutionOptimizationStrategy)paramGrid.getHyperParameterTuningStrategy();

    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // Initialization: a reproducible random sample of the grid forms the first population.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    // Clamp to the grid size: subList(0, n) throws IndexOutOfBoundsException when n > size(),
    // which previously crashed for grids with fewer than 20 combinations.
    int sizeOfPopulation = Math.min(20, paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, sizeOfPopulation);

    CrossValidationResult cvRes = new CrossValidationResult();

    // NOTE(review): when runParallel is used, this function may be invoked concurrently; confirm that
    // CrossValidationResult.addScores is safe to call from multiple threads.
    Function<Chromosome, Double> fitnessFunction = (Chromosome chromosome) -> {
        TaskResult tr = calculateScoresForFixedParamSet(chromosome.toDoubleArray());
        cvRes.addScores(tr.locScores, tr.paramMap);
        // NOTE(review): Double.MIN_VALUE is the smallest *positive* double, not the most negative one,
        // so an empty score array gets a fitness above any negative average score. Kept for compatibility.
        return Arrays.stream(tr.locScores).average().orElse(Double.MIN_VALUE);
    };

    // A single Random shared by all mutator calls yields successive values of one seeded sequence
    // (not the same value on each call). java.util.Random is thread-safe, but under parallel execution
    // the order in which calls draw from it can vary between runs.
    Random rnd = new Random(stgy.getSeed());

    // Mutation replaces a gene with a value drawn from that parameter's raw candidate values.
    BiFunction<Integer, Double, Double> mutator = (Integer geneIdx, Double geneValue) -> {
        Double[] possibleGeneValues = paramGrid.getParamRawData().get(geneIdx);
        return possibleGeneValues[rnd.nextInt(possibleGeneValues.length)];
    };

    GeneticAlgorithm ga = new GeneticAlgorithm(rndParamSets);
    ga.withFitnessFunction(fitnessFunction)
        .withMutationOperator(mutator)
        .withAmountOfEliteChromosomes(stgy.getNumberOfEliteChromosomes())
        .withCrossingoverProbability(stgy.getCrossingoverProbability())
        .withCrossoverStgy(stgy.getCrossoverStgy())
        .withAmountOfGenerations(stgy.getNumberOfGenerations())
        .withSelectionStgy(stgy.getSelectionStgy())
        .withMutationProbability(stgy.getMutationProbability());

    if (environment.parallelismStrategy().getParallelism() > 1)
        ga.runParallel(environment);
    else
        ga.run();

    return cvRes;
}
Aggregations