use of org.apache.ignite.ml.dataset.Dataset in project ignite by apache.
the class LogisticRegressionSGDTrainer method updateModel.
/**
* {@inheritDoc}
*/
@Override
protected <K, V> LogisticRegressionModel updateModel(LogisticRegressionModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        Integer cols = dataset.compute(data -> {
            if (data.getFeatures() == null)
                return null;
            // Features are stored as a flat array, so the column count is length / rows.
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            // Prefer the non-null partition result; if both are null, null is
            // propagated and caught by the emptiness check below.
            if (a == null)
                return b;
            return a;
        });

        if (cols == null)
            throw new IllegalStateException("Cannot train on empty dataset");

        MLPArchitecture architecture = new MLPArchitecture(cols);
        architecture = architecture.withAddedLayer(1, true, Activators.SIGMOID);

        return architecture;
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.L2, updatesStgy, maxIterations,
        batchSize, locIterations, seed).withEnvironmentBuilder(envBuilder);

    MultilayerPerceptron mlp;

    // The MLP trainer consumes array-valued labels, so wrap each scalar label
    // into a single-element double[].
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func =
        lv -> new LabeledVector<>(lv.features(), new double[] {lv.label()});

    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(func, extractor);

    if (mdl != null) {
        mlp = restoreMLPState(mdl);
        mlp = trainer.update(mlp, datasetBuilder, patchedPreprocessor);
    }
    else
        mlp = trainer.fit(datasetBuilder, patchedPreprocessor);

    double[] params = mlp.parameters().getStorage().data();

    // The last parameter of the single-neuron MLP is the intercept; the rest are the weights.
    return new LogisticRegressionModel(new DenseVector(Arrays.copyOf(params, params.length - 1)),
        params[params.length - 1]);
}
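For orientation, here is a minimal sketch of how this trainer is typically driven from user code; fit(...) (and update(...) for incremental learning) delegates to the updateModel(...) above. The cache layout, key type, and hyperparameter values are illustrative assumptions, not part of the Ignite snippet:

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;

public class LogRegUsageSketch {
    public static LogisticRegressionModel train(Ignite ignite, IgniteCache<Integer, Vector> dataCache) {
        LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
            .withMaxIterations(1000)   // hyperparameters are illustrative
            .withLocIterations(10)
            .withBatchSize(20)
            .withSeed(1234L);

        // Assumes each cached Vector holds its 0/1 label in the first coordinate.
        return trainer.fit(ignite, dataCache,
            new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));
    }
}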
use of org.apache.ignite.ml.dataset.Dataset in project ignite by apache.
the class Deltas method updateModel.
/**
* {@inheritDoc}
*/
@Override
protected <K, V> SVMLinearClassificationModel updateModel(SVMLinearClassificationModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    // The SVM formulation expects labels in {-1, 1}, so remap the 0 class to -1.
    IgniteFunction<Double, Double> lbTransformer = lb -> {
        if (lb == 0.0)
            return -1.0;
        else
            return lb;
    };

    IgniteFunction<LabeledVector<Double>, LabeledVector<Double>> func =
        lv -> new LabeledVector<>(lv.features(), lbTransformer.apply(lv.label()));

    PatchedPreprocessor<K, V, Double, Double> patchedPreprocessor = new PatchedPreprocessor<>(func, preprocessor);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(patchedPreprocessor);

    Vector weights;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        partDataBuilder,
        learningEnvironment()
    )) {
        if (mdl == null) {
            final int cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
                if (a == null)
                    return b == null ? 0 : b;
                if (b == null)
                    return a;
                return b;
            });

            final int weightVectorSizeWithIntercept = cols + 1;
            weights = initializeWeightsWithZeros(weightVectorSizeWithIntercept);
        }
        else
            weights = getStateVector(mdl);

        for (int i = 0; i < this.getAmountOfIterations(); i++) {
            Vector deltaWeights = calculateUpdates(weights, dataset);
            if (deltaWeights == null)
                return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

            // plus(...) creates a new vector rather than mutating in place.
            weights = weights.plus(deltaWeights);
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    // weights[0] is the intercept; the remaining components are the feature weights.
    return new SVMLinearClassificationModel(weights.copyOfRange(1, weights.size()), weights.get(0));
}
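A minimal usage sketch along the same lines, assuming a cache of feature Vectors with a 0/1 label in the first coordinate (updateModel(...) above remaps label 0 to -1 internally):

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;

public class SvmUsageSketch {
    public static SVMLinearClassificationModel train(Ignite ignite, IgniteCache<Integer, Vector> dataCache) {
        SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();

        return trainer.fit(ignite, dataCache,
            new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));
    }
}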
use of org.apache.ignite.ml.dataset.Dataset in project ignite by apache.
the class ImputerTrainer method fit.
/**
* {@inheritDoc}
*/
@Override
public ImputerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> builder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, ImputerPartitionData> dataset = datasetBuilder.build(envBuilder, builder,
        (env, upstream, upstreamSize, ctx) -> {
            double[] sums = null;
            int[] counts = null;
            double[] maxs = null;
            double[] mins = null;
            Map<Double, Integer>[] valuesByFreq = null;

            // Accumulate the per-partition statistics required by the chosen imputing strategy.
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                switch (imputingStgy) {
                    case MEAN:
                        sums = updateTheSums(row, sums);
                        counts = updateTheCounts(row, counts);
                        break;
                    case MOST_FREQUENT:
                    case LEAST_FREQUENT:
                        valuesByFreq = updateFrequenciesByGivenRow(row, valuesByFreq);
                        break;
                    case MAX:
                        maxs = updateTheMaxs(row, maxs);
                        break;
                    case MIN:
                        mins = updateTheMins(row, mins);
                        break;
                    case COUNT:
                        counts = updateTheCounts(row, counts);
                        break;
                    default:
                        throw new UnsupportedOperationException("The chosen strategy is not supported");
                }
            }

            // Pack the accumulated statistics into the partition data object.
            ImputerPartitionData partData;

            switch (imputingStgy) {
                case MEAN:
                    partData = new ImputerPartitionData().withSums(sums).withCounts(counts);
                    break;
                case MOST_FREQUENT:
                case LEAST_FREQUENT:
                    partData = new ImputerPartitionData().withValuesByFrequency(valuesByFreq);
                    break;
                case MAX:
                    partData = new ImputerPartitionData().withMaxs(maxs);
                    break;
                case MIN:
                    partData = new ImputerPartitionData().withMins(mins);
                    break;
                case COUNT:
                    partData = new ImputerPartitionData().withCounts(counts);
                    break;
                default:
                    throw new UnsupportedOperationException("The chosen strategy is not supported");
            }
            return partData;
        }, learningEnvironment(basePreprocessor))) {

        // Reduce the per-partition statistics into one imputing value per feature.
        Vector imputingValues;

        switch (imputingStgy) {
            case MEAN:
                imputingValues = VectorUtils.of(calculateImputingValuesBySumsAndCounts(dataset));
                break;
            case MOST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheMostFrequentValues(dataset));
                break;
            case LEAST_FREQUENT:
                imputingValues = VectorUtils.of(calculateImputingValuesByTheLeastFrequentValues(dataset));
                break;
            case MAX:
                imputingValues = VectorUtils.of(calculateImputingValuesByMaxValues(dataset));
                break;
            case MIN:
                imputingValues = VectorUtils.of(calculateImputingValuesByMinValues(dataset));
                break;
            case COUNT:
                imputingValues = VectorUtils.of(calculateImputingValuesByCounts(dataset));
                break;
            default:
                throw new UnsupportedOperationException("The chosen strategy is not supported");
        }

        return new ImputerPreprocessor<>(imputingValues, basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
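A minimal usage sketch, assuming a cache of Vectors with some missing feature values; the MEAN strategy, key type, and cache layout are illustrative assumptions:

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.imputing.ImputingStrategy;

public class ImputerUsageSketch {
    public static Preprocessor<Integer, Vector> buildImputer(Ignite ignite, IgniteCache<Integer, Vector> dataCache) {
        // fit(...) above runs a distributed pass over the cache to compute the
        // per-feature imputing values for the chosen strategy.
        return new ImputerTrainer<Integer, Vector>()
            .withImputingStrategy(ImputingStrategy.MEAN)
            .fit(ignite, dataCache,
                new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));
    }
}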
use of org.apache.ignite.ml.dataset.Dataset in project ignite by apache.
the class ImpurityHistogramsComputer method aggregateImpurityStatisticsOnPartition.
/**
 * Aggregates the statistics needed for impurity computation for each leaf node of each tree in the random
 * forest. For every learning vector the algorithm predicts the corresponding leaf node in each decision tree
 * and adds the vector to that leaf's histograms.
 *
 * @param dataset Dataset partition.
 * @param roots Trees.
 * @param histMeta Histogram buckets meta.
 * @param part Nodes taken from the learning queue, keyed by node id.
 * @return Per-leaf statistics for impurity computation.
 */
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
    BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots,
    Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> part) {

    Map<NodeId, NodeImpurityHistograms<S>> res = part.keySet().stream()
        .collect(Collectors.toMap(n -> n, NodeImpurityHistograms::new));

    dataset.forEach(vector -> {
        for (int sampleId = 0; sampleId < vector.counters().length; sampleId++) {
            // Skip vectors not included in this tree's bootstrap sample.
            if (vector.counters()[sampleId] == 0)
                continue;

            RandomForestTreeModel root = roots.get(sampleId);
            NodeId key = root.getRootNode().predictNextNodeKey(vector.features());

            // Skip nodes we didn't take from the learning queue.
            if (!part.containsKey(key))
                continue;

            NodeImpurityHistograms<S> statistics = res.get(key);

            for (Integer featureId : root.getUsedFeatures()) {
                BucketMeta meta = histMeta.get(featureId);
                if (!statistics.perFeatureStatistics.containsKey(featureId))
                    statistics.perFeatureStatistics.put(featureId, createImpurityComputerForFeature(sampleId, meta));

                S impurityComputer = statistics.perFeatureStatistics.get(featureId);
                impurityComputer.addElement(vector);
            }
        }
    });

    return res;
}
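ImpurityHistogramsComputer is internal machinery rather than a user-facing API; for context, a minimal sketch of the public trainer that drives it is below. The feature metadata, hyperparameters, and class wrapper are illustrative assumptions:

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.tree.randomforest.RandomForestClassifierTrainer;

public class RandomForestUsageSketch {
    public static void train(Ignite ignite, IgniteCache<Integer, Vector> dataCache, int featureCnt) {
        // One FeatureMeta per feature ("false" marks the feature as continuous);
        // the histogram bucket meta used above is derived from this metadata.
        List<FeatureMeta> meta = IntStream.range(0, featureCnt)
            .mapToObj(i -> new FeatureMeta("f" + i, i, false))
            .collect(Collectors.toList());

        RandomForestClassifierTrainer trainer = new RandomForestClassifierTrainer(meta)
            .withAmountOfTrees(101)
            .withMaxDepth(4)
            .withSeed(0);

        // fit(...) grows the forest; split search aggregates impurity histograms
        // per partition via code like aggregateImpurityStatisticsOnPartition above.
        trainer.fit(ignite, dataCache, new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST));
    }
}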