Search in sources:

Example 1 with AssignRequestMessage

Use of org.nd4j.parameterserver.distributed.messages.requests.AssignRequestMessage in project nd4j by deeplearning4j.

From class VoidParameterServerTest, method testNodeInitialization3.

/**
 * Starts one client node and three shard nodes in-process, then drives them purely through
 * messages sent from the client: an {@link InitializationRequestMessage}, several
 * {@link AssignRequestMessage}s and a single {@link SkipGramRequestMessage}. Vectors are read
 * back with {@code clientNode.getVector(...)} and compared against pre-calculated values.
 *
 * <p>PLEASE NOTE: This test uses automatic feeding through messages.
 *
 * <p>Shards are started on worker threads. A failure there (including a failed role check) is
 * counted in {@code failCnt} and reported on the test thread. Otherwise an assertion error on a
 * worker thread would not fail the test, and the start-up wait would spin forever. Every node
 * is shut down in a {@code finally} block, so a failing assertion does not leave transports open
 * for later tests.
 *
 * @throws Exception if node start-up, messaging or waiting fails
 */
@Test
public void testNodeInitialization3() throws Exception {
    // upper bound for shard start-up before the test is failed instead of hanging
    final long shardStartupTimeoutMs = 30_000L;
    final AtomicInteger failCnt = new AtomicInteger(0);
    final AtomicInteger startCnt = new AtomicInteger(0);
    // NOTE(review): presumably forces ND4J backend initialization before any node starts - confirm
    Nd4j.create(1);
    final VoidConfiguration clientConf = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).shardAddresses(localIPs).multicastNetwork("224.0.1.1").streamId(119).forcedRole(NodeRole.CLIENT).ttl(4).build();
    final VoidConfiguration shardConf1 = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    // we'll never get anything on this port
    final VoidConfiguration shardConf2 = VoidConfiguration.builder().unicastPort(34569).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    // we'll never get anything on this port
    final VoidConfiguration shardConf3 = VoidConfiguration.builder().unicastPort(34570).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();

    VoidParameterServer clientNode = new VoidParameterServer();
    clientNode.setShardIndex((short) 0);
    clientNode.init(clientConf);
    clientNode.getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);

    final Thread[] threads = new Thread[3];
    final VoidConfiguration[] voidConfigurations = new VoidConfiguration[] { shardConf1, shardConf2, shardConf3 };
    final VoidParameterServer[] shards = new VoidParameterServer[threads.length];
    final AtomicBoolean runner = new AtomicBoolean(true);
    try {
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        for (int t = 0; t < threads.length; t++) {
            final int x = t;
            threads[t] = new Thread(() -> {
                try {
                    shards[x] = new VoidParameterServer();
                    shards[x].setShardIndex((short) x);
                    shards[x].init(voidConfigurations[x]);
                    shards[x].getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
                    assertEquals(NodeRole.SHARD, shards[x].getNodeRole());
                } catch (Throwable e) {
                    // an error thrown here would be invisible to JUnit; hand it to the test thread
                    log.error("Shard {} failed to start", x, e);
                    failCnt.incrementAndGet();
                    return;
                }
                startCnt.incrementAndGet();
                // keep this thread alive until the test signals shutdown via `runner`
                try {
                    while (runner.get()) {
                        Thread.sleep(100);
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }

        // waiting till all shards have either started or failed, bounded by a deadline
        final long deadline = System.currentTimeMillis() + shardStartupTimeoutMs;
        while (startCnt.get() + failCnt.get() < threads.length) {
            if (System.currentTimeMillis() > deadline) {
                throw new AssertionError("Timed out waiting for shards: " + startCnt.get() + " of " + threads.length + " started");
            }
            Thread.sleep(20);
        }
        if (failCnt.get() != 0) {
            throw new AssertionError(failCnt.get() + " shard(s) failed to start; see log for details");
        }

        InitializationRequestMessage irm = InitializationRequestMessage.builder().numWords(100).columnsPerShard(50).seed(123).useHs(true).useNeg(false).vectorLength(150).build();
        // after this point we'll assume all Shards are initialized
        // mostly because Init message is blocking
        clientNode.getTransport().sendMessage(irm);
        log.info("------------------");
        AssignRequestMessage arm = new AssignRequestMessage(WordVectorStorage.SYN_0, 192f, 11);
        clientNode.getTransport().sendMessage(arm);
        // fixed delay: there is no completion signal for the assign request
        Thread.sleep(1000);
        // This is blocking method
        INDArray vec = clientNode.getVector(WordVectorStorage.SYN_0, 11);
        assertEquals(Nd4j.create(150).assign(192f), vec);

        // now we go for gradients-like test
        // first of all we set exptable to something predictable
        INDArray expSyn0 = Nd4j.create(150).assign(0.01f);
        INDArray expSyn1_1 = Nd4j.create(150).assign(0.020005);
        INDArray expSyn1_2 = Nd4j.create(150).assign(0.019995f);
        INDArray expTable = Nd4j.create(10000).assign(0.5f);
        AssignRequestMessage expReqMsg = new AssignRequestMessage(WordVectorStorage.EXP_TABLE, expTable);
        clientNode.getTransport().sendMessage(expReqMsg);
        arm = new AssignRequestMessage(WordVectorStorage.SYN_0, 0.01, -1);
        clientNode.getTransport().sendMessage(arm);
        arm = new AssignRequestMessage(WordVectorStorage.SYN_1, 0.02, -1);
        clientNode.getTransport().sendMessage(arm);
        Thread.sleep(500);

        // now we'll send single SkipGram request that involves calculation for 0 -> {1,2}, and will check result against pre-calculated values
        SkipGramRequestMessage sgrm = new SkipGramRequestMessage(0, 1, new int[] { 1, 2 }, new byte[] { 0, 1 }, (short) 0, 0.001, 119L);
        clientNode.getTransport().sendMessage(sgrm);
        // TODO: we might want to introduce optional CompletedMessage here
        // now we just wait till everything is finished
        Thread.sleep(1000);
        // This is blocking method
        INDArray row_syn0 = clientNode.getVector(WordVectorStorage.SYN_0, 0);
        INDArray row_syn1_1 = clientNode.getVector(WordVectorStorage.SYN_1, 1);
        INDArray row_syn1_2 = clientNode.getVector(WordVectorStorage.SYN_1, 2);
        assertEquals(expSyn0, row_syn0);
        assertArrayEquals(expSyn1_1.data().asFloat(), row_syn1_1.data().asFloat(), 1e-6f);
        assertArrayEquals(expSyn1_2.data().asFloat(), row_syn1_2.data().asFloat(), 1e-6f);
    } finally {
        // always release worker threads and node transports, even when an assertion failed
        runner.set(false);
        for (Thread thread : threads) {
            if (thread != null) {
                thread.join();
            }
        }
        // join() above makes the workers' writes to `shards` visible here
        for (VoidParameterServer server : shards) {
            if (server != null) {
                server.shutdown();
            }
        }
        clientNode.shutdown();
    }
}
Also used: InitializationRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.InitializationRequestMessage), VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration), AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean), INDArray (org.nd4j.linalg.api.ndarray.INDArray), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), AssignRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.AssignRequestMessage), SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage), Test (org.junit.Test)

Aggregations

AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 1 use; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1 use; Test (org.junit.Test): 1 use; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1 use; VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration): 1 use; AssignRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.AssignRequestMessage): 1 use; InitializationRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.InitializationRequestMessage): 1 use; SkipGramRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage): 1 use