Search in sources :

Example 1 with ParameterSetGenerator

use of org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator in project ignite by apache.

the class AbstractCrossValidation method scoreBruteForceHyperparameterOptimization.

/**
 * Finds the best set of hyperparameters using the brute force approach.
 *
 * Every parameter set produced by {@link ParameterSetGenerator} for the configured parameter grid
 * is scored via {@code calculateScoresForFixedParamSet}; the scoring tasks are handed to the
 * environment's parallelism strategy and their results are gathered into a single
 * {@link CrossValidationResult}.
 *
 * @return Cross-validation result holding the scores of every evaluated parameter set.
 */
private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
    ParameterSetGenerator generator = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx());
    List<Double[]> candidates = generator.generate();

    CrossValidationResult res = new CrossValidationResult();

    // One scoring task per candidate parameter set.
    List<IgniteSupplier<TaskResult>> jobs = new ArrayList<>(candidates.size());
    for (Double[] candidate : candidates)
        jobs.add(() -> calculateScoresForFixedParamSet(candidate));

    // Resolve every promise before touching the result object, in submission order.
    List<TaskResult> outcomes = new ArrayList<>(jobs.size());
    for (Promise<TaskResult> promise : environment.parallelismStrategy().submit(jobs))
        outcomes.add(promise.unsafeGet());

    for (TaskResult outcome : outcomes)
        res.addScores(outcome.locScores, outcome.paramMap);

    return res;
}
Also used : Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) SHA256UniformMapper(org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper) Arrays(java.util.Arrays) IgniteBiPredicate(org.apache.ignite.lang.IgniteBiPredicate) GeneticAlgorithm(org.apache.ignite.ml.util.genetic.GeneticAlgorithm) PipelineMdl(org.apache.ignite.ml.pipeline.PipelineMdl) Evaluator(org.apache.ignite.ml.selection.scoring.evaluator.Evaluator) BiFunction(java.util.function.BiFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) HashMap(java.util.HashMap) Random(java.util.Random) Function(java.util.function.Function) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ArrayList(java.util.ArrayList) HyperParameterTuningStrategy(org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EvolutionOptimizationStrategy(org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) BruteForceStrategy(org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator) List(java.util.List) Promise(org.apache.ignite.ml.environment.parallelism.Promise) Chromosome(org.apache.ignite.ml.util.genetic.Chromosome) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) UniformMapper(org.apache.ignite.ml.selection.split.mapper.UniformMapper) 
NotNull(org.jetbrains.annotations.NotNull) Collections(java.util.Collections) Pipeline(org.apache.ignite.ml.pipeline.Pipeline) IgniteDoubleConsumer(org.apache.ignite.ml.math.functions.IgniteDoubleConsumer) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator)

Example 2 with ParameterSetGenerator

use of org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator in project ignite by apache.

the class AbstractCrossValidation method scoreRandomSearchHyperparameterOptimization.

/**
 * Finds the best set of hyperparameters based on Random Search.
 *
 * All parameter sets of the grid are generated, shuffled with a {@link Random} seeded from the
 * strategy, and the first {@code maxTries} of them are scored. If {@code maxTries} is larger than
 * the number of generated parameter sets, all of them are scored.
 *
 * @return Cross-validation result holding the scores of every evaluated parameter set.
 */
private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
    RandomStrategy stgy = (RandomStrategy) paramGrid.getHyperParameterTuningStrategy();
    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
    // Shuffle a copy so the generator's list is left untouched.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));
    CrossValidationResult cvRes = new CrossValidationResult();
    // Clamp the number of tries: subList throws IndexOutOfBoundsException when the
    // upper bound exceeds the list size, which happens for small parameter grids.
    int amountOfTries = Math.min(stgy.getMaxTries(), paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, amountOfTries);
    List<IgniteSupplier<TaskResult>> tasks = rndParamSets.stream().map(paramSet -> (IgniteSupplier<TaskResult>) (() -> calculateScoresForFixedParamSet(paramSet))).collect(Collectors.toList());
    List<TaskResult> taskResults = environment.parallelismStrategy().submit(tasks).stream().map(Promise::unsafeGet).collect(Collectors.toList());
    taskResults.forEach(tr -> cvRes.addScores(tr.locScores, tr.paramMap));
    return cvRes;
}
Also used : Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) SHA256UniformMapper(org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper) Arrays(java.util.Arrays) IgniteBiPredicate(org.apache.ignite.lang.IgniteBiPredicate) GeneticAlgorithm(org.apache.ignite.ml.util.genetic.GeneticAlgorithm) PipelineMdl(org.apache.ignite.ml.pipeline.PipelineMdl) Evaluator(org.apache.ignite.ml.selection.scoring.evaluator.Evaluator) BiFunction(java.util.function.BiFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) HashMap(java.util.HashMap) Random(java.util.Random) Function(java.util.function.Function) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ArrayList(java.util.ArrayList) HyperParameterTuningStrategy(org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EvolutionOptimizationStrategy(org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) BruteForceStrategy(org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator) List(java.util.List) Promise(org.apache.ignite.ml.environment.parallelism.Promise) Chromosome(org.apache.ignite.ml.util.genetic.Chromosome) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) UniformMapper(org.apache.ignite.ml.selection.split.mapper.UniformMapper) 
NotNull(org.jetbrains.annotations.NotNull) Collections(java.util.Collections) Pipeline(org.apache.ignite.ml.pipeline.Pipeline) IgniteDoubleConsumer(org.apache.ignite.ml.math.functions.IgniteDoubleConsumer) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Random(java.util.Random) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) ArrayList(java.util.ArrayList) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator)

