Search in sources:

Example 6 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project Ignite by Apache.

Class GDBTrainer, method updateModel.

/**
 * {@inheritDoc}
 *
 * <p>Builds a {@link GDBModel} by running the configured {@link GDBLearningStrategy}. If {@code mdl} is
 * non-null, the strategy updates it; otherwise it learns models from scratch. If the labels cannot be
 * learned, or no initial value can be computed from the dataset, the method delegates to
 * {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)}. The elapsed training time is logged at
 * {@code LOW} verbosity.
 */
@Override
protected <K, V> GDBModel updateModel(GDBModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    boolean labelsLearned = learnLabels(datasetBuilder, preprocessor);
    if (!labelsLearned)
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

    // get1() is the initial (mean) label value, get2() is the sample size.
    IgniteBiTuple<Double, Long> initialStats = computeInitialValue(envBuilder, datasetBuilder, preprocessor);
    if (initialStats == null)
        return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

    Double meanLabel = initialStats.get1();
    Long sampleCnt = initialStats.get2();

    long startMillis = System.currentTimeMillis();

    GDBLearningStrategy strategy = getLearningStrategy()
        .withBaseModelTrainerBuilder(this::buildBaseModelTrainer)
        .withExternalLabelToInternal(this::externalLabelToInternal)
        .withCntOfIterations(cntOfIterations)
        .withEnvironmentBuilder(envBuilder)
        .withLossGradient(loss)
        .withSampleSize(sampleCnt)
        .withMeanLabelValue(meanLabel)
        .withDefaultGradStepSize(gradientStep)
        .withCheckConvergenceStgyFactory(checkConvergenceStgyFactory);

    List<IgniteModel<Vector, Double>> trainedModels;
    if (mdl == null)
        trainedModels = strategy.learnModels(datasetBuilder, preprocessor);
    else
        trainedModels = strategy.update(mdl, datasetBuilder, preprocessor);

    double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0;
    environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", elapsedSeconds);

    WeightedPredictionsAggregator aggregator =
        new WeightedPredictionsAggregator(strategy.getCompositionWeights(), strategy.getMeanValue());

    return new GDBModel(trainedModels, aggregator, this::internalLabelToExternal);
}
Also used: WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) IgniteModel(org.apache.ignite.ml.IgniteModel)

Example 7 with WeightedPredictionsAggregator

Use of org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator in project Ignite by Apache.

Class GDBTrainerTest, method testFitRegression.

/**
 * Fits a gradient-boosting regressor on {@code y = 2x}, sampled at 100 evenly spaced points in
 * {@code [-5, 5)}, and checks the following:
 * <ul>
 *     <li>the mean squared error on the training points is close to zero;</li>
 *     <li>the result is a {@link ModelsComposition} of {@link DecisionTreeModel}s with one model per
 *     requested iteration and a {@link WeightedPredictionsAggregator};</li>
 *     <li>enabling a mean-absolute-value convergence checker stops training before all iterations run.</li>
 * </ul>
 */
@Test
public void testFitRegression() {
    int size = 100;
    // Requested boosting iterations. Without a convergence checker, the composition must contain exactly
    // this many models; with one, it must contain fewer.
    int cntOfIterations = 2000;
    double[] xs = new double[size];
    double[] ys = new double[size];
    double from = -5.0;
    double to = 5.0;
    double step = Math.abs(from - to) / size;
    Map<Integer, double[]> learningSample = new HashMap<>();
    for (int i = 0; i < size; i++) {
        xs[i] = from + step * i;
        ys[i] = 2 * xs[i];
        learningSample.put(i, new double[] { xs[i], ys[i] });
    }
    GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, cntOfIterations, 3, 0.0).withUsingIdx(true);
    IgniteModel<Vector, Double> mdl = trainer.fit(learningSample, 1, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));
    double mse = 0.0;
    for (int j = 0; j < size; j++) {
        double x = xs[j];
        double y = ys[j];
        double p = mdl.predict(VectorUtils.of(x));
        mse += Math.pow(y - p, 2);
    }
    mse /= size;
    assertEquals(0.0, mse, 0.0001);
    ModelsComposition composition = (ModelsComposition) mdl;
    assertTrue(!composition.toString().isEmpty());
    assertTrue(!composition.toString(true).isEmpty());
    assertTrue(!composition.toString(false).isEmpty());
    composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeModel));
    assertEquals(cntOfIterations, composition.getModels().size());
    assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
    trainer = trainer.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
    // Use the same label spec as the first fit, so both trainings see identical features and labels.
    // The previous hard-coded index 1 matched LAST only because each row has exactly two elements.
    assertTrue(trainer.fit(learningSample, 1, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)).getModels().size() < cntOfIterations);
}
Also used: DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) GDBRegressionOnTreesTrainer(org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer) HashMap(java.util.HashMap) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) WeightedPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator) ModelsComposition(org.apache.ignite.ml.composition.ModelsComposition) MeanAbsValueConvergenceCheckerFactory(org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Aggregations

WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator)7 IgniteModel (org.apache.ignite.ml.IgniteModel)5 ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition)4 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)4 HashMap (java.util.HashMap)3 ArrayList (java.util.ArrayList)2 MeanAbsValueConvergenceCheckerFactory (org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory)2 DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel)2 IOException (java.io.IOException)1 Serializable (java.io.Serializable)1 TreeMap (java.util.TreeMap)1 Configuration (org.apache.hadoop.conf.Configuration)1 Path (org.apache.hadoop.fs.Path)1 TrainerTest (org.apache.ignite.ml.common.TrainerTest)1 GDBModel (org.apache.ignite.ml.composition.boosting.GDBModel)1 ConvergenceCheckerStubFactory (org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory)1 DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer)1 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)1 LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment)1 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)1