Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.
Class GroupTrainer, method train.
/**
 * {@inheritDoc}
 * <p>
 * Builds and runs the group-training computation chain for a freshly generated training UUID:
 * distributed initialization over {@code data.initialKeys(trainingUUID)}, local processing of the
 * reduced initialization data, the training loop (repeated while {@code shouldContinue} holds),
 * distributed extraction of the final results over {@code finalResultKeys}, and local mapping of the
 * reduced final result into the returned model.
 * <p>
 * {@code cleanup(locCtx)} is invoked after the chain finishes, whether it completes normally or
 * throws, so a failing stage does not skip the cleanup step.
 *
 * @param data Training input.
 * @return Model produced by the final {@code mapFinalResult} stage.
 */
@Override
public final M train(T data) {
    UUID trainingUUID = UUID.randomUUID();
    LC locCtx = initialLocalContext(data, trainingUUID);

    GroupTrainingContext<K, V, LC> ctx = new GroupTrainingContext<>(locCtx, cache, ignite);

    // Identity chain: the starting point onto which every training stage below is appended.
    ComputationsChain<LC, K, V, T, T> chain = (i, c) -> i;

    IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<IN>> distributedInitializer = distributedInitializer(data);

    init(data, trainingUUID);

    try {
        return chain
            .thenDistributedForKeys(distributedInitializer, (t, lc) -> data.initialKeys(trainingUUID),
                reduceDistributedInitData())
            .thenLocally(this::locallyProcessInitData)
            .thenWhile(this::shouldContinue, trainingLoopStep())
            .thenDistributedForEntries(this::extractContextForFinalResultCreation, finalResultsExtractor(),
                this::finalResultKeys, finalResultsReducer())
            .thenLocally(this::mapFinalResult)
            .process(data, ctx);
    }
    finally {
        // Previously this ran only on success; a failure in any stage skipped it.
        cleanup(locCtx);
    }
}
Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.
Class DistributedWorkersChainTest, method testDistributed.
/**
 * Checks a distributed step over cache entries. The chain result must equal the maximum of the
 * original cached values (per-entry results are reduced with {@code max}). Afterwards, every cached
 * value must be one greater than its original value.
 */
public void testDistributed() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
    int init = 1;
    UUID trainingUUID = UUID.randomUUID();
    TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);

    Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>();
    m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1);
    m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2);
    m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3);
    m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4);

    cache.putAll(m);

    // A Stream may be consumed only once, so build a fresh one on every supplier call instead of
    // handing out a single shared Stream instance; the supplier may be invoked more than once.
    IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function =
        (o, l) -> () -> m.keySet().stream();

    IgniteFunction<List<Integer>, Integer> max = ints -> ints.stream().mapToInt(x -> x).max().orElse(Integer.MIN_VALUE);

    Integer res = chain
        .thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max)
        .process(init, new GroupTrainingContext<>(locCtx, cache, ignite));

    int localMax = m.values().stream().mapToInt(Integer::intValue).max().orElse(Integer.MIN_VALUE);

    assertEquals((long) localMax, (long) res);

    // Expected cache contents: every original value incremented by one.
    m.replaceAll((k, v) -> v + 1);

    assertMapEqualsCache(m, cache);
}
Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.
Class DistributedWorkersChainTest, method testChangeLocalContext.
/**
 * Checks that a local step can modify the local context (setting its data to a new value)
 * while passing the incoming result through unchanged.
 */
public void testChangeLocalContext() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> identity = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> trainingCache = TestGroupTrainingCache.getOrCreate(ignite);

    final int input = 1;
    final int updatedData = 10;

    TestLocalContext ctx = new TestLocalContext(0, UUID.randomUUID());

    Integer out = identity
        .thenLocally((prevRes, localCtx) -> {
            // Mutate the local context, but forward the incoming value as-is.
            localCtx.setData(updatedData);

            return prevRes;
        })
        .process(input, new GroupTrainingContext<>(ctx, trainingCache, ignite));

    Assert.assertEquals(updatedData, ctx.data());
    Assert.assertEquals(input, out.intValue());
}
Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.
Class DistributedWorkersChainTest, method testChainLocal.
/**
 * Checks that two chained local steps are applied in order, so the result is {@code (input + 1) * 5}.
 * Also checks that steps which never touch the local context leave its data unchanged.
 */
public void testChainLocal() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> identity = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> trainingCache = TestGroupTrainingCache.getOrCreate(ignite);

    final int input = 1;
    final int ctxData = 0;

    TestLocalContext ctx = new TestLocalContext(ctxData, UUID.randomUUID());

    Integer out = identity
        .thenLocally((prevRes, localCtx) -> prevRes + 1)
        .thenLocally((prevRes, localCtx) -> prevRes * 5)
        .process(input, new GroupTrainingContext<>(ctx, trainingCache, ignite));

    int expected = (input + 1) * 5;

    Assert.assertEquals(expected, (long) out);
    Assert.assertEquals(ctxData, ctx.data());
}
Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.
Class DistributedWorkersChainTest, method testSimpleLocal.
/**
 * Checks that a single local step is applied to the chain input ({@code input + 1})
 * and that it leaves the local context data unchanged.
 */
public void testSimpleLocal() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> identity = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> trainingCache = TestGroupTrainingCache.getOrCreate(ignite);

    final int input = 1;
    final int ctxData = 0;

    TestLocalContext ctx = new TestLocalContext(ctxData, UUID.randomUUID());

    Integer out = identity
        .thenLocally((prevRes, localCtx) -> prevRes + 1)
        .process(input, new GroupTrainingContext<>(ctx, trainingCache, ignite));

    int expected = input + 1;

    Assert.assertEquals(expected, (long) out);
    Assert.assertEquals(ctxData, ctx.data());
}
Aggregations