Example 3 with ParameterSetGenerator

use of org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator in project ignite by apache.

the class AbstractCrossValidation method scoreEvolutionAlgorithmSearchHyperparameterOptimization.

/**
 * Finds the best set of hyper-parameters based on Genetic Programming approach.
 *
 * An initial population is drawn from the shuffled list of all parameter sets of the grid and is
 * evolved by {@link GeneticAlgorithm}, configured from the {@link EvolutionOptimizationStrategy}.
 * Every fitness evaluation scores one chromosome via {@code calculateScoresForFixedParamSet} and
 * records the scores in the returned result.
 *
 * @return Cross-validation result holding the scores of every evaluated chromosome.
 */
private CrossValidationResult scoreEvolutionAlgorithmSearchHyperparameterOptimization() {
    EvolutionOptimizationStrategy stgy = (EvolutionOptimizationStrategy) paramGrid.getHyperParameterTuningStrategy();
    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
    // initialization
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));
    // The population is at most 20 chromosomes; clamp to the number of available parameter sets,
    // otherwise subList throws IndexOutOfBoundsException for grids with fewer than 20 combinations.
    int sizeOfPopulation = Math.min(20, paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, sizeOfPopulation);
    CrossValidationResult cvRes = new CrossValidationResult();
    // Fitness of a chromosome is the average of its local scores; scores are recorded as a side effect.
    // NOTE(review): cvRes is shared by all fitness calls, including those issued by runParallel below —
    // confirm that CrossValidationResult.addScores is safe for concurrent use.
    Function<Chromosome, Double> fitnessFunction = (Chromosome chromosome) -> {
        TaskResult tr = calculateScoresForFixedParamSet(chromosome.toDoubleArray());
        cvRes.addScores(tr.locScores, tr.paramMap);
        // NOTE(review): Double.MIN_VALUE is the smallest positive double, not the most negative one —
        // confirm this is the intended fallback when there are no local scores.
        return Arrays.stream(tr.locScores).average().orElse(Double.MIN_VALUE);
    };
    // TODO: common seed for shared lambdas can produce the same value on each function call? or sequent?
    Random rnd = new Random(stgy.getSeed());
    // Mutation replaces a gene with a randomly picked value from that gene's raw parameter values.
    BiFunction<Integer, Double, Double> mutator = (Integer geneIdx, Double geneValue) -> {
        Double newGeneVal;
        Double[] possibleGeneValues = paramGrid.getParamRawData().get(geneIdx);
        newGeneVal = possibleGeneValues[rnd.nextInt(possibleGeneValues.length)];
        return newGeneVal;
    };
    GeneticAlgorithm ga = new GeneticAlgorithm(rndParamSets);
    ga.withFitnessFunction(fitnessFunction).withMutationOperator(mutator).withAmountOfEliteChromosomes(stgy.getNumberOfEliteChromosomes()).withCrossingoverProbability(stgy.getCrossingoverProbability()).withCrossoverStgy(stgy.getCrossoverStgy()).withAmountOfGenerations(stgy.getNumberOfGenerations()).withSelectionStgy(stgy.getSelectionStgy()).withMutationProbability(stgy.getMutationProbability());
    // Use the parallel runner only when the environment actually offers more than one worker.
    if (environment.parallelismStrategy().getParallelism() > 1)
        ga.runParallel(environment);
    else
        ga.run();
    return cvRes;
}
Also used : ArrayList(java.util.ArrayList) Chromosome(org.apache.ignite.ml.util.genetic.Chromosome) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator) EvolutionOptimizationStrategy(org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy) GeneticAlgorithm(org.apache.ignite.ml.util.genetic.GeneticAlgorithm) Random(java.util.Random)

Aggregations

ArrayList (java.util.ArrayList)3 Random (java.util.Random)3 EvolutionOptimizationStrategy (org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy)3 ParameterSetGenerator (org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator)3 Chromosome (org.apache.ignite.ml.util.genetic.Chromosome)3 GeneticAlgorithm (org.apache.ignite.ml.util.genetic.GeneticAlgorithm)3 Serializable (java.io.Serializable)2 Arrays (java.util.Arrays)2 Collections (java.util.Collections)2 HashMap (java.util.HashMap)2 List (java.util.List)2 Map (java.util.Map)2 BiFunction (java.util.function.BiFunction)2 Function (java.util.function.Function)2 Collectors (java.util.stream.Collectors)2 IgniteBiPredicate (org.apache.ignite.lang.IgniteBiPredicate)2 IgniteModel (org.apache.ignite.ml.IgniteModel)2 DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder)2 LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment)2 LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder)2