Example use of org.nd4j.parameterserver.distributed.messages.intercom.DistributedSolidMessage in the project nd4j by deeplearning4j.
Taken from the class VoidParameterServerTest, method testNodeInitialization2.
/**
 * This is very important test, it covers basic messages handling over network.
 * Here we have 1 client, 1 connected Shard + 2 shards available over multicast UDP.
 *
 * PLEASE NOTE: This test uses manual stepping through messages: each shard's
 * inbound queue is drained and handled explicitly so every stage (init, assign,
 * vector request, dot aggregation) can be asserted in isolation.
 *
 * @throws Exception if node startup, message delivery, sleeping or thread joining fails
 */
@Test
public void testNodeInitialization2() throws Exception {
    // passCnt counts shard threads that completed startup assertions;
    // startCnt gates the main thread until all shard transports are launched.
    final AtomicInteger passCnt = new AtomicInteger(0);
    final AtomicInteger startCnt = new AtomicInteger(0);

    // Expected aggregated row-1 vector: 3 shards x 10 columns, shard t contributes value t.
    INDArray exp = Nd4j.create(new double[] { 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00 });

    // All nodes share the same multicast group/stream; only shard 1 shares the client's unicast port.
    final VoidConfiguration clientConf = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).shardAddresses(localIPs).multicastNetwork("224.0.1.1").streamId(119).forcedRole(NodeRole.CLIENT).ttl(4).build();
    final VoidConfiguration shardConf1 = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    final VoidConfiguration shardConf2 = // we'll never get anything on this port
                    VoidConfiguration.builder().unicastPort(34569).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();
    final VoidConfiguration shardConf3 = // we'll never get anything on this port
                    VoidConfiguration.builder().unicastPort(34570).multicastPort(45678).numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4).build();

    // Client node in debug mode (true) so the transport queues messages instead of auto-processing.
    VoidParameterServer clientNode = new VoidParameterServer(true);
    clientNode.setShardIndex((short) 0);
    clientNode.init(clientConf);
    clientNode.getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

    Thread[] threads = new Thread[3];
    final VoidConfiguration[] voidConfigurations = new VoidConfiguration[] { shardConf1, shardConf2, shardConf3 };
    VoidParameterServer[] shards = new VoidParameterServer[threads.length];

    // Start each shard on its own daemon thread; the thread exits right after launch + assertion.
    for (int t = 0; t < threads.length; t++) {
        final int x = t;
        threads[t] = new Thread(() -> {
            shards[x] = new VoidParameterServer(true);
            shards[x].setShardIndex((short) x);
            shards[x].init(voidConfigurations[x]);
            shards[x].getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
            assertEquals(NodeRole.SHARD, shards[x].getNodeRole());
            startCnt.incrementAndGet();
            passCnt.incrementAndGet();
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }

    // we block until all threads are really started before sending commands
    while (startCnt.get() < threads.length)
        Thread.sleep(500);

    // give additional time to start handlers
    Thread.sleep(1000);

    // now we'll send commands from Client, and we'll check how these messages will be handled
    DistributedInitializationMessage message = DistributedInitializationMessage.builder().numWords(100).columnsPerShard(10).seed(123).useHs(false).useNeg(true).vectorLength(100).build();
    log.info("MessageType: {}", message.getMessageType());
    clientNode.getTransport().sendMessage(message);

    // every shard should have received the init message over multicast
    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        assertEquals("Failed for shard " + t, message.getMessageType(), incMessage.getMessageType());

        // we should put message back to corresponding
        shards[t].getTransport().putMessage(incMessage);
    }

    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);

        // NOTE(review): this handles the original `message`, not the dequeued `incMessage` —
        // functionally equivalent here since they carry the same payload, but confirm intent.
        shards[t].handleMessage(message);

        // Storage after init: HS disabled -> no syn1; neg sampling enabled -> syn1Neg present,
        // but negTable is only transferred separately (checked below).
        assertEquals(null, shards[t].getNegTable());
        assertEquals(null, shards[t].getSyn1());
        assertNotEquals(null, shards[t].getExpTable());
        assertNotEquals(null, shards[t].getSyn0());
        assertNotEquals(null, shards[t].getSyn1Neg());
    }

    // now we'll check passing for negTable, but please note - we're not sending it right now
    INDArray negTable = Nd4j.create(100000).assign(12.0f);
    DistributedSolidMessage negMessage = new DistributedSolidMessage(WordVectorStorage.NEGATIVE_TABLE, negTable, false);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(negMessage);
        assertNotEquals(null, shards[t].getNegTable());
        assertEquals(negTable, shards[t].getNegTable());
    }

    // now we assign each row to something: shard t fills row 1 of its syn0 slice with t
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, (double) t));
        assertEquals(Nd4j.create(message.getColumnsPerShard()).assign((double) t), shards[t].getSyn0().getRow(1));
    }

    // and now we'll request for aggregated vector for row 1
    clientNode.getVector(1);
    VoidMessage vecm = shards[0].getTransport().takeMessage();

    // 7 is the VectorRequestMessage type id — confirm against the messages type table if it changes
    assertEquals(7, vecm.getMessageType());

    VectorRequestMessage vrm = (VectorRequestMessage) vecm;
    assertEquals(1, vrm.getRowIndex());

    shards[0].handleMessage(vecm);
    Thread.sleep(100);

    // at this moment all 3 shards should already have distributed message (type 20)
    for (int t = 0; t < threads.length; t++) {
        VoidMessage dm = shards[t].getTransport().takeMessage();
        assertEquals(20, dm.getMessageType());
        shards[t].handleMessage(dm);
    }

    // at this moment we should have messages propagated across all shards; drain every queue
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }

    // and at this moment, Shard_0 should contain aggregated vector for us
    assertEquals(true, shards[0].clipboard.isTracking(0L, 1L));
    assertEquals(true, shards[0].clipboard.isReady(0L, 1L));

    INDArray jointVector = shards[0].clipboard.nextCandidate().getAccumulatedResult();
    log.info("Joint vector: {}", jointVector);
    assertEquals(exp, jointVector);

    // first, we're setting data to something predefined
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 2, 2.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 2, 2.0));
    }

    // SkipGram dot request across rows {0,1,2}; every shard handles the same message
    DistributedSgDotMessage ddot = new DistributedSgDotMessage(2L, new int[] { 0, 1, 2 }, new int[] { 0, 1, 2 }, 0, 1, new byte[] { 0, 1 }, true, (short) 0, 0.01f);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(ddot);
    }

    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }

    // at this moment dot should be calculated everywhere:
    // row r dotted with itself across 3 shards of 10 columns -> 30 * r^2
    exp = Nd4j.create(new double[] { 0.0, 30.0, 120.0 });
    for (int t = 0; t < threads.length; t++) {
        assertEquals(true, shards[t].clipboard.isReady(0L, 2L));
        DotAggregation dot = (DotAggregation) shards[t].clipboard.unpin(0L, 2L);
        INDArray aggregated = dot.getAccumulatedResult();
        assertEquals(exp, aggregated);
    }

    // shard threads only performed startup, so they have already finished
    for (int t = 0; t < threads.length; t++) {
        threads[t].join();
    }

    assertEquals(threads.length, passCnt.get());

    // shut every node down exactly once (shards were previously shut down twice here)
    for (VoidParameterServer server : shards) {
        server.shutdown();
    }
    clientNode.shutdown();
}
Aggregations