Usage of org.nd4j.parameterserver.distributed.messages.VoidMessage in project nd4j by deeplearning4j.
Class InterleavedRouterTest, method assignTarget1.
/**
 * Checks default routing for every message type except training requests: each routed
 * message must be assigned a shard index in the range [0, 3], and must report the expected
 * originator id afterwards.
 *
 * @throws Exception on any unexpected failure during setup or routing
 */
@Test
public void assignTarget1() throws Exception {
    final InterleavedRouter router = new InterleavedRouter();
    router.init(configuration, transport);

    int iteration = 0;
    while (iteration < 100) {
        VoidMessage request = new InitializationRequestMessage(100, 10, 123, false, false, 10);
        int assigned = router.assignTarget(request);

        assertTrue(assigned >= 0 && assigned <= 3);
        assertEquals(originator, request.getOriginatorId());
        iteration++;
    }
}
Usage of org.nd4j.parameterserver.distributed.messages.VoidMessage in project nd4j by deeplearning4j.
Class InterleavedRouterTest, method assignTarget3.
/**
 * Checks default routing for every message type except training requests, using a router
 * built with a pre-defined default index (2). Every InitializationRequestMessage must be
 * assigned exactly that index, and must report the expected originator id afterwards.
 *
 * @throws Exception on any unexpected failure during setup or routing
 */
@Test
public void assignTarget3() throws Exception {
    final int defaultIndex = 2;
    InterleavedRouter router = new InterleavedRouter(defaultIndex);
    router.init(configuration, transport);

    for (int round = 3; round > 0; round--) {
        VoidMessage request = new InitializationRequestMessage(100, 10, 123, false, false, 10);
        int assigned = router.assignTarget(request);

        assertEquals(defaultIndex, assigned);
        assertEquals(originator, request.getOriginatorId());
    }
}
Usage of org.nd4j.parameterserver.distributed.messages.VoidMessage in project nd4j by deeplearning4j.
Class BaseTransport, method internalMessageHandler.
/**
 * Handles coordination messages received on the Shard side. The incoming bytes are copied
 * out of {@code buffer}, deserialized into a {@link VoidMessage}, and added to
 * {@code messages}.
 *
 * @param buffer buffer holding the serialized message
 * @param offset index of the first message byte within {@code buffer}
 * @param length number of message bytes to read
 * @param header fragment header (not used by this handler)
 */
protected void internalMessageHandler(DirectBuffer buffer, int offset, int length, Header header) {
    // Incoming internal messages are either op commands or aggregation messages tied to commands.
    final byte[] payload = new byte[length];
    buffer.getBytes(offset, payload);
    messages.add(VoidMessage.fromBytes(payload));
}
Usage of org.nd4j.parameterserver.distributed.messages.VoidMessage in project nd4j by deeplearning4j.
Class BaseTransport, method shardMessageHandler.
/**
 * Handles unicast messages received on the Shard side.
 * <p>
 * Vector request messages (message type 7) are a special case: they are deserialized and
 * added to {@code messages} locally. Every other message is relayed unchanged (the same
 * buffer range) to the other Shards through {@code publicationForShards}.
 *
 * @param buffer buffer holding the serialized message
 * @param offset index of the first message byte within {@code buffer}
 * @param length number of message bytes to read
 * @param header fragment header (not used by this handler)
 */
protected void shardMessageHandler(DirectBuffer buffer, int offset, int length, Header header) {
    // TODO: implement a fragmentation handler here, probably. Or forbid messages larger than the MTU?
    final byte[] payload = new byte[length];
    buffer.getBytes(offset, payload);
    VoidMessage message = VoidMessage.fromBytes(payload);

    if (message.getMessageType() != 7) {
        // Not a vector request: relay it to the other Shards.
        // NOTE(review): the return value of offer() is discarded; if it signals a failed
        // offer, this message may be lost silently - confirm this is acceptable.
        publicationForShards.offer(buffer, offset, length);
        return;
    }

    // Vector request: keep it local instead of relaying it to other Shards (for now).
    messages.add(message);
}
Usage of org.nd4j.parameterserver.distributed.messages.VoidMessage in project nd4j by deeplearning4j.
Class RoutedTransportTest, method testMessaging1.
/**
 * Verifies that a client routing with {@code InterleavedRouter(0)} delivers its message to
 * Shard 0 only. Shard 0 must receive the message, and every other Shard's queue must stay
 * empty.
 * <p>
 * All transports are shut down in a {@code finally} block, so a failed assertion does not
 * leave them running.
 *
 * @throws Exception on transport setup failure or interruption while waiting for messages
 */
@Test
public void testMessaging1() throws Exception {
    final int shardBasePort = 38380;
    final int shardCount = 5;

    List<String> list = new ArrayList<>();
    for (int t = 0; t < shardCount; t++) {
        list.add("127.0.0.1:" + (shardBasePort + t));
    }

    // the unicast port is used only by the client
    VoidConfiguration voidConfiguration =
            VoidConfiguration.builder().shardAddresses(list).unicastPort(43120).build();

    // first of all we start the shards
    RoutedTransport[] transports = new RoutedTransport[list.size()];
    for (int t = 0; t < transports.length; t++) {
        Clipboard clipboard = new Clipboard();
        transports[t] = new RoutedTransport();
        transports[t].setIpAndPort("127.0.0.1", shardBasePort + t);
        transports[t].init(voidConfiguration, clipboard, NodeRole.SHARD, "127.0.0.1",
                voidConfiguration.getUnicastPort(), (short) t);
    }
    for (RoutedTransport shard : transports) {
        shard.launch(Transport.ThreadingModel.DEDICATED_THREADS);
    }

    RoutedTransport clientTransport = null;
    try {
        // now we start the client; this test uses a single client
        Clipboard clipboard = new Clipboard();
        clientTransport = new RoutedTransport();
        clientTransport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());

        // setRouter must be called before init
        ClientRouter router = new InterleavedRouter(0);
        clientTransport.setRouter(router);
        router.init(voidConfiguration, clientTransport);

        clientTransport.init(voidConfiguration, clipboard, NodeRole.CLIENT, "127.0.0.1",
                voidConfiguration.getUnicastPort(), (short) -1);
        clientTransport.launch(Transport.ThreadingModel.DEDICATED_THREADS);

        // send one message; with InterleavedRouter(0) only Shard 0 is expected to receive it
        VoidMessage message =
                new IntroductionRequestMessage("127.0.0.1", voidConfiguration.getUnicastPort());
        clientTransport.sendMessage(message);

        Thread.sleep(500);

        message = transports[0].messages.poll(1, TimeUnit.SECONDS);
        assertNotEquals(null, message);

        for (int t = 1; t < transports.length; t++) {
            message = transports[t].messages.poll(1, TimeUnit.SECONDS);
            assertEquals(null, message);
        }
    } finally {
        // Always shut down every transport, even when an assertion above fails,
        // so running transports do not leak into later tests.
        for (RoutedTransport transport : transports) {
            transport.shutdown();
        }
        if (clientTransport != null) {
            clientTransport.shutdown();
        }
    }
}
Aggregations