Use of org.apache.ignite.ml.trainers.DatasetTrainer in project ignite by apache.
The class RandomForestClassifierTrainerTest, method testUpdate.
/**
 * Verifies that updating a trained random forest classifier, either on the training sample
 * itself or on an empty dataset, keeps the prediction for a probe vector within 0.01 of the
 * prediction made by the original model.
 */
@Test
public void testUpdate() {
    final int featuresCnt = 4;
    final int sampleSize = 1000;

    // One metadata entry per feature: empty name, feature index, third constructor argument false.
    ArrayList<FeatureMeta> meta = new ArrayList<>();
    for (int featureId = 0; featureId < featuresCnt; featureId++) {
        meta.add(new FeatureMeta("", featureId, false));
    }

    // Row idx holds idx followed by three successive divisions by 10; the label alternates 0.0 / 1.0.
    Map<Integer, LabeledVector<Double>> sample = new HashMap<>();
    for (int idx = 0; idx < sampleSize; idx++) {
        double f1 = idx;
        double f2 = f1 / 10.0;
        double f3 = f2 / 10.0;
        double f4 = f3 / 10.0;
        double lbl = (double) idx % 2;

        sample.put(idx, VectorUtils.of(f1, f2, f3, f4).labeled(lbl));
    }

    DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(meta)
        .withAmountOfTrees(100)
        .withFeaturesCountSelectionStrgy(ignored -> 2)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder());

    RandomForestModel originalMdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());

    RandomForestModel updatedOnSameDS =
        trainer.update(originalMdl, sample, parts, new LabeledDummyVectorizer<>());

    Map<Integer, LabeledVector<Double>> emptySample = new HashMap<Integer, LabeledVector<Double>>();
    RandomForestModel updatedOnEmptyDS =
        trainer.update(originalMdl, emptySample, parts, new LabeledDummyVectorizer<>());

    // Probe vector used to compare the three models.
    Vector probe = VectorUtils.of(5, 0.5, 0.05, 0.005);
    double delta = 0.01;

    assertEquals(originalMdl.predict(probe), updatedOnSameDS.predict(probe), delta);
    assertEquals(originalMdl.predict(probe), updatedOnEmptyDS.predict(probe), delta);
}
Use of org.apache.ignite.ml.trainers.DatasetTrainer in project ignite by apache.
The class LearningEnvironmentTest, method testRandomNumbersGenerator.
/**
 * Checks the random number generator exposed through {@link LearningEnvironment}.
 * Two properties are verified:
 * 1. Every partition receives the correct generator.
 * 2. The generator state survives between compute calls (which is why compute is run several times).
 */
@Test
public void testRandomNumbersGenerator() {
    int iterations = 2;
    int partitions = 10;

    // The environment builder supplies MockRandom generators. Per the expectation checked at the end,
    // nextInt on partition p is expected to yield p * iteration number.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);

    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {
        /**
         * {@inheritDoc}
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            // Partition data is just a wrapper around the partition index.
            PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>> partDataBuilder =
                (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition());

            Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(
                envBuilder,
                new EmptyContextBuilder<>(),
                partDataBuilder,
                envBuilder.buildForTrainer());

            // Only the result of the last compute round is kept.
            Vector lastRes = null;
            for (int iter = 0; iter < iterations; iter++) {
                lastRes = ds.compute(
                    // Each partition writes its next random int into its own slot of a vector filled with -1.
                    (data, env) -> VectorUtils.fill(-1, partitions).set(env.partition(), env.randomNumbersGenerator().nextInt()),
                    // Partial vectors are merged with -1 passed as the "empty" marker.
                    (left, right) -> zipOverridingEmpty(left, right, -1));
            }

            return constantModel(lastRes);
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };

    trainer.withEnvironmentBuilder(envBuilder);

    IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null);
    Vector res = mdl.predict(null);

    // Expected value in slot p is p * iterations.
    Vector exp = VectorUtils.zeroes(partitions);
    for (int part = 0; part < partitions; part++) {
        exp.set(part, part * iterations);
    }

    assertEquals(exp, res);
}
Aggregations