Use of org.apache.ignite.ml.preprocessing.Preprocessor in project ignite by apache.
The class KNNUtils, method buildDataset.
/**
 * Builds dataset.
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Dataset builder.
 * @param vectorizer Upstream vectorizer.
 * @return Dataset.
 */
@Nullable
public static <K, V, C extends Serializable> Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> buildDataset(
    LearningEnvironmentBuilder envBuilder,
    DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = null;

    if (datasetBuilder != null) {
        dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder,
            environment
        );
    }

    return dataset;
}
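A minimal usage sketch (the in-memory data, partition count, and vectorizer below are illustrative assumptions, not part of KNNUtils; imports are elided as in the snippets above):

// Hypothetical usage: build a dataset from a local map whose last element is the label.
Map<Integer, double[]> data = new HashMap<>();
data.put(0, new double[] {1.0, 2.0, 0.0});
data.put(1, new double[] {3.0, 4.0, 1.0});

Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = KNNUtils.buildDataset(
    LearningEnvironmentBuilder.defaultBuilder(),
    new LocalDatasetBuilder<>(data, 2),
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);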
Use of org.apache.ignite.ml.preprocessing.Preprocessor in project ignite by apache.
The class LinearRegressionSGDTrainer, method updateModel.
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert updatesStgy != null;

    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        // Compute the number of feature columns from any non-empty partition.
        int cols = dataset.compute(data -> {
            if (data.getFeatures() == null)
                return null;
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            if (a == null)
                return b == null ? 0 : b;
            if (b == null)
                return a;
            return b;
        });

        MLPArchitecture architecture = new MLPArchitecture(cols);
        architecture = architecture.withAddedLayer(1, true, Activators.LINEAR);

        return architecture;
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(
        archSupplier,
        LossFunctions.MSE,
        updatesStgy,
        maxIterations,
        batchSize,
        locIterations,
        seed
    );

    // Adapt the Double label produced by the extractor to the double[] label expected by MLPTrainer.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func =
        lv -> new LabeledVector<>(lv.features(), new double[] {lv.label()});

    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(func, extractor);

    MultilayerPerceptron mlp = Optional.ofNullable(mdl)
        .map(this::restoreMLPState)
        .map(m -> trainer.update(m, datasetBuilder, patchedPreprocessor))
        .orElseGet(() -> trainer.fit(datasetBuilder, patchedPreprocessor));

    double[] p = mlp.parameters().getStorage().data();

    // The last parameter is the intercept; the rest are the linear weights.
    return new LinearRegressionModel(new DenseVector(Arrays.copyOf(p, p.length - 1)), p[p.length - 1]);
}
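A minimal training sketch for this trainer (the update strategy, hyperparameters, and data below are illustrative assumptions, not values taken from the source above):

// Hypothetical usage: fit y ≈ 2x on a tiny local dataset.
Map<Integer, double[]> data = new HashMap<>();
data.put(0, new double[] {1.0, 1.9});
data.put(1, new double[] {2.0, 4.1});
data.put(2, new double[] {3.0, 6.2});

LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(
    new UpdatesStrategy<>(
        new SimpleGDUpdateCalculator(0.2),
        SimpleGDParameterUpdate.SUM_LOCAL,
        SimpleGDParameterUpdate.AVG
    ),
    100,  // max iterations
    10,   // batch size
    10,   // local iterations
    123L  // seed
);

LinearRegressionModel mdl = trainer.fit(
    new LocalDatasetBuilder<>(data, 2),
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);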
Use of org.apache.ignite.ml.preprocessing.Preprocessor in project ignite by apache.
The class MinMaxScalerTrainer, method fit.
/**
 * {@inheritDoc}
 */
@Override
public MinMaxScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, MinMaxScalerPartitionData> dataset = datasetBuilder.build(
        envBuilder,
        ctxBuilder,
        (env, upstream, upstreamSize, ctx) -> {
            double[] min = null;
            double[] max = null;

            // Single pass over the partition: track the per-feature minimum and maximum.
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                if (min == null) {
                    min = new double[row.size()];
                    Arrays.fill(min, Double.MAX_VALUE);
                }
                else
                    assert min.length == row.size() : "Base preprocessor must return exactly " + min.length
                        + " features";

                if (max == null) {
                    max = new double[row.size()];
                    Arrays.fill(max, -Double.MAX_VALUE);
                }
                else
                    assert max.length == row.size() : "Base preprocessor must return exactly " + max.length
                        + " features";

                for (int i = 0; i < row.size(); i++) {
                    if (row.get(i) < min[i])
                        min[i] = row.get(i);
                    if (row.get(i) > max[i])
                        max[i] = row.get(i);
                }
            }

            return new MinMaxScalerPartitionData(min, max);
        },
        learningEnvironment(basePreprocessor)
    )) {
        // Reduce the per-partition extrema to global per-feature min and max arrays.
        double[][] minMax = dataset.compute(
            data -> data.getMin() != null ? new double[][] {data.getMin(), data.getMax()} : null,
            (a, b) -> {
                if (a == null)
                    return b;
                if (b == null)
                    return a;

                double[][] res = new double[2][];

                res[0] = new double[a[0].length];
                for (int i = 0; i < res[0].length; i++)
                    res[0][i] = Math.min(a[0][i], b[0][i]);

                res[1] = new double[a[1].length];
                for (int i = 0; i < res[1].length; i++)
                    res[1][i] = Math.max(a[1][i], b[1][i]);

                return res;
            }
        );

        return new MinMaxScalerPreprocessor<>(minMax[0], minMax[1], basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
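A minimal sketch of fitting the scaler and applying the resulting preprocessor (the data and partition count below are illustrative assumptions):

// Hypothetical usage: compute per-feature [min, max] and rescale a row to [0, 1].
Map<Integer, double[]> data = new HashMap<>();
data.put(0, new double[] {1.0, 10.0, 0.0});
data.put(1, new double[] {3.0, 20.0, 1.0});

Preprocessor<Integer, double[]> vectorizer =
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST);

MinMaxScalerPreprocessor<Integer, double[]> minMaxScaler = new MinMaxScalerTrainer<Integer, double[]>()
    .fit(LearningEnvironmentBuilder.defaultBuilder(), new LocalDatasetBuilder<>(data, 2), vectorizer);

// Features of row 0 rescale to (1-1)/(3-1) = 0.0 and (10-10)/(20-10) = 0.0.
LabeledVector scaled = minMaxScaler.apply(0, data.get(0));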
Use of org.apache.ignite.ml.preprocessing.Preprocessor in project ignite by apache.
The class LogisticRegressionSGDTrainer, method updateModel.
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        Integer cols = dataset.compute(data -> {
            if (data.getFeatures() == null)
                return null;
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            // If all partitions are empty, null is propagated and handled below.
            if (a == null)
                return b;
            return a;
        });

        if (cols == null)
            throw new IllegalStateException("Cannot train on empty dataset");

        MLPArchitecture architecture = new MLPArchitecture(cols);
        architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);

        return architecture;
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(
        archSupplier,
        LossFunctions.L2,
        updatesStgy,
        maxIterations,
        batchSize,
        locIterations,
        seed
    ).withEnvironmentBuilder(envBuilder);

    MultilayerPerceptron mlp;

    // Adapt the Double label produced by the extractor to the double[] label expected by MLPTrainer.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func =
        lv -> new LabeledVector<>(lv.features(), new double[] {lv.label()});

    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(func, extractor);

    if (mdl != null) {
        mlp = restoreMLPState(mdl);
        mlp = trainer.update(mlp, datasetBuilder, patchedPreprocessor);
    }
    else
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);

    double[] params = mlp.parameters().getStorage().data();

    // The last parameter is the intercept; the rest are the weights.
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)),
        params[params.length - 1]);
}
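A minimal training sketch (the builder-style hyperparameters and the data below are assumptions about the trainer's public API, not values from the source above):

// Hypothetical usage: train a binary classifier on 0/1-labeled rows.
Map<Integer, double[]> data = new HashMap<>();
data.put(0, new double[] {0.1, 0.2, 0.0});
data.put(1, new double[] {0.9, 0.8, 1.0});

LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
    .withMaxIterations(1000)
    .withLocIterations(10)
    .withBatchSize(10)
    .withSeed(123L);

LogisticRegressionModel mdl = trainer.fit(
    new LocalDatasetBuilder<>(data, 2),
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);

double prediction = mdl.predict(new DenseVector(new double[] {0.5, 0.5}));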
Use of org.apache.ignite.ml.preprocessing.Preprocessor in project ignite by apache.
The class Deltas, method updateModel.
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    // The SVM formulation expects labels in {-1, 1}, so remap the 0 class to -1.
    IgniteFunction<Double, Double> lbTransformer = lb -> {
        if (lb == 0.0)
            return -1.0;
        else
            return lb;
    };

    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> func =
        lv -> new LabeledVector<>(lv.features(), lbTransformer.apply(lv.label()));

    PatchedPreprocessor<K, V, Double, Double> patchedPreprocessor = new PatchedPreprocessor<>(func, preprocessor);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(patchedPreprocessor);

    Vector weights;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        partDataBuilder,
        learningEnvironment()
    )) {
        if (mdl == null) {
            final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                if (a == null)
                    return b == null ? 0 : b;
                if (b == null)
                    return a;
                return b;
            });

            final int weightVectorSizeWithIntercept = cols + 1;

            weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
        }
        else
            weights = getStateVector(mdl);

        for (int i = 0; i < this.getAmountOfIterations(); i++) {
            Vector deltaWeights = calculateUpdates(weights, dataset);

            if (deltaWeights == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

            // Creates a new vector.
            weights = weights.plus(deltaWeights);
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    return new SVMLinearClassificationModel(weights.copyOfRange(1, weights.size()), weights.get(0));
}
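A minimal training sketch for the SVM trainer that wraps this update loop (the data and iteration count below are illustrative assumptions):

// Hypothetical usage: labels in {0, 1} are remapped to {-1, 1} internally, as shown above.
Map<Integer, double[]> data = new HashMap<>();
data.put(0, new double[] {0.0, 0.1, 0.0});
data.put(1, new double[] {1.0, 0.9, 1.0});

SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer()
    .withAmountOfIterations(200);

SVMLinearClassificationModel mdl = trainer.fit(
    new LocalDatasetBuilder<>(data, 2),
    new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);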