Search in sources:

Example 1 with MulticastTransport

Use of org.nd4j.parameterserver.distributed.transport.MulticastTransport in the project nd4j by deeplearning4j.

From the class VoidParameterServerStressTest, method testPerformanceMulticast1.

/**
 * Benchmarks client-side {@code getVector} latency against five in-process shards that use
 * {@link MulticastTransport}, and logs the median (p50) latency in microseconds.
 *
 * <p>The shard host (192.168.1.35) and network mask (192.168.0.0/16) are hard-coded, so this only
 * works on a machine with that address; presumably that, plus it being a benchmark, is why it is
 * {@code @Ignore}d. Eight worker threads each issue 100,000 {@code getVector} calls over their own
 * contiguous slice of the {@code NUM_WORDS} index range.
 *
 * @throws Exception if server setup fails, a worker thread throws, or the test thread is
 *     interrupted while waiting for the workers
 */
@Test
@Ignore
public void testPerformanceMulticast1() throws Exception {
    final String host = "192.168.1.35";
    final String networkMask = "192.168.0.0/16";
    final int numShards = 5;
    // Shard s listens on basePort + s, i.e. ports 37890..37894.
    final int basePort = 37890;

    VoidConfiguration voidConfiguration = VoidConfiguration.builder()
            .networkMask(networkMask)
            .numberOfShards(1)
            .build();
    List<String> addresses = new ArrayList<>();
    for (int s = 0; s < numShards; s++) {
        addresses.add(host + ":" + (basePort + s));
    }
    voidConfiguration.setShardAddresses(addresses);
    voidConfiguration.setForcedRole(NodeRole.CLIENT);

    VoidConfiguration[] voidConfigurations = new VoidConfiguration[numShards];
    VoidParameterServer[] shards = new VoidParameterServer[numShards];
    try {
        for (int s = 0; s < numShards; s++) {
            int port = basePort + s;
            voidConfigurations[s] = VoidConfiguration.builder()
                    .unicastPort(port)
                    .networkMask(networkMask)
                    .build();
            voidConfigurations[s].setShardAddresses(addresses);
            MulticastTransport transport = new MulticastTransport();
            transport.setIpAndPort(host, port);
            VoidParameterServer shard = new VoidParameterServer(false);
            shard.setShardIndex((short) s);
            shard.init(voidConfigurations[s], transport, new SkipGramTrainer());
            // Register the shard for cleanup only after init succeeded, so the finally block
            // never tries to shut down a half-initialized server.
            shards[s] = shard;
            assertEquals(NodeRole.SHARD, shard.getNodeRole());
        }

        // this is going to be our Client shard
        final VoidParameterServer parameterServer = new VoidParameterServer();
        parameterServer.init(voidConfiguration);
        try {
            assertEquals(NodeRole.CLIENT, VoidParameterServer.getInstance().getNodeRole());
            log.info("Instantiation finished...");
            parameterServer.initializeSeqVec(100, NUM_WORDS, 123, 20, true, false);
            log.info("Initialization finished...");

            final List<Long> times = new CopyOnWriteArrayList<>();
            final List<Throwable> failures = new CopyOnWriteArrayList<>();
            final Thread[] threads = new Thread[8];
            final int chunk = NUM_WORDS / threads.length;
            for (int t = 0; t < threads.length; t++) {
                final int threadIndex = t;
                threads[t] = new Thread(() -> {
                    List<Long> results = new ArrayList<>();
                    int start = threadIndex * chunk;
                    int end = (threadIndex + 1) * chunk;
                    for (int i = 0; i < 100000; i++) {
                        long time1 = System.nanoTime();
                        parameterServer.getVector(RandomUtils.nextInt(start, end));
                        long time2 = System.nanoTime();
                        results.add(time2 - time1);
                        if ((i + 1) % 1000 == 0) {
                            log.info("Thread {} cnt {}", threadIndex, i + 1);
                        }
                    }
                    times.addAll(results);
                });
                threads[t].setDaemon(true);
                // Without this, a worker's exception is only printed and the p50 below would be
                // computed over silently incomplete data.
                threads[t].setUncaughtExceptionHandler((thread, error) -> failures.add(error));
                threads[t].start();
            }
            // Let InterruptedException propagate (the method declares throws Exception) instead
            // of swallowing it and reporting a result from unfinished workers.
            for (Thread thread : threads) {
                thread.join();
            }
            if (!failures.isEmpty()) {
                AssertionError error = new AssertionError(
                        failures.size() + " benchmark thread(s) failed", failures.get(0));
                for (int i = 1; i < failures.size(); i++) {
                    error.addSuppressed(failures.get(i));
                }
                throw error;
            }

            List<Long> newTimes = new ArrayList<>(times);
            Collections.sort(newTimes);
            log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
        } finally {
            parameterServer.shutdown();
        }
    } finally {
        // Always release shard resources, even when setup or the benchmark fails.
        for (VoidParameterServer server : shards) {
            if (server != null) {
                server.shutdown();
            }
        }
    }
}
Also used: VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration), SkipGramTrainer (org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer), ArrayList (java.util.ArrayList), CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList), INDArray (org.nd4j.linalg.api.ndarray.INDArray), MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport), AtomicLong (java.util.concurrent.atomic.AtomicLong), List (java.util.List), Ignore (org.junit.Ignore), Test (org.junit.Test)

Aggregations

ArrayList (java.util.ArrayList) 1, List (java.util.List) 1, CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList) 1, AtomicLong (java.util.concurrent.atomic.AtomicLong) 1, Ignore (org.junit.Ignore) 1, Test (org.junit.Test) 1, INDArray (org.nd4j.linalg.api.ndarray.INDArray) 1, VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration) 1, SkipGramTrainer (org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) 1, MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport) 1