Search in sources:

Example 1 with NodeRole

use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.

The class VoidParameterServer, method getRole.

/**
 * Determines the role of the local node by matching the configured shard and backup addresses
 * against the local IP addresses.
 * <p>
 * Matching uses the host part only. Anything from the first {@code ':'} onward (a port) is
 * stripped from each configured address before comparison.
 * <ul>
 *   <li>If a shard address matches, returns {@code (SHARD, configuredAddress)}.</li>
 *   <li>Otherwise, if a backup address matches, returns {@code (BACKUP, configuredAddress)}.</li>
 *   <li>Otherwise the node is a {@code CLIENT}. Its IP comes from the configured network mask
 *       (via {@link NetworkOrganizer}) or, as a last resort, from the {@code DL4J_VOID_IP}
 *       environment variable. The method then returns {@code (CLIENT, ip + ":" + unicastPort)}.</li>
 * </ul>
 *
 * @param voidConfiguration configuration providing shard/backup addresses, network mask and unicast port
 * @param localIPs          IP addresses of the local host
 * @return pair of the detected role and the address to use for it
 * @throws ND4JIllegalStateException if the node is a client and no IP address could be determined
 */
protected Pair<NodeRole, String> getRole(@NonNull VoidConfiguration voidConfiguration, @NonNull Collection<String> localIPs) {
    String shardAddress = findLocalAddress(voidConfiguration.getShardAddresses(), localIPs);
    if (shardAddress != null)
        return Pair.create(NodeRole.SHARD, shardAddress);

    String backupAddress = findLocalAddress(voidConfiguration.getBackupAddresses(), localIPs);
    if (backupAddress != null)
        return Pair.create(NodeRole.BACKUP, backupAddress);

    // Neither shard nor backup: this node is a client and needs an IP for UDP communication.
    String sparkIp = null;
    if (voidConfiguration.getNetworkMask() != null) {
        NetworkOrganizer organizer = new NetworkOrganizer(voidConfiguration.getNetworkMask());
        sparkIp = organizer.getMatchingAddress();
    }
    // last resort here...
    if (sparkIp == null)
        sparkIp = System.getenv("DL4J_VOID_IP");
    log.info("Got [{}] as sparkIp", sparkIp);
    if (sparkIp == null)
        throw new ND4JIllegalStateException("Can't get IP address for UDP communcation");
    // local IP from pair is used for shard only, so we don't care
    return Pair.create(NodeRole.CLIENT, sparkIp + ":" + voidConfiguration.getUnicastPort());
}

/**
 * Returns the first configured address whose host part is one of the local IPs.
 *
 * @param addresses configured addresses, each optionally carrying a {@code :port} suffix; may be {@code null}
 * @param localIPs  IP addresses of the local host
 * @return the matching configured address exactly as configured (including any port),
 *         or {@code null} if there is no match or {@code addresses} is {@code null}
 */
private static String findLocalAddress(Iterable<String> addresses, Collection<String> localIPs) {
    if (addresses == null)
        return null;
    for (String address : addresses) {
        // strip an optional ":port" suffix before comparing against local IPs
        String host = address.replaceAll(":.*", "");
        if (localIPs.contains(host))
            return address;
    }
    return null;
}
Also used : NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) NetworkOrganizer(org.nd4j.parameterserver.distributed.util.NetworkOrganizer) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 2 with NodeRole

use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.

The class FrameTest, method testFrame1.

/**
 * Basic Frame test: stacks ten stub training messages, each reporting a counter of 2.
 * It asserts that the frame reports a size of 10. After {@code frame.processMessage()}, it asserts
 * that the stubs' {@code processMessage()} calls have raised the shared counter to 20.
 */
