Search in sources:

Example 1 with Clipboard

use of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.

Class ClipboardTest, method testPin1.

/**
 * Pins 100 aggregations, each built with its own random id (presumably the task id) and a
 * different shard index, and checks that none of them forms a complete stack: the clipboard
 * must report 100 pinned stacks, zero complete stacks and no candidates.
 */
@Test
public void testPin1() throws Exception {
    Clipboard board = new Clipboard();
    Random seeded = new Random(12345L);

    int pinned = 0;
    while (pinned < 100) {
        board.pin(new VectorAggregation(seeded.nextLong(), (short) 100, (short) pinned, Nd4j.create(5)));
        pinned++;
    }

    assertEquals(false, board.hasCandidates());
    assertEquals(0, board.getNumberOfCompleteStacks());
    assertEquals(100, board.getNumberOfPinnedStacks());
}
Also used : Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) VectorAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation) Test(org.junit.Test)

Example 2 with Clipboard

use of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.

Class ClipboardTest, method testPin2.

/**
 * Pins 300 aggregations. The first 100 even-indexed ones are rewritten to share task id 123
 * and carry consecutive shard indexes 0..99. All others keep the random id they were built with.
 * Afterwards the shared task must be retrievable as one complete stack with no missing
 * chunks, and the clipboard must report exactly one complete stack.
 */
@Test
public void testPin2() throws Exception {
    Clipboard board = new Clipboard();
    Random seeded = new Random(12345L);
    Long validId = 123L;
    short nextShard = 0;

    for (int step = 0; step < 300; step++) {
        VectorAggregation chunk = new VectorAggregation(seeded.nextLong(), (short) 100, (short) 1, Nd4j.create(5));
        // imitating valid: every other chunk joins the shared task until all 100 shards are present
        boolean imitatesValidChunk = (step % 2 == 0) && (nextShard < 100);
        if (imitatesValidChunk) {
            chunk.setTaskId(validId);
            chunk.setShardIndex(nextShard);
            nextShard++;
        }
        board.pin(chunk);
    }

    VoidAggregation stack = board.getStackFromClipboard(0L, validId);
    assertNotEquals(null, stack);
    assertEquals(0, stack.getMissingChunks());
    assertEquals(true, board.hasCandidates());
    assertEquals(1, board.getNumberOfCompleteStacks());
}
Also used : VoidAggregation(org.nd4j.parameterserver.distributed.messages.VoidAggregation) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) VectorAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation) Test(org.junit.Test)

Example 3 with Clipboard

use of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.

Class ClipboardTest, method testPin3.

/**
 * This test checks how the clipboard handles singular aggregations: a single
 * {@link InitializationAggregation}, pinned on its own, must immediately be both tracked
 * and ready under its own task id.
 * @throws Exception
 */
@Test
public void testPin3() throws Exception {
    Clipboard clipboard = new Clipboard();
    InitializationAggregation aggregation = new InitializationAggregation(1, 0);
    clipboard.pin(aggregation);
    assertTrue(clipboard.isTracking(0L, aggregation.getTaskId()));
    assertTrue(clipboard.isReady(0L, aggregation.getTaskId()));
}
Also used : InitializationAggregation(org.nd4j.parameterserver.distributed.messages.aggregations.InitializationAggregation) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) Test(org.junit.Test)

Example 4 with Clipboard

use of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.

Class FrameTest, method testFrame1.

/**
 * Simple test for Frame functionality: ten stub {@link TrainingMessage}s are stacked into
 * one {@link Frame}, the frame is processed once, and the total number of
 * {@code processMessage()} invocations on the stubs is checked.
 */
@Test
public void testFrame1() {
    // shared across all stubs; each stub's processMessage() bumps it by one
    final AtomicInteger processed = new AtomicInteger();
    Frame<TrainingMessage> frame = new Frame<>();

    int remaining = 10;
    while (remaining-- > 0) {
        frame.stackMessage(new TrainingMessage() {

            @Override
            public void processMessage() {
                processed.getAndIncrement();
            }

            @Override
            public byte getCounter() {
                return 2;
            }

            @Override
            public boolean isJoinSupported() {
                return false;
            }

            @Override
            public void joinMessage(VoidMessage message) {
                // no-op: joins are not supported by this stub
            }

            @Override
            public boolean isBlockingMessage() {
                return false;
            }

            @Override
            public long getTaskId() {
                return 0;
            }

            @Override
            public int getMessageType() {
                return 0;
            }

            @Override
            public long getFrameId() {
                return 0;
            }

            @Override
            public void setFrameId(long frameId) {
                // ignored by this stub
            }

            @Override
            public long getOriginatorId() {
                return 0;
            }

            @Override
            public void setOriginatorId(long id) {
                // ignored by this stub
            }

            @Override
            public short getTargetId() {
                return 0;
            }

            @Override
            public void setTargetId(short id) {
                // ignored by this stub
            }

            @Override
            public int getRetransmitCount() {
                return 0;
            }

            @Override
            public void incrementRetransmitCount() {
                // ignored by this stub
            }

            @Override
            public byte[] asBytes() {
                return new byte[0];
            }

            @Override
            public UnsafeBuffer asUnsafeBuffer() {
                return null;
            }

            @Override
            public void attachContext(VoidConfiguration voidConfiguration, TrainingDriver<? extends TrainingMessage> trainer, Clipboard clipboard, Transport transport, Storage storage, NodeRole role, short shardIndex) {
                // no-op intentionally: the stub needs no transport/storage context
            }

            @Override
            public void extractContext(BaseVoidMessage message) {
                // no-op intentionally
            }
        });
    }

    assertEquals(10, frame.size());
    frame.processMessage();
    // 20 invocations for 10 stubs; presumably driven by getCounter() returning 2 — confirm against Frame
    assertEquals(20, processed.get());
}
Also used : VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) Storage(org.nd4j.parameterserver.distributed.logic.Storage) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) UnsafeBuffer(org.agrona.concurrent.UnsafeBuffer) Transport(org.nd4j.parameterserver.distributed.transport.Transport) Test(org.junit.Test)

