Search in sources:

Example 1 with GmmTrainer

Use of org.apache.ignite.ml.clustering.gmm.GmmTrainer in the Apache Ignite project (ignite by apache).

Method main of the class GmmClusterizationExample.

/**
 * Runs the GMM clustering example.
 *
 * <p>Starts an Ignite node from {@code examples/config/example-ignite.xml} and creates a
 * temporary cache. It fills the cache with 50000 vectors drawn from a family of three
 * two-dimensional Gaussian generators, then fits a {@link GmmModel} on the cached data. Finally
 * it prints the mean vector and covariance matrix of every fitted component. The cache is
 * destroyed before the node is stopped.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> GMM clustering algorithm over cached dataset usage example started.");

    // Start ignite grid; try-with-resources stops it once the example is done.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Each random source below consumes the next value of this counter (argument lists are
        // evaluated left to right), which keeps the generated dataset reproducible.
        long seed = 0;
        IgniteCache<Integer, LabeledVector<Double>> dataCache = null;

        try {
            CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg =
                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE");
            cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
            dataCache = ignite.createCache(cacheCfg);

            // Three 2D Gaussian generators: the first is rotated by PI/4 and shifted by (10, 10),
            // the second is rotated by -PI/4 and shifted by (-10, 10), the third is only shifted
            // by (0, -10).
            DataStreamGenerator dataStream = new VectorGeneratorsFamily.Builder()
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 2., seed++),
                        new GaussRandomProducer(0, 3., seed++))
                    .rotate(Math.PI / 4)
                    .move(VectorUtils.of(10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 1., seed++),
                        new GaussRandomProducer(0, 2., seed++))
                    .rotate(-Math.PI / 4)
                    .move(VectorUtils.of(-10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 3., seed++),
                        new GaussRandomProducer(0, 3., seed++))
                    .move(VectorUtils.of(0., -10.)))
                .build(seed++)
                .asDataStream();

            // Cache keys are consecutive integers starting from 0.
            AtomicInteger keyCounter = new AtomicInteger();
            dataStream.fillCacheWithCustomKey(50000, dataCache, vec -> keyCounter.getAndIncrement());

            GmmTrainer trainer = new GmmTrainer(1);
            GmmModel mdl = trainer
                .withMaxCountIterations(10)
                .withMaxCountOfClusters(4)
                .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
                .fit(ignite, dataCache, new LabeledDummyVectorizer<>());

            System.out.println(">>> GMM means and covariances");

            String separator = "============";
            for (int componentIdx = 0; componentIdx < mdl.countOfComponents(); componentIdx++) {
                MultivariateGaussianDistribution component = mdl.distributions().get(componentIdx);

                System.out.println();
                System.out.println(separator);
                System.out.println("Component #" + componentIdx);
                System.out.println(separator);

                System.out.println("Mean vector = ");
                Tracer.showAscii(component.mean());
                System.out.println();

                System.out.println("Covariance matrix = ");
                Tracer.showAscii(component.covariance());
            }

            System.out.println(">>>");
        }
        finally {
            // Drop the example cache even if data generation, training or printing failed.
            if (dataCache != null) {
                dataCache.destroy();
            }
        }
    }
    finally {
        System.out.flush();
    }
}
Also used: LabeledVector (org.apache.ignite.ml.structures.LabeledVector), GmmModel (org.apache.ignite.ml.clustering.gmm.GmmModel), GaussRandomProducer (org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), GmmTrainer (org.apache.ignite.ml.clustering.gmm.GmmTrainer), MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution), DataStreamGenerator (org.apache.ignite.ml.util.generators.DataStreamGenerator), VectorGeneratorsFamily (org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily), Ignite (org.apache.ignite.Ignite), RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)

Aggregations

AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1; Ignite (org.apache.ignite.Ignite): 1; RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction): 1; GmmModel (org.apache.ignite.ml.clustering.gmm.GmmModel): 1; GmmTrainer (org.apache.ignite.ml.clustering.gmm.GmmTrainer): 1; MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution): 1; LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 1; DataStreamGenerator (org.apache.ignite.ml.util.generators.DataStreamGenerator): 1; GaussRandomProducer (org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer): 1; VectorGeneratorsFamily (org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily): 1