Usage of org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder in the Apache Ignite project.
Example from class DataStreamGeneratorFillCacheTest, method testCacheFilling.
/**
 * Checks that {@link DataStreamGenerator#fillCacheWithVecUUIDAsKey} writes the requested number of rows.
 * <p>
 * A node is started with a static-IP discovery SPI limited to {@code 127.0.0.1:47500..47509}. A cache is
 * filled with {@code datasetSize} one-dimensional vectors coming from a {@link GaussRandomProducer}. A
 * {@link CacheBasedDataset} over that cache then counts the rows and sums their feature values. The test
 * expects every generated row to be present and the sample mean to lie within {@code 1e-2} of zero.
 * On success the cache is destroyed before the node is closed.
 */
@Test
public void testCacheFilling() {
    final String cacheName = "TEST_CACHE";
    final int datasetSize = 5000;

    // Static IP finder restricted to localhost ports 47500..47509.
    TcpDiscoveryVmIpFinder ipFinder = new TcpDiscoveryVmIpFinder();
    ipFinder.setAddresses(Arrays.asList("127.0.0.1:47500..47509"));

    TcpDiscoverySpi discoverySpi = new TcpDiscoverySpi();
    discoverySpi.setIpFinder(ipFinder);

    IgniteConfiguration configuration = new IgniteConfiguration();
    configuration.setDiscoverySpi(discoverySpi);

    // RendezvousAffinityFunction(excludeNeighbors = false, parts = 10).
    CacheConfiguration<UUID, LabeledVector<Double>> cacheConfiguration = new CacheConfiguration<>(cacheName);
    cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

    try (Ignite ignite = Ignition.start(configuration)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = ignite.getOrCreateCache(cacheConfiguration);

        // The constructor argument 0 is presumably the RNG seed; TODO confirm against GaussRandomProducer.
        DataStreamGenerator gaussStream = new GaussRandomProducer(0).vectorize(1).asDataStream();
        gaussStream.fillCacheWithVecUUIDAsKey(datasetSize, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> datasetBuilder =
            new CacheBasedDatasetBuilder<>(ignite, cache);

        // Per partition: (sum of all feature values, number of rows). Partial results are merged with StatPair::sum.
        IgniteFunction<SimpleDatasetData, StatPair> partitionStats =
            data -> new StatPair(DoubleStream.of(data.getFeatures()).sum(), data.getRows());

        // Register the map function with the trainer environment's deploying context before building the dataset.
        LearningEnvironment trainerEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        trainerEnv.deployingContext().initByClientObject(partitionStats);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 datasetBuilder.build(
                     LearningEnvironmentBuilder.defaultBuilder(),
                     new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer),
                     trainerEnv)) {
            StatPair total = dataset.compute(partitionStats, StatPair::sum);

            // Every generated vector must have landed in the cache.
            assertEquals(datasetSize, total.cntOfRows);

            double mean = total.elementsSum / total.cntOfRows;
            assertEquals(0.0, mean, 1e-2);
        }

        ignite.destroyCache(cacheName);
    }
}
Usage of org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder in the Apache Ignite project.
Example from class LearningEnvironmentTest, method testRandomNumbersGenerator.
/**
 * Tests the random number generator exposed by {@link LearningEnvironment}.
 * <p>
 * Two properties are verified:
 * <ol>
 *     <li>each partition is handed its own generator;</li>
 *     <li>a generator keeps its state across successive {@code compute} calls, which is why several
 *     compute passes are executed.</li>
 * </ol>
 */
@Test
public void testRandomNumbersGenerator() {
    // The random dependency is replaced by MockRandom. Per this test's expectations, the generator of
    // partition p returns p * (number of nextInt() calls made so far). After `iterations` passes,
    // partition p should therefore report p * iterations.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;

    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer =
        new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {
            /** {@inheritDoc} */
            @Override
            public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(
                DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
                // The data of every partition is just that partition's index.
                PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>> partIdxBuilder =
                    (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition());

                Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(
                    envBuilder, new EmptyContextBuilder<>(), partIdxBuilder, envBuilder.buildForTrainer());

                // On each pass, partition p writes one nextInt() value into cell p of a vector pre-filled
                // with -1. Partial vectors are merged by zipOverridingEmpty, presumably treating -1 as empty.
                // Only the result of the final pass is kept.
                Vector lastPass = null;
                for (int pass = 1; pass <= iterations; pass++) {
                    lastPass = ds.compute(
                        (data, env) -> VectorUtils.fill(-1, partitions)
                            .set(env.partition(), env.randomNumbersGenerator().nextInt()),
                        (left, right) -> zipOverridingEmpty(left, right, -1));
                }

                return constantModel(lastPass);
            }

            /** {@inheritDoc} */
            @Override
            public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
                return false;
            }

            /** {@inheritDoc} */
            @Override
            protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl,
                DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
                return null;
            }
        };

    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> model = trainer.fit(getCacheMock(partitions), partitions, null);

    Vector expected = VectorUtils.zeroes(partitions);
    for (int part = 0; part < partitions; part++) {
        expected.set(part, part * iterations);
    }

    // The trained model presumably ignores its input (built via constantModel), so null is passed.
    assertEquals(expected, model.predict(null));
}
Aggregations