Usage of org.apache.ignite.ml.structures.LabeledVectorSet in the Apache Ignite project.
Class NNClassificationModel, method getDistances.
/**
 * Measures how far the given vector is from every row of the training dataset.
 *
 * <p>Rows for which {@code trainingData.getRow(i)} returns {@code null} are skipped. For every other row the
 * distance to {@code v} is computed with {@code distanceMeasure} and recorded via {@code putDistanceIdxPair}.</p>
 *
 * @param v The given vector.
 * @param trainingData The training dataset.
 * @return Map keyed by distance from {@code v} (ascending, since it is a {@link TreeMap} with natural ordering)
 *     whose values are the indices of the training rows at that distance. Values are sets because several rows
 *     can lie at the same distance.
 */
@NotNull
protected TreeMap<Double, Set<Integer>> getDistances(Vector v, LabeledVectorSet<LabeledVector> trainingData) {
    final TreeMap<Double, Set<Integer>> res = new TreeMap<>();

    for (int rowIdx = 0; rowIdx < trainingData.rowSize(); rowIdx++) {
        final LabeledVector row = trainingData.getRow(rowIdx);

        // Null rows carry no features to measure against.
        if (row == null)
            continue;

        putDistanceIdxPair(res, rowIdx, distanceMeasure.compute(v, row.features()));
    }

    return res;
}
Usage of org.apache.ignite.ml.structures.LabeledVectorSet in the Apache Ignite project.
Class ANNClassificationTrainer, method getCentroidStat.
/**
 * Assigns every row of the dataset to its closest center and collects per-center label statistics.
 *
 * <p>For each row, the center index returned by {@code findClosestCentroid} is used to update three things:
 * the per-center count of each label ({@code centroidStat}), the per-center row count ({@code counts}) and the
 * global set of labels ({@code labels()}). Per-partition results are combined with {@code CentroidStat.merge};
 * if both sides of a reduction step are {@code null}, an empty {@code CentroidStat} is produced.</p>
 *
 * @param datasetBuilder Dataset builder.
 * @param vectorizer Upstream vectorizer used to build labeled vectors for each partition.
 * @param centers Cluster centers to which rows are attributed.
 * @param <K> Type of upstream keys.
 * @param <V> Type of upstream values.
 * @return Centroid statistics combined across all partitions.
 * @throws RuntimeException Wrapping any exception raised while building, computing over or closing the dataset.
 */
private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer, List<Vector> centers) {
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        partDataBuilder,
        learningEnvironment()
    )) {
        return dataset.compute(data -> {
            CentroidStat partStat = new CentroidStat();

            for (int row = 0; row < data.rowSize(); row++) {
                int centroidIdx = findClosestCentroid(centers, data.getRow(row)).get1();
                double lbl = data.label(row);

                partStat.labels().add(lbl);

                // Get-or-create the label histogram of this center, then bump the label's count.
                ConcurrentHashMap<Double, Integer> lblCounts = partStat.centroidStat.get(centroidIdx);
                if (lblCounts == null) {
                    lblCounts = new ConcurrentHashMap<>();
                    partStat.centroidStat.put(centroidIdx, lblCounts);
                }
                lblCounts.merge(lbl, 1, Integer::sum);

                partStat.counts.merge(centroidIdx, 1, (IgniteBiFunction<Integer, Integer, Integer>) (i1, i2) -> i1 + i2);
            }

            return partStat;
        }, (left, right) -> {
            if (left != null)
                return right == null ? left : left.merge(right);

            return right != null ? right : new CentroidStat();
        });
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
Usage of org.apache.ignite.ml.structures.LabeledVectorSet in the Apache Ignite project.
Class KNNUtils, method buildDataset.
/**
 * Builds a {@link LabeledVectorSet}-backed dataset using {@code LabeledDatasetPartitionDataBuilderOnHeap}.
 *
 * <p>The trainer learning environment is always created from {@code envBuilder}, and its deploying context is
 * initialized with {@code vectorizer}, even when {@code datasetBuilder} is {@code null} and nothing is built.</p>
 *
 * @param envBuilder Learning environment builder.
 * @param datasetBuilder Dataset builder; may be {@code null}.
 * @param vectorizer Upstream vectorizer.
 * @param <K> Type of upstream keys.
 * @param <V> Type of upstream values.
 * @param <C> Type parameter not referenced by this method.
 * @return Built dataset, or {@code null} when {@code datasetBuilder} is {@code null}.
 */
@Nullable
public static <K, V, C extends Serializable> Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> buildDataset(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment trainerEnv = envBuilder.buildForTrainer();
    trainerEnv.initDeployingContext(vectorizer);

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> onHeapBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    if (datasetBuilder == null)
        return null;

    return datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), onHeapBuilder, trainerEnv);
}
Usage of org.apache.ignite.ml.structures.LabeledVectorSet in the Apache Ignite project.
Class LabeledDatasetPartitionDataBuilderOnHeap, method build.
/**
 * {@inheritDoc}
 *
 * <p>Each upstream entry is passed through {@code preprocessor}. Its feature vector becomes one row of the feature
 * matrix, and its label becomes the matching entry of the label array. Both are sized by {@code upstreamDataSize},
 * which must fit in an {@code int} ({@link Math#toIntExact} throws {@link ArithmeticException} otherwise).
 * The column count is taken from the first row. Later rows are checked against it only by a Java {@code assert},
 * so the check runs only when assertions are enabled.</p>
 *
 * <p>NOTE(review): when {@code upstreamData} is empty, the feature matrix stays {@code null} and is passed as-is
 * to the {@code LabeledVectorSet} constructor. Confirm that this constructor accepts it.</p>
 */