@Test
public void testFrame1() {
    final int messageCount = 10;
    final AtomicInteger processed = new AtomicInteger();
    Frame<TrainingMessage> frame = new Frame<>();

    for (int remaining = messageCount; remaining > 0; remaining--) {
        frame.stackMessage(new TrainingMessage() {

            // the only method with an observable effect: counts invocations
            @Override
            public void processMessage() {
                processed.incrementAndGet();
            }

            @Override
            public byte getCounter() {
                return 2;
            }

            // identifiers: all fixed at zero, setters ignore their argument
            @Override
            public long getFrameId() {
                return 0;
            }

            @Override
            public void setFrameId(long frameId) {
            }

            @Override
            public long getOriginatorId() {
                return 0;
            }

            @Override
            public void setOriginatorId(long id) {
            }

            @Override
            public short getTargetId() {
                return 0;
            }

            @Override
            public void setTargetId(short id) {
            }

            @Override
            public long getTaskId() {
                return 0;
            }

            @Override
            public int getMessageType() {
                return 0;
            }

            // retransmission: never retransmitted
            @Override
            public int getRetransmitCount() {
                return 0;
            }

            @Override
            public void incrementRetransmitCount() {
            }

            // serialization: empty payload, no buffer
            @Override
            public byte[] asBytes() {
                return new byte[0];
            }

            @Override
            public UnsafeBuffer asUnsafeBuffer() {
                return null;
            }

            // context handling: stubs do nothing
            @Override
            public void attachContext(VoidConfiguration voidConfiguration, TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard, Transport transport, Storage storage, NodeRole role, short shardIndex) {
                // deliberately empty
            }

            @Override
            public void extractContext(BaseVoidMessage message) {
                // deliberately empty
            }

            // joining / blocking: unsupported
            @Override
            public boolean isJoinSupported() {
                return false;
            }

            @Override
            public void joinMessage(VoidMessage message) {
                // deliberately empty
            }

            @Override
            public boolean isBlockingMessage() {
                return false;
            }
        });
    }

    assertEquals(messageCount, frame.size());
    frame.processMessage();
    assertEquals(20, processed.get());
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) Storage(org.nd4j.parameterserver.distributed.logic.Storage) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) UnsafeBuffer(org.agrona.concurrent.UnsafeBuffer) Transport(org.nd4j.parameterserver.distributed.transport.Transport) Test(org.junit.Test)

Example 3 with NodeRole

use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.

The class MulticastTransport, method init.

/**
 * Initializes the multicast transport for the given node role.
 * <p>
 * First it validates the multicast settings and, for a CLIENT, the presence of a shard address.
 * Then it launches an embedded Aeron media driver and client and builds the multicast channel URI
 * ({@code endpoint}, optional {@code interface}, {@code ttl}). Finally it sets up role-specific
 * channels:
 * <ul>
 *   <li>SHARD / BACKUP: a unicast subscription for batches from clients (stream id); multicast
 *       publications on stream id + 1 (reports to clients) and stream id + 2 (to other shards);
 *       and a multicast subscription on stream id + 2.</li>
 *   <li>CLIENT: a unicast publication to the first configured shard address (stream id) and a
 *       multicast subscription on stream id + 1 for reports from shards.</li>
 * </ul>
 * If the configuration holds exactly one shard and this node is a SHARD, {@code shutdownSilent()}
 * is called at the end.
 *
 * @throws ND4JIllegalStateException if TTL is below 1, no multicast network is configured, a CLIENT
 *                                   has no shard address to connect to, or the role is not supported
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    if (voidConfiguration.getTtl() < 1)
        throw new ND4JIllegalStateException("For MulticastTransport you should have TTL >= 1, it won't work otherwise");
    if (voidConfiguration.getMulticastNetwork() == null || voidConfiguration.getMulticastNetwork().isEmpty())
        throw new ND4JIllegalStateException("For MulticastTransport you should provide IP from multicast network available/allowed in your environment, i.e.: 224.0.1.1");
    // fail fast, before the media driver is launched: a client must have a shard to send to
    if (role == NodeRole.CLIENT && (voidConfiguration.getShardAddresses() == null || voidConfiguration.getShardAddresses().isEmpty()))
        throw new ND4JIllegalStateException("For MulticastTransport CLIENT role you should provide at least one shard address");
    // shutdown hook
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
    this.voidConfiguration = voidConfiguration;
    this.nodeRole = role;
    this.clipboard = clipboard;
    context = new Aeron.Context();
    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);
    this.shardIndex = shardIndex;
    multicastChannelUri = "aeron:udp?endpoint=" + voidConfiguration.getMulticastNetwork() + ":" + voidConfiguration.getMulticastPort();
    if (voidConfiguration.getMulticastInterface() != null && !voidConfiguration.getMulticastInterface().isEmpty())
        multicastChannelUri = multicastChannelUri + "|interface=" + voidConfiguration.getMulticastInterface();
    multicastChannelUri = multicastChannelUri + "|ttl=" + voidConfiguration.getTtl();
    // a negative shard count means "not set": derive it from the configured shard addresses
    if (voidConfiguration.getNumberOfShards() < 0)
        voidConfiguration.setNumberOfShards(voidConfiguration.getShardAddresses().size());
    switch(nodeRole) {
        case BACKUP:
        case SHARD:
            /*
                    In case of Shard, unicast address for communication is known in advance
                 */
            // ip/port already set (e.g. externally) are kept; otherwise use localIp + configured unicast port
            if (ip == null) {
                ip = localIp;
                port = voidConfiguration.getUnicastPort();
            }
            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Shard unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
            // this channel will be used to receive batches from Clients
            subscriptionForShards = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
            // this channel will be used to send completion reports back to Clients
            publicationForClients = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 1);
            // this channel will be used for communication with other Shards
            publicationForShards = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 2);
            // this channel will be used to receive messages from other Shards
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 2);
            messageHandlerForShards = new FragmentAssembler((buffer, offset, length, header) -> shardMessageHandler(buffer, offset, length, header));
            messageHandlerForClients = new FragmentAssembler(((buffer, offset, length, header) -> internalMessageHandler(buffer, offset, length, header)));
            break;
        case CLIENT:
            ip = localIp;
            /*
                    In case of Client, unicast goes to a single shard.
                    FIXME: currently always the first configured shard address; random selection is not used.
                 */
            String rts = voidConfiguration.getShardAddresses().get(0);
            String[] split = rts.split(":");
            if (split.length == 1) {
                ip = rts;
                port = voidConfiguration.getUnicastPort();
            } else {
                ip = split[0];
                port = Integer.parseInt(split[1]);
            }
            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Client unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
            /*
                 this channel will be used to send batches to Shards, it's 1:1 channel to one of the Shards
                */
            publicationForShards = aeron.addPublication(unicastChannelUri, voidConfiguration.getStreamId());
            // this channel will be used to receive completion reports from Shards
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 1);
            messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> clientMessageHandler(buffer, offset, length, header));
            break;
        default:
            // e.g. MASTER is not handled by this transport
            log.warn("Unknown role passed: {}", nodeRole);
            throw new ND4JIllegalStateException("Unknown role passed: " + nodeRole);
    }
    // if that's local spark run - we don't need this
    if (voidConfiguration.getNumberOfShards() == 1 && nodeRole == NodeRole.SHARD)
        shutdownSilent();
}
Also used : MediaDriver(io.aeron.driver.MediaDriver) Slf4j(lombok.extern.slf4j.Slf4j) Aeron(io.aeron.Aeron) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NonNull(lombok.NonNull) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) FragmentAssembler(io.aeron.FragmentAssembler) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) CloseHelper(org.agrona.CloseHelper) MeaningfulMessage(org.nd4j.parameterserver.distributed.messages.MeaningfulMessage) VoidMessage(org.nd4j.parameterserver.distributed.messages.VoidMessage) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Aeron(io.aeron.Aeron) FragmentAssembler(io.aeron.FragmentAssembler)

