Search in sources:

Example 1 with InterleavedRouter

use of org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter in project nd4j by deeplearning4j.

the class VoidParameterServerStressTest method testPerformanceUnicast1.

/**
 * This is one of the MOST IMPORTANT tests.
 *
 * Starts the shard node(s) and a single client node in-process over {@code RoutedTransport},
 * then issues {@code getVector} requests from 4 concurrent threads (200 requests each) and
 * logs the median round-trip time. Every transport is shut down when the test finishes,
 * whether it passed or not, so the ports are released for subsequent tests.
 */
@Test
public void testPerformanceUnicast1() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 1; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(list.size()).shardAddresses(list).build();
    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    for (int t = 0; t < shards.length; t++) {
        shards[t] = new VoidParameterServer(NodeRole.SHARD);
        Transport transport = new RoutedTransport();
        // port mirrors the address registered in the shard list above
        transport.setIpAndPort("127.0.0.1", Integer.parseInt("3838" + t));
        shards[t].setShardIndex((short) t);
        shards[t].init(voidConfiguration, transport, new SkipGramTrainer());
        assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
    }
    VoidParameterServer clientNode = new VoidParameterServer(NodeRole.CLIENT);
    RoutedTransport transport = new RoutedTransport();
    ClientRouter router = new InterleavedRouter(0);
    transport.setRouter(router);
    transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
    router.init(voidConfiguration, transport);
    clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
    try {
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        final List<Long> times = new CopyOnWriteArrayList<>();
        // at this point, everything should be started, time for tests
        clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
        log.info("Initialization finished, going to tests...");
        final Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int e = t;
            threads[t] = new Thread(() -> {
                List<Long> results = new ArrayList<>();
                // each thread queries its own slice of the index range
                int chunk = NUM_WORDS / threads.length;
                int start = e * chunk;
                int end = (e + 1) * chunk;
                for (int i = 0; i < 200; i++) {
                    long time1 = System.nanoTime();
                    // only the round-trip time matters here, the returned vector is not inspected
                    clientNode.getVector(RandomUtils.nextInt(start, end));
                    long time2 = System.nanoTime();
                    results.add(time2 - time1);
                    if ((i + 1) % 100 == 0)
                        log.info("Thread {} cnt {}", e, i + 1);
                }
                times.addAll(results);
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException interrupted) {
                // keep the interrupt visible to the test runner and stop waiting for the workers
                Thread.currentThread().interrupt();
                break;
            }
        }
        List<Long> newTimes = new ArrayList<>(times);
        if (newTimes.isEmpty()) {
            // workers publish their timings only on normal completion
            throw new IllegalStateException("No timings were collected: worker threads did not complete");
        }
        Collections.sort(newTimes);
        log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    } finally {
        // shutdown everything, even if an assertion or a worker failed
        for (VoidParameterServer shard : shards) {
            shard.getTransport().shutdown();
        }
        clientNode.getTransport().shutdown();
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicLong(java.util.concurrent.atomic.AtomicLong) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Test(org.junit.Test)

Example 2 with InterleavedRouter

use of org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter in project nd4j by deeplearning4j.

the class RoutedTransport method init.

/**
 * Initializes this unicast transport: stores the node settings, starts an embedded Aeron
 * media driver and client, subscribes to this node's own endpoint, opens one publication
 * per configured shard address, and finally initializes the router and the originator id.
 *
 * @param voidConfiguration configuration providing shard addresses, unicast port and stream id
 * @param clipboard clipboard instance stored on this transport
 * @param role role of this node; values other than MASTER, BACKUP, SHARD or CLIENT are rejected
 * @param localIp ip used for the own endpoint when no port was assigned beforehand
 * @param localPort port used for the own endpoint when no port was assigned beforehand
 * @param shardIndex shard index stored on this transport and reported in the log for SHARD nodes
 * @throws ND4JIllegalStateException if the role is not handled by the switch below
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    this.nodeRole = role;
    this.clipboard = clipboard;
    this.voidConfiguration = voidConfiguration;
    this.shardIndex = shardIndex;
    this.messages = new LinkedBlockingQueue<>();
    // delegate to the parent initialization with the same arguments
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
    // NOTE(review): presumably java.lang.System.setProperty (static import) - confirm
    setProperty("aeron.client.liveness.timeout", "30000000000");
    context = new Aeron.Context().publicationConnectionTimeout(30000000000L).driverTimeoutMs(30000).keepAliveInterval(100000000);
    // embedded media driver; the client below connects through the driver's directory
    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);
    // fall back to a default router if none was set before init
    if (router == null)
        router = new InterleavedRouter();
    // we skip IPs assign process if they were defined externally
    if (port == 0) {
        ip = localIp;
        port = localPort;
    }
    // own endpoint: this node listens here for incoming messages
    unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
    subscriptionForClients = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
    // clean shut down
    // NOTE(review): the subscription is closed after the Aeron client and the driver - confirm this order is intended
    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
        CloseHelper.quietClose(aeron);
        CloseHelper.quietClose(driver);
        CloseHelper.quietClose(context);
        CloseHelper.quietClose(subscriptionForClients);
    }));
    // fragments are reassembled before being handed to jointMessageHandler
    messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> jointMessageHandler(buffer, offset, length, header));
    /*
        Now, regardless of current role,
        we set up publication channel to each shard
     */
    String shardChannelUri = null;
    String remoteIp = null;
    int remotePort = 0;
    // note: the loop variable "ip" shadows the transport's own "ip" field inside this loop
    for (String ip : voidConfiguration.getShardAddresses()) {
        if (ip.contains(":")) {
            // address already carries its own port ("host:port")
            shardChannelUri = "aeron:udp?endpoint=" + ip;
            String[] split = ip.split(":");
            remoteIp = split[0];
            remotePort = Integer.valueOf(split[1]);
        } else {
            // bare host: use the unicast port from the configuration
            shardChannelUri = "aeron:udp?endpoint=" + ip + ":" + voidConfiguration.getUnicastPort();
            remoteIp = ip;
            remotePort = voidConfiguration.getUnicastPort();
        }
        Publication publication = aeron.addPublication(shardChannelUri, voidConfiguration.getStreamId());
        RemoteConnection connection = RemoteConnection.builder().ip(remoteIp).port(remotePort).publication(publication).locker(new Object()).build();
        shards.add(connection);
    }
    if (nodeRole == NodeRole.SHARD)
        log.info("Initialized as [{}]; ShardIndex: [{}]; Own endpoint: [{}]", nodeRole, shardIndex, unicastChannelUri);
    else
        log.info("Initialized as [{}]; Own endpoint: [{}]", nodeRole, unicastChannelUri);
    switch(nodeRole) {
        // MASTER and BACKUP have an empty body and no break, so they fall through to the SHARD branch
        // NOTE(review): confirm the fall-through is intentional
        case MASTER:
        case BACKUP:
            {
            }
        case SHARD:
            {
                /*
                    For unicast transport we want to have interconnects between all shards first of all, because we know their IPs in advance.
                    But due to design requirements, clients have the same first step, so it's kinda shared for all states :)
                 */
                /*
                    Next step is connections setup for backup nodes.
                    TODO: to be implemented
                 */
                // "ip" and "port" here are this transport's own endpoint (the loop variable above is out of scope)
                addClient(ip, port);
            }
            break;
        case CLIENT:
            {
            /*
                    For Clients on unicast transport, we either set up connection to single Shard, or to multiple shards
                    But since this code is shared - we don't do anything here
                 */
            }
            break;
        default:
            throw new ND4JIllegalStateException("Unknown NodeRole being passed: " + nodeRole);
    }
    router.init(voidConfiguration, this);
    // originator id is a hash of this node's "ip:port" string
    this.originatorId = HashUtil.getLongHash(this.getIp() + ":" + this.getPort());
}
Also used : java.util(java.util) HashUtil(org.nd4j.linalg.util.HashUtil) FragmentAssembler(io.aeron.FragmentAssembler) org.nd4j.parameterserver.distributed.messages(org.nd4j.parameterserver.distributed.messages) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) IntroductionRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.IntroductionRequestMessage) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StringUtils(org.nd4j.linalg.io.StringUtils) Publication(io.aeron.Publication) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) RetransmissionHandler(org.nd4j.parameterserver.distributed.logic.RetransmissionHandler) System.setProperty(java.lang.System.setProperty) CloseHelper(org.agrona.CloseHelper) MediaDriver(io.aeron.driver.MediaDriver) Aeron(io.aeron.Aeron) Nd4j(org.nd4j.linalg.factory.Nd4j) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) lombok(lombok) LockSupport(java.util.concurrent.locks.LockSupport) Slf4j(lombok.extern.slf4j.Slf4j) Header(io.aeron.logbuffer.Header) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) DirectBuffer(org.agrona.DirectBuffer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) Publication(io.aeron.Publication) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) FragmentAssembler(io.aeron.FragmentAssembler)

