Usage of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project:
class ANNClassificationTrainer, method getCentroidStat.
/**
 * Assigns every row of the dataset to its closest center and gathers per-center label statistics.
 *
 * <p>For each row the closest center is looked up with {@code findClosestCentroid}. The row's label is added to the
 * global label set, the per-center count of that label is incremented, and the per-center row count is incremented.
 * Partition results are combined with {@code CentroidStat.merge}; a {@code null} partial result is treated as absent.</p>
 *
 * @param datasetBuilder Dataset builder.
 * @param vectorizer Upstream vectorizer.
 * @param centers Centers; a row's center index is the position of its closest center in this list.
 * @return Merged statistics over all partitions (an empty {@code CentroidStat} if both partial results are null).
 * @throws RuntimeException Wrapping any exception raised while building, computing on or closing the dataset.
 */
private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer, List<Vector> centers) {
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> onHeapBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        onHeapBuilder,
        learningEnvironment()
    )) {
        return dataset.compute(
            partition -> {
                CentroidStat partStat = new CentroidStat();

                for (int row = 0; row < partition.rowSize(); row++) {
                    int nearest = findClosestCentroid(centers, partition.getRow(row)).get1();
                    double label = partition.label(row);

                    // Remember every label seen in the data.
                    partStat.labels().add(label);

                    // Per-center histogram: label -> number of rows with that label assigned to this center.
                    ConcurrentHashMap<Double, Integer> labelCounts =
                        partStat.centroidStat.computeIfAbsent(nearest, idx -> new ConcurrentHashMap<>());
                    labelCounts.merge(label, 1, Integer::sum);

                    // Total number of rows assigned to this center.
                    partStat.counts.merge(nearest, 1, (IgniteBiFunction<Integer, Integer, Integer>) Integer::sum);
                }

                return partStat;
            },
            (left, right) -> {
                if (left != null && right != null)
                    return left.merge(right);

                if (left != null)
                    return left;

                return right != null ? right : new CentroidStat();
            }
        );
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Usage of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project:
class KNNUtils, method buildDataset.
/**
 * Builds an on-heap labeled dataset for KNN-style trainers.
 *
 * <p>A learning environment is always created from {@code envBuilder}, and its deploying context is always initialized
 * with {@code vectorizer}, even when {@code datasetBuilder} is {@code null}. The dataset itself is built only when a
 * dataset builder is supplied; every partition gets an {@link EmptyContext}.</p>
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Dataset builder; may be {@code null}.
 * @param vectorizer Upstream vectorizer.
 * @param <K> Type of upstream keys.
 * @param <V> Type of upstream values.
 * @param <C> Type parameter not referenced by this method.
 * @return Built dataset, or {@code null} if {@code datasetBuilder} is {@code null}.
 */
@Nullable
public static <K, V, C extends Serializable> Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> buildDataset(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment trainerEnv = envBuilder.buildForTrainer();
    trainerEnv.initDeployingContext(vectorizer);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> onHeapBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    return datasetBuilder == null
        ? null
        : datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), onHeapBuilder, trainerEnv);
}
Usage of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project:
class LinearRegressionSGDTrainer, method updateModel.
/**
 * {@inheritDoc}
 *
 * <p>The regression is trained as a multilayer perceptron: an input layer whose width is the number of feature
 * columns found in the dataset, followed by a single linear output neuron with bias. After training, the last
 * perceptron parameter is used as the intercept and all preceding parameters as the weight vector.</p>
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert updatesStgy != null;

    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = ds -> {
        // Feature column count = feature array length / row count. Partitions without features yield null;
        // the reducer keeps the latest non-null value and falls back to 0 when there is none.
        int featureCnt = ds.compute(
            part -> {
                if (part.getFeatures() != null)
                    return part.getFeatures().length / part.getRows();

                return null;
            },
            (x, y) -> {
                if (y != null)
                    return y;

                return x == null ? 0 : x;
            }
        );

        MLPArchitecture arch = new MLPArchitecture(featureCnt);

        return arch.withAddedLayer(1, true, Activators.LINEAR);
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.MSE, updatesStgy, maxIterations, batchSize, locIterations, seed);

    // The perceptron expects a vector label, so the scalar label is wrapped into a one-element array.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> toVectorLbl =
        lv -> new LabeledVector<>(lv.features(), new double[] {lv.label()});
    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor = new PatchedPreprocessor<>(toVectorLbl, extractor);

    // Continue training from the previous model when one is given; if there is none (or any step yields null),
    // fit a new perceptron from scratch.
    MultilayerPerceptron mlp = Optional.ofNullable(mdl)
        .map(this::restoreMLPState)
        .map(m -> trainer.update(m, datasetBuilder, patchedPreprocessor))
        .orElseGet(() -> trainer.fit(datasetBuilder, patchedPreprocessor));

    double[] params = mlp.parameters().getStorage().data();
    int interceptIdx = params.length - 1;

    return new LinearRegressionModel(new DenseVector(Arrays.copyOf(params, interceptIdx)), params[interceptIdx]);
}
Usage of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project:
class Evaluator, method initEvaluationContexts.
/**
 * Inits evaluation contexts for metrics.
 *
 * <p>One context is created per distinct aggregator class; metrics whose aggregators share a class share a context.
 * If no freshly initialized context reports {@code needToCompute()}, the initialized contexts of all metrics are
 * returned directly and the dataset is not scanned. Otherwise every partition feeds each of its labeled rows to all of
 * its contexts, and the per-partition results are merged with {@code EvaluationContext.mergeWith}.</p>
 *
 * @param dataset Dataset.
 * @param metrics Metrics.
 * @return Computed contexts keyed by aggregator class.
 */
