Use of org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer in the Apache Ignite project.
Class RingsDataStream, method labeled.
/**
 * {@inheritDoc}
 */
@Override
public Stream<LabeledVector<Double>> labeled() {
    VectorGeneratorsFamily.Builder familyBuilder = new VectorGeneratorsFamily.Builder();

    int ringIdx = 0;
    while (ringIdx < cntOfRings) {
        // Rings are spaced evenly starting at minRadius; the noise variance grows with the ring index.
        double ringRadius = minRadius + distanceBetweenRings * ringIdx;
        double noiseVariance = 0.1 * (ringIdx + 1);

        // Full circle [0, 2*PI) perturbed by zero-mean Gaussian noise.
        familyBuilder = familyBuilder.add(
            ring(ringRadius, 0, 2 * Math.PI).noisify(new GaussRandomProducer(0, noiseVariance, seed)));

        // NOTE(review): seed is not a local of this method, so this doubling appears to persist
        // between calls (and leaves a zero seed unchanged) - confirm this is intended.
        seed *= 2;
        ringIdx++;
    }

    return familyBuilder.build().asDataStream().labeled();
}
Use of org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer in the Apache Ignite project.
Class GmmClusterizationExample, method main.
/**
 * Example entry point: starts an Ignite node, fills a cache with points drawn from three
 * generated gaussians, trains a GMM on the cache and prints the mean vector and covariance
 * matrix of every component of the resulting model.
 *
 * @param args Command line arguments (not read by this method).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> GMM clustering algorithm over cached dataset usage example started.");

    // Bring up the Ignite node; it is closed automatically when the block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Every generator below consumes the next value of this counter, in source order.
        long rngSeed = 0;
        IgniteCache<Integer, LabeledVector<Double>> pointsCache = null;

        try {
            pointsCache = ignite.createCache(
                new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
                    .setAffinity(new RendezvousAffinityFunction(false, 10)));

            // Three gaussians: the first is rotated by PI/4, the second by -PI/4, the third is only shifted.
            DataStreamGenerator pointsStream = new VectorGeneratorsFamily.Builder()
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 2., rngSeed++),
                        new GaussRandomProducer(0, 3., rngSeed++))
                    .rotate(Math.PI / 4)
                    .move(VectorUtils.of(10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 1., rngSeed++),
                        new GaussRandomProducer(0, 2., rngSeed++))
                    .rotate(-Math.PI / 4)
                    .move(VectorUtils.of(-10., 10.)))
                .add(RandomProducer.vectorize(
                        new GaussRandomProducer(0, 3., rngSeed++),
                        new GaussRandomProducer(0, 3., rngSeed++))
                    .move(VectorUtils.of(0., -10.)))
                .build(rngSeed++)
                .asDataStream();

            // Cache keys are sequential integers handed out by this counter.
            AtomicInteger nextKey = new AtomicInteger();
            pointsStream.fillCacheWithCustomKey(50000, pointsCache, vec -> nextKey.getAndIncrement());

            LearningEnvironmentBuilder envBuilder =
                LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(rngSeed);

            GmmModel model = new GmmTrainer(1)
                .withMaxCountIterations(10)
                .withMaxCountOfClusters(4)
                .withEnvironmentBuilder(envBuilder)
                .fit(ignite, pointsCache, new LabeledDummyVectorizer<>());

            System.out.println(">>> GMM means and covariances");

            int componentIdx = 0;
            while (componentIdx < model.countOfComponents()) {
                MultivariateGaussianDistribution component = model.distributions().get(componentIdx);

                System.out.println();
                System.out.println("============");
                System.out.println("Component #" + componentIdx);
                System.out.println("============");

                System.out.println("Mean vector = ");
                Tracer.showAscii(component.mean());

                System.out.println();
                System.out.println("Covariance matrix = ");
                Tracer.showAscii(component.covariance());

                componentIdx++;
            }

            System.out.println(">>>");
        } finally {
            // Drop the example cache even when training or printing failed.
            if (pointsCache != null) {
                pointsCache.destroy();
            }
        }
    } finally {
        System.out.flush();
    }
}
Use of org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer in the Apache Ignite project.
Class DataStreamGeneratorFillCacheTest, method testCacheFilling.
/**
 * Fills a cache from a data stream generator and verifies, through a cache-based dataset,
 * that the cache holds exactly the requested number of rows and that the sum of all feature
 * values divided by the row count is within 1e-2 of zero.
 */
@Test
public void testCacheFilling() {
    // Discovery is restricted to a fixed local address range.
    TcpDiscoveryVmIpFinder ipFinder = new TcpDiscoveryVmIpFinder();
    ipFinder.setAddresses(Arrays.asList("127.0.0.1:47500..47509"));

    TcpDiscoverySpi discoverySpi = new TcpDiscoverySpi();
    discoverySpi.setIpFinder(ipFinder);

    IgniteConfiguration igniteCfg = new IgniteConfiguration();
    igniteCfg.setDiscoverySpi(discoverySpi);

    String cacheName = "TEST_CACHE";
    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg = new CacheConfiguration<>(cacheName);
    cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

    int expectedRows = 5000;

    try (Ignite ignite = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheCfg);

        // One-dimensional vectors produced by a Gaussian producer constructed with argument 0.
        DataStreamGenerator generator = new GaussRandomProducer(0).vectorize(1).asDataStream();
        generator.fillCacheWithVecUUIDAsKey(expectedRows, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder =
            new CacheBasedDatasetBuilder<>(ignite, cache);

        // Per-partition statistic: sum of all feature values together with the row count.
        IgniteFunction<SimpleDatasetData, StatPair> toStats =
            data -> new StatPair(DoubleStream.of(data.getFeatures()).sum(), data.getRows());

        LearningEnvironment env = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        env.deployingContext().initByClientObject(toStats);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 datasetBuilder.build(
                     LearningEnvironmentBuilder.defaultBuilder(),
                     new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer),
                     env)) {
            StatPair total = dataset.compute(toStats, StatPair::sum);

            assertEquals(expectedRows, total.cntOfRows);
            assertEquals(0.0, total.elementsSum / total.cntOfRows, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Aggregations