Search in sources:

Example 26 with MediaDriver

use of io.aeron.driver.MediaDriver in project deeplearning4j by deeplearning4j.

the class ParameterServerParallelWrapper method init.

/**
 * Launches the embedded media driver and the parameter server node, waits for the
 * server's subscriber to come up, and then starts one {@code Trainer} per worker.
 *
 * <p>The worker count is resolved before the {@code ParameterServerNode} is constructed:
 * the node receives {@code numWorkers} as a constructor argument and its subscribers are
 * indexed once per worker further down, so it must see the final value rather than the
 * "0 = use all processors" placeholder.
 *
 * <p>If the calling thread is interrupted while waiting for the server, the wait is still
 * completed and the interrupt status is restored afterwards. Re-asserting the flag inside
 * the polling loop would make every subsequent {@code Thread.sleep} fail immediately and
 * turn the loop into a busy spin.
 *
 * @param iterator the training data source; must be a {@code DataSetIterator} or a
 *                 {@code MultiDataSetIterator}
 * @throws IllegalStateException    if {@code numEpochs} is less than 1
 * @throws IllegalArgumentException if {@code iterator} is of any other type
 */
private void init(Object iterator) {
    if (numEpochs < 1)
        throw new IllegalStateException("numEpochs must be >= 1");
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        DataSetIterator dataSetIterator = (DataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(dataSetIterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        MultiDataSetIterator iterator1 = (MultiDataSetIterator) iterator;
        numUpdatesPerEpoch = numUpdatesPerEpoch(iterator1);
    } else
        throw new IllegalArgumentException("Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    // Default the worker count first: everything below (server node, queue capacity,
    // trainer array, thread pool) is sized from it.
    if (numWorkers == 0)
        numWorkers = Runtime.getRuntime().availableProcessors();
    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);
    if (parameterServerArgs == null)
        parameterServerArgs = new String[] { "-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch) };
    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);
    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);
    // Remember an interrupt instead of re-asserting it immediately; see the method docs.
    boolean interrupted = false;
    while (!parameterServerNode.subscriberLaunched()) {
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            interrupted = true;
        }
    }
    // Extra grace period after the subscriber reports that it has launched.
    try {
        Thread.sleep(10000);
    } catch (InterruptedException e) {
        interrupted = true;
    }
    if (interrupted)
        Thread.currentThread().interrupt();
    log.info("Parameter server started");
    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        // Each worker trains on its own copy of the model; stays null for any other Model type.
        Model model = null;
        if (this.model instanceof ComputationGraph) {
            ComputationGraph computationGraph = (ComputationGraph) this.model;
            model = computationGraph.clone();
        } else if (this.model instanceof MultiLayerNetwork) {
            MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) this.model;
            model = multiLayerNetwork.clone();
        }
        parameterServerClient[i] = new Trainer(ParameterServerClient.builder().aeron(parameterServerNode.getAeron()).ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl()).ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl()).subscriberHost("localhost").masterStatusHost("localhost").masterStatusPort(statusServerPort).subscriberPort(40625 + i).subscriberStream(12 + i).build(), running, linkedBlockingQueue, model);
        // Lambdas may only capture effectively final locals.
        final int j = i;
        executorService.submit(() -> parameterServerClient[j].start());
    }
    init = true;
    log.info("Initialized wrapper");
}
Also used : ParameterServerNode(org.nd4j.parameterserver.node.ParameterServerNode) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) MediaDriver(io.aeron.driver.MediaDriver) Model(org.deeplearning4j.nn.api.Model) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) AsyncDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) AsyncMultiDataSetIterator(org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator)

Example 27 with MediaDriver

use of io.aeron.driver.MediaDriver in project Aeron by real-logic.

the class TermBufferLengthTest method shouldHaveCorrectTermBufferLength.

/**
 * Launches a media driver whose default publication and IPC term lengths are both twice
 * {@code TEST_TERM_LENGTH}, adds a publication on the supplied channel, and asserts that
 * the publication reports a term buffer length of exactly {@code TEST_TERM_LENGTH}.
 *
 * <p>NOTE(review): presumably the channel data points carry their own term length that
 * takes precedence over the driver defaults; confirm against the theory's data points.
 *
 * @param channel channel to publish on, supplied by the theory runner
 */
