Search in sources:

Example 1 with Promise

Use of org.apache.ignite.ml.environment.parallelism.Promise in the Apache Ignite project (apache/ignite).

Class TrainerTransformers, method runOnEnsemble.

/**
 * This method accepts a function which, for a given dataset builder and index of a model in the ensemble,
 * generates the task of training this model, and uses it to train the whole ensemble.
 *
 * <p>Every model gets its own dataset builder obtained via
 * {@code datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i))}.
 * When {@code featuresVectorSize > 0} and {@code featureSubspaceDim != featuresVectorSize}, every model also
 * gets its own random feature mapping. The extractor handed to its training task is wrapped with that mapping,
 * and the trained model is given the matching projector.
 *
 * @param trainingTaskGenerator Training task generator: maps (dataset builder, model index, feature extractor)
 *      to a supplier that trains the model.
 * @param datasetBuilder Dataset builder.
 * @param ensembleSize Size of ensemble.
 * @param subsampleRatio Ratio (subsample size) / (initial dataset size).
 * @param featuresVectorSize Dimensionality of feature vector.
 * @param featureSubspaceDim Dimensionality of feature subspace.
 * @param extractor Feature extractor; passed to the training tasks as-is, or wrapped with a per-model feature
 *      mapping when feature subspace sampling is enabled.
 * @param aggregator Aggregator of models.
 * @param environment Environment.
 * @param <K> Type of keys in dataset builder.
 * @param <V> Type of values in dataset builder.
 * @param <M> Type of model.
 * @return Composition of models trained on bagged dataset.
 */
private static <K, V, M extends IgniteModel<Vector, Double>> ModelsComposition runOnEnsemble(
    IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator,
    DatasetBuilder<K, V> datasetBuilder,
    int ensembleSize,
    double subsampleRatio,
    int featuresVectorSize,
    int featureSubspaceDim,
    IgniteBiFunction<K, V, Vector> extractor,
    PredictionsAggregator aggregator,
    LearningEnvironment environment) {
    MLLogger log = environment.logger(datasetBuilder.getClass());
    log.log(MLLogger.VerboseLevel.LOW, "Start learning.");

    // Per-model random feature mappings, only when a proper feature subspace is requested.
    List<int[]> mappings = null;
    if (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) {
        mappings = IntStream.range(0, ensembleSize)
            .mapToObj(modelIdx -> getMapping(featuresVectorSize, featureSubspaceDim,
                environment.randomNumbersGenerator().nextLong() + modelIdx))
            .collect(Collectors.toList());
    }

    // System.nanoTime() is monotonic and meant for elapsed-time measurement; a primitive long avoids boxing.
    long startNanos = System.nanoTime();

    List<IgniteSupplier<M>> tasks = new ArrayList<>();
    List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
    if (mappings != null) {
        for (int[] mapping : mappings)
            extractors.add(wrapExtractor(extractor, mapping));
    }

    for (int i = 0; i < ensembleSize; i++) {
        DatasetBuilder<K, V> newBuilder =
            datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i));
        tasks.add(trainingTaskGenerator.apply(newBuilder, i, mappings != null ? extractors.get(i) : extractor));
    }

    // Submit all training tasks and block until every model is available (in task order).
    List<ModelWithMapping<Vector, Double, M>> models = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .map(ModelWithMapping<Vector, Double, M>::new)
        .collect(Collectors.toList());

    // If we need to do projection, do it: model i sees only the features selected by mapping i.
    if (mappings != null) {
        for (int i = 0; i < models.size(); i++)
            models.get(i).setMapping(VectorUtils.getProjector(mappings.get(i)));
    }

    double learningTime = (System.nanoTime() - startNanos) / 1_000_000_000.0;
    log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
    log.log(MLLogger.VerboseLevel.LOW, "Learning finished.");

    return new ModelsComposition(models, aggregator);
}
Also used: IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) ArrayList(java.util.ArrayList) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) Promise(org.apache.ignite.ml.environment.parallelism.Promise) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) MLLogger(org.apache.ignite.ml.environment.logging.MLLogger)

Example 2 with Promise

Use of org.apache.ignite.ml.environment.parallelism.Promise in the Apache Ignite project (apache/ignite).

Class AbstractCrossValidation, method scoreBruteForceHyperparameterOptimization.

/**
 * Finds the best set of hyperparameters using a brute-force approach: every combination generated from the
 * parameter grid is scored.
 *
 * <p>One scoring task is built per parameter set and handed to the environment's parallelism strategy. All
 * results are awaited first, and only then are their scores added to the result, in submission order.
 *
 * @return Cross-validation result with the scores of every parameter combination.
 */
private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
    ParameterSetGenerator gen = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx());
    List<Double[]> allParamSets = gen.generate();

    CrossValidationResult res = new CrossValidationResult();

    List<IgniteSupplier<TaskResult>> scoringTasks = new ArrayList<>();
    for (Double[] paramSet : allParamSets)
        scoringTasks.add(() -> calculateScoresForFixedParamSet(paramSet));

    // Wait for every task before touching the result.
    List<TaskResult> finished = environment.parallelismStrategy()
        .submit(scoringTasks)
        .stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    for (TaskResult taskRes : finished)
        res.addScores(taskRes.locScores, taskRes.paramMap);

    return res;
}
Also used : Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) SHA256UniformMapper(org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper) Arrays(java.util.Arrays) IgniteBiPredicate(org.apache.ignite.lang.IgniteBiPredicate) GeneticAlgorithm(org.apache.ignite.ml.util.genetic.GeneticAlgorithm) PipelineMdl(org.apache.ignite.ml.pipeline.PipelineMdl) Evaluator(org.apache.ignite.ml.selection.scoring.evaluator.Evaluator) BiFunction(java.util.function.BiFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) HashMap(java.util.HashMap) Random(java.util.Random) Function(java.util.function.Function) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ArrayList(java.util.ArrayList) HyperParameterTuningStrategy(org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EvolutionOptimizationStrategy(org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) BruteForceStrategy(org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator) List(java.util.List) Promise(org.apache.ignite.ml.environment.parallelism.Promise) Chromosome(org.apache.ignite.ml.util.genetic.Chromosome) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) UniformMapper(org.apache.ignite.ml.selection.split.mapper.UniformMapper) 
NotNull(org.jetbrains.annotations.NotNull) Collections(java.util.Collections) Pipeline(org.apache.ignite.ml.pipeline.Pipeline) IgniteDoubleConsumer(org.apache.ignite.ml.math.functions.IgniteDoubleConsumer) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator)

