Search in sources :

Example 1 with DistributedSolidMessage

Use of org.nd4j.parameterserver.distributed.messages.intercom.DistributedSolidMessage in the project nd4j by deeplearning4j.

The following example comes from the class VoidParameterServerTest, method testNodeInitialization2.

/**
 * This is a very important test: it covers basic message handling over the network.
 * Here we have 1 client, 1 connected Shard + 2 shards available over multicast UDP.
 *
 * PLEASE NOTE: This test uses manual stepping through messages — the shards' message
 * queues are drained and handled explicitly rather than by background workers.
 *
 * @throws Exception if transport setup, message handling, or thread joins fail
 */
@Test
public void testNodeInitialization2() throws Exception {
    // Counters incremented by shard threads; the main thread spins on them to synchronize.
    final AtomicInteger passCnt = new AtomicInteger(0);
    final AtomicInteger startCnt = new AtomicInteger(0);
    // Expected aggregated vector: 3 shards x 10 columns, shard t contributes rows of value t.
    INDArray exp = Nd4j.create(new double[] { 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00 });
    // Client is forced into CLIENT role; all configurations share the same multicast group/stream.
    final VoidConfiguration clientConf = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).shardAddresses(localIPs).multicastNetwork("224.0.1.1").streamId(119).forcedRole(NodeRole.CLIENT).ttl(4).build();
    final VoidConfiguration shardConf1 = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    final VoidConfiguration shardConf2 = // we'll never get anything on this port
    VoidConfiguration.builder().unicastPort(34569).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    final VoidConfiguration shardConf3 = // we'll never get anything on this port
    VoidConfiguration.builder().unicastPort(34570).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    VoidParameterServer clientNode = new VoidParameterServer(true);
    clientNode.setShardIndex((short) 0);
    clientNode.init(clientConf);
    clientNode.getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
    // Spin up one thread per shard; each initializes its server and reports readiness via startCnt.
    Thread[] threads = new Thread[3];
    final VoidConfiguration[] voidConfigurations = new VoidConfiguration[] { shardConf1, shardConf2, shardConf3 };
    VoidParameterServer[] shards = new VoidParameterServer[threads.length];
    for (int t = 0; t < threads.length; t++) {
        final int x = t;
        threads[t] = new Thread(() -> {
            shards[x] = new VoidParameterServer(true);
            shards[x].setShardIndex((short) x);
            shards[x].init(voidConfigurations[x]);
            shards[x].getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
            assertEquals(NodeRole.SHARD, shards[x].getNodeRole());
            startCnt.incrementAndGet();
            passCnt.incrementAndGet();
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }
    // we block until all threads are really started before sending commands
    while (startCnt.get() < threads.length) Thread.sleep(500);
    // give additional time to start handlers
    Thread.sleep(1000);
    // now we'll send commands from Client, and we'll check how these messages will be handled
    DistributedInitializationMessage message = DistributedInitializationMessage.builder().numWords(100).columnsPerShard(10).seed(123).useHs(false).useNeg(true).vectorLength(100).build();
    log.info("MessageType: {}", message.getMessageType());
    clientNode.getTransport().sendMessage(message);
    // now we check message queue within Shards
    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        assertEquals("Failed for shard " + t, message.getMessageType(), incMessage.getMessageType());
        // we should put message back to corresponding queue
        shards[t].getTransport().putMessage(incMessage);
    }
    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        // NOTE(review): the dequeued incMessage is only null-checked; the original
        // `message` is what gets handled — presumably equivalent payloads here, confirm intent.
        shards[t].handleMessage(message);
        /**
         * Now we're checking how data storage was initialized:
         * useHs(false) => no syn1; useNeg(true) => syn1Neg present but negTable not yet sent.
         */
        assertEquals(null, shards[t].getNegTable());
        assertEquals(null, shards[t].getSyn1());
        assertNotEquals(null, shards[t].getExpTable());
        assertNotEquals(null, shards[t].getSyn0());
        assertNotEquals(null, shards[t].getSyn1Neg());
    }
    // now we'll check passing for negTable, but please note - we're not sending it right now
    INDArray negTable = Nd4j.create(100000).assign(12.0f);
    DistributedSolidMessage negMessage = new DistributedSolidMessage(WordVectorStorage.NEGATIVE_TABLE, negTable, false);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(negMessage);
        assertNotEquals(null, shards[t].getNegTable());
        assertEquals(negTable, shards[t].getNegTable());
    }
    // now we assign row 1 of each shard's syn0 slice to the shard index
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, (double) t));
        assertEquals(Nd4j.create(message.getColumnsPerShard()).assign((double) t), shards[t].getSyn0().getRow(1));
    }
    // and now we'll request for aggregated vector for row 1
    clientNode.getVector(1);
    VoidMessage vecm = shards[0].getTransport().takeMessage();
    // message type 7 corresponds to VectorRequestMessage
    assertEquals(7, vecm.getMessageType());
    VectorRequestMessage vrm = (VectorRequestMessage) vecm;
    assertEquals(1, vrm.getRowIndex());
    shards[0].handleMessage(vecm);
    Thread.sleep(100);
    // at this moment all 3 shards should already have distributed message (type 20)
    for (int t = 0; t < threads.length; t++) {
        VoidMessage dm = shards[t].getTransport().takeMessage();
        assertEquals(20, dm.getMessageType());
        shards[t].handleMessage(dm);
    }
    // at this moment we should have messages propagated across all shards;
    // drain in reverse order so shard 0 (the aggregator) is processed last
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }
    // and at this moment, Shard_0 should contain aggregated vector for us
    assertEquals(true, shards[0].clipboard.isTracking(0L, 1L));
    assertEquals(true, shards[0].clipboard.isReady(0L, 1L));
    INDArray jointVector = shards[0].clipboard.nextCandidate().getAccumulatedResult();
    log.info("Joint vector: {}", jointVector);
    assertEquals(exp, jointVector);
    // first, we're setting data to something predefined
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 2, 2.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 2, 2.0));
    }
    DistributedSgDotMessage ddot = new DistributedSgDotMessage(2L, new int[] { 0, 1, 2 }, new int[] { 0, 1, 2 }, 0, 1, new byte[] { 0, 1 }, true, (short) 0, 0.01f);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(ddot);
    }
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }
    // at this moment it should be calculated everywhere:
    // dot per row r = 3 shards * 10 cols * r * r => 0, 30, 120
    exp = Nd4j.create(new double[] { 0.0, 30.0, 120.0 });
    for (int t = 0; t < threads.length; t++) {
        assertEquals(true, shards[t].clipboard.isReady(0L, 2L));
        DotAggregation dot = (DotAggregation) shards[t].clipboard.unpin(0L, 2L);
        INDArray aggregated = dot.getAccumulatedResult();
        assertEquals(exp, aggregated);
    }
    for (int t = 0; t < threads.length; t++) {
        threads[t].join();
    }
    assertEquals(threads.length, passCnt.get());
    // shut down each shard exactly once (original code invoked shutdown() twice per shard)
    for (VoidParameterServer server : shards) {
        server.shutdown();
    }
    clientNode.shutdown();
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) DistributedSolidMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedSolidMessage) VectorRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.VectorRequestMessage) DistributedAssignMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedAssignMessage) DotAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DistributedInitializationMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedInitializationMessage) DistributedSgDotMessage(org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage) Test(org.junit.Test)

Aggregations

AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 Test (org.junit.Test)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration)1 DotAggregation (org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation)1 DistributedAssignMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedAssignMessage)1 DistributedInitializationMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedInitializationMessage)1 DistributedSgDotMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage)1 DistributedSolidMessage (org.nd4j.parameterserver.distributed.messages.intercom.DistributedSolidMessage)1 VectorRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.VectorRequestMessage)1