Example usage of org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage in the nd4j project by deeplearning4j, taken from the class VoidParameterServerTest, method testNodeInitialization2.
/**
 * This is a very important test: it covers basic message handling over the network.
 * Here we have 1 client and 1 connected Shard, plus 2 more shards available over multicast UDP.
 *
 * PLEASE NOTE: This test uses manual stepping through messages.
 *
 * @throws Exception if any step of the distributed exchange fails or a wait is interrupted
 */
@Test
public void testNodeInitialization2() throws Exception {
    // counts shard threads that passed their in-thread assertions / finished startup
    final AtomicInteger passCnt = new AtomicInteger(0);
    final AtomicInteger startCnt = new AtomicInteger(0);

    // expected aggregated row: 10 columns per shard, shard t's slice filled with value t
    INDArray exp = Nd4j.create(new double[] {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 1.00,
                    1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00,
                    2.00, 2.00});

    final VoidConfiguration clientConf = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678)
                    .numberOfShards(3).shardAddresses(localIPs).multicastNetwork("224.0.1.1").streamId(119)
                    .forcedRole(NodeRole.CLIENT).ttl(4).build();
    final VoidConfiguration shardConf1 = VoidConfiguration.builder().unicastPort(34567).multicastPort(45678)
                    .numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4)
                    .build();
    // we'll never get anything on this port
    final VoidConfiguration shardConf2 = VoidConfiguration.builder().unicastPort(34569).multicastPort(45678)
                    .numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4)
                    .build();
    // we'll never get anything on this port
    final VoidConfiguration shardConf3 = VoidConfiguration.builder().unicastPort(34570).multicastPort(45678)
                    .numberOfShards(3).streamId(119).shardAddresses(localIPs).multicastNetwork("224.0.1.1").ttl(4)
                    .build();

    VoidParameterServer clientNode = new VoidParameterServer(true);
    clientNode.setShardIndex((short) 0);
    clientNode.init(clientConf);
    clientNode.getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
    assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());

    Thread[] threads = new Thread[3];
    final VoidConfiguration[] voidConfigurations = new VoidConfiguration[] {shardConf1, shardConf2, shardConf3};
    VoidParameterServer[] shards = new VoidParameterServer[threads.length];
    for (int t = 0; t < threads.length; t++) {
        final int x = t;
        threads[t] = new Thread(() -> {
            shards[x] = new VoidParameterServer(true);
            shards[x].setShardIndex((short) x);
            shards[x].init(voidConfigurations[x]);
            shards[x].getTransport().launch(Transport.ThreadingModel.DEDICATED_THREADS);
            assertEquals(NodeRole.SHARD, shards[x].getNodeRole());
            startCnt.incrementAndGet();
            passCnt.incrementAndGet();
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }

    // we block until all threads are really started before sending commands
    while (startCnt.get() < threads.length)
        Thread.sleep(500);

    // give additional time to start handlers
    Thread.sleep(1000);

    // now we'll send commands from Client, and we'll check how these messages will be handled
    DistributedInitializationMessage message = DistributedInitializationMessage.builder().numWords(100)
                    .columnsPerShard(10).seed(123).useHs(false).useNeg(true).vectorLength(100).build();
    log.info("MessageType: {}", message.getMessageType());
    clientNode.getTransport().sendMessage(message);

    // now we check the message queue within Shards
    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        assertEquals("Failed for shard " + t, message.getMessageType(), incMessage.getMessageType());
        // we should put the message back into the corresponding queue
        shards[t].getTransport().putMessage(incMessage);
    }

    for (int t = 0; t < threads.length; t++) {
        VoidMessage incMessage = shards[t].getTransport().takeMessage();
        assertNotEquals("Failed for shard " + t, null, incMessage);
        // NOTE(review): the local `message` is handled here rather than the dequeued
        // `incMessage`; both carry the same initialization payload, so the effect is identical
        shards[t].handleMessage(message);

        // now we're checking how data storage was initialized:
        // useHs=false -> no syn1; useNeg=true -> syn1Neg present; negTable not sent yet
        assertEquals(null, shards[t].getNegTable());
        assertEquals(null, shards[t].getSyn1());
        assertNotEquals(null, shards[t].getExpTable());
        assertNotEquals(null, shards[t].getSyn0());
        assertNotEquals(null, shards[t].getSyn1Neg());
    }

    // now we'll check passing for negTable, but please note - we're not sending it right now
    INDArray negTable = Nd4j.create(100000).assign(12.0f);
    DistributedSolidMessage negMessage = new DistributedSolidMessage(WordVectorStorage.NEGATIVE_TABLE, negTable, false);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(negMessage);
        assertNotEquals(null, shards[t].getNegTable());
        assertEquals(negTable, shards[t].getNegTable());
    }

    // now we assign row 1 on each shard to that shard's index
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, (double) t));
        assertEquals(Nd4j.create(message.getColumnsPerShard()).assign((double) t), shards[t].getSyn0().getRow(1));
    }

    // and now we'll request the aggregated vector for row 1
    clientNode.getVector(1);
    VoidMessage vecm = shards[0].getTransport().takeMessage();
    assertEquals(7, vecm.getMessageType());
    VectorRequestMessage vrm = (VectorRequestMessage) vecm;
    assertEquals(1, vrm.getRowIndex());
    shards[0].handleMessage(vecm);
    Thread.sleep(100);

    // at this moment all 3 shards should already have the distributed message
    for (int t = 0; t < threads.length; t++) {
        VoidMessage dm = shards[t].getTransport().takeMessage();
        assertEquals(20, dm.getMessageType());
        shards[t].handleMessage(dm);
    }

    // at this moment we should have messages propagated across all shards;
    // drain each queue in reverse order and step through the handlers manually
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }

    // and at this moment, Shard_0 should contain the aggregated vector for us
    assertEquals(true, shards[0].clipboard.isTracking(0L, 1L));
    assertEquals(true, shards[0].clipboard.isReady(0L, 1L));
    INDArray jointVector = shards[0].clipboard.nextCandidate().getAccumulatedResult();
    log.info("Joint vector: {}", jointVector);
    assertEquals(exp, jointVector);

    // first, we're setting data to something predefined
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 2, 2.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 0, 0.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 1, 1.0));
        shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_1_NEGATIVE, 2, 2.0));
    }

    DistributedSgDotMessage ddot = new DistributedSgDotMessage(2L, new int[] {0, 1, 2}, new int[] {0, 1, 2}, 0, 1,
                    new byte[] {0, 1}, true, (short) 0, 0.01f);
    for (int t = 0; t < threads.length; t++) {
        shards[t].handleMessage(ddot);
    }
    Thread.sleep(100);
    for (int t = threads.length - 1; t >= 0; t--) {
        VoidMessage msg;
        while ((msg = shards[t].getTransport().takeMessage()) != null) {
            shards[t].handleMessage(msg);
        }
    }

    // at this moment the dot product should be calculated everywhere
    exp = Nd4j.create(new double[] {0.0, 30.0, 120.0});
    for (int t = 0; t < threads.length; t++) {
        assertEquals(true, shards[t].clipboard.isReady(0L, 2L));
        DotAggregation dot = (DotAggregation) shards[t].clipboard.unpin(0L, 2L);
        INDArray aggregated = dot.getAccumulatedResult();
        assertEquals(exp, aggregated);
    }

    for (int t = 0; t < threads.length; t++) {
        threads[t].join();
    }

    assertEquals(threads.length, passCnt.get());

    // shut each shard down exactly once (the original code shut every shard down twice,
    // via an indexed loop followed by an enhanced-for loop over the same array)
    for (VoidParameterServer server : shards) {
        server.shutdown();
    }
    clientNode.shutdown();
}
Example usage of org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage in the nd4j project by deeplearning4j, taken from the class SkipGramTrainer, method startTraining.
/**
 * Kicks off a SkipGram training round: registers a task chain for the request and
 * broadcasts the dot-product calculation to the shards. All we do right HERE is start
 * the dot calculation; the gradient step happens later in the chain.
 *
 * If we're on hierarchical softmax, the pairs are known in advance: they are the
 * message's points. For negative sampling, random negative rows are drawn here.
 *
 * @param message the SkipGram request to start; its negatives array is populated
 *                in place when negative sampling is enabled
 * @throws RuntimeException if the points and codes arrays differ in length
 */
