Search in sources :

Example 1 with Transport

use of org.nd4j.parameterserver.distributed.transport.Transport in project nd4j by deeplearning4j.

From the class VoidParameterServerStressTest, method testPerformanceUnicast4.

/**
 * Stress test: multiple CLIENT nodes hammering a single SHARD over unicast transport.
 *
 * <p>Every client runs on its own thread, sends 500 frames of 128 skip-gram requests each and
 * records how long each {@code execDistributed} call takes. Only the median latency is logged;
 * no latency threshold is asserted.
 *
 * @throws Exception propagated from node setup or from waiting on the worker threads
 */
@Test
public void testPerformanceUnicast4() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1).shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
    Transport transport = new RoutedTransport();
    transport.setIpAndPort("127.0.0.1", 49823);
    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new SkipGramTrainer());
    // Slots stay null until the corresponding client has been initialized, so the cleanup
    // below only touches nodes that were actually started.
    final VoidParameterServer[] clients = new VoidParameterServer[1];
    try {
        parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
        for (int c = 0; c < clients.length; c++) {
            VoidParameterServer client = new VoidParameterServer(NodeRole.CLIENT);
            Transport clientTransport = new RoutedTransport();
            // Client ports are 48720, 48721, ... (arithmetic instead of string concatenation).
            clientTransport.setIpAndPort("127.0.0.1", 48720 + c);
            client.init(voidConfiguration, clientTransport, new SkipGramTrainer());
            clients[c] = client;
            assertEquals(NodeRole.CLIENT, client.getNodeRole());
        }
        final List<Long> times = new CopyOnWriteArrayList<>();
        log.info("Starting loop...");
        Thread[] threads = new Thread[clients.length];
        for (int t = 0; t < threads.length; t++) {
            final int c = t;
            threads[t] = new Thread(() -> {
                // Collected locally and published once, to keep the shared list out of the timed loop.
                List<Long> results = new ArrayList<>();
                AtomicLong sequence = new AtomicLong(0);
                for (int i = 0; i < 500; i++) {
                    Frame<SkipGramRequestMessage> frame = new Frame<>(sequence.incrementAndGet());
                    for (int f = 0; f < 128; f++) {
                        frame.stackMessage(getSGRM());
                    }
                    long time1 = System.nanoTime();
                    clients[c].execDistributed(frame);
                    long time2 = System.nanoTime();
                    results.add(time2 - time1);
                    if ((i + 1) % 50 == 0) {
                        // i is zero-based, so i + 1 frames have been completed at this point.
                        log.info("Thread_{} finished {} frames...", c, i + 1);
                    }
                }
                times.addAll(results);
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread thread : threads) {
            thread.join();
        }
        List<Long> newTimes = new ArrayList<>(times);
        Collections.sort(newTimes);
        // Samples are in nanoseconds; divide by 1000 to report microseconds.
        log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    } finally {
        // Always release the nodes, even when the workload fails, so the ports are not
        // left occupied for whatever runs next.
        try {
            for (VoidParameterServer client : clients) {
                if (client != null) {
                    client.shutdown();
                }
            }
        } finally {
            parameterServer.shutdown();
        }
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) Frame(org.nd4j.parameterserver.distributed.messages.Frame) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) AtomicLong(java.util.concurrent.atomic.AtomicLong) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Test(org.junit.Test)

Example 2 with Transport

use of org.nd4j.parameterserver.distributed.transport.Transport in project nd4j by deeplearning4j.

From the class VoidParameterServerStressTest, method testPerformanceUnicast3.

/**
 * Stress test for the single-Shard scenario, where the Shard also acts as the Client.
 *
 * <p>Sends 200 frames of 128 CBOW requests each through the shard's own
 * {@code execDistributed} and logs the median call latency. No latency threshold is asserted.
 *
 * @throws Exception propagated from node setup or from the workload
 */
@Test
public void testPerformanceUnicast3() throws Exception {
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(1).shardAddresses(Arrays.asList("127.0.0.1:49823")).build();
    Transport transport = new RoutedTransport();
    transport.setIpAndPort("127.0.0.1", 49823);
    VoidParameterServer parameterServer = new VoidParameterServer(NodeRole.SHARD);
    parameterServer.setShardIndex((short) 0);
    parameterServer.init(voidConfiguration, transport, new CbowTrainer());
    try {
        parameterServer.initializeSeqVec(100, NUM_WORDS, 123L, 100, true, false);
        final List<Long> times = new ArrayList<>();
        log.info("Starting loop...");
        for (int i = 0; i < 200; i++) {
            Frame<CbowRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
            for (int f = 0; f < 128; f++) {
                frame.stackMessage(getCRM());
            }
            long time1 = System.nanoTime();
            parameterServer.execDistributed(frame);
            long time2 = System.nanoTime();
            times.add(time2 - time1);
            if ((i + 1) % 50 == 0) {
                // i is zero-based, so i + 1 frames have actually been executed here.
                log.info("{} frames passed...", i + 1);
            }
        }
        Collections.sort(times);
        // Samples are in nanoseconds; divide by 1000 to report microseconds.
        log.info("p50: {} us", times.get(times.size() / 2) / 1000);
    } finally {
        // Release the shard even if the workload throws, so port 49823 is not left occupied.
        parameterServer.shutdown();
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) CbowTrainer(org.nd4j.parameterserver.distributed.training.impl.CbowTrainer) Frame(org.nd4j.parameterserver.distributed.messages.Frame) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CbowRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage) Test(org.junit.Test)

Example 3 with Transport

use of org.nd4j.parameterserver.distributed.transport.Transport in project nd4j by deeplearning4j.

From the class VoidParameterServerStressTest, method testPerformanceUnicast1.

/**
 * This is one of the MOST IMPORTANT tests.
 *
 * <p>Starts the configured shard(s) plus one routed CLIENT node, then has four threads
 * repeatedly call {@code getVector} on the client (200 calls per thread, each thread drawing
 * indices from its own slice of {@code NUM_WORDS}) while timing every call. The median latency
 * is logged; no latency threshold is asserted.
 */
@Test
public void testPerformanceUnicast1() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 1; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(list.size()).shardAddresses(list).build();
    // Slots stay null until the shard has been initialized, so cleanup only touches started nodes.
    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    // Set once the client is initialized; kept separate from clientNode so that clientNode
    // remains effectively final for the worker lambdas.
    VoidParameterServer startedClient = null;
    try {
        for (int t = 0; t < shards.length; t++) {
            VoidParameterServer shard = new VoidParameterServer(NodeRole.SHARD);
            Transport shardTransport = new RoutedTransport();
            // Shard ports are 38380, 38381, ... matching the addresses built above.
            shardTransport.setIpAndPort("127.0.0.1", 38380 + t);
            shard.setShardIndex((short) t);
            shard.init(voidConfiguration, shardTransport, new SkipGramTrainer());
            shards[t] = shard;
            assertEquals(NodeRole.SHARD, shard.getNodeRole());
        }
        final VoidParameterServer clientNode = new VoidParameterServer(NodeRole.CLIENT);
        RoutedTransport transport = new RoutedTransport();
        ClientRouter router = new InterleavedRouter(0);
        transport.setRouter(router);
        transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
        router.init(voidConfiguration, transport);
        clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
        startedClient = clientNode;
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        final List<Long> times = new CopyOnWriteArrayList<>();
        // at this point, everything should be started, time for tests
        clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
        log.info("Initialization finished, going to tests...");
        Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int e = t;
            threads[t] = new Thread(() -> {
                List<Long> results = new ArrayList<>();
                // Each worker samples indices from its own contiguous slice [start, end).
                int chunk = NUM_WORDS / threads.length;
                int start = e * chunk;
                int end = (e + 1) * chunk;
                for (int i = 0; i < 200; i++) {
                    long time1 = System.nanoTime();
                    // Only the call latency matters here; the returned vector is discarded.
                    clientNode.getVector(RandomUtils.nextInt(start, end));
                    long time2 = System.nanoTime();
                    results.add(time2 - time1);
                    if ((i + 1) % 100 == 0) {
                        log.info("Thread {} cnt {}", e, i + 1);
                    }
                }
                times.addAll(results);
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread worker : threads) {
            try {
                worker.join();
            } catch (InterruptedException interrupted) {
                // Restore the flag for the caller and stop waiting: with the flag set,
                // every further join() would fail immediately anyway.
                Thread.currentThread().interrupt();
                break;
            }
        }
        List<Long> newTimes = new ArrayList<>(times);
        Collections.sort(newTimes);
        // An interrupted run may have produced no samples; skip the percentile in that case.
        if (!newTimes.isEmpty()) {
            // Samples are in nanoseconds; divide by 1000 to report microseconds.
            log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
        }
    } finally {
        // shutdown everything that was started, even when setup or the workload failed
        try {
            for (VoidParameterServer shard : shards) {
                if (shard != null) {
                    shard.getTransport().shutdown();
                }
            }
        } finally {
            if (startedClient != null) {
                startedClient.getTransport().shutdown();
            }
        }
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicLong(java.util.concurrent.atomic.AtomicLong) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Test(org.junit.Test)

Example 4 with Transport

use of org.nd4j.parameterserver.distributed.transport.Transport in project nd4j by deeplearning4j.

From the class FrameTest, method testFrame1.

/**
 * Basic Frame check: stacks ten stub messages into a Frame, verifies the reported size and
 * then verifies how many times the stubs' {@code processMessage()} ran after one call to
 * {@code Frame.processMessage()}.
 */
@Test
public void testFrame1() {
    final int messageCount = 10;
    final AtomicInteger invocations = new AtomicInteger(0);
    Frame<TrainingMessage> batch = new Frame<>();
    int stacked = 0;
    while (stacked < messageCount) {
        // A fresh stub per slot: everything is inert except processMessage(), which counts calls.
        batch.stackMessage(new TrainingMessage() {

            // ---- behaviour under test ----

            @Override
            public void processMessage() {
                invocations.incrementAndGet();
            }

            @Override
            public byte getCounter() {
                return 2;
            }

            // ---- identifiers: fixed dummy values, setters ignored ----

            @Override
            public long getFrameId() {
                return 0;
            }

            @Override
            public void setFrameId(long frameId) {
            }

            @Override
            public long getOriginatorId() {
                return 0;
            }

            @Override
            public void setOriginatorId(long id) {
            }

            @Override
            public short getTargetId() {
                return 0;
            }

            @Override
            public void setTargetId(short id) {
            }

            @Override
            public long getTaskId() {
                return 0;
            }

            @Override
            public int getMessageType() {
                return 0;
            }

            // ---- retransmission: never retransmitted ----

            @Override
            public int getRetransmitCount() {
                return 0;
            }

            @Override
            public void incrementRetransmitCount() {
            }

            // ---- serialization: empty payloads ----

            @Override
            public byte[] asBytes() {
                return new byte[0];
            }

            @Override
            public UnsafeBuffer asUnsafeBuffer() {
                return null;
            }

            // ---- context wiring: deliberately ignored ----

            @Override
            public void attachContext(VoidConfiguration voidConfiguration, TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard, Transport transport, Storage storage, NodeRole role, short shardIndex) {
                // intentionally left empty
            }

            @Override
            public void extractContext(BaseVoidMessage message) {
                // intentionally left empty
            }

            // ---- join / blocking semantics: unsupported ----

            @Override
            public boolean isJoinSupported() {
                return false;
            }

            @Override
            public void joinMessage(VoidMessage message) {
                // intentionally left empty
            }

            @Override
            public boolean isBlockingMessage() {
                return false;
            }
        });
        stacked++;
    }
    assertEquals(messageCount, batch.size());
    batch.processMessage();
    // 20 calls expected for 10 messages; presumably Frame invokes each message getCounter() (= 2)
    // times -- confirm against Frame.processMessage().
    assertEquals(20, invocations.get());
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) Storage(org.nd4j.parameterserver.distributed.logic.Storage) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) UnsafeBuffer(org.agrona.concurrent.UnsafeBuffer) Transport(org.nd4j.parameterserver.distributed.transport.Transport) Test(org.junit.Test)

Example 5 with Transport

use of org.nd4j.parameterserver.distributed.transport.Transport in project nd4j by deeplearning4j.

From the class VoidParameterServerStressTest, method testPerformanceUnicast2.

/**
 * This is the second super-important test for unicast transport.
 * Here we send non-blocking messages.
 *
 * <p>Starts five shards plus one routed client node, then has four threads each push 200 frames
 * of 128 skip-gram requests through the client's {@code execDistributed} while timing every
 * call. The median latency is logged; no latency threshold is asserted. Currently disabled via
 * {@code @Ignore}.
 */
@Test
@Ignore
public void testPerformanceUnicast2() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 5; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(list.size()).shardAddresses(list).build();
    // Slots stay null until the shard has been initialized, so cleanup only touches started nodes.
    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    // Set once the client is initialized; kept separate from clientNode so that clientNode
    // remains effectively final for the worker lambdas.
    VoidParameterServer startedClient = null;
    try {
        for (int t = 0; t < shards.length; t++) {
            VoidParameterServer shard = new VoidParameterServer(NodeRole.SHARD);
            Transport shardTransport = new RoutedTransport();
            // Shard ports are 38380..38384, matching the addresses built above.
            shardTransport.setIpAndPort("127.0.0.1", 38380 + t);
            shard.setShardIndex((short) t);
            shard.init(voidConfiguration, shardTransport, new SkipGramTrainer());
            shards[t] = shard;
            assertEquals(NodeRole.SHARD, shard.getNodeRole());
        }
        // NOTE(review): created without an explicit role, yet asserted to be CLIENT below --
        // presumably the role is resolved during init(); confirm against VoidParameterServer.
        final VoidParameterServer clientNode = new VoidParameterServer();
        RoutedTransport transport = new RoutedTransport();
        ClientRouter router = new InterleavedRouter(0);
        transport.setRouter(router);
        transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
        router.init(voidConfiguration, transport);
        clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
        startedClient = clientNode;
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        final List<Long> times = new CopyOnWriteArrayList<>();
        // at this point, everything should be started, time for tests
        clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
        log.info("Initialization finished, going to tests...");
        Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int e = t;
            threads[t] = new Thread(() -> {
                // Collected locally and published once, to keep the shared list out of the timed loop.
                List<Long> results = new ArrayList<>();
                for (int i = 0; i < 200; i++) {
                    Frame<SkipGramRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
                    for (int f = 0; f < 128; f++) {
                        frame.stackMessage(getSGRM());
                    }
                    long time1 = System.nanoTime();
                    clientNode.execDistributed(frame);
                    long time2 = System.nanoTime();
                    results.add(time2 - time1);
                    if ((i + 1) % 100 == 0) {
                        log.info("Thread {} cnt {}", e, i + 1);
                    }
                }
                times.addAll(results);
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread worker : threads) {
            try {
                worker.join();
            } catch (InterruptedException interrupted) {
                // Restore the flag for the caller and stop waiting: with the flag set,
                // every further join() would fail immediately anyway.
                Thread.currentThread().interrupt();
                break;
            }
        }
        List<Long> newTimes = new ArrayList<>(times);
        Collections.sort(newTimes);
        // An interrupted run may have produced no samples; skip the percentile in that case.
        if (!newTimes.isEmpty()) {
            // Samples are in nanoseconds; divide by 1000 to report microseconds.
            log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
        }
    } finally {
        // shutdown everything that was started, even when setup or the workload failed
        try {
            for (VoidParameterServer shard : shards) {
                if (shard != null) {
                    shard.getTransport().shutdown();
                }
            }
        } finally {
            if (startedClient != null) {
                startedClient.getTransport().shutdown();
            }
        }
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) Frame(org.nd4j.parameterserver.distributed.messages.Frame) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

Test (org.junit.Test)5 VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration)5 Transport (org.nd4j.parameterserver.distributed.transport.Transport)5 ArrayList (java.util.ArrayList)4 CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)4 AtomicLong (java.util.concurrent.atomic.AtomicLong)4 MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport)4 RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)4 List (java.util.List)3 Frame (org.nd4j.parameterserver.distributed.messages.Frame)3 SkipGramTrainer (org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer)3 ClientRouter (org.nd4j.parameterserver.distributed.logic.ClientRouter)2 InterleavedRouter (org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter)2 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 UnsafeBuffer (org.agrona.concurrent.UnsafeBuffer)1 Ignore (org.junit.Ignore)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 NodeRole (org.nd4j.parameterserver.distributed.enums.NodeRole)1 Storage (org.nd4j.parameterserver.distributed.logic.Storage)1 Clipboard (org.nd4j.parameterserver.distributed.logic.completion.Clipboard)1