Search in sources:

Example 1 with ComputationsChain

Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.

Class GroupTrainer, method train.

/**
 * {@inheritDoc}
 * <p>
 * Runs one complete group-training session for {@code data}:
 * <ol>
 *     <li>a fresh training UUID and local context are created and {@link #init} is called;</li>
 *     <li>the initial keys are processed in a distributed way by the initializer obtained from
 *     {@code distributedInitializer(data)} and the results are reduced;</li>
 *     <li>the reduced init data is processed locally, then the training-loop step is repeated
 *     while {@code shouldContinue} holds;</li>
 *     <li>final results are extracted from cache entries, reduced, and mapped to the model.</li>
 * </ol>
 * {@code cleanup(locCtx)} is always invoked once the chain has been started, whether the
 * chain completes normally or throws, so that per-training state is not left behind on failure.
 */
@Override
public final M train(T data) {
    UUID trainingUUID = UUID.randomUUID();
    LC locCtx = initialLocalContext(data, trainingUUID);
    GroupTrainingContext<K, V, LC> ctx = new GroupTrainingContext<>(locCtx, cache, ignite);

    // Identity chain: the starting point onto which all training steps are appended.
    ComputationsChain<LC, K, V, T, T> chain = (i, c) -> i;

    IgniteFunction<GroupTrainerCacheKey<K>, ResultAndUpdates<IN>> distributedInitializer = distributedInitializer(data);

    init(data, trainingUUID);

    // Cleanup is paired with a successful init(): it must run even if any step of the chain fails.
    try {
        return chain
            .thenDistributedForKeys(distributedInitializer, (t, lc) -> data.initialKeys(trainingUUID), reduceDistributedInitData())
            .thenLocally(this::locallyProcessInitData)
            .thenWhile(this::shouldContinue, trainingLoopStep())
            .thenDistributedForEntries(this::extractContextForFinalResultCreation, finalResultsExtractor(), this::finalResultKeys, finalResultsReducer())
            .thenLocally(this::mapFinalResult)
            .process(data, ctx);
    }
    finally {
        cleanup(locCtx);
    }
}
Also used : IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Model(org.apache.ignite.ml.Model) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) Serializable(java.io.Serializable) Trainer(org.apache.ignite.ml.trainers.Trainer) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) HasTrainingUUID(org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID)

Example 2 with ComputationsChain

Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.

Class DistributedWorkersChainTest, method testDistributed.

/**
 * Runs a single distributed step over four cache entries and checks both the reduced result
 * and the cache contents afterwards.
 * <p>
 * The step result is reduced with {@code max}, and the test expects it to equal the largest
 * value originally stored. Afterwards the cache is expected to hold every original value
 * incremented by one (as the method name {@code readAndIncrement} suggests).
 */
public void testDistributed() {
    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> chain = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> cache = TestGroupTrainingCache.getOrCreate(ignite);
    int init = 1;
    UUID trainingUUID = UUID.randomUUID();
    TestLocalContext locCtx = new TestLocalContext(0, trainingUUID);

    Map<GroupTrainerCacheKey<Double>, Integer> m = new HashMap<>();
    m.put(new GroupTrainerCacheKey<>(0L, 1.0, trainingUUID), 1);
    m.put(new GroupTrainerCacheKey<>(1L, 2.0, trainingUUID), 2);
    m.put(new GroupTrainerCacheKey<>(2L, 3.0, trainingUUID), 3);
    m.put(new GroupTrainerCacheKey<>(3L, 4.0, trainingUUID), 4);

    cache.putAll(m);

    // A Stream can be consumed only once. Build a new one on every call instead of sharing a
    // single instance, which would fail with IllegalStateException on a second invocation.
    IgniteBiFunction<Integer, TestLocalContext, IgniteSupplier<Stream<GroupTrainerCacheKey<Double>>>> function =
        (o, l) -> () -> m.keySet().stream();

    IgniteFunction<List<Integer>, Integer> max = ints -> ints.stream().mapToInt(x -> x).max().orElse(Integer.MIN_VALUE);

    Integer res = chain
        .thenDistributedForEntries((integer, context) -> () -> null, this::readAndIncrement, function, max)
        .process(init, new GroupTrainingContext<>(locCtx, cache, ignite));

    // Computed before m is updated below, so this is the maximum of the original values.
    int localMax = m.values().stream().mapToInt(Integer::intValue).max().orElse(Integer.MIN_VALUE);

    assertEquals((long) localMax, (long) res);

    // Expected cache state: every original value incremented by one.
    m.replaceAll((k, v) -> v + 1);

    assertMapEqualsCache(m, cache);
}
Also used : GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Chains(org.apache.ignite.ml.trainers.group.chain.Chains) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) HashMap(java.util.HashMap) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) Map(java.util.Map) Comparator(java.util.Comparator) Assert(org.junit.Assert)

Example 3 with ComputationsChain

Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.

Class DistributedWorkersChainTest, method testChangeLocalContext.

/**
 * Checks that a local step can write new data into the local context while passing its
 * input value through to the chain result unchanged.
 */