Example 4 with NodeRole

use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.

The class RoutedTransport, method init.

/**
 * Initializes the routed (unicast) transport for the given node role.
 * <p>
 * Steps:
 * <ol>
 *   <li>Stores the configuration and role fields.</li>
 *   <li>Launches an embedded Aeron media driver and client.</li>
 *   <li>Installs an {@code InterleavedRouter} if no router was set.</li>
 *   <li>Subscribes on this node's own unicast endpoint. The endpoint is {@code localIp:localPort},
 *       unless {@code port} was already set to a non-zero value.</li>
 *   <li>Registers a shutdown hook that closes the Aeron resources.</li>
 *   <li>Creates a publication to every configured shard address, regardless of role.</li>
 * </ol>
 * SHARD, MASTER and BACKUP nodes then register their own endpoint via {@code addClient}. Finally
 * the router is initialized and {@code originatorId} is derived from the hash of {@code ip:port}.
 *
 * @throws ND4JIllegalStateException if the role is not one of MASTER, BACKUP, SHARD or CLIENT
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    this.nodeRole = role;
    this.clipboard = clipboard;
    this.voidConfiguration = voidConfiguration;
    this.shardIndex = shardIndex;
    this.messages = new LinkedBlockingQueue<>();
    // shutdown hook
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
    setProperty("aeron.client.liveness.timeout", "30000000000");
    context = new Aeron.Context().publicationConnectionTimeout(30000000000L).driverTimeoutMs(30000).keepAliveInterval(100000000);
    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);
    if (router == null)
        router = new InterleavedRouter();
    // we skip IPs assign process if they were defined externally
    if (port == 0) {
        ip = localIp;
        port = localPort;
    }
    unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
    subscriptionForClients = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
    // clean shut down: release in reverse order of acquisition
    // (subscription -> Aeron client -> client context -> media driver)
    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
        CloseHelper.quietClose(subscriptionForClients);
        CloseHelper.quietClose(aeron);
        CloseHelper.quietClose(context);
        CloseHelper.quietClose(driver);
    }));
    messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> jointMessageHandler(buffer, offset, length, header));
    /*
            Now, regardless of current role,
             we set up publication channel to each shard
         */
    // loop variable is named "address" so it does not shadow the "ip" field used below
    for (String address : voidConfiguration.getShardAddresses()) {
        String shardChannelUri;
        String remoteIp;
        int remotePort;
        if (address.contains(":")) {
            shardChannelUri = "aeron:udp?endpoint=" + address;
            String[] split = address.split(":");
            remoteIp = split[0];
            remotePort = Integer.parseInt(split[1]);
        } else {
            shardChannelUri = "aeron:udp?endpoint=" + address + ":" + voidConfiguration.getUnicastPort();
            remoteIp = address;
            remotePort = voidConfiguration.getUnicastPort();
        }
        Publication publication = aeron.addPublication(shardChannelUri, voidConfiguration.getStreamId());
        RemoteConnection connection = RemoteConnection.builder().ip(remoteIp).port(remotePort).publication(publication).locker(new Object()).build();
        shards.add(connection);
    }
    if (nodeRole == NodeRole.SHARD)
        log.info("Initialized as [{}]; ShardIndex: [{}]; Own endpoint: [{}]", nodeRole, shardIndex, unicastChannelUri);
    else
        log.info("Initialized as [{}]; Own endpoint: [{}]", nodeRole, unicastChannelUri);
    switch(nodeRole) {
        case MASTER:
        case BACKUP:
            // intentional fall-through: MASTER and BACKUP currently share the SHARD setup
        case SHARD:
            /*
                For unicast transport we want to have interconnects between all shards first of all, because we know their IPs in advance.
                But due to design requirements, clients have the same first step, so it's kinda shared for all states :)
             */
            /*
                Next step is connections setup for backup nodes.
                TODO: to be implemented
             */
            addClient(ip, port);
            break;
        case CLIENT:
            /*
                For Clients on unicast transport, we either set up connection to single Shard, or to multiple shards
                But since this code is shared - we don't do anything here
             */
            break;
        default:
            throw new ND4JIllegalStateException("Unknown NodeRole being passed: " + nodeRole);
    }
    router.init(voidConfiguration, this);
    this.originatorId = HashUtil.getLongHash(this.getIp() + ":" + this.getPort());
}
Also used : java.util(java.util) HashUtil(org.nd4j.linalg.util.HashUtil) FragmentAssembler(io.aeron.FragmentAssembler) org.nd4j.parameterserver.distributed.messages(org.nd4j.parameterserver.distributed.messages) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) IntroductionRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.IntroductionRequestMessage) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StringUtils(org.nd4j.linalg.io.StringUtils) Publication(io.aeron.Publication) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) RetransmissionHandler(org.nd4j.parameterserver.distributed.logic.RetransmissionHandler) System.setProperty(java.lang.System.setProperty) CloseHelper(org.agrona.CloseHelper) MediaDriver(io.aeron.driver.MediaDriver) Aeron(io.aeron.Aeron) Nd4j(org.nd4j.linalg.factory.Nd4j) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) lombok(lombok) LockSupport(java.util.concurrent.locks.LockSupport) Slf4j(lombok.extern.slf4j.Slf4j) Header(io.aeron.logbuffer.Header) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) DirectBuffer(org.agrona.DirectBuffer) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) Publication(io.aeron.Publication) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) FragmentAssembler(io.aeron.FragmentAssembler)