@SuppressWarnings("unchecked")
private static Map<Class, EvaluationContext> initEvaluationContexts(Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, Metric... metrics) {
    boolean needDatasetPass = Arrays.stream(metrics)
        .anyMatch(m -> ((EvaluationContext) m.makeAggregator().createInitializedContext()).needToCompute());

    if (!needDatasetPass) {
        // Collect the initialized context of EVERY metric; returning only after the loop ensures
        // no metric is dropped.
        Map<Class, EvaluationContext> res = new HashMap<>();

        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            res.put(aggregator.getClass(), (EvaluationContext) aggregator.createInitializedContext());
        }

        return res;
    }

    return dataset.compute(data -> {
        // One aggregator per distinct aggregator class.
        Map<Class, MetricStatsAggregator> aggrs = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            if (!aggrs.containsKey(aggregator.getClass()))
                aggrs.put(aggregator.getClass(), aggregator);
        }

        Map<Class, EvaluationContext> aggrToEvCtx = new HashMap<>();
        aggrs.forEach((clazz, aggr) -> aggrToEvCtx.put(clazz, (EvaluationContext) aggr.createInitializedContext()));

        // Feed every labeled row of this partition to every context.
        for (int i = 0; i < data.getLabels().length; i++) {
            LabeledVector<Double> vector = VectorUtils.of(data.getFeatures()[i]).labeled(data.getLabels()[i]);
            aggrToEvCtx.values().forEach(ctx -> ctx.aggregate(vector));
        }

        return aggrToEvCtx;
    }, (left, right) -> {
        if (left == null && right == null)
            return new HashMap<>();
        if (left == null)
            return right;
        if (right == null)
            return left;

        // Both sides were built from the same metrics, so they must share the same key set.
        HashMap<Class, EvaluationContext> res = new HashMap<>();
        for (Class key : left.keySet()) {
            EvaluationContext ctx1 = left.get(key);
            EvaluationContext ctx2 = right.get(key);
            A.ensure(ctx1 != null && ctx2 != null, "ctx1 != null && ctx2 != null");
            res.put(key, ctx1.mergeWith(ctx2));
        }

        return res;
    });
}
Usage of org.apache.ignite.ml.dataset.Dataset in the Apache Ignite project:
class KMeansTrainer, method updateModel.
/**
 * {@inheritDoc}
 *
 * <p>Runs Lloyd-style iterations. Each iteration assigns rows to centers via {@code calcDataForNewCentroids}, moves
 * every center that received rows to the mean of those rows, and stops after {@code maxIterations} iterations or once
 * no center moved by more than {@code epsilon * epsilon} according to {@code distance}. Training starts from a copy of
 * the previous model's centers if a model is given, otherwise from randomly initialized centers.</p>
 */
@Override
protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(preprocessor);

    Vector[] centers;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder, learningEnvironment())) {
        // Feature width: take the widest partition. Picking an arbitrary partition (as "return b" did) lets a
        // partition reporting fewer columns (presumably an empty one reporting 0) mask the real width.
        final Integer cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
            if (a == null)
                return b == null ? 0 : b;
            if (b == null)
                return a;
            return Math.max(a, b);
        });

        // No partitions (null) or no feature columns (0): there is nothing to cluster.
        if (cols == null || cols == 0)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        // The loop below overwrites entries of this array in place, so never work on the previous
        // model's own array (the copy is harmless if centers() already returns a copy).
        centers = Optional.ofNullable(mdl)
            .map(KMeansModel::centers)
            .map(prevCenters -> prevCenters.clone())
            .orElseGet(() -> initClusterCentersRandomly(dataset, k));

        boolean converged = false;
        int iteration = 0;

        while (iteration < maxIterations && !converged) {
            // Declared as Vector[] rather than DenseVector[]: the mass centers come from Vector.times(), and storing a
            // vector whose runtime class is not DenseVector into a DenseVector[] would throw ArrayStoreException.
            Vector[] newCentroids = new Vector[k];

            TotalCostAndCounts totalRes = calcDataForNewCentroids(centers, dataset, cols);

            converged = true;

            for (Map.Entry<Integer, Vector> entry : totalRes.sums.entrySet()) {
                // Mean of the rows assigned to this center.
                Vector massCenter = entry.getValue().times(1.0 / totalRes.counts.get(entry.getKey()));

                if (converged && distance.compute(massCenter, centers[entry.getKey()]) > epsilon * epsilon)
                    converged = false;

                newCentroids[entry.getKey()] = massCenter;
            }

            iteration++;

            // Centers that received no rows keep their previous position.
            for (int i = 0; i < centers.length; i++) {
                if (newCentroids[i] != null)
                    centers[i] = newCentroids[i];
            }
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    return new KMeansModel(centers, distance);
}
Aggregations