Example 5 with Clipboard

use of org.nd4j.parameterserver.distributed.logic.completion.Clipboard in project nd4j by deeplearning4j.

Class MulticastTransport, method init.

/**
 * Initializes the multicast transport for the given node role.
 *
 * <p>Validates the multicast-related configuration, launches an embedded {@link MediaDriver},
 * connects an {@link Aeron} client to it, and builds the multicast channel URI (with an optional
 * interface and the configured TTL). Publications and subscriptions are then wired up by role:
 * <ul>
 *   <li>SHARD / BACKUP:
 *     <ul>
 *       <li>subscribes to its own unicast endpoint on {@code streamId};</li>
 *       <li>publishes to the multicast channel on {@code streamId + 1} and {@code streamId + 2};</li>
 *       <li>subscribes to the multicast channel on {@code streamId + 2}.</li>
 *     </ul>
 *   </li>
 *   <li>CLIENT:
 *     <ul>
 *       <li>publishes to the first configured shard address on {@code streamId};</li>
 *       <li>subscribes to the multicast channel on {@code streamId + 1}.</li>
 *     </ul>
 *   </li>
 * </ul>
 * If this node is a SHARD and the configuration declares exactly one shard,
 * {@code shutdownSilent()} is called at the end.
 *
 * @param voidConfiguration transport configuration; TTL must be >= 1 and a multicast network must be set
 * @param clipboard clipboard forwarded to {@code super.init} and stored on this instance
 * @param role role of this node
 * @param localIp local IP; used as the unicast IP of a SHARD/BACKUP whose {@code ip} is not yet set
 * @param localPort local port, forwarded to {@code super.init}
 * @param shardIndex index of this shard
 * @throws ND4JIllegalStateException in any of these cases:
 *     <ul>
 *       <li>the TTL is below 1;</li>
 *       <li>the multicast network is missing;</li>
 *       <li>a CLIENT has no shard addresses configured.</li>
 *     </ul>
 * @throws RuntimeException if {@code role} is not SHARD, BACKUP or CLIENT
 */