@Theory
@Test(timeout = 10000)
public void shouldHaveCorrectTermBufferLength(final String channel) throws Exception {
    final int doubledTermLength = TEST_TERM_LENGTH * 2;
    final MediaDriver.Context driverContext = new MediaDriver.Context();
    driverContext.publicationTermBufferLength(doubledTermLength);
    driverContext.ipcTermBufferLength(doubledTermLength);
    try (MediaDriver ignored = MediaDriver.launch(driverContext);
        Aeron client = Aeron.connect();
        Publication pub = client.addPublication(channel, STREAM_ID)) {
        assertThat(pub.termBufferLength(), is(TEST_TERM_LENGTH));
    } finally {
        // Remove the driver directory whether or not the assertion passed.
        driverContext.deleteAeronDirectory();
    }
}
Also used : MediaDriver(io.aeron.driver.MediaDriver) Theory(org.junit.experimental.theories.Theory) Test(org.junit.Test)

Example 28 with MediaDriver

use of io.aeron.driver.MediaDriver in project Aeron by real-logic.

the class FragmentedMessageTest method shouldReceivePublishedMessage.

/**
 * Publishes one message four MTU lengths long and checks that the
 * {@code FragmentAssembler} hands it to the mocked handler exactly once, with the same
 * bytes and with the header flags equal to {@code END_FRAG_FLAG}.
 *
 * @param channel       channel used for both the publication and the subscription
 * @param threadingMode threading mode the media driver is launched with
 */
@Theory
@Test(timeout = 10000)
public void shouldReceivePublishedMessage(final String channel, final ThreadingMode threadingMode) throws Exception {
    final MediaDriver.Context driverContext = new MediaDriver.Context();
    driverContext.threadingMode(threadingMode);
    final FragmentAssembler assembler = new FragmentAssembler(mockFragmentHandler);
    try (MediaDriver ignored = MediaDriver.launch(driverContext);
        Aeron client = Aeron.connect();
        Publication pub = client.addPublication(channel, STREAM_ID);
        Subscription sub = client.addSubscription(channel, STREAM_ID)) {
        final int sectionCount = 4;
        final int messageOffset = 0;
        final UnsafeBuffer message = new UnsafeBuffer(new byte[driverContext.mtuLength() * sectionCount]);
        final int messageLength = message.capacity();
        final int sectionLength = messageLength / sectionCount;
        // Fill each quarter of the message with a distinct byte: 'A', 'B', 'C', 'D'.
        for (int section = 0, start = 0; section < sectionCount; section++, start += sectionLength) {
            message.setMemory(start, sectionLength, (byte) ('A' + section));
        }
        // A negative result means the offer was not accepted; retry until it is.
        while (pub.offer(message, messageOffset, messageLength) < 0L) {
            Thread.yield();
        }
        // Four MTUs of payload are expected to arrive as five fragments because of headers.
        final int expectedFragments = 5;
        int fragmentsSeen = 0;
        while (fragmentsSeen < expectedFragments) {
            fragmentsSeen += sub.poll(assembler, FRAGMENT_COUNT_LIMIT);
        }
        final ArgumentCaptor<DirectBuffer> capturedPayload = ArgumentCaptor.forClass(DirectBuffer.class);
        final ArgumentCaptor<Header> capturedHeader = ArgumentCaptor.forClass(Header.class);
        verify(mockFragmentHandler, times(1)).onFragment(capturedPayload.capture(), eq(messageOffset), eq(messageLength), capturedHeader.capture());
        final DirectBuffer reassembled = capturedPayload.getValue();
        for (int i = 0; i < messageLength; i++) {
            assertThat("same at i=" + i, reassembled.getByte(i), is(message.getByte(i)));
        }
        assertThat(capturedHeader.getValue().flags(), is(END_FRAG_FLAG));
    } finally {
        // Remove the driver directory regardless of the test outcome.
        driverContext.deleteAeronDirectory();
    }
}
Also used : DataPoint(org.junit.experimental.theories.DataPoint) DirectBuffer(org.agrona.DirectBuffer) MediaDriver(io.aeron.driver.MediaDriver) Header(io.aeron.logbuffer.Header) UnsafeBuffer(org.agrona.concurrent.UnsafeBuffer) Theory(org.junit.experimental.theories.Theory) Test(org.junit.Test)

Example 29 with MediaDriver

