Search in sources:

Example 1 with GaussRandomProducer

Use of org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer in the Apache Ignite project.

Class RingsDataStream, method labeled.

/**
 * {@inheritDoc}
 *
 * <p>Adds one noisy ring generator per ring to a {@code VectorGeneratorsFamily} and returns the
 * labeled stream of the resulting family. Ring {@code i} has radius
 * {@code minRadius + distanceBetweenRings * i} and is noisified by a {@code GaussRandomProducer}
 * whose second constructor argument (named variance here) is {@code 0.1 * (i + 1)}.
 *
 * <p>NOTE(review): {@code seed} is not declared inside this method (presumably an instance field)
 * and is doubled after every ring, so it stays modified once this method returns. Confirm that
 * repeated calls are really meant to start from a different seed.
 */
@Override
public Stream<LabeledVector<Double>> labeled() {
    VectorGeneratorsFamily.Builder family = new VectorGeneratorsFamily.Builder();
    int ringIdx = 0;
    while (ringIdx < cntOfRings) {
        // Rings are evenly spaced, starting from the innermost one.
        final double ringRadius = minRadius + distanceBetweenRings * ringIdx;
        // The noise parameter grows linearly with the ring index.
        final double noiseVariance = 0.1 * (ringIdx + 1);
        GaussRandomProducer noise = new GaussRandomProducer(0, noiseVariance, seed);
        family = family.add(ring(ringRadius, 0, 2 * Math.PI).noisify(noise));
        // Each ring gets its own seed: the value is doubled before the next ring is built.
        seed *= 2;
        ringIdx++;
    }
    return family.build().asDataStream().labeled();
}
Also used : VectorGeneratorsFamily(org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer)

Example 2 with GaussRandomProducer

Use of org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer in the Apache Ignite project.

Class GmmClusterizationExample, method main.

/**
 * Runs the example.
 *
 * <p>Starts an Ignite node from {@code examples/config/example-ignite.xml}, fills the cache
 * {@code GMM_EXAMPLE_CACHE} with 50000 generated vectors, fits a {@code GmmTrainer} on that cache
 * and prints the mean vector and covariance matrix of every component of the resulting model.
 * The cache is destroyed before the node is closed, even if an exception is thrown in between.
 *
 * @param args Command line arguments.
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> GMM clustering algorithm over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Running counter: every seeded component below takes the next value (0, 1, 2, ...).
        long seed = 0;
        IgniteCache<Integer, LabeledVector<Double>> dataCache = null;

        try {
            CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg =
                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
                    .setAffinity(new RendezvousAffinityFunction(false, 10));
            dataCache = ignite.createCache(cacheCfg);

            // The dataset mixes three Gaussian generators; two of them are rotated
            // (by PI/4 and by -PI/4) before being moved to their centers.
            DataStreamGenerator dataStream = new VectorGeneratorsFamily.Builder()
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 2., seed++),
                        new GaussRandomProducer(0, 3., seed++))
                    .rotate(Math.PI / 4)
                    .move(VectorUtils.of(10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 1., seed++),
                        new GaussRandomProducer(0, 2., seed++))
                    .rotate(-Math.PI / 4)
                    .move(VectorUtils.of(-10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 3., seed++),
                        new GaussRandomProducer(0, 3., seed++))
                    .move(VectorUtils.of(0., -10.)))
                .build(seed++)
                .asDataStream();

            // Cache keys are consecutive integers handed out by a counter.
            AtomicInteger nextKey = new AtomicInteger();
            dataStream.fillCacheWithCustomKey(50000, dataCache, vec -> nextKey.getAndIncrement());

            GmmTrainer trainer = new GmmTrainer(1);
            GmmModel model = trainer
                .withMaxCountIterations(10)
                .withMaxCountOfClusters(4)
                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
                .fit(ignite, dataCache, new LabeledDummyVectorizer<>());

            System.out.println(">>> GMM means and covariances");
            for (int idx = 0; idx < model.countOfComponents(); idx++) {
                MultivariateGaussianDistribution component = model.distributions().get(idx);

                System.out.println();
                System.out.println("============");
                System.out.println("Component #" + idx);
                System.out.println("============");

                System.out.println("Mean vector = ");
                Tracer.showAscii(component.mean());

                System.out.println();
                System.out.println("Covariance matrix = ");
                Tracer.showAscii(component.covariance());
            }
            System.out.println(">>>");
        } finally {
            // The cache may be null if its creation failed.
            if (dataCache != null) {
                dataCache.destroy();
            }
        }
    } finally {
        System.out.flush();
    }
}
Also used : LabeledVector(org.apache.ignite.ml.structures.LabeledVector) GmmModel(org.apache.ignite.ml.clustering.gmm.GmmModel) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) GmmTrainer(org.apache.ignite.ml.clustering.gmm.GmmTrainer) MultivariateGaussianDistribution(org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DataStreamGenerator(org.apache.ignite.ml.util.generators.DataStreamGenerator) VectorGeneratorsFamily(org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)

Example 3 with GaussRandomProducer

Use of org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer in the Apache Ignite project.

Class DataStreamGeneratorFillCacheTest, method testCacheFilling.

/**
 * Fills the cache {@code TEST_CACHE} with 5000 vectors produced by a
 * {@code new GaussRandomProducer(0).vectorize(1)} data stream, builds a cache-based dataset over
 * it and checks two aggregates computed across the dataset: the total row count must equal the
 * requested dataset size, and the sum of all feature values divided by the row count must be
 * {@code 0.0} within a tolerance of {@code 1e-2}.
 */
