Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class MinMaxScalerTrainer, method fit.
/** {@inheritDoc} */
@Override public MinMaxScalerPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> basePreprocessor) {
    PartitionContextBuilder<K, V, EmptyContext> ctxBuilder = (env, upstream, upstreamSize) -> new EmptyContext();

    try (Dataset<EmptyContext, MinMaxScalerPartitionData> dataset = datasetBuilder.build(
        envBuilder,
        ctxBuilder,
        (env, upstream, upstreamSize, ctx) -> {
            double[] min = null;
            double[] max = null;

            // Scan the partition once, tracking the per-feature minimum and maximum.
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                if (min == null) {
                    min = new double[row.size()];
                    Arrays.fill(min, Double.MAX_VALUE);
                }
                else
                    assert min.length == row.size() : "Base preprocessor must return exactly " + min.length + " features";

                if (max == null) {
                    max = new double[row.size()];
                    Arrays.fill(max, -Double.MAX_VALUE);
                }
                else
                    assert max.length == row.size() : "Base preprocessor must return exactly " + max.length + " features";

                for (int i = 0; i < row.size(); i++) {
                    if (row.get(i) < min[i])
                        min[i] = row.get(i);
                    if (row.get(i) > max[i])
                        max[i] = row.get(i);
                }
            }

            return new MinMaxScalerPartitionData(min, max);
        },
        learningEnvironment(basePreprocessor)
    )) {
        // Reduce per-partition extrema to global ones; null marks an empty partition.
        double[][] minMax = dataset.compute(
            data -> data.getMin() != null ? new double[][] {data.getMin(), data.getMax()} : null,
            (a, b) -> {
                if (a == null)
                    return b;
                if (b == null)
                    return a;

                double[][] res = new double[2][];

                res[0] = new double[a[0].length];
                for (int i = 0; i < res[0].length; i++)
                    res[0][i] = Math.min(a[0][i], b[0][i]);

                res[1] = new double[a[1].length];
                for (int i = 0; i < res[1].length; i++)
                    res[1][i] = Math.max(a[1][i], b[1][i]);

                return res;
            }
        );

        return new MinMaxScalerPreprocessor<>(minMax[0], minMax[1], basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
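The fit above only computes the global per-feature extrema; applying them is the job of the returned MinMaxScalerPreprocessor. Below is a minimal standalone sketch of that transformation, assuming the standard formula x' = (x - min) / (max - min) and mapping constant features to 0; both the class name MinMaxSketch and the constant-feature choice are ours for illustration, not Ignite's.

import java.util.Arrays;

/** Standalone sketch of the min-max transform a fitted preprocessor applies. */
public class MinMaxSketch {
    /** Scales each feature into [0, 1] using the global extrema found by fit(). */
    static double[] scale(double[] row, double[] min, double[] max) {
        double[] res = new double[row.length];
        for (int i = 0; i < row.length; i++) {
            double range = max[i] - min[i];
            // Constant feature: mapping to 0 is an assumption made for this sketch.
            res[i] = range == 0 ? 0 : (row[i] - min[i]) / range;
        }
        return res;
    }

    public static void main(String[] args) {
        double[] min = {0, 10};
        double[] max = {4, 20};

        System.out.println(Arrays.toString(scale(new double[] {2, 15}, min, max))); // [0.5, 0.5]
    }
}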
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class StandardScalerTrainer, method computeSum.
/**
* Computes sum, squared sum and row count.
*/
private StandardScalerData computeSum(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            double[] sum = null;
            double[] squaredSum = null;
            long cnt = 0;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                Vector row = basePreprocessor.apply(entity.getKey(), entity.getValue()).features();

                if (sum == null) {
                    sum = new double[row.size()];
                    squaredSum = new double[row.size()];
                }
                else
                    assert sum.length == row.size() : "Base preprocessor must return exactly " + sum.length + " features";

                ++cnt;

                for (int i = 0; i < row.size(); i++) {
                    double x = row.get(i);
                    sum[i] += x;
                    squaredSum[i] += x * x;
                }
            }

            return new StandardScalerData(sum, squaredSum, cnt);
        },
        learningEnvironment(basePreprocessor)
    )) {
        return dataset.compute(data -> data, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
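StandardScalerData only carries the three accumulators; per-feature mean and standard deviation follow from them by standard arithmetic. A plain-Java sketch of that step (the class name StandardScalerSketch is ours; the formulas are the usual population statistics):

/** Standalone sketch: deriving mean and standard deviation from (sum, squaredSum, cnt). */
public class StandardScalerSketch {
    public static void main(String[] args) {
        double[] sum = {10.0, 40.0};
        double[] squaredSum = {30.0, 420.0};
        long cnt = 4;

        for (int i = 0; i < sum.length; i++) {
            double mean = sum[i] / cnt;
            // Population variance via E[x^2] - E[x]^2, which is all the merged accumulators support.
            double variance = squaredSum[i] / cnt - mean * mean;
            System.out.printf("feature %d: mean=%.3f std=%.3f%n", i, mean, Math.sqrt(variance));
        }
    }
}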
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class Evaluator, method initEvaluationContexts.
/**
 * Initializes evaluation contexts for metrics.
 *
 * @param dataset Dataset.
 * @param metrics Metrics.
 * @return Computed contexts.
 */
@SuppressWarnings("unchecked")
private static Map<Class, EvaluationContext> initEvaluationContexts(
    Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset, Metric... metrics) {
    long nonEmptyCtxsCnt = Arrays.stream(metrics)
        .map(x -> x.makeAggregator().createInitializedContext())
        .filter(x -> ((EvaluationContext)x).needToCompute())
        .count();

    // If no metric actually needs a computed context, skip the dataset pass entirely.
    if (nonEmptyCtxsCnt == 0) {
        HashMap<Class, EvaluationContext> res = new HashMap<>();

        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            res.put(aggregator.getClass(), (EvaluationContext)aggregator.createInitializedContext());
        }

        // Note: returning inside the loop, as the original did, would register only the first metric.
        return res;
    }

    return dataset.compute(data -> {
        // One aggregator per aggregator class; duplicate metrics share a context.
        Map<Class, MetricStatsAggregator> aggrs = new HashMap<>();
        for (Metric m : metrics) {
            MetricStatsAggregator<Double, ?, ?> aggregator = m.makeAggregator();
            if (!aggrs.containsKey(aggregator.getClass()))
                aggrs.put(aggregator.getClass(), aggregator);
        }

        Map<Class, EvaluationContext> aggrToEvCtx = new HashMap<>();
        aggrs.forEach((clazz, aggr) -> aggrToEvCtx.put(clazz, (EvaluationContext)aggr.createInitializedContext()));

        // Feed every labeled vector of the partition into every context.
        for (int i = 0; i < data.getLabels().length; i++) {
            LabeledVector<Double> vector = VectorUtils.of(data.getFeatures()[i]).labeled(data.getLabels()[i]);
            aggrToEvCtx.values().forEach(ctx -> ctx.aggregate(vector));
        }

        return aggrToEvCtx;
    }, (left, right) -> {
        if (left == null && right == null)
            return new HashMap<>();
        if (left == null)
            return right;
        if (right == null)
            return left;

        HashMap<Class, EvaluationContext> res = new HashMap<>();

        for (Class key : left.keySet()) {
            EvaluationContext ctx1 = left.get(key);
            EvaluationContext ctx2 = right.get(key);

            A.ensure(ctx1 != null && ctx2 != null, "ctx1 != null && ctx2 != null");
            res.put(key, ctx1.mergeWith(ctx2));
        }

        return res;
    });
}
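The reducer above follows the same null-tolerant pattern that every dataset.compute call in these examples uses: null marks a partition with no data and must act as the identity of the merge. Extracted as a generic helper, the idiom looks like this (a sketch; the names are ours, not Ignite API):

import java.util.function.BinaryOperator;

/** Standalone sketch of the null-tolerant reducer idiom used with dataset.compute(). */
public class NullSafeReduceSketch {
    /** Wraps a merge function so that null (an empty partition) acts as the identity element. */
    static <T> BinaryOperator<T> nullSafe(BinaryOperator<T> merge) {
        return (a, b) -> a == null ? b : b == null ? a : merge.apply(a, b);
    }

    public static void main(String[] args) {
        BinaryOperator<Integer> sum = nullSafe(Integer::sum);

        System.out.println(sum.apply(null, 3)); // 3
        System.out.println(sum.apply(4, null)); // 4
        System.out.println(sum.apply(4, 3));    // 7
    }
}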
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class KMeansTrainer, method updateModel.
/** {@inheritDoc} */
@Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(preprocessor);

    Vector[] centers;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        partDataBuilder,
        learningEnvironment()
    )) {
        final Integer cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> {
            if (a == null)
                return b == null ? 0 : b;
            if (b == null)
                return a;
            return b;
        });

        if (cols == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        centers = Optional.ofNullable(mdl)
            .map(KMeansModel::centers)
            .orElseGet(() -> initClusterCentersRandomly(dataset, k));

        boolean converged = false;
        int iteration = 0;

        while (iteration < maxIterations && !converged) {
            Vector[] newCentroids = new DenseVector[k];

            TotalCostAndCounts totalRes = calcDataForNewCentroids(centers, dataset, cols);

            converged = true;

            for (Map.Entry<Integer, Vector> entry : totalRes.sums.entrySet()) {
                Vector massCenter = entry.getValue().times(1.0 / totalRes.counts.get(entry.getKey()));

                if (converged && distance.compute(massCenter, centers[entry.getKey()]) > epsilon * epsilon)
                    converged = false;

                newCentroids[entry.getKey()] = massCenter;
            }

            iteration++;

            for (int i = 0; i < centers.length; i++) {
                if (newCentroids[i] != null)
                    centers[i] = newCentroids[i];
            }
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    return new KMeansModel(centers, distance);
}
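For readers who want the loop's arithmetic without the dataset plumbing, here is a single Lloyd iteration over plain arrays, mirroring the trainer's pattern of keeping a center unchanged when its cluster is empty and its epsilon-squared convergence test. This is a simplified sketch with squared Euclidean distance and hard-coded data, not Ignite's implementation; all names in it are ours.

import java.util.Arrays;

/** Standalone sketch of one k-means (Lloyd) iteration over plain arrays. */
public class KMeansStepSketch {
    static double sqDist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++)
            s += (a[i] - b[i]) * (a[i] - b[i]);
        return s;
    }

    public static void main(String[] args) {
        double[][] points = {{0, 0}, {0, 1}, {9, 9}, {10, 10}};
        double[][] centers = {{0, 0}, {10, 10}};
        double epsilon = 1e-3;

        // Assignment step: accumulate each point into its nearest center's sum.
        double[][] sums = new double[centers.length][centers[0].length];
        int[] counts = new int[centers.length];
        for (double[] p : points) {
            int best = 0;
            for (int c = 1; c < centers.length; c++)
                if (sqDist(p, centers[c]) < sqDist(p, centers[best]))
                    best = c;
            for (int i = 0; i < p.length; i++)
                sums[best][i] += p[i];
            counts[best]++;
        }

        // Update step: move each center to the mass center of its cluster.
        boolean converged = true;
        for (int c = 0; c < centers.length; c++) {
            if (counts[c] == 0)
                continue; // Empty cluster keeps its old center, as in the trainer above.
            double[] massCenter = new double[centers[c].length];
            for (int i = 0; i < massCenter.length; i++)
                massCenter[i] = sums[c][i] / counts[c];
            if (sqDist(massCenter, centers[c]) > epsilon * epsilon)
                converged = false;
            centers[c] = massCenter;
        }

        System.out.println(Arrays.deepToString(centers) + ", converged=" + converged);
    }
}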
Use of org.apache.ignite.ml.dataset.primitive.context.EmptyContext in project ignite by apache.
The class NormalizationTrainer, method fit.
/** {@inheritDoc} */
@Override public NormalizationPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
    IgniteBiFunction<K, V, double[]> basePreprocessor, int cols) {
    try (Dataset<EmptyContext, NormalizationPartitionData> dataset = datasetBuilder.build(
        (upstream, upstreamSize) -> new EmptyContext(),
        (upstream, upstreamSize, ctx) -> {
            double[] min = new double[cols];
            double[] max = new double[cols];

            for (int i = 0; i < cols; i++) {
                min[i] = Double.MAX_VALUE;
                max[i] = -Double.MAX_VALUE;
            }

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                double[] row = basePreprocessor.apply(entity.getKey(), entity.getValue());

                for (int i = 0; i < cols; i++) {
                    if (row[i] < min[i])
                        min[i] = row[i];
                    if (row[i] > max[i])
                        max[i] = row[i];
                }
            }

            return new NormalizationPartitionData(min, max);
        }
    )) {
        double[][] minMax = dataset.compute(data -> new double[][] {data.getMin(), data.getMax()}, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;

            double[][] res = new double[2][];

            res[0] = new double[a[0].length];
            for (int i = 0; i < res[0].length; i++)
                res[0][i] = Math.min(a[0][i], b[0][i]);

            res[1] = new double[a[1].length];
            for (int i = 0; i < res[1].length; i++)
                res[1][i] = Math.max(a[1][i], b[1][i]);

            return res;
        });

        return new NormalizationPreprocessor<>(minMax[0], minMax[1], basePreprocessor);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
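The reducer here is the same element-wise extrema merge used in MinMaxScalerTrainer.fit above. Pulled out as a standalone helper it reads as follows; this is a sketch, and the class and method names are ours.

/** Standalone sketch of the element-wise {min[], max[]} merge used by both trainers. */
public class MinMaxMergeSketch {
    /** Merges two per-partition {min, max} pairs into global extrema; null means an empty partition. */
    static double[][] merge(double[][] a, double[][] b) {
        if (a == null)
            return b;
        if (b == null)
            return a;

        double[][] res = new double[2][a[0].length];
        for (int i = 0; i < a[0].length; i++) {
            res[0][i] = Math.min(a[0][i], b[0][i]);
            res[1][i] = Math.max(a[1][i], b[1][i]);
        }
        return res;
    }

    public static void main(String[] args) {
        double[][] p1 = {{1, 5}, {3, 9}};
        double[][] p2 = {{0, 7}, {2, 8}};

        System.out.println(java.util.Arrays.deepToString(merge(p1, p2))); // [[0.0, 5.0], [3.0, 9.0]]
    }
}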