use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class GroupTrainer method train.
/**
 * {@inheritDoc}
 *
 * <p>Creates a training UUID and the initial local context, calls {@code init}, and then runs the
 * training computations chain: distributed initialization over the initial keys, local processing of
 * the init data, the training loop while {@code shouldContinue} holds, distributed extraction of the
 * final results and a local mapping to the resulting model.
 *
 * <p>{@code cleanup} is executed in a {@code finally} block, so it also runs when the chain fails.
 *
 * @param data Training input.
 * @return Value produced by the last step of the chain ({@code mapFinalResult}).
 */
@Override
public final M train(T data) {
    UUID trainingUUID = UUID.randomUUID();
    LC locCtx = initialLocalContext(data, trainingUUID);

    GroupTrainingContext<K, V, LC> ctx = new GroupTrainingContext<>(locCtx, cache, ignite);

    // Identity chain used as the starting point for the steps appended below.
    ComputationsChain<LC, K, V, T, T> chain = (i, c) -> i;

    IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<IN>> distributedInitializer = distributedInitializer(data);

    init(data, trainingUUID);

    try {
        return chain
            .thenDistributedForKeys(distributedInitializer, (t, lc) -> data.initialKeys(trainingUUID), reduceDistributedInitData())
            .thenLocally(this::locallyProcessInitData)
            .thenWhile(this::shouldContinue, trainingLoopStep())
            .thenDistributedForEntries(this::extractContextForFinalResultCreation, finalResultsExtractor(), this::finalResultKeys, finalResultsReducer())
            .thenLocally(this::mapFinalResult)
            .process(data, ctx);
    }
    finally {
        // Must not be skipped when a chain step throws: init(...) has already been performed.
        cleanup(locCtx);
    }
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class SplitDataGenerator method testByGen.
/**
 * Generates points, loads them in shuffled order into a column-stored sparse distributed matrix,
 * trains a {@link ColumnDecisionTreeTrainer} on that matrix and applies the resulting model to the
 * first collected point of every region. The predictions are not checked here.
 *
 * @param totalPts Value passed to {@code points(...)}; also used as the row count of the matrix.
 * @param calc Continuous split calculator factory handed to the trainer.
 * @param catImpCalc Categorical impurity calculator factory handed to the trainer.
 * @param regCalc Region value calculator handed to the trainer.
 * @param ignite Ignite instance handed to the trainer.
 */
<D extends ContinuousRegionInfo> void testByGen(int totalPts,
    IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc,
    IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc,
    IgniteFunction<DoubleStream, Double> regCalc,
    Ignite ignite) {
    List<IgniteBiTuple<Integer, V>> pts = points(totalPts, (i, rn) -> i).collect(Collectors.toList());

    Collections.shuffle(pts, rnd);

    SparseDistributedMatrix mtx = new SparseDistributedMatrix(totalPts, featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    // Labeled vectors grouped by the first tuple component (the region key).
    Map<Integer, List<LabeledVectorDouble>> regions = new HashMap<>();

    int row = 0;

    for (IgniteBiTuple<Integer, V> pt : pts) {
        Integer region = pt.get1();
        V vec = pt.get2();

        regions.computeIfAbsent(region, r -> new LinkedList<>()).add(asLabeledVector(vec.getStorage().data()));

        // Rows are written in the shuffled order.
        mtx.setRow(row++, vec.getStorage().data());
    }

    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(mtx, catFeaturesInfo));

    // Smoke-apply the model to one representative of each region.
    for (List<LabeledVectorDouble> regionPts : regions.values())
        mdl.apply(regionPts.get(0).features());
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class GradientDescent method getLossGradientFunction.
/**
 * Curries the loss gradient function: fixes the data matrix so that only the weights vector remains
 * as an argument.
 *
 * <p>A {@link SparseDistributedMatrix} in row storage mode is delegated to
 * {@code calculateDistributedGradient}. For any other matrix the inputs and the ground truth are
 * extracted once, up front, and captured by the returned function.
 *
 * @param data Data matrix.
 * @return Function that maps a weights vector to the gradient vector.
 */
private IgniteFunction<Vector, Vector> getLossGradientFunction(Matrix data) {
    boolean rowStoredDistributed = data instanceof SparseDistributedMatrix
        && ((SparseDistributedMatrix)data).getStorage().storageMode() == StorageConstants.ROW_STORAGE_MODE;

    if (!rowStoredDistributed) {
        Matrix inputs = extractInputs(data);
        Vector groundTruth = extractGroundTruth(data);

        return w -> lossGradient.compute(inputs, groundTruth, w);
    }

    SparseDistributedMatrix rowMatrix = (SparseDistributedMatrix)data;

    return w -> calculateDistributedGradient(rowMatrix, w);
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class DistributedWorkersChainTest method testDistributed.
/**
 * Checks a single "distributed for entries" chain step.
 *
 * <p>Four entries are put into the test cache, then the step runs {@code readAndIncrement} over
 * their keys and reduces the results with a maximum. The test asserts that the chain result equals
 * the maximum of the locally stored values and that afterwards every cached value is greater by one
 * than the value initially put.
 */
public void testDistributed() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
    int init = 1;

    UUID trainingUUID = UUID.randomUUID();
    TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);

    Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>();
    m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1);
    m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2);
    m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3);
    m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4);

    cache.putAll(m);

    // A stream can be consumed only once, so the supplier creates a fresh one on every call
    // instead of handing out a single shared instance.
    IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function =
        (o, l) -> () -> m.keySet().stream();

    IgniteFunction<List<Integer>, Integer> max = ints -> ints.stream().mapToInt(x -> x).max().orElse(Integer.MIN_VALUE);

    Integer res = chain
        .thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max)
        .process(init, new GroupTrainingContext<>(locCtx, cache, ignite));

    // Expected result is computed from the values as they were before the increment.
    int localMax = m.values().stream().mapToInt(Integer::intValue).max().orElse(Integer.MIN_VALUE);

    assertEquals((long) localMax, (long) res);

    // Mirror the increment locally before comparing the map with the cache.
    m.replaceAll((k, v) -> v + 1);

    assertMapEqualsCache(m, cache);
}
use of org.apache.ignite.ml.math.functions.IgniteFunction in project ignite by apache.
the class ColumnDecisionTreeTrainerBenchmark method testByGenStreamerLoad.
/**
 * Trains a column decision tree on generated data that is batch-loaded into the cache backing a
 * sparse distributed matrix, printing how long loading and training took, and then asserts that for
 * the first collected point of every region the prediction equals the point's label.
 *
 * @param ptsPerReg Value passed to {@code gen.points(...)}; multiplied by the regions count to get the matrix row count.
 * @param catsInfo Categorical features info passed to the trainer input.
 * @param gen Data generator.
 * @param rnd Random used to shuffle the generated points.
 */