Example 3 with InterleavedRouter

use of org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter in project nd4j by deeplearning4j.

the class RoutedTransportTest method testMessaging1.

/**
 * This test starts five shard transports and a single client transport routed through
 * {@code InterleavedRouter(0)}, sends one introduction request from the client and checks
 * that the message is queued on the first shard transport and on none of the others.
 * All transports are shut down afterwards, even when an assertion fails.
 *
 * @throws Exception if a transport call fails or the thread is interrupted while waiting
 */
@Test
public void testMessaging1() throws Exception {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 5; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = // this port will be used only by client
    VoidConfiguration.builder().shardAddresses(list).unicastPort(43120).build();
    // first of all we start shards
    RoutedTransport[] transports = new RoutedTransport[list.size()];
    for (int t = 0; t < transports.length; t++) {
        Clipboard clipboard = new Clipboard();
        transports[t] = new RoutedTransport();
        // port mirrors the address registered in the shard list above
        transports[t].setIpAndPort("127.0.0.1", Integer.parseInt("3838" + t));
        transports[t].init(voidConfiguration, clipboard, NodeRole.SHARD, "127.0.0.1", voidConfiguration.getUnicastPort(), (short) t);
    }
    for (int t = 0; t < transports.length; t++) {
        transports[t].launch(Transport.ThreadingModel.DEDICATED_THREADS);
    }
    // now we start client, for this test we'll have only one client
    Clipboard clipboard = new Clipboard();
    RoutedTransport clientTransport = new RoutedTransport();
    clientTransport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
    // setRouter must be called before init, so that the transport uses this explicitly configured router
    ClientRouter router = new InterleavedRouter(0);
    clientTransport.setRouter(router);
    router.init(voidConfiguration, clientTransport);
    clientTransport.init(voidConfiguration, clipboard, NodeRole.CLIENT, "127.0.0.1", voidConfiguration.getUnicastPort(), (short) -1);
    clientTransport.launch(Transport.ThreadingModel.DEDICATED_THREADS);
    try {
        // we send message somewhere
        VoidMessage message = new IntroductionRequestMessage("127.0.0.1", voidConfiguration.getUnicastPort());
        clientTransport.sendMessage(message);
        // give the message some time to arrive
        Thread.sleep(500);
        // the first shard transport must have received the message...
        message = transports[0].messages.poll(1, TimeUnit.SECONDS);
        assertNotEquals(null, message);
        // ...and no other shard transport may have seen it
        for (int t = 1; t < transports.length; t++) {
            message = transports[t].messages.poll(1, TimeUnit.SECONDS);
            assertEquals(null, message);
        }
    } finally {
        /*
         * This is very important part, shutting down all transports,
         * even if one of the assertions above has failed
         */
        for (RoutedTransport transport : transports) {
            transport.shutdown();
        }
        clientTransport.shutdown();
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) VoidMessage(org.nd4j.parameterserver.distributed.messages.VoidMessage) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) IntroductionRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.IntroductionRequestMessage) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) Test(org.junit.Test)