@Override
public void startTraining(SkipGramRequestMessage message) {
    // log.info("sI_{} adding SkipGramChain originator: {}; frame: {}; task: {}", transport.getShardIndex(), message.getOriginatorId(), message.getFrameId(), message.getTaskId());
    SkipGramChain chain = new SkipGramChain(message.getOriginatorId(), message.getTaskId(), message.getFrameId());
    chain.addElement(message);
    // log.info("Starting chain [{}]", chain.getTaskId());
    chains.put(RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId()), chain);

    // Validate up front, BEFORE the negative-sampling block mutates the message:
    // points and codes must be paired one-to-one for the HS round.
    if (message.getPoints().length != message.getCodes().length)
        throw new RuntimeException("Mismatching points/codes lengths here!");

    // we assume this is HS round
    // if (message.getPoints() != null && message.getPoints().length > 0) {
    // replicate(message.getW2(), message.getPoints().length);
    int[] rowSyn0 = new int[0];
    int[] rowSyn1 = message.getPoints();
    if (message.getNegSamples() > 0) {
        // negative sampling: slot 0 holds w1 itself, the rest are random rows
        // resampled until they differ from w1
        int rows = storage.getArray(WordVectorStorage.SYN_0).rows();
        int[] negatives = new int[message.getNegSamples() + 1];
        negatives[0] = message.getW1();
        for (int e = 1; e < negatives.length; e++) {
            int rnd;
            do {
                rnd = RandomUtils.nextInt(0, rows);
            } while (rnd == message.getW1());
            negatives[e] = rnd;
        }
        rowSyn1 = ArrayUtils.addAll(rowSyn1, negatives);
        message.setNegatives(negatives);
    }

    // FIXME: taskId should be real here, since it'll be used for task chain tracking
    // as result, we'll have aggregated dot as single ordered column, which might be used for gradient calculation
    DistributedSgDotMessage ddm = new DistributedSgDotMessage(message.getTaskId(), rowSyn0, rowSyn1, message.getW1(),
                    message.getW2(), message.getCodes(), message.getCodes() != null && message.getCodes().length > 0,
                    message.getNegSamples(), (float) message.getAlpha());
    ddm.setTargetId((short) -1);
    ddm.setOriginatorId(message.getOriginatorId());
    if (voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
        transport.putMessage(ddm);
    } else if (voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
        transport.sendMessage(ddm);
    }
    // } //else log.info("sI_{} Skipping step: {}", transport.getShardIndex(), chain.getTaskId());
}
Aggregations