@Override
public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Clipboard clipboard, @NonNull NodeRole role, @NonNull String localIp, int localPort, short shardIndex) {
    if (voidConfiguration.getTtl() < 1) {
        throw new ND4JIllegalStateException("For MulticastTransport you should have TTL >= 1, it won't work otherwise");
    }
    if (voidConfiguration.getMulticastNetwork() == null || voidConfiguration.getMulticastNetwork().isEmpty()) {
        throw new ND4JIllegalStateException("For MulticastTransport you should provide IP from multicast network available/allowed in your environment, i.e.: 224.0.1.1");
    }
    // The CLIENT branch below reads the first shard address. Validate it here, before the embedded
    // MediaDriver/Aeron client are started, so a bad config cannot leave them running.
    if (role == NodeRole.CLIENT && (voidConfiguration.getShardAddresses() == null || voidConfiguration.getShardAddresses().isEmpty())) {
        throw new ND4JIllegalStateException("For MulticastTransport in CLIENT role you should provide at least one shard address");
    }
    // shutdown hook
    super.init(voidConfiguration, clipboard, role, localIp, localPort, shardIndex);
    this.voidConfiguration = voidConfiguration;
    this.nodeRole = role;
    this.clipboard = clipboard;
    context = new Aeron.Context();
    driver = MediaDriver.launchEmbedded();
    context.aeronDirectoryName(driver.aeronDirectoryName());
    aeron = Aeron.connect(context);
    this.shardIndex = shardIndex;
    multicastChannelUri = "aeron:udp?endpoint=" + voidConfiguration.getMulticastNetwork() + ":" + voidConfiguration.getMulticastPort();
    if (voidConfiguration.getMulticastInterface() != null && !voidConfiguration.getMulticastInterface().isEmpty()) {
        multicastChannelUri = multicastChannelUri + "|interface=" + voidConfiguration.getMulticastInterface();
    }
    multicastChannelUri = multicastChannelUri + "|ttl=" + voidConfiguration.getTtl();
    if (voidConfiguration.getNumberOfShards() < 0) {
        // shard count not configured: derive it from the list of shard addresses
        voidConfiguration.setNumberOfShards(voidConfiguration.getShardAddresses().size());
    }
    switch(nodeRole) {
        case BACKUP:
        case SHARD:
            /*
                    In case of Shard, unicast address for communication is known in advance
                 */
            if (ip == null) {
                ip = localIp;
                port = voidConfiguration.getUnicastPort();
            }
            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Shard unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
            // this channel will be used to receive batches from Clients
            subscriptionForShards = aeron.addSubscription(unicastChannelUri, voidConfiguration.getStreamId());
            // this channel will be used to send completion reports back to Clients
            publicationForClients = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 1);
            // this channel will be used for communication with other Shards
            publicationForShards = aeron.addPublication(multicastChannelUri, voidConfiguration.getStreamId() + 2);
            // this channel will be used to receive messages from other Shards
            // NOTE(review): stored in subscriptionForClients although it carries shard-to-shard traffic — confirm the naming is intended
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 2);
            messageHandlerForShards = new FragmentAssembler((buffer, offset, length, header) -> shardMessageHandler(buffer, offset, length, header));
            messageHandlerForClients = new FragmentAssembler(((buffer, offset, length, header) -> internalMessageHandler(buffer, offset, length, header)));
            break;
        case CLIENT:
            /*
                    In case of Client, unicast goes to a single shard: always the first entry of
                    shardAddresses. Random selection (ArrayUtil.getRandomElement) was disabled
                    earlier with "FIXME: we don't want that".
                    An address is either "host" (the configured unicast port is used) or "host:port".
                 */
            String rts = voidConfiguration.getShardAddresses().get(0);
            String[] split = rts.split(":");
            if (split.length == 1) {
                ip = rts;
                port = voidConfiguration.getUnicastPort();
            } else {
                ip = split[0];
                port = Integer.parseInt(split[1]);
            }
            unicastChannelUri = "aeron:udp?endpoint=" + ip + ":" + port;
            log.info("Client unicast URI: {}/{}", unicastChannelUri, voidConfiguration.getStreamId());
            /*
                 this channel will be used to send batches to Shards, it's 1:1 channel to one of the Shards
                */
            publicationForShards = aeron.addPublication(unicastChannelUri, voidConfiguration.getStreamId());
            // this channel will be used to receive completion reports from Shards
            subscriptionForClients = aeron.addSubscription(multicastChannelUri, voidConfiguration.getStreamId() + 1);
            messageHandlerForClients = new FragmentAssembler((buffer, offset, length, header) -> clientMessageHandler(buffer, offset, length, header));
            break;
        default:
            log.warn("Unknown role passed: {}", nodeRole);
            throw new RuntimeException("Unknown role passed: " + nodeRole);
    }
    // if that's local spark run - we don't need this
    if (voidConfiguration.getNumberOfShards() == 1 && nodeRole == NodeRole.SHARD) {
        shutdownSilent();
    }
}
Also used : MediaDriver(io.aeron.driver.MediaDriver) Slf4j(lombok.extern.slf4j.Slf4j) Aeron(io.aeron.Aeron) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) NonNull(lombok.NonNull) Clipboard(org.nd4j.parameterserver.distributed.logic.completion.Clipboard) NodeRole(org.nd4j.parameterserver.distributed.enums.NodeRole) FragmentAssembler(io.aeron.FragmentAssembler) VoidConfiguration(org.nd4j.parameterserver.distributed.conf.VoidConfiguration) CloseHelper(org.agrona.CloseHelper) MeaningfulMessage(org.nd4j.parameterserver.distributed.messages.MeaningfulMessage) VoidMessage(org.nd4j.parameterserver.distributed.messages.VoidMessage) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) Aeron(io.aeron.Aeron) FragmentAssembler(io.aeron.FragmentAssembler)

Aggregations

Clipboard (org.nd4j.parameterserver.distributed.logic.completion.Clipboard)7 Test (org.junit.Test)5 VoidConfiguration (org.nd4j.parameterserver.distributed.conf.VoidConfiguration)4 NodeRole (org.nd4j.parameterserver.distributed.enums.NodeRole)3 Aeron (io.aeron.Aeron)2 FragmentAssembler (io.aeron.FragmentAssembler)2 MediaDriver (io.aeron.driver.MediaDriver)2 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)2 Slf4j (lombok.extern.slf4j.Slf4j)2 CloseHelper (org.agrona.CloseHelper)2 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)2 ClientRouter (org.nd4j.parameterserver.distributed.logic.ClientRouter)2 InterleavedRouter (org.nd4j.parameterserver.distributed.logic.routing.InterleavedRouter)2 VoidMessage (org.nd4j.parameterserver.distributed.messages.VoidMessage)2 VectorAggregation (org.nd4j.parameterserver.distributed.messages.aggregations.VectorAggregation)2 IntroductionRequestMessage (org.nd4j.parameterserver.distributed.messages.requests.IntroductionRequestMessage)2 Publication (io.aeron.Publication)1 Header (io.aeron.logbuffer.Header)1 System.setProperty (java.lang.System.setProperty)1 java.util (java.util)1