Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.
The class BinarizationExample, method createCache.
/** Fills the cache with test data. */
private static IgniteCache<Integer, Vector> createCache(Ignite ignite) {
    CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>();
    cacheConfiguration.setName("PERSONS");
    cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 2));

    IgniteCache<Integer, Vector> persons = ignite.createCache(cacheConfiguration);

    persons.put(1, new DenseVector(new Serializable[] {"Mike", 42, 10000}));
    persons.put(2, new DenseVector(new Serializable[] {"John", 32, 64000}));
    persons.put(3, new DenseVector(new Serializable[] {"George", 53, 120000}));
    persons.put(4, new DenseVector(new Serializable[] {"Karl", 24, 70000}));

    return persons;
}
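After the cache is filled, the example binarizes the numeric features. The following is a minimal sketch of that step, assuming the BinarizationTrainer and DummyVectorizer API of the ignite-ml module; the selected coordinates (1 and 2, i.e. age and salary) and the threshold of 40 are illustrative assumptions rather than values taken from the snippet above.

import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.binarization.BinarizationTrainer;

public class BinarizationSketch {
    public static void main(String[] args) {
        // Assumes a local Ignite node can be started with the default configuration.
        try (Ignite ignite = Ignition.start()) {
            // createCache(Ignite) is the method shown above, assumed to be copied into this class.
            IgniteCache<Integer, Vector> persons = createCache(ignite);

            // Illustrative assumption: use coordinates 1 and 2 (age, salary) as features.
            Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(1, 2);

            // Values above the (illustrative) threshold map to 1.0, the rest to 0.0.
            Preprocessor<Integer, Vector> binarized = new BinarizationTrainer<Integer, Vector>()
                .withThreshold(40.0)
                .fit(ignite, persons, vectorizer);

            System.out.println("Binarized row 1: " + binarized.apply(1, persons.get(1)));
        }
    }
}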
Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.
The class MaxAbsScalerExample, method createCache.
/** Fills the cache with test data. */
private static IgniteCache<Integer, Vector> createCache(Ignite ignite) {
    CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>();
    cacheConfiguration.setName("PERSONS");
    cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 2));

    IgniteCache<Integer, Vector> persons = ignite.createCache(cacheConfiguration);

    persons.put(1, new DenseVector(new Serializable[] {"Mike", 42, 10000}));
    persons.put(2, new DenseVector(new Serializable[] {"John", 32, 64000}));
    persons.put(3, new DenseVector(new Serializable[] {"George", 53, 120000}));
    persons.put(4, new DenseVector(new Serializable[] {"Karl", 24, 70000}));

    return persons;
}
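MaxAbsScalerExample consumes the same cache, dividing every selected feature by its maximum absolute value over the cache so that scaled values fall into [-1, 1]. The following is a minimal sketch of that step, assuming the MaxAbsScalerTrainer API of the ignite-ml preprocessing package and the same scaffold (running node, persons cache) as the binarization sketch above; the selected coordinates are again illustrative.

import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.maxabsscaling.MaxAbsScalerTrainer;

// Inside the same kind of main(...) scaffold as the binarization sketch above,
// with `ignite` started and `persons` returned by createCache(ignite):
Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(1, 2);

// Each selected feature is divided by its maximum absolute value over the cache.
Preprocessor<Integer, Vector> scaled = new MaxAbsScalerTrainer<Integer, Vector>()
    .fit(ignite, persons, vectorizer);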
Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.
The class GaussianNaiveBayesTrainer, method updateModel.
/** {@inheritDoc} */
@Override protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert datasetBuilder != null;

    try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
                Vector features = lv.features();
                Double label = (Double)lv.label();

                double[] toMeans;
                double[] sqSum;

                if (!res.featureSumsPerLbl.containsKey(label)) {
                    toMeans = new double[features.size()];
                    Arrays.fill(toMeans, 0.);
                    res.featureSumsPerLbl.put(label, toMeans);
                }
                if (!res.featureSquaredSumsPerLbl.containsKey(label)) {
                    sqSum = new double[features.size()];
                    res.featureSquaredSumsPerLbl.put(label, sqSum);
                }
                if (!res.featureCountersPerLbl.containsKey(label))
                    res.featureCountersPerLbl.put(label, 0);

                res.featureCountersPerLbl.put(label, res.featureCountersPerLbl.get(label) + 1);

                toMeans = res.featureSumsPerLbl.get(label);
                sqSum = res.featureSquaredSumsPerLbl.get(label);

                // Accumulate per-label running sums of each feature and of its square.
                for (int j = 0; j < features.size(); j++) {
                    double x = features.get(j);
                    toMeans[j] += x;
                    sqSum[j] += x * x;
                }
            }
            return res;
        },
        learningEnvironment()
    )) {
        GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;
            return a.merge(b);
        });

        if (mdl != null && mdl.getSumsHolder() != null)
            sumsHolder = sumsHolder.merge(mdl.getSumsHolder());

        List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert !sortedLabels.isEmpty() : "The dataset should contain at least one label";

        int labelCount = sortedLabels.size();
        int featureCount = sumsHolder.featureSumsPerLbl.get(sortedLabels.get(0)).length;

        double[][] means = new double[labelCount][featureCount];
        double[][] variances = new double[labelCount][featureCount];
        double[] classProbabilities = new double[labelCount];
        double[] labels = new double[labelCount];

        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToInt(i -> i).sum();

        int lbl = 0;
        for (Double label : sortedLabels) {
            int count = sumsHolder.featureCountersPerLbl.get(label);
            double[] sum = sumsHolder.featureSumsPerLbl.get(label);
            double[] sqSum = sumsHolder.featureSquaredSumsPerLbl.get(label);

            for (int i = 0; i < featureCount; i++) {
                means[lbl][i] = sum[i] / count;
                variances[lbl][i] = (sqSum[i] - sum[i] * sum[i] / count) / count;
            }

            if (equiprobableClasses)
                classProbabilities[lbl] = 1. / labelCount;
            else if (priorProbabilities != null) {
                assert classProbabilities.length == priorProbabilities.length;
                classProbabilities[lbl] = priorProbabilities[lbl];
            }
            else
                classProbabilities[lbl] = (double)count / datasetSize;

            labels[lbl] = label;
            ++lbl;
        }
        return new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, sumsHolder);
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}
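The partition builder above accumulates only per-label running statistics (a count, the sum of each feature, and the sum of each feature squared); the loop over sortedLabels then recovers each Gaussian's mean and variance from those sums as mean = sum / count and variance = sqSum / count - (sum / count)^2, which is exactly the expression on the variances[lbl][i] line. A small self-contained check of those formulas with made-up numbers:

// Illustrative check of the sums-based formulas used above, for one feature of one label
// with observed values {1.0, 2.0, 3.0} (the numbers are made up for the example).
double sum = 1.0 + 2.0 + 3.0;        // featureSumsPerLbl entry        -> 6.0
double sqSum = 1.0 + 4.0 + 9.0;      // featureSquaredSumsPerLbl entry -> 14.0
int count = 3;                       // featureCountersPerLbl entry

double mean = sum / count;                              // 6 / 3 = 2.0
double variance = (sqSum - sum * sum / count) / count;  // (14 - 12) / 3 = 0.666...

System.out.printf("mean=%.3f variance=%.3f%n", mean, variance);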
Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.
The class ReplicatedVectorMatrix, method assignColumn.
/** {@inheritDoc} */
@Override public Matrix assignColumn(int col, Vector vec) {
    int rows = asCol ? vector.size() : replicationCnt;
    int cols = asCol ? replicationCnt : vector.size();
    int times = asCol ? cols : rows;

    Matrix res = new DenseMatrix(rows, cols);

    IgniteBiConsumer<Integer, Vector> replicantAssigner = asCol ? res::assignColumn : res::assignRow;
    IgniteBiConsumer<Integer, Vector> assigner = res::assignColumn;

    assign(replicantAssigner, assigner, vector, vec, times, col);

    return res;
}
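For reference, assignColumn above materializes the replicated vector into a plain DenseMatrix and then overwrites one row or column with vec via the two assignment methods it references. A small sketch of those DenseMatrix calls, using the same DenseMatrix and DenseVector classes as the snippets above; the sizes and values are made up for illustration.

// assignRow/assignColumn overwrite one full row/column of a dense matrix with a vector,
// which is how assignColumn(...) above both replicates `vector` and injects `vec`.
Matrix res = new DenseMatrix(3, 2);                                  // 3 rows, 2 columns
res.assignRow(0, new DenseVector(new double[] {1.0, 2.0}));          // vector length = columnSize
res.assignColumn(1, new DenseVector(new double[] {7.0, 8.0, 9.0}));  // vector length = rowSize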
Use of org.apache.ignite.ml.math.primitives.vector.Vector in project ignite by apache.
The class ReplicatedVectorMatrix, method minus.
/**
 * Specialized optimized version of minus for ReplicatedVectorMatrix.
 *
 * @param mtx Matrix to be subtracted.
 * @return New ReplicatedVectorMatrix resulting from subtraction.
 */
public Matrix minus(ReplicatedVectorMatrix mtx) {
    if (isColumnReplicated() == mtx.isColumnReplicated()) {
        checkCardinality(mtx.rowSize(), mtx.columnSize());

        Vector minus = vector.minus(mtx.replicant());

        return new ReplicatedVectorMatrix(minus, replicationCnt, asCol);
    }

    throw new UnsupportedOperationException();
}
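The optimization works because subtracting two matrices that replicate their vectors along the same axis yields a matrix that replicates the element-wise difference of those vectors, so a single vector subtraction replaces a full element-by-element matrix subtraction. A small sketch of that equivalence, using the same DenseVector class as the createCache snippets above; the numbers are made up for illustration.

// Each ReplicatedVectorMatrix conceptually repeats its vector as every row (or column).
// Subtracting two such matrices replicated along the same axis therefore repeats a.minus(b).
Vector a = new DenseVector(new double[] {5.0, 7.0, 9.0});
Vector b = new DenseVector(new double[] {1.0, 2.0, 3.0});

Vector diff = a.minus(b); // [4.0, 5.0, 6.0] -- the vector each row/column of the result replicates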