use of io.aeron.driver.MediaDriver in project nd4j by deeplearning4j.

the class LowLatencyMediaDriver method main.

/**
 * Loads any properties files named on the command line, applies a set of fixed system
 * properties, launches a media driver with dedicated threading and busy-spin idle
 * strategies, and then blocks on a {@code SigIntBarrier}; the driver is closed when the
 * barrier returns.
 *
 * @param args passed to {@code MediaDriver.loadPropertiesFiles}
 */
@SuppressWarnings("checkstyle:UncommentedMain")
public static void main(final String... args) {
    MediaDriver.loadPropertiesFiles(args);

    // The two socket buffers and the initial receive window share one value.
    final String bufferBytes = "2097152";
    setProperty(DISABLE_BOUNDS_CHECKS_PROP_NAME, "true");
    setProperty("aeron.mtu.length", "16384");
    setProperty("aeron.socket.so_sndbuf", bufferBytes);
    setProperty("aeron.socket.so_rcvbuf", bufferBytes);
    setProperty("aeron.rcv.initial.window.length", bufferBytes);

    final MediaDriver.Context lowLatencyContext = new MediaDriver.Context()
        .threadingMode(ThreadingMode.DEDICATED)
        .dirsDeleteOnStart(true)
        .termBufferSparseFile(false)
        .conductorIdleStrategy(new BusySpinIdleStrategy())
        .receiverIdleStrategy(new BusySpinIdleStrategy())
        .senderIdleStrategy(new BusySpinIdleStrategy());

    try (MediaDriver ignored = MediaDriver.launch(lowLatencyContext)) {
        new SigIntBarrier().await();
    }
}
Also used : SigIntBarrier(org.agrona.concurrent.SigIntBarrier) MediaDriver(io.aeron.driver.MediaDriver) BusySpinIdleStrategy(org.agrona.concurrent.BusySpinIdleStrategy)

Example 30 with MediaDriver

use of io.aeron.driver.MediaDriver in project nd4j by deeplearning4j.

the class AeronNDArrayResponseTest method before.

/**
 * Launches an embedded media driver with shared threading and busy-spin idle strategies
 * before each test, stores it in {@code mediaDriver}, and prints its Aeron directory.
 */
@Before
public void before() {
    final MediaDriver.Context driverContext = new MediaDriver.Context()
        .threadingMode(ThreadingMode.SHARED)
        .dirsDeleteOnStart(true)
        .termBufferSparseFile(false)
        .conductorIdleStrategy(new BusySpinIdleStrategy())
        .receiverIdleStrategy(new BusySpinIdleStrategy())
        .senderIdleStrategy(new BusySpinIdleStrategy());
    mediaDriver = MediaDriver.launchEmbedded(driverContext);

    final String aeronDirectory = mediaDriver.aeronDirectoryName();
    System.out.println("Using media driver directory " + aeronDirectory);
    System.out.println("Launched media driver");
}
Also used : MediaDriver(io.aeron.driver.MediaDriver) BusySpinIdleStrategy(org.agrona.concurrent.BusySpinIdleStrategy) Before(org.junit.Before)

Aggregations

MediaDriver (io.aeron.driver.MediaDriver) 59, Aeron (io.aeron.Aeron) 22, AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean) 22, Subscription (io.aeron.Subscription) 13, FragmentHandler (io.aeron.logbuffer.FragmentHandler) 11, ContinueBarrier (org.agrona.console.ContinueBarrier) 11, Test (org.junit.jupiter.api.Test) 10, Publication (io.aeron.Publication) 9, IdleStrategy (org.agrona.concurrent.IdleStrategy) 9, UnsafeBuffer (org.agrona.concurrent.UnsafeBuffer) 9, InterruptAfter (io.aeron.test.InterruptAfter) 8, ExecutorService (java.util.concurrent.ExecutorService) 8, BusySpinIdleStrategy (org.agrona.concurrent.BusySpinIdleStrategy) 8, TestMediaDriver (io.aeron.test.driver.TestMediaDriver) 7, Archive (io.aeron.archive.Archive) 6, ParameterizedTest (org.junit.jupiter.params.ParameterizedTest) 6, Tests (io.aeron.test.Tests) 5, DirectBuffer (org.agrona.DirectBuffer) 5, CommonContext (io.aeron.CommonContext) 4, File (java.io.File) 4