Usage of `org.apache.ignite.ml.dataset.PartitionDataBuilder` in the Apache Ignite project.
Taken from class `LearningEnvironmentTest`, method `testRandomNumbersGenerator`.
/**
 * Verifies the per-partition random number generator exposed through {@link LearningEnvironment}.
 * <p>
 * Two properties are checked:
 * <ol>
 *     <li>every partition is handed its own generator, and</li>
 *     <li>a generator keeps its state between successive {@code compute} calls, which is why the
 *     computation below is repeated several times.</li>
 * </ol>
 */
@Test
public void testRandomNumbersGenerator() {
    // MockRandom is plugged in as the random dependency. Per the design of this test, its nextInt()
    // yields (partition index) * (number of calls made so far), so after the last pass partition i
    // should report i * iterations.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);

    final int partitions = 10;
    final int iterations = 2;

    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {
        /**
         * Builds a dataset whose partition data is the partition index, runs {@code iterations}
         * passes of {@code compute}, and wraps the result of the last pass in a constant model.
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
            Preprocessor<K, V> preprocessor) {
            // Each partition's data is simply its own index.
            PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>> partDataBuilder =
                (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition());

            Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> dataset = datasetBuilder.build(
                envBuilder,
                new EmptyContextBuilder<>(),
                partDataBuilder,
                envBuilder.buildForTrainer()
            );

            Vector lastPass = null;
            for (int pass = 0; pass < iterations; pass++) {
                // Each partition writes its random draw into its own slot and leaves the rest at -1.
                // Presumably zipOverridingEmpty treats -1 as an empty slot when merging partial
                // results; verify against the helper.
                lastPass = dataset.compute(
                    (data, env) -> {
                        Vector partial = VectorUtils.fill(-1, partitions);
                        return partial.set(env.partition(), env.randomNumbersGenerator().nextInt());
                    },
                    (left, right) -> zipOverridingEmpty(left, right, -1)
                );
            }

            return constantModel(lastPass);
        }

        /** This stub trainer never supports incremental updates. */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /** Not used by this test; always yields {@code null}. */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl,
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };

    trainer.withEnvironmentBuilder(envBuilder);
    IgniteModel<Object, Vector> model = trainer.fit(getCacheMock(partitions), partitions, null);

    // Partition p is expected to report p * iterations after the final pass.
    Vector expected = VectorUtils.zeroes(partitions);
    for (int part = 0; part < partitions; part++) {
        expected.set(part, part * iterations);
    }

    assertEquals(expected, model.predict(null));
}
Aggregations