@Override
public LabeledVectorSet<LabeledVector> build(LearningEnvironment env, Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
    final int rowCnt = Math.toIntExact(upstreamDataSize);

    double[] labels = new double[rowCnt];
    double[][] features = null;
    int expectedCols = -1;

    for (int rowIdx = 0; upstreamData.hasNext(); rowIdx++) {
        UpstreamEntry<K, V> upstreamEntry = upstreamData.next();
        LabeledVector<Double> vec = preprocessor.apply(upstreamEntry.getKey(), upstreamEntry.getValue());
        Vector featureRow = vec.features();

        if (expectedCols >= 0) {
            assert featureRow.size() == expectedCols : "X extractor must return exactly " + expectedCols + " columns";
        }
        else {
            // The first row fixes the matrix width.
            expectedCols = featureRow.size();
            features = new double[rowCnt][expectedCols];
        }

        features[rowIdx] = featureRow.asArray();
        labels[rowIdx] = vec.label();
    }

    return new LabeledVectorSet<>(features, labels);
}
Usage of org.apache.ignite.ml.structures.LabeledVectorSet in the Apache Ignite project.
Class KMeansTrainer, method updateModel.
/**
 * {@inheritDoc}
 *
 * <p>Starting centers are taken from {@code mdl} when it is non-null; otherwise they come from
 * {@code initClusterCentersRandomly(dataset, k)}. Each iteration obtains per-center sums and counts from
 * {@code calcDataForNewCentroids} and moves every center that received points to {@code sum / count}.
 * Iteration stops after {@code maxIterations} rounds, or once no moved center is farther than
 * {@code epsilon * epsilon} (under {@code distance}) from its previous position.</p>
 *
 * <p>If no partition reports a column count, the dataset is treated as empty and
 * {@code getLastTrainedModelOrThrowEmptyDatasetException(mdl)} decides the outcome.</p>
 *
 * @throws RuntimeException Wrapping any exception raised while building, computing over or closing the dataset.
 */
@Override
protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
    assert datasetBuilder != null;

    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder = new LabeledDatasetPartitionDataBuilderOnHeap<>(preprocessor);

    Vector[] centers;

    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), partDataBuilder, learningEnvironment())) {
        // Any reported column count will do. (null, null) must stay null: mapping it to 0 would make the
        // empty-dataset guard below unreachable.
        final Integer cols = dataset.compute(org.apache.ignite.ml.structures.Dataset::colSize, (a, b) -> b == null ? a : b);

        if (cols == null)
            return getLastTrainedModelOrThrowEmptyDatasetException(mdl);

        centers = Optional.ofNullable(mdl).map(KMeansModel::centers).orElseGet(() -> initClusterCentersRandomly(dataset, k));

        boolean converged = false;
        int iteration = 0;

        while (iteration < maxIterations && !converged) {
            // Keys of totalRes.sums index into centers (see the distance check below), so size the array by
            // centers.length rather than k. Typing it as Vector[] also avoids an ArrayStoreException when
            // times() returns a vector that is not a DenseVector.
            Vector[] newCentroids = new Vector[centers.length];

            TotalCostAndCounts totalRes = calcDataForNewCentroids(centers, dataset, cols);

            converged = true;

            for (Map.Entry<Integer, Vector> entry : totalRes.sums.entrySet()) {
                Vector massCenter = entry.getValue().times(1.0 / totalRes.counts.get(entry.getKey()));

                if (converged && distance.compute(massCenter, centers[entry.getKey()]) > epsilon * epsilon)
                    converged = false;

                newCentroids[entry.getKey()] = massCenter;
            }

            iteration++;

            // Centers that received no points keep their previous position.
            for (int i = 0; i < centers.length; i++) {
                if (newCentroids[i] != null)
                    centers[i] = newCentroids[i];
            }
        }
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }

    return new KMeansModel(centers, distance);
}
Aggregations