Example 3 with Promise

Use of org.apache.ignite.ml.environment.parallelism.Promise in the Apache Ignite project (apache/ignite).

Class AbstractCrossValidation, method scoreRandomSearchHyperparameterOptimization.

/**
 * Finds the best set of hyperparameters based on random search.
 *
 * <p>All combinations from the parameter grid are generated and copied. The copy is shuffled with a
 * {@link Random} seeded by the strategy's seed, and the first {@code min(maxTries, gridSize)} combinations are
 * scored in parallel through the environment's parallelism strategy.
 *
 * @return Cross-validation result with the scores of every evaluated parameter combination.
 */
private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
    RandomStrategy stgy = (RandomStrategy) paramGrid.getHyperParameterTuningStrategy();
    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();

    // Shuffle a copy so the generated list itself is not reordered; the seed makes the pick reproducible.
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    CrossValidationResult cvRes = new CrossValidationResult();

    // maxTries may exceed the number of grid combinations; subList would then throw
    // IndexOutOfBoundsException, so evaluate at most the combinations that exist.
    int tries = Math.min(stgy.getMaxTries(), paramSetsCp.size());
    List<Double[]> rndParamSets = paramSetsCp.subList(0, tries);

    List<IgniteSupplier<TaskResult>> tasks = rndParamSets.stream()
        .map(paramSet -> (IgniteSupplier<TaskResult>) (() -> calculateScoresForFixedParamSet(paramSet)))
        .collect(Collectors.toList());

    List<TaskResult> taskResults = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    taskResults.forEach(tr -> cvRes.addScores(tr.locScores, tr.paramMap));

    return cvRes;
}
Also used : Metric(org.apache.ignite.ml.selection.scoring.metric.Metric) SHA256UniformMapper(org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper) Arrays(java.util.Arrays) IgniteBiPredicate(org.apache.ignite.lang.IgniteBiPredicate) GeneticAlgorithm(org.apache.ignite.ml.util.genetic.GeneticAlgorithm) PipelineMdl(org.apache.ignite.ml.pipeline.PipelineMdl) Evaluator(org.apache.ignite.ml.selection.scoring.evaluator.Evaluator) BiFunction(java.util.function.BiFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) HashMap(java.util.HashMap) Random(java.util.Random) Function(java.util.function.Function) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ArrayList(java.util.ArrayList) HyperParameterTuningStrategy(org.apache.ignite.ml.selection.paramgrid.HyperParameterTuningStrategy) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EvolutionOptimizationStrategy(org.apache.ignite.ml.selection.paramgrid.EvolutionOptimizationStrategy) MetricName(org.apache.ignite.ml.selection.scoring.metric.MetricName) Map(java.util.Map) ParamGrid(org.apache.ignite.ml.selection.paramgrid.ParamGrid) BruteForceStrategy(org.apache.ignite.ml.selection.paramgrid.BruteForceStrategy) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator) List(java.util.List) Promise(org.apache.ignite.ml.environment.parallelism.Promise) Chromosome(org.apache.ignite.ml.util.genetic.Chromosome) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) UniformMapper(org.apache.ignite.ml.selection.split.mapper.UniformMapper) 
NotNull(org.jetbrains.annotations.NotNull) Collections(java.util.Collections) Pipeline(org.apache.ignite.ml.pipeline.Pipeline) IgniteDoubleConsumer(org.apache.ignite.ml.math.functions.IgniteDoubleConsumer) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Random(java.util.Random) RandomStrategy(org.apache.ignite.ml.selection.paramgrid.RandomStrategy) ArrayList(java.util.ArrayList) ParameterSetGenerator(org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator)

Aggregations

ArrayList (java.util.ArrayList): 3, Promise (org.apache.ignite.ml.environment.parallelism.Promise): 3, IgniteSupplier (org.apache.ignite.ml.math.functions.IgniteSupplier): 3, Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 3, Serializable (java.io.Serializable): 2, Arrays (java.util.Arrays): 2, Collections (java.util.Collections): 2, HashMap (java.util.HashMap): 2, List (java.util.List): 2, Map (java.util.Map): 2, Random (java.util.Random): 2, BiFunction (java.util.function.BiFunction): 2, Function (java.util.function.Function): 2, Collectors (java.util.stream.Collectors): 2, IgniteBiPredicate (org.apache.ignite.lang.IgniteBiPredicate): 2, IgniteModel (org.apache.ignite.ml.IgniteModel): 2, DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder): 2, LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment): 2, LearningEnvironmentBuilder (org.apache.ignite.ml.environment.LearningEnvironmentBuilder): 2, IgniteDoubleConsumer (org.apache.ignite.ml.math.functions.IgniteDoubleConsumer): 2