Use of org.nd4j.parameterserver.distributed.messages.requests.AssignRequestMessage in project nd4j by deeplearning4j.
The class VoidParameterServerTest, method testNodeInitialization3.
/**
 * Starts one client node and three shard nodes inside this JVM, then drives them purely through
 * transport messages: an initialization request, several assign requests and a single skip-gram
 * request. Rows read back through the client are compared against pre-calculated values.
 *
 * <p>PLEASE NOTE: This test uses automatic feeding through messages.
 *
 * <p>A shard that fails during startup is reported as a test failure instead of leaving the test
 * waiting forever, and the nodes are shut down even when an assertion fails.
 *
 * @throws Exception if node setup, messaging or waiting fails
 */
@Test
public void testNodeInitialization3() throws Exception {
    final AtomicInteger failCnt = new AtomicInteger(0);
    final AtomicInteger startCnt = new AtomicInteger(0);
    // NOTE(review): presumably forces ND4J backend initialization before any node starts - confirm
    Nd4j.create(1);
    final VoidConfiguration clientConf = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).shardAddresses(localIPs).multicastNetwork("224.0.1.1").streamId(119).forcedRole(NodeRole.CLIENT).ttl(4).build();
    final VoidConfiguration shardConf1 = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    final VoidConfiguration shardConf2 = // we'll never get anything on this port
            VoidConfiguration.builder().unicastPort(34569).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    final VoidConfiguration shardConf3 = // we'll never get anything on this port
            VoidConfiguration.builder().unicastPort(34570).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();

    VoidParameterServer clientNode = new VoidParameterServer();
    clientNode.setShardIndex((short) 0);
    clientNode.init(clientConf);
    clientNode.getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);

    Thread[] threads = new Thread[3];
    final VoidConfiguration[] voidConfigurations = new VoidConfiguration[] { shardConf1, shardConf2, shardConf3 };
    VoidParameterServer[] shards = new VoidParameterServer[threads.length];
    // one slot per shard; a slot is written before failCnt is incremented, so a reader that
    // observes failCnt > 0 also observes the stored failure
    final Throwable[] shardFailures = new Throwable[threads.length];
    final AtomicBoolean runner = new AtomicBoolean(true);

    // everything past this point runs under try/finally so that a failed assertion
    // does not leave the client and shard nodes running
    try {
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

        for (int t = 0; t < threads.length; t++) {
            final int x = t;
            threads[t] = new Thread(() -> {
                try {
                    shards[x] = new VoidParameterServer();
                    shards[x].setShardIndex((short) x);
                    shards[x].init(voidConfigurations[x]);
                    shards[x].getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
                    assertEquals(NodeRole.SHARD, shards[x].getNodeRole());
                    startCnt.incrementAndGet();
                } catch (Throwable e) {
                    // Throwable on purpose: a failed assertEquals raises AssertionError, and without
                    // recording it here the test thread would wait for startCnt forever
                    shardFailures[x] = e;
                    failCnt.incrementAndGet();
                    return;
                }
                try {
                    // keep the shard thread alive until the test is done
                    while (runner.get()) {
                        Thread.sleep(100);
                    }
                } catch (InterruptedException e) {
                    // restore the flag and let the thread finish
                    Thread.currentThread().interrupt();
                }
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }

        // waiting till all shards are initialized, failing fast if any of them could not start
        while (startCnt.get() < threads.length) {
            if (failCnt.get() > 0) {
                for (int t = 0; t < shardFailures.length; t++) {
                    if (shardFailures[t] != null) {
                        throw new AssertionError("Shard " + t + " failed to start", shardFailures[t]);
                    }
                }
            }
            Thread.sleep(20);
        }

        InitializationRequestMessage irm = InitializationRequestMessage.builder().numWords(100).columnsPerShard(50).seed(123).useHs(true).useNeg(false).vectorLength(150).build();
        // after this point we'll assume all Shards are initialized
        // mostly because Init message is blocking
        clientNode.getTransport().sendMessage(irm);
        log.info("------------------");

        AssignRequestMessage arm = new AssignRequestMessage(WordVectorStorage.SYN_0, 192f, 11);
        clientNode.getTransport().sendMessage(arm);
        Thread.sleep(1000);

        // This is blocking method
        INDArray vec = clientNode.getVector(WordVectorStorage.SYN_0, 11);
        assertEquals(Nd4j.create(150).assign(192f), vec);

        // now we go for gradients-like test
        // first of all we set exptable to something predictable
        INDArray expSyn0 = Nd4j.create(150).assign(0.01f);
        INDArray expSyn1_1 = Nd4j.create(150).assign(0.020005);
        INDArray expSyn1_2 = Nd4j.create(150).assign(0.019995f);
        INDArray expTable = Nd4j.create(10000).assign(0.5f);
        AssignRequestMessage expReqMsg = new AssignRequestMessage(WordVectorStorage.EXP_TABLE, expTable);
        clientNode.getTransport().sendMessage(expReqMsg);
        arm = new AssignRequestMessage(WordVectorStorage.SYN_0, 0.01, -1);
        clientNode.getTransport().sendMessage(arm);
        arm = new AssignRequestMessage(WordVectorStorage.SYN_1, 0.02, -1);
        clientNode.getTransport().sendMessage(arm);
        Thread.sleep(500);

        // now we'll send single SkipGram request that involves calculation for 0 -> {1,2}, and will check result against pre-calculated values
        SkipGramRequestMessage sgrm = new SkipGramRequestMessage(0, 1, new int[] { 1, 2 }, new byte[] { 0, 1 }, (short) 0, 0.001, 119L);
        clientNode.getTransport().sendMessage(sgrm);
        // TODO: we might want to introduce optional CompletedMessage here
        // now we just wait till everything is finished
        Thread.sleep(1000);

        // This is blocking method
        INDArray row_syn0 = clientNode.getVector(WordVectorStorage.SYN_0, 0);
        INDArray row_syn1_1 = clientNode.getVector(WordVectorStorage.SYN_1, 1);
        INDArray row_syn1_2 = clientNode.getVector(WordVectorStorage.SYN_1, 2);
        assertEquals(expSyn0, row_syn0);
        assertArrayEquals(expSyn1_1.data().asFloat(), row_syn1_1.data().asFloat(), 1e-6f);
        assertArrayEquals(expSyn1_2.data().asFloat(), row_syn1_2.data().asFloat(), 1e-6f);
    } finally {
        // release the shard threads from their keep-alive loop
        runner.set(false);
        try {
            for (Thread thread : threads) {
                // a slot is null if the test failed before that thread was created
                if (thread != null) {
                    thread.join();
                }
            }
            for (VoidParameterServer server : shards) {
                // a slot is null if that shard thread never got as far as creating its server
                if (server != null) {
                    server.shutdown();
                }
            }
        } finally {
            // the client is shut down even if stopping a shard throws
            clientNode.shutdown();
        }
    }
}
Aggregations