Use of org.apache.ignite.ml.util.genetic.Chromosome in the Apache Ignite project.
Class AbstractCrossValidation, method scoreEvolutionAlgorithmSearchHyperparameterOptimization.
/**
 * Finds the best set of hyper-parameters using a genetic algorithm.
 *
 * <p>All parameter sets produced from the parameter grid are shuffled with the strategy's seed and the first
 * {@code min(20, numberOfParameterSets)} of them form the initial population. Every chromosome evaluated by
 * the fitness function has its scores recorded in the returned result.
 *
 * @return Cross-validation result that accumulates the scores of every evaluated parameter set.
 */
private CrossValidationResult scoreEvolutionAlgorithmSearchHyperparameterOptimization() {
    EvolutionOptimizationStrategy stgy = (EvolutionOptimizationStrategy) paramGrid.getHyperParameterTuningStrategy();

    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // Initialization: shuffle a copy so the generated list itself is left in its original order.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    // The grid may yield fewer combinations than the desired population size; clamp the bound,
    // otherwise subList throws IndexOutOfBoundsException for small grids.
    int sizeOfPopulation = 20;
    List<Double[]> rndParamSets = paramSetsCp.subList(0, Math.min(sizeOfPopulation, paramSetsCp.size()));

    CrossValidationResult cvRes = new CrossValidationResult();

    // Fitness of a chromosome is the mean of its scores; each evaluation is also recorded in cvRes.
    Function<Chromosome, Double> fitnessFunction = (Chromosome chromosome) -> {
        TaskResult tr = calculateScoresForFixedParamSet(chromosome.toDoubleArray());

        cvRes.addScores(tr.locScores, tr.paramMap);

        // NOTE(review): Double.MIN_VALUE is the smallest POSITIVE double, not the most negative one.
        // Confirm this is the intended fallback fitness when there are no scores.
        return Arrays.stream(tr.locScores).average().orElse(Double.MIN_VALUE);
    };

    // TODO: common seed for shared lambdas can produce the same value on each function call? or sequent?
    Random rnd = new Random(stgy.getSeed());

    // Mutation replaces a gene with a randomly picked value from the raw values listed for that gene index.
    BiFunction<Integer, Double, Double> mutator = (Integer geneIdx, Double geneValue) -> {
        Double[] possibleGeneValues = paramGrid.getParamRawData().get(geneIdx);

        return possibleGeneValues[rnd.nextInt(possibleGeneValues.length)];
    };

    GeneticAlgorithm ga = new GeneticAlgorithm(rndParamSets);

    ga.withFitnessFunction(fitnessFunction)
        .withMutationOperator(mutator)
        .withAmountOfEliteChromosomes(stgy.getNumberOfEliteChromosomes())
        .withCrossingoverProbability(stgy.getCrossingoverProbability())
        .withCrossoverStgy(stgy.getCrossoverStgy())
        .withAmountOfGenerations(stgy.getNumberOfGenerations())
        .withSelectionStgy(stgy.getSelectionStgy())
        .withMutationProbability(stgy.getMutationProbability());

    // Use the parallel runner only when the environment's parallelism is greater than one.
    if (environment.parallelismStrategy().getParallelism() > 1)
        ga.runParallel(environment);
    else
        ga.run();

    return cvRes;
}
Aggregations