Search in sources :

Example 1 with DistributedSgDotMessage

use of org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage in project nd4j by deeplearning4j.

the class VoidParameterServerTest method testNodeInitialization2.

/**
 * This is a very important test: it covers basic message handling over the network.
 * Here we have 1 client and 1 connected Shard, plus 2 more Shards available over multicast UDP.
 *
 * PLEASE NOTE: This test steps through the messages manually.
 *
 * @throws Exception
 */
@Test
public void testNodeInitialization2() throws Exception {
    final AtomicInteger failCnt = new AtomicInteger(0);
    final AtomicInteger passCnt = new AtomicInteger(0);
    final AtomicInteger startCnt = new AtomicInteger(0);
    INDArray exp = Nd4j.create(new double[] { 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00 });
    final VoidConfiguration clientConf = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).shardAddresses(localIPs).multicastNetwork("224.0.1.1").streamId(119).forcedRole(NodeRole.CLIENT).ttl(4).build();
    final VoidConfiguration shardConf1 = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    // we'll never get anything on this port
    final VoidConfiguration shardConf2 = VoidConfiguration.builder().unicastPort(34569).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    // we'll never get anything on this port
    final VoidConfiguration shardConf3 = VoidConfiguration.builder().unicastPort(34570).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    VoidParameterServer clientNode = new VoidParameterServer(true);
    clientNode.setShardIndex((short) 0);
    clientNode.init(clientConf);
    clientNode.getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
    Thread[] threads = new Thread[3];
    final VoidConfiguration[] voidConfigurations = new VoidConfiguration[] { shardConf1, shardConf2, shardConf3 };
    VoidParameterServer[] shards = new VoidParameterServer[threads.length];
    for (int t = 0; t < threads.length; t++) {
        final int x = t;
        threads[t] = new Thread(() -> {
            shards[x] = new VoidParameterServer(true);
            shards[x].setShardIndex((short) x);
            shards[x].init(voidConfigurations[x]);
            shards[x].getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
            assertEquals(NodeRole.SHARD, shards[x].getNodeRole());
            startCnt.incrementAndGet();
            passCnt.incrementAndGet();
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }
    // we block until all threads are really started before sending commands
    while (startCnt.get() < threads.length) Thread.sleep(500);
    // give additional time to start handlers
    Thread.sleep(1000);
    // now we'll send commands from the Client, and we'll check how these messages are handled
    DistributedInitializationMessage message = DistributedInitializationMessage.builder().numWords(100).columnsPerShard(10).seed(123).useHs(false).useNeg(true).vectorLength(100).build();
    log.info("MessageType: {}", message.getMessageType());
    clientNode.getTransport().sendMessage(message);
    // now we check the message queues within the Shards
    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        assertEquals("Failed for shard " + t, message.getMessageType(), incMessage.getMessageType());
        // put the message back into the corresponding shard's queue
        shards[t].getTransport().putMessage(incMessage);
    }
    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        shards[t].handleMessage(message);
        // now we check how the data storage was initialized
        assertEquals(null, shards[t].getNegTable());
        assertEquals(null, shards[t].getSyn1());
        assertNotEquals(null, shards[t].getExpTable());
        assertNotEquals(null, shards[t].getSyn0());
        assertNotEquals(null, shards[t].getSyn1Neg());
    }
    // now we'll check negTable passing; note that we're not actually sending it over the wire here
    INDArray negTable = Nd4j.create(100000).assign(12.0f);
    DistributedSolidMessage negMessage = new DistributedSolidMessage(WordVectorStorage.NEGATIVE_TABLE, negTable, false);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(negMessage);
        assertNotEquals(null, shards[t].getNegTable());
        assertEquals(negTable, shards[t].getNegTable());
    }
    // now we assign row 1 on each shard to that shard's index
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, (double) t));
        assertEquals(Nd4j.create(message.getColumnsPerShard()).assign((double) t), shards[t].getSyn0().getRow(1));
    }
    // and now we'll request the aggregated vector for row 1
    clientNode.getVector(1);
    VoidMessage vecm = shards[0].getTransport().takeMessage();
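    // messageType 7 is the VectorRequestMessage type, as the cast below confirms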
    assertEquals(7, vecm.getMessageType());
    VectorRequestMessage vrm = (VectorRequestMessage) vecm;
    assertEquals(1, vrm.getRowIndex());
    shards[0].handleMessage(vecm);
    Thread.sleep(100);
    // at this moment all 3 shards should already have the distributed message
    for (int t = 0; t < threads.length; t++) {
        VoidMessage dm = shards[t].getTransport().takeMessage();
        assertEquals(20, dm.getMessageType());
        shards[t].handleMessage(dm);
    }
    // at this moment we should have messages propagated across all shards
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }
    // and at this moment, Shard_0 should contain the aggregated vector for us
    assertEquals(true, shards[0].clipboard.isTracking(0L, 1L));
    assertEquals(true, shards[0].clipboard.isReady(0L, 1L));
    INDArray jointVector = shards[0].clipboard.nextCandidate().getAccumulatedResult();
    log.info("Joint vector: {}", jointVector);
    assertEquals(exp, jointVector);
    // first, we set the data to predefined values
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 2, 2.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 2, 2.0));
    }
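    // Constructor arguments, inferred from the call in Example 2 below (an assumption, not
    // the declared parameter names): taskId=2L, rowsA (syn0 rows), rowsB (syn1 rows),
    // w1=0, w2=1, codes, useHS=true, negSamples=0, alpha=0.01f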
    DistributedSgDotMessage ddot = new DistributedSgDotMessage(2L, new int[] { 0, 1, 2 }, new int[] { 0, 1, 2 }, 0, 1, new byte[] { 0, 1 }, true, (short) 0, 0.01f);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(ddot);
    }
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }
    // at this moment the dot should be calculated everywhere
    exp = Nd4j.create(new double[] { 0.0, 30.0, 120.0 });
    for (int t = 0; t < threads.length; t++) {
        assertEquals(true, shards[t].clipboard.isReady(0L, 2L));
        DotAggregation dot = (DotAggregation) shards[t].clipboard.unpin(0L, 2L);
        INDArray aggregated = dot.getAccumulatedResult();
        assertEquals(exp, aggregated);
    }
    for (int t = 0; t < threads.length; t++) {
        threads[t].join();
    }
    assertEquals(threads.length, passCnt.get());
    // shut down every shard exactly once
    for (VoidParameterServer server : shards) {
        server.shutdown();
    }
    clientNode.shutdown();
}
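As a sanity check on the expected values in this test (a minimal plain-Java sketch, independent of nd4j; the class name DotExpectation is hypothetical): row 1 of syn0 is assigned the shard index t on each of the 3 shards, so the aggregated row-1 vector is ten 0s, ten 1s and ten 2s; and once syn0 row i and syn1Neg row i are both filled with the constant i across 3 shards × 10 columns, the aggregated dot for row i is i · i · 30, i.e. {0, 30, 120}.

public class DotExpectation {
    public static void main(String[] args) {
        int shards = 3;
        int columnsPerShard = 10;
        int columns = shards * columnsPerShard; // 30 columns in total across all shards

        // syn0 row i and syn1Neg row i are both constant i, so their dot product
        // over all 30 columns is i * i * 30 -> 0.0, 30.0, 120.0
        for (int i = 0; i < 3; i++) {
            double dot = 0.0;
            for (int c = 0; c < columns; c++)
                dot += (double) i * i;
            System.out.println("row " + i + " -> " + dot);
        }
    }
}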
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) DistributedSolidMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedSolidMessage) VectorRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.VectorRequestMessage) DistributedAssignMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedAssignMessage) DotAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DistributedInitializationMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedInitializationMessage) DistributedSgDotMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage) Test(org.junit.Test)

Example 2 with DistributedSgDotMessage

use of org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage in project nd4j by deeplearning4j.

the class SkipGramTrainer method startTraining.

@Override
public void startTraining(SkipGramRequestMessage message) {
    // All we do right HERE is start the dot calculation.
    // If we're on HS, we know the pairs in advance: they're our points.
    // log.info("sI_{} adding SkipGramChain originator: {}; frame: {}; task: {}", transport.getShardIndex(), message.getOriginatorId(), message.getFrameId(), message.getTaskId());
    SkipGramChain chain = new SkipGramChain(message.getOriginatorId(), message.getTaskId(), message.getFrameId());
    chain.addElement(message);
    // log.info("Starting chain [{}]", chain.getTaskId());
    chains.put(RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId()), chain);
    // we assume this is an HS round
    // if (message.getPoints() != null && message.getPoints().length > 0) {
    // replicate(message.getW2(), message.getPoints().length);
    int[] row_syn0 = new int[0];
    int[] row_syn1 = message.getPoints();
    if (message.getNegSamples() > 0) {
        int rows = storage.getArray(WordVectorStorage.SYN_0).rows();
        int[] tempArray = new int[message.getNegSamples() + 1];
        tempArray[0] = message.getW1();
        for (int e = 1; e < message.getNegSamples() + 1; e++) {
            while (true) {
                int rnd = RandomUtils.nextInt(0, rows);
                if (rnd != message.getW1()) {
                    tempArray[e] = rnd;
                    break;
                }
            }
        }
        row_syn1 = ArrayUtils.addAll(row_syn1, tempArray);
        message.setNegatives(tempArray);
    }
    if (message.getPoints().length != message.getCodes().length)
        throw new RuntimeException("Mismatiching points/codes lengths here!");
    // FIXME: taskId should be real here, since it'll be used for task chain tracking
    // as a result, we'll have the aggregated dot as a single ordered column, which might be used for gradient calculation
    DistributedSgDotMessage ddm = new DistributedSgDotMessage(message.getTaskId(), row_syn0, row_syn1, message.getW1(), message.getW2(), message.getCodes(), message.getCodes() != null && message.getCodes().length > 0, message.getNegSamples(), (float) message.getAlpha());
    ddm.setTargetId((short) -1);
    ddm.setOriginatorId(message.getOriginatorId());
    if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
        transport.putMessage(ddm);
    } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
        transport.sendMessage(ddm);
    }
// } //else log.info("sI_{} Skipping step: {}", transport.getShardIndex(), chain.getTaskId());
}
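For clarity, the rejection-sampling step in startTraining can be written as a standalone helper (a minimal sketch; the class and method names are hypothetical, and java.util.Random stands in for Apache Commons' RandomUtils):

import java.util.Random;

class NegativeSampler {
    // Slot 0 carries the positive target w1; the remaining negSamples slots are
    // random row indices drawn by rejection sampling until they differ from w1.
    static int[] sampleNegatives(int w1, int rows, int negSamples, Random rng) {
        int[] out = new int[negSamples + 1];
        out[0] = w1;
        for (int e = 1; e <= negSamples; e++) {
            int rnd;
            do {
                rnd = rng.nextInt(rows);
            } while (rnd == w1);
            out[e] = rnd;
        }
        return out;
    }
}

As in the original loop, the sampled indices are then appended to the HS points to form row_syn1 before the DistributedSgDotMessage is built.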
Also used : SkipGramChain(org.nd4j.parameterserver.distributed.training.chains.SkipGramChain) DistributedSgDotMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage)

Aggregations

DistributedSgDotMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage)2 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 Test (org.junit.Test)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration)1 DotAggregation (org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation)1 DistributedAssignMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedAssignMessage)1 DistributedInitializationMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedInitializationMessage)1 DistributedSolidMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedSolidMessage)1 VectorRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.VectorRequestMessage)1 SkipGramChain (org.nd4j.parameterserver.distributed.training.chains.SkipGramChain)1