@Test
public void testCacheFilling() {
    String cacheName = "TEST_CACHE";
    int datasetSize = 5000;

    IgniteConfiguration igniteCfg = new IgniteConfiguration()
        .setDiscoverySpi(new TcpDiscoverySpi()
            .setIpFinder(new TcpDiscoveryVmIpFinder()
                .setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));

    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));

    try (Ignite ignite = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheCfg);

        // Populate the cache; keys are produced by the generator itself.
        DataStreamGenerator generator = new GaussRandomProducer(0).vectorize(1).asDataStream();
        generator.fillCacheWithVecUUIDAsKey(datasetSize, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder =
            new CacheBasedDatasetBuilder<>(ignite, cache);

        // Per-partition statistic: (sum of feature values, number of rows).
        IgniteFunction<SimpleDatasetData, StatPair> partStat = part ->
            new StatPair(DoubleStream.of(part.getFeatures()).sum(), part.getRows());

        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(partStat);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 datasetBuilder.build(
                     LearningEnvironmentBuilder.defaultBuilder(),
                     new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer),
                     env)) {
            // Partition statistics are combined pairwise with StatPair::sum.
            StatPair total = dataset.compute(partStat, StatPair::sum);

            assertEquals(datasetSize, total.cntOfRows);
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Also used : Arrays(java.util.Arrays) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) SimpleDatasetDataBuilder(org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) Test(org.junit.Test) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DoubleStream(java.util.stream.DoubleStream) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) CacheBasedDataset(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) 
LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Ignite(org.apache.ignite.Ignite) UUID(java.util.UUID) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Aggregations

GaussRandomProducer (org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer)3 Ignite (org.apache.ignite.Ignite)2 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)2 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)2 VectorGeneratorsFamily (org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily)2 Arrays (java.util.Arrays)1 UUID (java.util.UUID)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 DoubleStream (java.util.stream.DoubleStream)1 IgniteCache (org.apache.ignite.IgniteCache)1 Ignition (org.apache.ignite.Ignition)1 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)1 IgniteConfiguration (org.apache.ignite.configuration.IgniteConfiguration)1 GmmModel (org.apache.ignite.ml.clustering.gmm.GmmModel)1 GmmTrainer (org.apache.ignite.ml.clustering.gmm.GmmTrainer)1 LabeledDummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer)1 CacheBasedDataset (org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset)1 CacheBasedDatasetBuilder (org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder)1 EmptyContextBuilder (org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder)1 SimpleDatasetDataBuilder (org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder)1