Search in sources:

Example 1 with EmptyContextBuilder

Use of org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder in project ignite by apache.

The class DataStreamGeneratorFillCacheTest, method testCacheFilling.

/**
 * Fills an Ignite cache from a {@link GaussRandomProducer}-based data stream and verifies the
 * result through a cache-based dataset: the number of rows must equal the requested dataset size
 * and the sum of all features divided by the row count must be within 1e-2 of zero.
 */
@Test
public void testCacheFilling() {
    final String cacheName = "TEST_CACHE";
    final int datasetSize = 5000;

    // Static IP finder pointing at 127.0.0.1, ports 47500..47509.
    TcpDiscoveryVmIpFinder ipFinder = new TcpDiscoveryVmIpFinder()
        .setAddresses(Arrays.asList("127.0.0.1:47500..47509"));
    TcpDiscoverySpi discoverySpi = new TcpDiscoverySpi().setIpFinder(ipFinder);
    IgniteConfiguration configuration = new IgniteConfiguration().setDiscoverySpi(discoverySpi);

    CacheConfiguration<UUID, LabeledVector<Double>> cacheConfiguration =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));

    try (Ignite ignite = Ignition.start(configuration)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheConfiguration);

        // Populate the cache with generated vectors keyed by UUID.
        DataStreamGenerator generator = new GaussRandomProducer(0).vectorize(1).asDataStream();
        generator.fillCacheWithVecUUIDAsKey(datasetSize, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder =
            new CacheBasedDatasetBuilder<>(ignite, cache);

        // Per-partition statistic: (sum of features, number of rows).
        IgniteFunction<SimpleDatasetData, StatPair> map = partData ->
            new StatPair(DoubleStream.of(partData.getFeatures()).sum(), partData.getRows());

        // The deploying context is initialised with the mapper before the dataset is built.
        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(map);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 datasetBuilder.build(
                     LearningEnvironmentBuilder.defaultBuilder(),
                     new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer),
                     env)) {
            StatPair res = dataset.compute(map, StatPair::sum);

            assertEquals(datasetSize, res.cntOfRows);
            assertEquals(0.0, res.elementsSum / res.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Also used : Arrays(java.util.Arrays) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) SimpleDatasetDataBuilder(org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) Test(org.junit.Test) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DoubleStream(java.util.stream.DoubleStream) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) CacheBasedDataset(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) 
LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Ignite(org.apache.ignite.Ignite) UUID(java.util.UUID) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Example 2 with EmptyContextBuilder

Use of org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder in project ignite by apache.

The class LearningEnvironmentTest, method testRandomNumbersGenerator.

/**
 * Verifies the random number generator exposed by {@link LearningEnvironment}.
 * Two properties are checked:
 * 1. Every partition receives its own, correct generator.
 * 2. The generator keeps its state across compute calls (hence compute is run several times).
 */
@Test
public void testRandomNumbersGenerator() {
    // The environment builder is wired with MockRandom, whose nextInt is expected to yield
    // partition index * iteration; the expected vector below is built on that assumption.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;

    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {

        /**
         * {@inheritDoc}
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            // Each partition's data is simply a wrapper around the partition index.
            PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>> partDataBuilder =
                (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition());

            Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> dataset = datasetBuilder.build(
                envBuilder,
                new EmptyContextBuilder<>(),
                partDataBuilder,
                envBuilder.buildForTrainer());

            // Only the result of the last compute call is kept.
            Vector lastResult = null;
            int iteration = 0;
            while (iteration < iterations) {
                lastResult = dataset.compute(
                    (wrapper, env) -> VectorUtils.fill(-1, partitions)
                        .set(env.partition(), env.randomNumbersGenerator().nextInt()),
                    (left, right) -> zipOverridingEmpty(left, right, -1));
                iteration++;
            }

            return constantModel(lastResult);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };

    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> model = trainer.fit(getCacheMock(partitions), partitions, null);

    // Expected value for partition p after the final iteration is p * iterations.
    Vector expected = VectorUtils.zeroes(partitions);
    for (int part = 0; part < partitions; part++) {
        expected.set(part, part * iterations);
    }

    Vector actual = model.predict(null);
    assertEquals(expected, actual);
}
Also used : IntStream(java.util.stream.IntStream) TestUtils(org.apache.ignite.ml.TestUtils) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) Random(java.util.Random) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) ParallelismStrategy(org.apache.ignite.ml.environment.parallelism.ParallelismStrategy) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) RandomForestRegressionTrainer(org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer) Map(java.util.Map) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) MLLogger(org.apache.ignite.ml.environment.logging.MLLogger) PartitionDataBuilder(org.apache.ignite.ml.dataset.PartitionDataBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) ConsoleLogger(org.apache.ignite.ml.environment.logging.ConsoleLogger) Test(org.junit.Test) FeaturesCountSelectionStrategies(org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies) IgniteModel(org.apache.ignite.ml.IgniteModel) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Collectors(java.util.stream.Collectors) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Dataset(org.apache.ignite.ml.dataset.Dataset) DefaultParallelismStrategy(org.apache.ignite.ml.environment.parallelism.DefaultParallelismStrategy) TestUtils.constantModel(org.apache.ignite.ml.TestUtils.constantModel) Assert.assertEquals(org.junit.Assert.assertEquals) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) DatasetBuilder(org.apache.ignite.ml.dataset.DatasetBuilder) Preprocessor(org.apache.ignite.ml.preprocessing.Preprocessor) IgniteModel(org.apache.ignite.ml.IgniteModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Test(org.junit.Test)

Aggregations

EmptyContextBuilder (org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder)2 EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext)2 Test (org.junit.Test)2 Arrays (java.util.Arrays)1 Map (java.util.Map)1 Random (java.util.Random)1 UUID (java.util.UUID)1 Collectors (java.util.stream.Collectors)1 DoubleStream (java.util.stream.DoubleStream)1 IntStream (java.util.stream.IntStream)1 Ignite (org.apache.ignite.Ignite)1 IgniteCache (org.apache.ignite.IgniteCache)1 Ignition (org.apache.ignite.Ignition)1 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)1 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)1 IgniteConfiguration (org.apache.ignite.configuration.IgniteConfiguration)1 IgniteModel (org.apache.ignite.ml.IgniteModel)1 TestUtils (org.apache.ignite.ml.TestUtils)1 TestUtils.constantModel (org.apache.ignite.ml.TestUtils.constantModel)1 Dataset (org.apache.ignite.ml.dataset.Dataset)1