Usage of org.nd4j.parameterserver.distributed.transport.MulticastTransport in the nd4j project by deeplearning4j.
Class VoidParameterServerStressTest, method testPerformanceMulticast1.
/**
 * Latency stress test for {@code getVector} over {@link MulticastTransport}.
 *
 * <p>Starts five shard servers, each on its own {@code MulticastTransport} bound to
 * 192.168.1.35:37890..37894, then a client that initializes a seq-vec table and
 * hammers {@code getVector} from 8 threads (100000 calls each, each thread on its own
 * contiguous slice of word indices). The median (p50) call latency is logged in
 * microseconds.
 *
 * <p>NOTE(review): presumably {@code @Ignore}d because it needs a host reachable at the
 * hard-coded address 192.168.1.35 — confirm before re-enabling.
 *
 * <p>All started servers are shut down in a {@code finally} block so a failed assertion
 * or initialization error does not leave ports bound for subsequent tests.
 */
@Test
@Ignore
public void testPerformanceMulticast1() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().networkMask("192.168.0.0/16").numberOfShards(1).build();
    List<String> addresses = new ArrayList<>();
    for (int s = 0; s < 5; s++) {
        addresses.add("192.168.1.35:3789" + s);
    }
    voidConfiguration.setShardAddresses(addresses);
    voidConfiguration.setForcedRole(NodeRole.CLIENT);

    VoidConfiguration[] voidConfigurations = new VoidConfiguration[5];
    VoidParameterServer[] shards = new VoidParameterServer[5];
    // Nullable so the finally block can tell whether the client was ever created.
    VoidParameterServer parameterServer = null;
    try {
        for (int s = 0; s < shards.length; s++) {
            voidConfigurations[s] = VoidConfiguration.builder().unicastPort(Integer.valueOf("3789" + s)).networkMask("192.168.0.0/16").build();
            voidConfigurations[s].setShardAddresses(addresses);
            MulticastTransport transport = new MulticastTransport();
            transport.setIpAndPort("192.168.1.35", Integer.valueOf("3789" + s));
            shards[s] = new VoidParameterServer(false);
            shards[s].setShardIndex((short) s);
            shards[s].init(voidConfigurations[s], transport, new SkipGramTrainer());
            assertEquals(NodeRole.SHARD, shards[s].getNodeRole());
        }

        // this is going to be our Client shard
        parameterServer = new VoidParameterServer();
        parameterServer.init(voidConfiguration);
        assertEquals(NodeRole.CLIENT, VoidParameterServer.getInstance().getNodeRole());
        log.info("Instantiation finished...");
        parameterServer.initializeSeqVec(100, NUM_WORDS, 123, 20, true, false);
        log.info("Initialization finished...");

        List<Long> newTimes = measureGetVectorTimes(parameterServer, 8, 100000);
        if (newTimes.isEmpty()) {
            // Every worker died before recording; fail with a clear message instead of
            // an IndexOutOfBoundsException from get(0).
            throw new IllegalStateException("No getVector timings were collected; all worker threads failed");
        }
        Collections.sort(newTimes);
        log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    } finally {
        if (parameterServer != null) {
            parameterServer.shutdown();
        }
        for (VoidParameterServer server : shards) {
            // A shard slot is null if initialization failed before reaching it.
            if (server != null) {
                server.shutdown();
            }
        }
    }
}

/**
 * Runs {@code numThreads} worker threads that each call {@code client.getVector} with
 * random indices from their own contiguous slice of {@code [0, NUM_WORDS)}, and collects
 * the per-call wall-clock durations.
 *
 * @param client              initialized client used for the lookups
 * @param numThreads          number of concurrent worker threads
 * @param iterationsPerThread number of {@code getVector} calls each worker makes
 * @return a mutable list of the recorded durations in nanoseconds (in no particular order)
 * @throws InterruptedException if the calling thread is interrupted while waiting for workers
 */
private List<Long> measureGetVectorTimes(final VoidParameterServer client, int numThreads, final int iterationsPerThread) throws InterruptedException {
    final List<Long> times = new CopyOnWriteArrayList<>();
    final int chunk = NUM_WORDS / numThreads;
    Thread[] threads = new Thread[numThreads];
    for (int t = 0; t < threads.length; t++) {
        final int e = t;
        threads[t] = new Thread(() -> {
            // Collect locally and publish once, to avoid a copy-on-write per sample.
            List<Long> results = new ArrayList<>(iterationsPerThread);
            int start = e * chunk;
            int end = (e + 1) * chunk;
            for (int i = 0; i < iterationsPerThread; i++) {
                long time1 = System.nanoTime();
                client.getVector(RandomUtils.nextInt(start, end));
                long time2 = System.nanoTime();
                results.add(time2 - time1);
                if ((i + 1) % 1000 == 0)
                    log.info("Thread {} cnt {}", e, i + 1);
            }
            times.addAll(results);
        });
        threads[t].setDaemon(true);
        threads[t].start();
    }
    // Let InterruptedException propagate rather than silently swallowing it.
    for (Thread thread : threads) {
        thread.join();
    }
    return new ArrayList<>(times);
}
Aggregations