Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class EncoderTrainer, method fit.
/**
 * {@inheritDoc}
 */
@Override public EncoderPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    if (handledIndices.isEmpty() && encoderType != EncoderType.LABEL_ENCODER)
        throw new RuntimeException("Add indices of handled features");

    try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            EncoderPartitionData partData = new EncoderPartitionData();

            if (encoderType == EncoderType.LABEL_ENCODER) {
                // Count how often each label value occurs in this partition.
                Map<String, Integer> lbFrequencies = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                    lbFrequencies = updateLabelFrequenciesForNextRow(row, lbFrequencies);
                }

                partData.withLabelFrequencies(lbFrequencies);
            }
            else if (encoderType == EncoderType.TARGET_ENCODER) {
                // Accumulate per-category target statistics for the handled features.
                TargetCounter[] targetCounter = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                    targetCounter = updateTargetCountersForNextRow(row, targetCounter);
                }

                partData.withTargetCounters(targetCounter);
            }
            else {
                // This array will contain non-null values only for the handled indices.
                Map<String, Integer>[] categoryFrequencies = null;

                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    LabeledVector<Double> row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                    categoryFrequencies = updateFeatureFrequenciesForNextRow(row, categoryFrequencies);
                }

                partData.withCategoryFrequencies(categoryFrequencies);
            }

            return partData;
        },
        learningEnvironment(basePreprocessor)
    )) {
        switch (encoderType) {
            case ONE_HOT_ENCODER:
                return new OneHotEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);

            case STRING_ENCODER:
                return new StringEncoderPreprocessor<>(calculateEncodingValuesByFrequencies(dataset), basePreprocessor, handledIndices);

            case LABEL_ENCODER:
                return new LabelEncoderPreprocessor<>(calculateEncodingValuesForLabelsByFrequencies(dataset), basePreprocessor);

            case FREQUENCY_ENCODER:
                return new FrequencyEncoderPreprocessor<>(calculateEncodingFrequencies(dataset), basePreprocessor, handledIndices);

            case TARGET_ENCODER:
                return new TargetEncoderPreprocessor<>(calculateTargetEncodingFrequencies(dataset), basePreprocessor, handledIndices);

            default:
                throw new IllegalStateException("Define the type of the resulting preprocessor.");
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
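For context, a minimal usage sketch of this trainer, patterned on Ignite's encoder example. The `ignite` node, the `dataCache` of Object[] rows, and the chosen feature indices are illustrative assumptions, not part of the snippet above:

// Hypothetical setup: rows are Object[] with the label in coordinate 0.
Vectorizer<Integer, Object[], Integer, Object> vectorizer = new ObjectArrayVectorizer<Integer>(1, 2, 3, 4).labeled(0);

// String-encode the categorical features at indices 1 and 4 on top of the base vectorizer.
Preprocessor<Integer, Object[]> encoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
    .withEncoderType(EncoderType.STRING_ENCODER)
    .withEncodedFeature(1)
    .withEncodedFeature(4)
    .fit(ignite, dataCache, vectorizer);

The resulting preprocessor can then serve as the base preprocessor of any downstream trainer.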
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class GaussianNaiveBayesTrainer, method updateModel.
/**
 * {@inheritDoc}
 */
@Override protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> extractor) {
    assert datasetBuilder != null;

    try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
                Vector features = lv.features();
                Double label = (Double)lv.label();

                double[] toMeans;
                double[] sqSum;

                // Lazily create per-label accumulators on the first occurrence of a label.
                if (!res.featureSumsPerLbl.containsKey(label)) {
                    toMeans = new double[features.size()];
                    Arrays.fill(toMeans, 0.);
                    res.featureSumsPerLbl.put(label, toMeans);
                }
                if (!res.featureSquaredSumsPerLbl.containsKey(label)) {
                    sqSum = new double[features.size()];
                    res.featureSquaredSumsPerLbl.put(label, sqSum);
                }
                if (!res.featureCountersPerLbl.containsKey(label))
                    res.featureCountersPerLbl.put(label, 0);

                res.featureCountersPerLbl.put(label, res.featureCountersPerLbl.get(label) + 1);

                toMeans = res.featureSumsPerLbl.get(label);
                sqSum = res.featureSquaredSumsPerLbl.get(label);

                // Accumulate sums and squared sums per feature; means and variances are derived later.
                for (int j = 0; j < features.size(); j++) {
                    double x = features.get(j);
                    toMeans[j] += x;
                    sqSum[j] += x * x;
                }
            }

            return res;
        },
        learningEnvironment()
    )) {
        GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });

        // For incremental learning, merge the new statistics with those of the previous model.
        if (mdl != null && mdl.getSumsHolder() != null)
            sumsHolder = sumsHolder.merge(mdl.getSumsHolder());

        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert !sortedLabels.isEmpty() : "The dataset should contain at least one label";

        int labelCount = sortedLabels.size();
        int featureCount = sumsHolder.featureSumsPerLbl.get(sortedLabels.get(0)).length;

        double[][] means = new double[labelCount][featureCount];
        double[][] variances = new double[labelCount][featureCount];
        double[] classProbabilities = new double[labelCount];
        double[] labels = new double[labelCount];
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();

        int lbl = 0;
        for (Double label : sortedLabels) {
            int count = sumsHolder.featureCountersPerLbl.get(label);
            double[] sum = sumsHolder.featureSumsPerLbl.get(label);
            double[] sqSum = sumsHolder.featureSquaredSumsPerLbl.get(label);

            // mean = sum / n; variance = E[x^2] - (E[x])^2, computed from the accumulated sums.
            for (int i = 0; i < featureCount; i++) {
                means[lbl][i] = sum[i] / count;
                variances[lbl][i] = (sqSum[i] - sum[i] * sum[i] / count) / count;
            }

            if (equiprobableClasses)
                classProbabilities[lbl] = 1. / labelCount;
            else if (priorProbabilities != null) {
                assert classProbabilities.length == priorProbabilities.length;
                classProbabilities[lbl] = priorProbabilities[lbl];
            }
            else
                classProbabilities[lbl] = (double)count / datasetSize;

            labels[lbl] = label;
            ++lbl;
        }

        return new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, sumsHolder);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
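A rough end-to-end sketch of how this trainer is typically invoked. The `ignite` node, the `dataCache` of labeled vectors, and the `observation` feature vector are assumed context, not part of the snippet above:

// Rows are assumed to be Vectors with the class label in the first coordinate.
Vectorizer<Integer, Vector, Integer, Double> vectorizer =
    new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);

GaussianNaiveBayesModel mdl = new GaussianNaiveBayesTrainer()
    .fit(ignite, dataCache, vectorizer);

double predictedClass = mdl.predict(observation);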
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class MLPTrainer, method updateModel.
/**
 * {@inheritDoc}
 */
@Override protected <K, V> MultilayerPerceptron updateModel(MultilayerPerceptron lastLearnedMdl, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> extractor) {
    assert archSupplier != null;
    assert loss != null;
    assert updatesStgy != null;

    try (Dataset<EmptyContext, SimpleLabeledDatasetData> dataset = datasetBuilder.build(
        envBuilder,
        new EmptyContextBuilder<>(),
        new SimpleLabeledDatasetDataBuilder<>(extractor),
        learningEnvironment()
    )) {
        MultilayerPerceptron mdl;

        // Continue from the previously learned model, or initialize a new one from the architecture supplier.
        if (lastLearnedMdl != null)
            mdl = lastLearnedMdl;
        else {
            MLPArchitecture arch = archSupplier.apply(dataset);
            mdl = new MultilayerPerceptron(arch, new RandomInitializer(seed));
        }

        ParameterUpdateCalculator<? super MultilayerPerceptron, P> updater = updatesStgy.getUpdatesCalculator();

        for (int i = 0; i < maxIterations; i += locIterations) {
            MultilayerPerceptron finalMdl = mdl;
            int finalI = i;

            List<P> totUp = dataset.compute(data -> {
                P update = updater.init(finalMdl, loss);
                MultilayerPerceptron mlp = Utils.copy(finalMdl);

                if (data.getFeatures() != null) {
                    List<P> updates = new ArrayList<>();

                    // Run several local training steps on this partition before synchronizing.
                    for (int locStep = 0; locStep < locIterations; locStep++) {
                        int[] rows = Utils.selectKDistinct(data.getRows(), Math.min(batchSize, data.getRows()),
                            new Random(seed ^ (finalI * locStep)));

                        double[] inputsBatch = batch(data.getFeatures(), rows, data.getRows());
                        double[] groundTruthBatch = batch(data.getLabels(), rows, data.getRows());

                        Matrix inputs = new DenseMatrix(inputsBatch, rows.length, 0);
                        Matrix groundTruth = new DenseMatrix(groundTruthBatch, rows.length, 0);

                        update = updater.calculateNewUpdate(mlp, update, locStep, inputs.transpose(), groundTruth.transpose());
                        mlp = updater.update(mlp, update);
                        updates.add(update);
                    }

                    List<P> res = new ArrayList<>();
                    res.add(updatesStgy.locStepUpdatesReducer().apply(updates));

                    return res;
                }

                return null;
            }, (a, b) -> {
                if (a == null)
                    return b;
                else if (b == null)
                    return a;
                else {
                    a.addAll(b);
                    return a;
                }
            });

            if (totUp == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(lastLearnedMdl);

            // Reduce the per-partition updates into one global update and apply it.
            P update = updatesStgy.allUpdatesReducer().apply(totUp);
            mdl = updater.update(mdl, update);
        }

        return mdl;
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
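For orientation, a sketch of constructing and fitting this trainer, closely following Ignite's MLP trainer example. The cache and vectorizer names and the two-feature layout are assumptions:

// Network architecture: 2 inputs -> 10 ReLU neurons -> 1 sigmoid output.
MLPArchitecture arch = new MLPArchitecture(2)
    .withAddedLayer(10, true, Activators.RELU)
    .withAddedLayer(1, true, Activators.SIGMOID);

// Simple gradient descent with per-partition local steps, as in updateModel above.
MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
    arch,
    LossFunctions.MSE,
    new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.1), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG),
    3000, // maxIterations
    4,    // batchSize
    50,   // locIterations
    123L  // seed
);

MultilayerPerceptron mlp = trainer.fit(ignite, trainingSetCache, vectorizer);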
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class DiscreteNaiveBayesTrainer, method updateModel.
/**
 * {@inheritDoc}
 */
@Override protected <K, V> DiscreteNaiveBayesModel updateModel(DiscreteNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> extractor) {
    try (Dataset<EmptyContext, DiscreteNaiveBayesSumsHolder> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            DiscreteNaiveBayesSumsHolder res = new DiscreteNaiveBayesSumsHolder();

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
                Vector features = lv.features();
                Double lb = (Double)lv.label();

                long[][] valuesInBucket;
                int size = features.size();

                // Lazily create the per-label bucket counters: one row per feature, one cell per bucket
                // (the thresholds split feature i into bucketThresholds[i].length + 1 buckets).
                if (!res.valuesInBucketPerLbl.containsKey(lb)) {
                    valuesInBucket = new long[size][];
                    for (int i = 0; i < size; i++) {
                        valuesInBucket[i] = new long[bucketThresholds[i].length + 1];
                        Arrays.fill(valuesInBucket[i], 0L);
                    }
                    res.valuesInBucketPerLbl.put(lb, valuesInBucket);
                }

                if (!res.featureCountersPerLbl.containsKey(lb))
                    res.featureCountersPerLbl.put(lb, 0);

                res.featureCountersPerLbl.put(lb, res.featureCountersPerLbl.get(lb) + 1);

                valuesInBucket = res.valuesInBucketPerLbl.get(lb);

                for (int j = 0; j < size; j++) {
                    double x = features.get(j);
                    int bucketNum = toBucketNumber(x, bucketThresholds[j]);
                    valuesInBucket[j][bucketNum] += 1;
                }
            }

            return res;
        },
        learningEnvironment()
    )) {
        DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });

        // For incremental learning, merge with the statistics of the previous model if they are compatible.
        if (mdl != null && isUpdateable(mdl)) {
            if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
                sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
        }

        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert !sortedLabels.isEmpty() : "The dataset should contain at least one label";

        int lbCnt = sortedLabels.size();
        int featureCnt = sumsHolder.valuesInBucketPerLbl.get(sortedLabels.get(0)).length;

        double[][][] probabilities = new double[lbCnt][featureCnt][];
        double[] classProbabilities = new double[lbCnt];
        double[] labels = new double[lbCnt];
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();

        int lbl = 0;
        for (Double label : sortedLabels) {
            int cnt = sumsHolder.featureCountersPerLbl.get(label);
            long[][] sum = sumsHolder.valuesInBucketPerLbl.get(label);

            // Per-bucket conditional probability: count in bucket / total count for this label.
            for (int i = 0; i < featureCnt; i++) {
                int bucketsCnt = sum[i].length;
                probabilities[lbl][i] = new double[bucketsCnt];

                for (int j = 0; j < bucketsCnt; j++)
                    probabilities[lbl][i][j] = (double)sum[i][j] / cnt;
            }

            if (equiprobableClasses)
                classProbabilities[lbl] = 1. / lbCnt;
            else if (priorProbabilities != null) {
                assert classProbabilities.length == priorProbabilities.length;
                classProbabilities[lbl] = priorProbabilities[lbl];
            }
            else
                classProbabilities[lbl] = (double)cnt / datasetSize;

            labels[lbl] = label;
            ++lbl;
        }

        return new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, bucketThresholds, sumsHolder);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
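A short setup sketch for this trainer, assuming the builder-style setBucketThresholds setter used in Ignite's naive Bayes example; the thresholds, cache, and vectorizer are illustrative:

// Three features, each split into two buckets at 0.5 (illustrative thresholds).
double[][] thresholds = new double[][] {{.5}, {.5}, {.5}};

DiscreteNaiveBayesModel mdl = new DiscreteNaiveBayesTrainer()
    .setBucketThresholds(thresholds)
    .fit(ignite, dataCache, vectorizer);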
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class MaxAbsScalerTrainer, method fit.
/**
 * {@inheritDoc}
 */
@Override public MaxAbsScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    try (Dataset<EmptyContext, MaxAbsScalerPartitionData> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            double[] maxAbs = null;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                if (maxAbs == null) {
                    maxAbs = new double[row.size()];
                    Arrays.fill(maxAbs, .0);
                }
                else
                    assert maxAbs.length == row.size() : "Base preprocessor must return exactly " + maxAbs.length + " features";

                for (int i = 0; i < row.size(); i++) {
                    if (Math.abs(row.get(i)) > Math.abs(maxAbs[i]))
                        maxAbs[i] = Math.abs(row.get(i));
                }
            }

            return new MaxAbsScalerPartitionData(maxAbs);
        },
        learningEnvironment(basePreprocessor)
    )) {
        double[] maxAbs = dataset.compute(MaxAbsScalerPartitionData::getMaxAbs, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;

            double[] res = new double[a.length];

            for (int i = 0; i < res.length; i++)
                res[i] = Math.max(Math.abs(a[i]), Math.abs(b[i]));

            return res;
        });

        return new MaxAbsScalerPreprocessor<>(maxAbs, basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
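Finally, a minimal sketch of applying this trainer; `ignite`, `dataCache`, and `vectorizer` are assumed context as in the earlier sketches:

// Fit the scaler: computes max|x_i| per feature across the dataset, then rescales each feature to [-1, 1].
Preprocessor<Integer, Vector> maxAbsScaler = new MaxAbsScalerTrainer<Integer, Vector>()
    .fit(ignite, dataCache, vectorizer);

// Like the encoder above, the fitted preprocessor can be chained into any downstream trainer.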