Search in sources :

Example 1 with MLLogger

Use of org.apache.ignite.ml.environment.logging.MLLogger in the Apache Ignite project.

Class TrainerTransformers, method runOnEnsemble.

/**
 * Trains an ensemble of models and combines them into a single composition.
 * <p>
 * For every model index in {@code [0, ensembleSize)} the supplied generator is asked for a training task,
 * given a dataset builder decorated with a bagging upstream transformer and the feature extractor to use.
 * All tasks are submitted through the parallelism strategy of the learning environment.
 * <p>
 * When {@code featuresVectorSize > 0} and {@code featureSubspaceDim != featuresVectorSize}, a separate feature
 * mapping is generated for each model; the extractor is wrapped with that mapping for training, and a
 * projector built from the same mapping is attached to the trained model.
 *
 * @param trainingTaskGenerator Training task generator: for a dataset builder, a model index and a feature
 *      extractor it returns a supplier which trains the model.
 * @param datasetBuilder Dataset builder.
 * @param ensembleSize Size of ensemble.
 * @param subsampleRatio Ratio (subsample size) / (initial dataset size).
 * @param featuresVectorSize Dimensionality of feature vector.
 * @param featureSubspaceDim Dimensionality of feature subspace.
 * @param extractor Function producing a feature vector from a key and a value of the dataset; it is wrapped
 *      with the per-model feature mapping when feature mappings are generated.
 * @param aggregator Aggregator of models.
 * @param environment Environment.
 * @param <K> Type of keys in dataset builder.
 * @param <V> Type of values in dataset builder.
 * @param <M> Type of model.
 * @return Composition of models trained on bagged dataset.
 */
private static <K, V, M extends IgniteModel<Vector, Double>> ModelsComposition runOnEnsemble(
    IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator,
    DatasetBuilder<K, V> datasetBuilder,
    int ensembleSize,
    double subsampleRatio,
    int featuresVectorSize,
    int featureSubspaceDim,
    IgniteBiFunction<K, V, Vector> extractor,
    PredictionsAggregator aggregator,
    LearningEnvironment environment) {
    MLLogger log = environment.logger(datasetBuilder.getClass());
    log.log(MLLogger.VerboseLevel.LOW, "Start learning.");

    // Feature mappings are only needed when models are trained on a feature subspace.
    List<int[]> mappings = null;
    if (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) {
        // NOTE(review): the third argument looks like a per-model seed - confirm against getMapping.
        mappings = IntStream.range(0, ensembleSize)
            .mapToObj(modelIdx -> getMapping(
                featuresVectorSize,
                featureSubspaceDim,
                environment.randomNumbersGenerator().nextLong() + modelIdx))
            .collect(Collectors.toList());
    }

    // Monotonic clock and a primitive local: the value is only used to measure an elapsed interval.
    long startNs = System.nanoTime();

    List<IgniteSupplier<M>> tasks = new ArrayList<>();
    List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
    if (mappings != null) {
        for (int[] mapping : mappings) {
            extractors.add(wrapExtractor(extractor, mapping));
        }
    }

    for (int i = 0; i < ensembleSize; i++) {
        DatasetBuilder<K, V> newBuilder =
            datasetBuilder.withUpstreamTransformer(BaggingUpstreamTransformer.builder(subsampleRatio, i));
        IgniteBiFunction<K, V, Vector> modelExtractor = mappings != null ? extractors.get(i) : extractor;
        tasks.add(trainingTaskGenerator.apply(newBuilder, i, modelExtractor));
    }

    List<ModelWithMapping<Vector, Double, M>> models = environment.parallelismStrategy().submit(tasks).stream()
        .map(Promise::unsafeGet)
        .map(ModelWithMapping<Vector, Double, M>::new)
        .collect(Collectors.toList());

    // If we need to do projection, do it.
    if (mappings != null) {
        for (int i = 0; i < models.size(); i++) {
            models.get(i).setMapping(VectorUtils.getProjector(mappings.get(i)));
        }
    }

    double learningTime = (System.nanoTime() - startNs) / 1e9;
    log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
    log.log(MLLogger.VerboseLevel.LOW, "Learning finished.");

    return new ModelsComposition(models, aggregator);
}
Also used : IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) ArrayList(java.util.ArrayList) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) Promise(org.apache.ignite.ml.environment.parallelism.Promise) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) MLLogger(org.apache.ignite.ml.environment.logging.MLLogger)

Aggregations

ArrayList (java.util.ArrayList)1 ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition)1 MLLogger (org.apache.ignite.ml.environment.logging.MLLogger)1 Promise (org.apache.ignite.ml.environment.parallelism.Promise)1 IgniteBiFunction (org.apache.ignite.ml.math.functions.IgniteBiFunction)1 IgniteSupplier (org.apache.ignite.ml.math.functions.IgniteSupplier)1 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)1