Usage of org.apache.ignite.ml.preprocessing.Preprocessor in the Apache Ignite project.
Class ImputerTrainer, method fit.
/**
 * {@inheritDoc}
 * <p>
 * Builds a dataset in which every partition keeps only the statistics required by the configured
 * imputing strategy ({@code imputingStgy}), reduces them into one vector of imputing values and returns an
 * {@link ImputerPreprocessor} built from those values and {@code basePreprocessor}.
 *
 * @throws UnsupportedOperationException If the configured imputing strategy is not supported. When it is
 *     raised inside the partition-data builder it may reach the caller wrapped, depending on the
 *     {@code DatasetBuilder} implementation.
 * @throws RuntimeException Wrapping any checked exception raised while building or closing the dataset.
 */
@Override
public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> builder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(envBuilder, builder,
        (env, upstream, upstreamSize, ctx) -> {
            // Per-partition accumulators. Only those used by the strategy are filled in, and all of them
            // stay null when the partition has no rows.
            double[] sums = null;
            int[] counts = null;
            double[] maxs = null;
            double[] mins = null;
            Map<Double, Integer>[] valuesByFreq = null;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                switch (imputingStgy) {
                    case MEAN:
                        // MEAN tracks both running sums and counts.
                        sums = updateTheSums(row, sums);
                        counts = updateTheCounts(row, counts);
                        break;

                    case MOST_FREQUENT:
                    case LEAST_FREQUENT:
                        // Both frequency-based strategies share the same per-feature value counters.
                        valuesByFreq = updateFrequenciesByGivenRow(row, valuesByFreq);
                        break;

                    case MAX:
                        maxs = updateTheMaxs(row, maxs);
                        break;

                    case MIN:
                        mins = updateTheMins(row, mins);
                        break;

                    case COUNT:
                        counts = updateTheCounts(row, counts);
                        break;

                    default:
                        throw new UnsupportedOperationException("The chosen strategy is not supported");
                }
            }

            // Package only the statistics that the strategy's reduce step below consumes.
            switch (imputingStgy) {
                case MEAN:
                    return new ImputerPartitionData().withSums(sums).withCounts(counts);

                case MOST_FREQUENT:
                case LEAST_FREQUENT:
                    return new ImputerPartitionData().withValuesByFrequency(valuesByFreq);

                case MAX:
                    return new ImputerPartitionData().withMaxs(maxs);

                case MIN:
                    return new ImputerPartitionData().withMins(mins);

                case COUNT:
                    return new ImputerPartitionData().withCounts(counts);

                default:
                    throw new UnsupportedOperationException("The chosen strategy is not supported");
            }
        }, learningEnvironment(basePreprocessor))) {

        Vector imputingValues;

        // Reduce the per-partition statistics into the final imputing values.
        switch (imputingStgy) {
            case MEAN:
                imputingValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
                break;

            case MOST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheMostFrequentValues(dataset));
                break;

            case LEAST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheLeastFrequentValues(dataset));
                break;

            case MAX:
                imputingValues = VectorUtils.of(calculateImputingValuesByMaxValues(dataset));
                break;

            case MIN:
                imputingValues = VectorUtils.of(calculateImputingValuesByMinValues(dataset));
                break;

            case COUNT:
                imputingValues = VectorUtils.of(calculateImputingValuesByCounts(dataset));
                break;

            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }

        return new ImputerPreprocessor<>(imputingValues, basePreprocessor);
    } catch (RuntimeException e) {
        // Let unchecked exceptions (e.g. the unsupported-strategy error) propagate with their original type
        // instead of burying them inside another RuntimeException.
        throw e;
    } catch (Exception e) {
        // Checked exceptions (presumably from closing the dataset) are wrapped because fit() declares none.
        throw new RuntimeException(e);
    }
}
Usage of org.apache.ignite.ml.preprocessing.Preprocessor in the Apache Ignite project.
Class MLDeployingTest, method createPreprocessor.
/**
 * Reflectively instantiates a preprocessor class loaded through {@code getExternalClassLoader()}.
 * <p>
 * The loaded class must declare a public constructor that takes a single {@link Preprocessor};
 * {@code basePreprocessor} is passed to that constructor.
 *
 * @param basePreprocessor Preprocessor handed to the loaded class's constructor.
 * @param clsName Fully qualified name of the class to load.
 * @return Newly created preprocessor instance.
 * @throws Exception If the class cannot be loaded, lacks a suitable public constructor, or instantiation fails.
 */