private void testByGenStreamerLoad(int ptsPerReg, HashMap<Integer, Integer> catsInfo,
    SplitDataGenerator<DenseLocalOnHeapVector> gen, Random rnd) {
    List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> pts =
        gen.points(ptsPerReg, (i, rn) -> i).collect(Collectors.toList());

    int featCnt = gen.featuresCnt();

    Collections.shuffle(pts, rnd);

    int regsCnt = gen.regsCount();

    SparseDistributedMatrix mtx = new SparseDistributedMatrix(regsCnt * ptsPerReg, featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    SparseDistributedMatrixStorage storage = (SparseDistributedMatrixStorage)mtx.getStorage();

    long loadStart = System.currentTimeMillis();

    X.println("Batch loading started...");

    // Note: the matrix is filled from a second gen.points(...) call, not from the shuffled list above.
    loadVectorsIntoSparseDistributedMatrixCache(storage.cache().getName(), storage.getUUID(),
        gen.points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(), featCnt + 1);

    X.println("Batch loading took " + (System.currentTimeMillis() - loadStart) + " ms.");

    // Labeled vectors grouped by the first tuple component (the region key).
    Map<Integer, List<LabeledVectorDouble>> regions = new HashMap<>();

    for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> pt : pts)
        regions.computeIfAbsent(pt.get1(), r -> new LinkedList<>()).add(asLabeledVector(pt.get2().getStorage().data()));

    IgniteFunction<DoubleStream, Double> regCalc = s -> s.average().orElse(0.0);

    ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer = new ColumnDecisionTreeTrainer<>(2,
        ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite);

    long trainStart = System.currentTimeMillis();

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(mtx, catsInfo));

    X.println("Training took: " + (System.currentTimeMillis() - trainStart) + " ms.");

    for (List<LabeledVectorDouble> regionPts : regions.values()) {
        LabeledVectorDouble sample = regionPts.get(0);

        Tracer.showAscii(sample.features());

        X.println("Predicted value and label [pred=" + mdl.apply(sample.features()) + ", label=" + sample.doubleLabel() + "]");

        assert mdl.apply(sample.features()) == sample.doubleLabel();
    }
}
Aggregations