Example 4 with InterleavedRouter

use of org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter in project nd4j by deeplearning4j.

the class VoidParameterServerStressTest method testPerformanceUnicast2.

/**
 * This is second super-important test for unicast transport.
 * Here we send non-blocking messages.
 *
 * Starts five shard nodes and a single client node in-process over {@code RoutedTransport},
 * then from 4 concurrent threads submits 200 frames of 128 stacked messages each through
 * {@code execDistributed} and logs the median submission time. Every transport is shut down
 * when the test finishes, whether it passed or not.
 */
@Test
@Ignore
public void testPerformanceUnicast2() {
    List<String> list = new ArrayList<>();
    for (int t = 0; t < 5; t++) {
        list.add("127.0.0.1:3838" + t);
    }
    VoidConfiguration voidConfiguration = VoidConfiguration.builder().unicastPort(49823).numberOfShards(list.size()).shardAddresses(list).build();
    VoidParameterServer[] shards = new VoidParameterServer[list.size()];
    for (int t = 0; t < shards.length; t++) {
        shards[t] = new VoidParameterServer(NodeRole.SHARD);
        Transport transport = new RoutedTransport();
        // port mirrors the address registered in the shard list above
        transport.setIpAndPort("127.0.0.1", Integer.parseInt("3838" + t));
        shards[t].setShardIndex((short) t);
        shards[t].init(voidConfiguration, transport, new SkipGramTrainer());
        assertEquals(NodeRole.SHARD, shards[t].getNodeRole());
    }
    VoidParameterServer clientNode = new VoidParameterServer();
    RoutedTransport transport = new RoutedTransport();
    ClientRouter router = new InterleavedRouter(0);
    transport.setRouter(router);
    transport.setIpAndPort("127.0.0.1", voidConfiguration.getUnicastPort());
    router.init(voidConfiguration, transport);
    clientNode.init(voidConfiguration, transport, new SkipGramTrainer());
    try {
        assertEquals(NodeRole.CLIENT, clientNode.getNodeRole());
        final List<Long> times = new CopyOnWriteArrayList<>();
        // at this point, everything should be started, time for tests
        clientNode.initializeSeqVec(100, NUM_WORDS, 123, 25, true, false);
        log.info("Initialization finished, going to tests...");
        final Thread[] threads = new Thread[4];
        for (int t = 0; t < threads.length; t++) {
            final int e = t;
            threads[t] = new Thread(() -> {
                List<Long> results = new ArrayList<>();
                int chunk = NUM_WORDS / threads.length;
                int start = e * chunk;
                int end = (e + 1) * chunk;
                for (int i = 0; i < 200; i++) {
                    // frame construction is kept outside of the timed section
                    Frame<SkipGramRequestMessage> frame = new Frame<>(BasicSequenceProvider.getInstance().getNextValue());
                    for (int f = 0; f < 128; f++) {
                        frame.stackMessage(getSGRM());
                    }
                    long time1 = System.nanoTime();
                    clientNode.execDistributed(frame);
                    long time2 = System.nanoTime();
                    results.add(time2 - time1);
                    if ((i + 1) % 100 == 0)
                        log.info("Thread {} cnt {}", e, i + 1);
                }
                times.addAll(results);
            });
            threads[t].setDaemon(true);
            threads[t].start();
        }
        for (Thread thread : threads) {
            try {
                thread.join();
            } catch (InterruptedException interrupted) {
                // keep the interrupt visible to the test runner and stop waiting for the workers
                Thread.currentThread().interrupt();
                break;
            }
        }
        List<Long> newTimes = new ArrayList<>(times);
        if (newTimes.isEmpty()) {
            // workers publish their timings only on normal completion
            throw new IllegalStateException("No timings were collected: worker threads did not complete");
        }
        Collections.sort(newTimes);
        log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000);
    } finally {
        // shutdown everything, even if an assertion or a worker failed
        for (VoidParameterServer shard : shards) {
            shard.getTransport().shutdown();
        }
        clientNode.getTransport().shutdown();
    }
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) Frame(org.nd4j.parameterserver.distributed.messages.Frame) SkipGramTrainer(org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) ArrayList(java.util.ArrayList) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) AtomicLong(java.util.concurrent.atomic.AtomicLong) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) ArrayList(java.util.ArrayList) List(java.util.List) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) MulticastTransport(org.nd4j.parameterserver.distributed.transport.MulticastTransport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) Transport(org.nd4j.parameterserver.distributed.transport.Transport) RoutedTransport(org.nd4j.parameterserver.distributed.transport.RoutedTransport) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration)4 ClientRouter (org.nd4j.parameterserver.distributed.logic.ClientRouter)4 InterleavedRouter (org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter)4 ArrayList (java.util.ArrayList)3 Test (org.junit.Test)3 List (java.util.List)2 CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)2 AtomicLong (java.util.concurrent.atomic.AtomicLong)2 Clipboard (org.nd4j.parameterserver.distributed.logic.completion.Clipboard)2 IntroductionRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.IntroductionRequestMessage)2 SkipGramTrainer (org.nd4j.parameterserver.distributed.training.impl.SkipGramTrainer)2 MulticastTransport (org.nd4j.parameterserver.distributed.transport.MulticastTransport)2 RoutedTransport (org.nd4j.parameterserver.distributed.transport.RoutedTransport)2 Transport (org.nd4j.parameterserver.distributed.transport.Transport)2 Aeron (io.aeron.Aeron)1 FragmentAssembler (io.aeron.FragmentAssembler)1 Publication (io.aeron.Publication)1 MediaDriver (io.aeron.driver.MediaDriver)1 Header (io.aeron.logbuffer.Header)1 System.setProperty (java.lang.System.setProperty)1