Use of org.apache.ignite.ml.clustering.gmm.GmmTrainer in the Apache Ignite project.
The main method of the class GmmClusterizationExample.
/**
 * Entry point of the example. Starts an Ignite node from {@code examples/config/example-ignite.xml},
 * fills the {@code GMM_EXAMPLE_CACHE} cache with 50000 generated labeled vectors, fits a GMM model on
 * that cache and prints the mean vector and covariance matrix of every component of the model.
 * The cache is destroyed and standard output is flushed before the method returns.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> GMM clustering algorithm over cached dataset usage example started.");

    // The Ignite node is closed automatically when the try-with-resources block is left.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Each random source below consumes the next value of this counter as its seed.
        long seed = 0;

        IgniteCache<Integer, LabeledVector<Double>> cache = null;

        try {
            CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg =
                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
                    .setAffinity(new RendezvousAffinityFunction(false, 10));

            cache = ignite.createCache(cacheCfg);

            // Three two-dimensional gaussian generators: the first is rotated by PI/4 and moved
            // to (10, 10), the second is rotated by -PI/4 and moved to (-10, 10), the third is
            // only moved to (0, -10).
            DataStreamGenerator generator = new VectorGeneratorsFamily.Builder()
                .add(
                    RandomProducer.vectorize(
                        new GaussRandomProducer(0, 2., seed++),
                        new GaussRandomProducer(0, 3., seed++)
                    ).rotate(Math.PI / 4).move(VectorUtils.of(10., 10.)))
                .add(
                    RandomProducer.vectorize(
                        new GaussRandomProducer(0, 1., seed++),
                        new GaussRandomProducer(0, 2., seed++)
                    ).rotate(-Math.PI / 4).move(VectorUtils.of(-10., 10.)))
                .add(
                    RandomProducer.vectorize(
                        new GaussRandomProducer(0, 3., seed++),
                        new GaussRandomProducer(0, 3., seed++)
                    ).move(VectorUtils.of(0., -10.)))
                .build(seed++)
                .asDataStream();

            // Cache keys are consecutive integers starting from zero.
            AtomicInteger nextKey = new AtomicInteger();
            generator.fillCacheWithCustomKey(50000, cache, vec -> nextKey.getAndIncrement());

            LearningEnvironmentBuilder envBuilder =
                LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed);

            GmmModel model = new GmmTrainer(1)
                .withMaxCountIterations(10)
                .withMaxCountOfClusters(4)
                .withEnvironmentBuilder(envBuilder)
                .fit(ignite, cache, new LabeledDummyVectorizer<>());

            System.out.println(">>> GMM means and covariances");

            for (int idx = 0; idx < model.countOfComponents(); idx++) {
                printComponent(idx, model.distributions().get(idx));
            }

            System.out.println(">>>");
        }
        finally {
            // The cache reference stays null if cache creation itself failed.
            if (cache != null) {
                cache.destroy();
            }
        }
    }
    finally {
        System.out.flush();
    }
}

/**
 * Prints a header followed by the mean vector and the covariance matrix of one model component.
 *
 * @param idx Index of the component, shown in the header.
 * @param component Distribution whose mean and covariance are printed.
 */
private static void printComponent(int idx, MultivariateGaussianDistribution component) {
    System.out.println();
    System.out.println("============");
    System.out.println("Component #" + idx);
    System.out.println("============");

    System.out.println("Mean vector = ");
    Tracer.showAscii(component.mean());

    System.out.println();
    System.out.println("Covariance matrix = ");
    Tracer.showAscii(component.covariance());
}
Aggregations