Example 5 with NodeRole

use of org.nd4j.parameterserver.distributed.enums.NodeRole in project nd4j by deeplearning4j.

The class RoutedTransport, method sendMessageToAllClients.

/**
 * Sends {@code message} to every known client connection, using a parallel stream over the
 * client map.
 * <p>
 * Skipped connections:
 * <ul>
 *   <li>the one whose hash equals this node's {@code originatorId};</li>
 *   <li>any whose hash is {@code 0};</li>
 *   <li>any whose hash appears in {@code exclusions}.</li>
 * </ul>
 * Each send is retried until it succeeds. On back-pressure or admin action the thread parks for
 * the configured retransmit timeout (converted from milliseconds to nanoseconds) and tries again.
 * On NOT_CONNECTED, a connection that has never been activated is retried up to 20 times.
 * A NOT_CONNECTED connection that was already activated causes an immediate failure.
 *
 * @param message    message to broadcast
 * @param exclusions connection hashes that must not receive the message; may be {@code null},
 *                   {@code null} elements are ignored
 * @throws ND4JIllegalStateException if this node is not a SHARD, a connection cannot be
 *                                   established, or an activated connection is lost
 */
@Override
public void sendMessageToAllClients(VoidMessage message, Long... exclusions) {
    if (nodeRole != NodeRole.SHARD)
        throw new ND4JIllegalStateException("Only SHARD allowed to send messages to all Clients");
    final DirectBuffer buffer = message.asUnsafeBuffer();
    // ms -> ns with long arithmetic, so large timeouts cannot overflow an int product
    final long retransmitNanos = voidConfiguration.getRetransmitTimeout() * 1000000L;
    // number of exclusions matched so far: once every exclusion has been matched, the scan is skipped
    final AtomicInteger cnt = new AtomicInteger(0);
    clients.values().parallelStream().filter(rc -> {
        // do not send message back to yourself :)
        if (rc.getLongHash() == this.originatorId || rc.getLongHash() == 0)
            return false;
        // we skip exclusions here
        if (exclusions != null && cnt.get() < exclusions.length) {
            for (Long exclude : exclusions) {
                if (exclude != null && exclude.longValue() == rc.getLongHash()) {
                    cnt.incrementAndGet();
                    return false;
                }
            }
        }
        return true;
    }).forEach(rc -> {
        RetransmissionHandler.TransmissionStatus res;
        long retr = 0;
        boolean delivered = false;
        while (!delivered) {
            // still stupid. maybe use real reentrant lock here?
            synchronized (rc.locker) {
                res = RetransmissionHandler.getTransmissionStatus(rc.getPublication().offer(buffer));
            }
            switch(res) {
                case NOT_CONNECTED:
                    if (rc.getActivated().get())
                        throw new ND4JIllegalStateException("Shards reassignment is to be implemented yet");
                    retr++;
                    if (retr > 20)
                        throw new ND4JIllegalStateException("Can't connect to Shard: [" + rc.getPublication().channel() + "]");
                    LockSupport.parkNanos(retransmitNanos);
                    break;
                case ADMIN_ACTION:
                case BACKPRESSURE:
                    LockSupport.parkNanos(retransmitNanos);
                    break;
                case MESSAGE_SENT:
                    delivered = true;
                    rc.getActivated().set(true);
                    break;
            }
        }
    });
}
Also used : DirectBuffer(org.agrona.DirectBuffer) java.util(java.util) HashUtil(org.nd4j.linalg.util.HashUtil) FragmentAssembler(io.aeron.FragmentAssembler) org.nd4j.parameterserver.distributed.messages(org.nd4j.parameterserver.distributed.messages) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) IntroductionRequestMessage(org.nd4j.parameterserver.distributed.messages.requests.IntroductionRequestMessage) ClientRouter(org.nd4j.parameterserver.distributed.logic.ClientRouter) InterleavedRouter(org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) StringUtils(org.nd4j.linalg.io.StringUtils) Publication(io.aeron.Publication) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) RetransmissionHandler(org.nd4j.parameterserver.distributed.logic.RetransmissionHandler) System.setProperty(java.lang.System.setProperty) CloseHelper(org.agrona.CloseHelper) MediaDriver(io.aeron.driver.MediaDriver) Aeron(io.aeron.Aeron) Nd4j(org.nd4j.linalg.factory.Nd4j) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) LinkedBlockingQueue(java.util.concurrent.LinkedBlockingQueue) lombok(lombok) LockSupport(java.util.concurrent.locks.LockSupport) Slf4j(lombok.extern.slf4j.Slf4j) Header(io.aeron.logbuffer.Header) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) DirectBuffer(org.agrona.DirectBuffer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RetransmissionHandler(org.nd4j.parameterserver.distributed.logic.RetransmissionHandler) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Aggregations

NodeRole (org.nd4j.parameterserver.distributed.enums.NodeRole): 5; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 4; VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration): 4; Clipboard (org.nd4j.parameterserver.distributed.logic.completion.Clipboard): 4; Aeron (io.aeron.Aeron): 3; FragmentAssembler (io.aeron.FragmentAssembler): 3; MediaDriver (io.aeron.driver.MediaDriver): 3; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 3; Slf4j (lombok.extern.slf4j.Slf4j): 3; CloseHelper (org.agrona.CloseHelper): 3; Publication (io.aeron.Publication): 2; Header (io.aeron.logbuffer.Header): 2; System.setProperty (java.lang.System.setProperty): 2; java.util (java.util): 2; ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 2; LinkedBlockingQueue (java.util.concurrent.LinkedBlockingQueue): 2; AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 2; LockSupport (java.util.concurrent.locks.LockSupport): 2; lombok (lombok): 2; DirectBuffer (org.agrona.DirectBuffer): 2