Search in sources:

Example 1 with DatasetTrainer

Use of org.apache.ignite.ml.trainers.DatasetTrainer in the Apache Ignite project.

Class RandomForestClassifierTrainerTest, method testUpdate.

/**
 * Verifies that updating a trained random forest classifier does not change its prediction
 * on a probe vector, both when the update reuses the original training data and when the
 * update is given an empty dataset.
 */
@Test
public void testUpdate() {
    final int rows = 1000;
    Map<Integer, LabeledVector<Double>> data = new HashMap<>();
    int row = 0;
    while (row < rows) {
        // Each feature is the previous one divided by 10.
        double first = row;
        double second = first / 10.0;
        double third = second / 10.0;
        double fourth = third / 10.0;
        // The label alternates between 0.0 and 1.0 with the parity of the row index.
        data.put(row, VectorUtils.of(first, second, third, fourth).labeled((double) row % 2));
        row++;
    }

    ArrayList<FeatureMeta> featureMetas = new ArrayList<>();
    for (int featureIdx = 0; featureIdx < 4; featureIdx++) {
        featureMetas.add(new FeatureMeta("", featureIdx, false));
    }

    DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(featureMetas)
        .withAmountOfTrees(100)
        .withFeaturesCountSelectionStrgy(x -> 2)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder());

    RandomForestModel original = trainer.fit(data, parts, new LabeledDummyVectorizer<>());
    RandomForestModel retrainedOnSameData = trainer.update(original, data, parts, new LabeledDummyVectorizer<>());
    RandomForestModel retrainedOnNoData = trainer.update(
        original, new HashMap<Integer, LabeledVector<Double>>(), parts, new LabeledDummyVectorizer<>());

    Vector probe = VectorUtils.of(5, 0.5, 0.05, 0.005);
    assertEquals(original.predict(probe), retrainedOnSameData.predict(probe), 0.01);
    assertEquals(original.predict(probe), retrainedOnNoData.predict(probe), 0.01);
}
Also used : TrainerTest(org.apache.ignite.ml.common.TrainerTest) OnMajorityPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator) TestUtils(org.apache.ignite.ml.TestUtils) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Assert.assertTrue(org.junit.Assert.assertTrue) HashMap(java.util.HashMap) Test(org.junit.Test) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) ArrayList(java.util.ArrayList) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Map(java.util.Map) Assert.assertEquals(org.junit.Assert.assertEquals) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 2 with DatasetTrainer

Use of org.apache.ignite.ml.trainers.DatasetTrainer in the Apache Ignite project.

Class LearningEnvironmentTest, method testRandomNumbersGenerator.

/**
 * Tests the random number generator provided by {@link LearningEnvironment}.
 * It checks that:
 * 1. The correct random generator is returned for each partition.
 * 2. The generator's state is preserved between compute calls (checked by running several
 *    compute iterations).
 */
@Test
public void testRandomNumbersGenerator() {
    // Plug in MockRandom as the random dependency. Its nextInt is expected to return
    // partition index * iteration, which is what the expected vector below is built from.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;
    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {

        /**
         * {@inheritDoc}
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            Vector v = null;
            // Dataset is AutoCloseable, so release it once all compute iterations have finished.
            try (Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>) (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition()), envBuilder.buildForTrainer())) {
                for (int iter = 0; iter < iterations; iter++) {
                    // Each partition writes its generator's next value into its own slot; the
                    // other slots stay -1 and are overridden when partial results are merged.
                    v = ds.compute((dw, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()), (v1, v2) -> zipOverridingEmpty(v1, v2, -1));
                }
            }
            catch (RuntimeException e) {
                // Rethrow unchanged so assertion and runtime failures keep their original type.
                throw e;
            }
            catch (Exception e) {
                // Only a checked exception from Dataset.close() can reach this branch.
                throw new IllegalStateException("Failed to close dataset", e);
            }
            return constantModel(v);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };
    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null);
    // After `iterations` compute calls, partition i is expected to hold i * iterations.
    Vector exp = VectorUtils.zeroes(partitions);
    for (int i = 0; i < partitions; i++) {
        exp.set(i, i * iterations);
    }
    Vector res = mdl.predict(null);
    assertEquals(exp, res);
}
Also used : IntStream(java.util.stream.IntStream) TestUtils(org.apache.ignite.ml.TestUtils) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) Random(java.util.Random) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ParallelismStrategy(org.apache.ignite.ml.environment.parallelism.ParallelismStrategy) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) RandomForestRegressionTrainer(org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer) Map(java.util.Map) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) MLLogger(org.apache.ignite.ml.environment.logging.MLLogger) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) ConsoleLogger(org.apache.ignite.ml.environment.logging.ConsoleLogger) Test(org.junit.Test) FeaturesCountSelectionStrategies(org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Dataset(org.apache.ignite.ml.dataset.Dataset) DefaultParallelismStrategy(org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy) TestUtils.constantModel(org.apache.ignite.ml.TestUtils.constantModel) Assert.assertEquals(org.junit.Assert.assertEquals) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Test(org.junit.Test)

Aggregations

Usage counts: Map (java.util.Map): 2; TestUtils (org.apache.ignite.ml.TestUtils): 2; FeatureMeta (org.apache.ignite.ml.dataset.feature.FeatureMeta): 2; Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 2; VectorUtils (org.apache.ignite.ml.math.primitives.vector.VectorUtils): 2; DatasetTrainer (org.apache.ignite.ml.trainers.DatasetTrainer): 2; Assert.assertEquals (org.junit.Assert.assertEquals): 2; Test (org.junit.Test): 2; ArrayList (java.util.ArrayList): 1; HashMap (java.util.HashMap): 1; Random (java.util.Random): 1; Collectors (java.util.stream.Collectors): 1; IntStream (java.util.stream.IntStream): 1; IgniteModel (org.apache.ignite.ml.IgniteModel): 1; TestUtils.constantModel (org.apache.ignite.ml.TestUtils.constantModel): 1; TrainerTest (org.apache.ignite.ml.common.TrainerTest): 1; OnMajorityPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator): 1; Dataset (org.apache.ignite.ml.dataset.Dataset): 1; DatasetBuilder (org.apache.ignite.ml.dataset.DatasetBuilder): 1; PartitionDataBuilder (org.apache.ignite.ml.dataset.PartitionDataBuilder): 1