private Preprocessor<Integer, Vector> createPreprocessor(Preprocessor<Integer, Vector> basePreprocessor,
    String clsName) throws Exception {
    ClassLoader ldr = getExternalClassLoader();
    Class<?> clazz = ldr.loadClass(clsName);
    Constructor<?> ctor = clazz.getConstructor(Preprocessor.class);

    // Reflection only yields an Object; callers are trusted to name a Preprocessor<Integer, Vector> implementation.
    @SuppressWarnings("unchecked")
    Preprocessor<Integer, Vector> preprocessor = (Preprocessor<Integer, Vector>) ctor.newInstance(basePreprocessor);

    return preprocessor;
}
Usage of org.apache.ignite.ml.preprocessing.Preprocessor in the Apache Ignite project.
Class LearningEnvironmentTest, method testRandomNumbersGenerator.
/**
 * Test random number generator provided by {@link LearningEnvironment}.
 * We test that:
 * 1. Correct random generator is returned for each partition.
 * 2. Its state is saved between compute calls (for this we do several iterations of compute).
 */
@Test
public void testRandomNumbersGenerator() {
    // The environment uses MockRandom as its random dependency: the generator of partition p returns
    // p * k from its k-th nextInt() call, so after all iterations partition p should report p * iterations.
    LearningEnvironmentBuilder envBuilder = TestUtils.testEnvBuilder().withRandomDependency(MockRandom::new);
    int partitions = 10;
    int iterations = 2;

    DatasetTrainer<IgniteModel<Object, Vector>, Void> trainer = new DatasetTrainer<IgniteModel<Object, Vector>, Void>() {
        /**
         * {@inheritDoc}
         */
        @Override
        public <K, V> IgniteModel<Object, Vector> fitWithInitializedDeployingContext(
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            // Each partition's data is just its own index. The dataset is closed once the model has been computed.
            try (Dataset<EmptyContext, TestUtils.DataWrapper<Integer>> ds = datasetBuilder.build(
                envBuilder,
                new EmptyContextBuilder<>(),
                (PartitionDataBuilder<K, V, EmptyContext, TestUtils.DataWrapper<Integer>>)
                    (env, upstreamData, upstreamDataSize, ctx) -> TestUtils.DataWrapper.of(env.partition()),
                envBuilder.buildForTrainer())) {
                Vector v = null;

                // Repeat compute so the generator state has to survive between calls. Each partition writes its
                // next random number into its own slot of a -1-filled vector; results are zipped together
                // (presumably with -1 treated as "empty" and overridden).
                for (int iter = 0; iter < iterations; iter++) {
                    v = ds.compute(
                        (dw, env) -> VectorUtils.fill(-1, partitions)
                            .set(env.partition(), env.randomNumbersGenerator().nextInt()),
                        (v1, v2) -> zipOverridingEmpty(v1, v2, -1));
                }

                return constantModel(v);
            } catch (RuntimeException e) {
                throw e;
            } catch (Exception e) {
                // Closing the dataset may throw a checked exception; this method declares none.
                throw new RuntimeException(e);
            }
        }

        /**
         * {@inheritDoc}
         */
        @Override
        public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
            return false;
        }

        /**
         * {@inheritDoc}
         */
        @Override
        protected <K, V> IgniteModel<Object, Vector> updateModel(IgniteModel<Object, Vector> mdl,
            DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
            return null;
        }
    };

    trainer.withEnvironmentBuilder(envBuilder);

    IgniteModel<Object, Vector> mdl = trainer.fit(getCacheMock(partitions), partitions, null);

    Vector exp = VectorUtils.zeroes(partitions);
    for (int i = 0; i < partitions; i++) {
        exp.set(i, i * iterations);
    }

    Vector res = mdl.predict(null);

    assertEquals(exp, res);
}
Aggregations