Use of org.apache.ignite.ml.environment.parallelism.Promise in project ignite by apache.
The class TrainerTransformers, method runOnEnsemble.
/**
* This method accepts a function which, for a given dataset builder and an index of a model in the ensemble,
* generates a task of training this model.
*
* @param trainingTaskGenerator Training task generator.
* @param datasetBuilder Dataset builder.
* @param ensembleSize Size of ensemble.
* @param subsampleRatio Ratio (subsample size) / (initial dataset size).
* @param featuresVectorSize Dimensionality of feature vector.
* @param featureSubspaceDim Dimensionality of feature subspace.
* @param extractor Feature extractor.
* @param aggregator Aggregator of models.
* @param environment Environment.
* @param <K> Type of keys in dataset builder.
* @param <V> Type of values in dataset builder.
* @param <M> Type of model.
* @return Composition of models trained on bagged dataset.
*/
private static <K, V, M extends IgniteModel<Vector, Double>> ModelsComposition runOnEnsemble(
    IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator,
    DatasetBuilder<K, V> datasetBuilder, int ensembleSize, double subsampleRatio, int featuresVectorSize,
    int featureSubspaceDim, IgniteBiFunction<K, V, Vector> extractor, PredictionsAggregator aggregator,
    LearningEnvironment environment) {
    MLLogger log = environment.logger(datasetBuilder.getClass());
    log.log(MLLogger.VerboseLevel.LOW, "Start learning.");

    // Generate one feature-subspace mapping per ensemble member if feature subspace sampling is requested.
    List<int[]> mappings = null;
    if (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) {
        mappings = IntStream.range(0, ensembleSize)
            .mapToObj(modelIdx -> getMapping(featuresVectorSize, featureSubspaceDim,
                environment.randomNumbersGenerator().nextLong() + modelIdx))
            .collect(Collectors.toList());
    }

    Long startTs = System.currentTimeMillis();

    List<IgniteSupplier<M>> tasks = new ArrayList<>();
    List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
    if (mappings != null) {
        for (int[] mapping : mappings)
            extractors.add(wrapExtractor(extractor, mapping));
    }

    // One training task per ensemble member, each on its own bootstrapped (bagged) dataset.
    for (int i = 0; i < ensembleSize; i++) {
        DatasetBuilder<K, V> newBuilder =
            datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i));
        tasks.add(trainingTaskGenerator.apply(newBuilder, i, mappings != null ? extractors.get(i) : extractor));
    }

    // Submit all tasks through the parallelism strategy and block on the resulting promises.
    List<ModelWithMapping<Vector, Double, M>> models = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .map(ModelWithMapping<Vector, Double, M>::new)
        .collect(Collectors.toList());

    // If we need to do projection, do it.
    if (mappings != null) {
        for (int i = 0; i < models.size(); i++)
            models.get(i).setMapping(VectorUtils.getProjector(mappings.get(i)));
    }

    double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
    log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
    log.log(MLLogger.VerboseLevel.LOW, "Learning finished.");

    return new ModelsComposition(models, aggregator);
}
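The Promise usage above follows one pattern: wrap each unit of work in an IgniteSupplier, submit the whole list through the environment's ParallelismStrategy, and join the returned promises with Promise::unsafeGet. A minimal standalone sketch of that pattern, assuming the default learning environment builder of recent Ignite ML versions; the class name and toy task bodies are illustrative stand-ins for real training tasks, not part of the original code:

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteSupplier;

public class PromiseSubmitSketch {
    public static void main(String[] args) {
        // Assumption: defaultBuilder()/buildForTrainer() behave as in recent Ignite ML versions.
        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();

        // Two toy tasks standing in for per-model training tasks.
        List<IgniteSupplier<Integer>> tasks = Arrays.asList(
            () -> 1 + 1,
            () -> 2 * 21);

        // Submit and block on each Promise, just as runOnEnsemble does with its training tasks.
        List<Integer> results = env.parallelismStrategy().submit(tasks).stream()
            .map(Promise::unsafeGet)
            .collect(Collectors.toList());

        System.out.println(results); // [2, 42]
    }
}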
Use of org.apache.ignite.ml.environment.parallelism.Promise in project ignite by apache.
The class AbstractCrossValidation, method scoreBruteForceHyperparameterOptimization.
/**
* Finds the best set of hyperparameters using a brute-force approach.
*/
private CrossValidationResult scoreBruteForceHyperparameterOptimization() {
    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
    CrossValidationResult cvRes = new CrossValidationResult();

    List<IgniteSupplier<TaskResult>> tasks = paramSets.stream()
        .map(paramSet -> (IgniteSupplier<TaskResult>)(() -> calculateScoresForFixedParamSet(paramSet)))
        .collect(Collectors.toList());

    List<TaskResult> taskResults = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    taskResults.forEach(tr -> cvRes.addScores(tr.locScores, tr.paramMap));
    return cvRes;
}
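Each call to calculateScoresForFixedParamSet becomes one independent task, and the grid produced by ParameterSetGenerator is the Cartesian product of the per-parameter value lists. A hand-rolled illustration of that expansion in plain Java; the hyperparameter names and values are made up for the example and this is not the Ignite API:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class GridExpansionSketch {
    public static void main(String[] args) {
        // Hypothetical grid: maxDepth in {1, 2, 3}, minImpurityDecrease in {0.0, 0.25}.
        Double[] maxDepth = {1.0, 2.0, 3.0};
        Double[] minImpurity = {0.0, 0.25};

        // Cartesian product: 3 * 2 = 6 parameter sets, one score-computation task each.
        List<Double[]> paramSets = new ArrayList<>();
        for (Double d : maxDepth)
            for (Double m : minImpurity)
                paramSets.add(new Double[] {d, m});

        paramSets.forEach(p -> System.out.println(Arrays.toString(p)));
    }
}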
Use of org.apache.ignite.ml.environment.parallelism.Promise in project ignite by apache.
The class AbstractCrossValidation, method scoreRandomSearchHyperparameterOptimization.
/**
* Finds the best set of hyperparameters using random search.
*/
private CrossValidationResult scoreRandomSearchHyperparameterOptimization() {
    RandomStrategy stgy = (RandomStrategy)paramGrid.getHyperParameterTuningStrategy();

    List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
    List<Double[]> paramSetsCp = new ArrayList<>(paramSets);
    Collections.shuffle(paramSetsCp, new Random(stgy.getSeed()));

    CrossValidationResult cvRes = new CrossValidationResult();
    List<Double[]> rndParamSets = paramSetsCp.subList(0, stgy.getMaxTries());

    List<IgniteSupplier<TaskResult>> tasks = rndParamSets.stream()
        .map(paramSet -> (IgniteSupplier<TaskResult>)(() -> calculateScoresForFixedParamSet(paramSet)))
        .collect(Collectors.toList());

    List<TaskResult> taskResults = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .collect(Collectors.toList());

    taskResults.forEach(tr -> cvRes.addScores(tr.locScores, tr.paramMap));
    return cvRes;
}
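Random search scores the same kind of parameter sets as brute force, but only a subset of them: the expanded grid is shuffled with the strategy's seed and the first getMaxTries() combinations are kept. A standalone illustration of that selection step in plain Java, with made-up values in place of the real grid and strategy:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class RandomSearchSelectionSketch {
    public static void main(String[] args) {
        // Pretend the expanded grid has 6 parameter sets (see the brute-force sketch above).
        List<Double[]> paramSets = Arrays.asList(
            new Double[] {1.0, 0.0}, new Double[] {1.0, 0.25},
            new Double[] {2.0, 0.0}, new Double[] {2.0, 0.25},
            new Double[] {3.0, 0.0}, new Double[] {3.0, 0.25});

        long seed = 123L;   // plays the role of RandomStrategy.getSeed()
        int maxTries = 3;   // plays the role of RandomStrategy.getMaxTries()

        // Same selection as scoreRandomSearchHyperparameterOptimization: shuffle a copy, take a prefix.
        List<Double[]> copy = new ArrayList<>(paramSets);
        Collections.shuffle(copy, new Random(seed));
        List<Double[]> selected = copy.subList(0, maxTries);

        selected.forEach(p -> System.out.println(Arrays.toString(p)));
    }
}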