public void testChangeLocalContext() {
    final int startVal = 1;
    final int updatedCtxData = 10;

    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> pipeline = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> trainingCache = TestGroupTrainingCache.getOrCreate(ignite);
    TestLocalContext localCtx = new TestLocalContext(0, UUID.randomUUID());

    GroupTrainingContext<Double, Integer, TestLocalContext> groupCtx =
        new GroupTrainingContext<>(localCtx, trainingCache, ignite);

    // The step overwrites the context data and returns its input as-is.
    Integer out = pipeline
        .thenLocally((input, ctx) -> {
            ctx.setData(updatedCtxData);

            return input;
        })
        .process(startVal, groupCtx);

    Assert.assertEquals(updatedCtxData, localCtx.data());
    Assert.assertEquals(startVal, out.intValue());
}
Also used : GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Chains(org.apache.ignite.ml.trainers.group.chain.Chains) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) HashMap(java.util.HashMap) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) Map(java.util.Map) Comparator(java.util.Comparator) Assert(org.junit.Assert)

Example 4 with ComputationsChain

Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.

Class DistributedWorkersChainTest, method testChainLocal.

/**
 * Checks that two chained local steps are applied in order (first {@code +1}, then
 * {@code *5}) and that neither step modifies the local context data.
 */
public void testChainLocal() {
    int startVal = 1;
    int ctxDataAtStart = 0;

    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> pipeline = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> trainingCache = TestGroupTrainingCache.getOrCreate(ignite);
    TestLocalContext localCtx = new TestLocalContext(ctxDataAtStart, UUID.randomUUID());

    GroupTrainingContext<Double, Integer, TestLocalContext> groupCtx =
        new GroupTrainingContext<>(localCtx, trainingCache, ignite);

    Integer out = pipeline
        .thenLocally((prev, lc) -> prev + 1)
        .thenLocally((prev, lc) -> prev * 5)
        .process(startVal, groupCtx);

    long expected = (startVal + 1) * 5;

    Assert.assertEquals(expected, (long) out);
    Assert.assertEquals(ctxDataAtStart, localCtx.data());
}
Also used : GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Chains(org.apache.ignite.ml.trainers.group.chain.Chains) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) HashMap(java.util.HashMap) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) Map(java.util.Map) Comparator(java.util.Comparator) Assert(org.junit.Assert)

Example 5 with ComputationsChain

Usage of org.apache.ignite.ml.trainers.group.chain.ComputationsChain in the Apache Ignite project.

Class DistributedWorkersChainTest, method testSimpleLocal.

/**
 * Checks that a single local step ({@code +1}) produces the expected chain result and
 * leaves the local context data unchanged.
 */
public void testSimpleLocal() {
    int startVal = 1;
    int ctxDataAtStart = 0;

    ComputationsChain<TestLocalContext, Double, Integer, Integer, Integer> pipeline = Chains.create();
    IgniteCache<GroupTrainerCacheKey<Double>, Integer> trainingCache = TestGroupTrainingCache.getOrCreate(ignite);
    TestLocalContext localCtx = new TestLocalContext(ctxDataAtStart, UUID.randomUUID());

    GroupTrainingContext<Double, Integer, TestLocalContext> groupCtx =
        new GroupTrainingContext<>(localCtx, trainingCache, ignite);

    Integer out = pipeline
        .thenLocally((prev, lc) -> prev + 1)
        .process(startVal, groupCtx);

    long expected = startVal + 1;

    Assert.assertEquals(expected, (long) out);
    Assert.assertEquals(ctxDataAtStart, localCtx.data());
}
Also used : GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) IgniteSupplier(org.apache.ignite.ml.math.functions.IgniteSupplier) Chains(org.apache.ignite.ml.trainers.group.chain.Chains) ComputationsChain(org.apache.ignite.ml.trainers.group.chain.ComputationsChain) HashMap(java.util.HashMap) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) IgniteCache(org.apache.ignite.IgniteCache) EntryAndContext(org.apache.ignite.ml.trainers.group.chain.EntryAndContext) List(java.util.List) Stream(java.util.stream.Stream) Ignition(org.apache.ignite.Ignition) IgniteBiFunction(org.apache.ignite.ml.math.functions.IgniteBiFunction) Map(java.util.Map) Comparator(java.util.Comparator) Assert(org.junit.Assert)

Aggregations

List (java.util.List)5 UUID (java.util.UUID)5 Stream (java.util.stream.Stream)5 Ignite (org.apache.ignite.Ignite)5 IgniteCache (org.apache.ignite.IgniteCache)5 IgniteFunction (org.apache.ignite.ml.math.functions.IgniteFunction)5 IgniteSupplier (org.apache.ignite.ml.math.functions.IgniteSupplier)5 ComputationsChain (org.apache.ignite.ml.trainers.group.chain.ComputationsChain)5 EntryAndContext (org.apache.ignite.ml.trainers.group.chain.EntryAndContext)5 Comparator (java.util.Comparator)4 HashMap (java.util.HashMap)4 Map (java.util.Map)4 Ignition (org.apache.ignite.Ignition)4 IgniteBiFunction (org.apache.ignite.ml.math.functions.IgniteBiFunction)4 Chains (org.apache.ignite.ml.trainers.group.chain.Chains)4 GridCommonAbstractTest (org.apache.ignite.testframework.junits.common.GridCommonAbstractTest)4 Assert (org.junit.Assert)4 Serializable (java.io.Serializable)1 Model (org.apache.ignite.ml.Model)1 Trainer (org.apache.ignite.ml.trainers.Trainer)1