Usage of io.aeron.driver.MediaDriver in the project deeplearning4j by deeplearning4j.
Class ParameterServerParallelWrapper, method init.
/**
 * Sets up the embedded parameter server and one training client per worker.
 *
 * <p>Steps, in order: validates configuration, derives {@code numUpdatesPerEpoch} from the
 * iterator, launches an embedded Aeron {@link MediaDriver}, starts a {@code ParameterServerNode}
 * (with {@code parameterServerArgs}, or a default localhost argument set when none were
 * supplied), blocks until the node reports its subscriber launched, then creates one
 * {@code Trainer} per worker, each holding its own clone of {@code model}, and submits each
 * trainer's {@code start()} to a fixed-size thread pool. Finally sets {@code init = true}.
 *
 * @param iterator a {@code DataSetIterator} or {@code MultiDataSetIterator} used to compute
 *     the number of updates per epoch
 * @throws IllegalArgumentException if {@code iterator} is of neither supported type
 * @throws IllegalStateException if {@code numEpochs < 1}; if {@code model} is not a
 *     {@code ComputationGraph} or {@code MultiLayerNetwork}; or if the calling thread is
 *     interrupted while waiting for the parameter server to start (the interrupt flag is
 *     left set)
 */
private void init(Object iterator) {
    if (numEpochs < 1)
        throw new IllegalStateException("numEpochs must be >= 1");
    //TODO: make this efficient
    if (iterator instanceof DataSetIterator) {
        numUpdatesPerEpoch = numUpdatesPerEpoch((DataSetIterator) iterator);
    } else if (iterator instanceof MultiDataSetIterator) {
        numUpdatesPerEpoch = numUpdatesPerEpoch((MultiDataSetIterator) iterator);
    } else {
        throw new IllegalArgumentException("Illegal type of object passed in for initialization. Must be of type DataSetIterator or MultiDataSetIterator");
    }

    // The worker loop below can only clone these two model types. Reject anything else
    // before the media driver is launched, so a bad configuration does not leave an
    // embedded driver running or hand a null model to a Trainer.
    if (!(model instanceof ComputationGraph) && !(model instanceof MultiLayerNetwork)) {
        throw new IllegalStateException("Model must be a ComputationGraph or MultiLayerNetwork but was "
                        + (model == null ? "null" : model.getClass().getName()));
    }

    // Resolve the worker count BEFORE anything is sized by it. The ParameterServerNode is
    // constructed with numWorkers, and the client loop indexes getSubscriber()[i] for every
    // i < numWorkers, so both must see the same non-zero value.
    if (numWorkers == 0)
        numWorkers = Runtime.getRuntime().availableProcessors();

    mediaDriverContext = new MediaDriver.Context();
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext);
    parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers);
    running = new AtomicBoolean(true);
    if (parameterServerArgs == null)
        parameterServerArgs = new String[] { "-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch) };
    linkedBlockingQueue = new LinkedBlockingQueue<>(numWorkers);
    //pass through args for the parameter server subscriber
    parameterServerNode.runMain(parameterServerArgs);

    // Poll interval while waiting for the subscriber. The same duration is then used as a
    // fixed extra delay after it reports launched (presumably to let it finish binding;
    // TODO confirm).
    final long startupWaitMillis = 10000L;
    try {
        while (!parameterServerNode.subscriberLaunched()) {
            Thread.sleep(startupWaitMillis);
        }
        Thread.sleep(startupWaitMillis);
    } catch (InterruptedException e) {
        // Abort instead of continuing to poll: with the interrupt flag restored, every further
        // sleep would throw immediately and the wait would degrade into a busy spin.
        Thread.currentThread().interrupt();
        throw new IllegalStateException("Interrupted while waiting for the parameter server to start", e);
    }
    log.info("Parameter server started");

    parameterServerClient = new Trainer[numWorkers];
    executorService = Executors.newFixedThreadPool(numWorkers);
    for (int i = 0; i < numWorkers; i++) {
        // Each worker trains its own copy of the model.
        final Model workerModel;
        if (this.model instanceof ComputationGraph) {
            workerModel = ((ComputationGraph) this.model).clone();
        } else {
            workerModel = ((MultiLayerNetwork) this.model).clone();
        }
        final Trainer trainer = new Trainer(ParameterServerClient.builder().aeron(parameterServerNode.getAeron()).ndarrayRetrieveUrl(parameterServerNode.getSubscriber()[i].getResponder().connectionUrl()).ndarraySendUrl(parameterServerNode.getSubscriber()[i].getSubscriber().connectionUrl()).subscriberHost("localhost").masterStatusHost("localhost").masterStatusPort(statusServerPort).subscriberPort(40625 + i).subscriberStream(12 + i).build(), running, linkedBlockingQueue, workerModel);
        parameterServerClient[i] = trainer;
        executorService.submit(() -> trainer.start());
    }
    init = true;
    log.info("Initialized wrapper");
}
Usage of io.aeron.driver.MediaDriver in the project Aeron by real-logic.
Class TermBufferLengthTest, method shouldHaveCorrectTermBufferLength.
/**
 * Verifies that a publication on {@code channel} reports a term buffer length of
 * {@code TEST_TERM_LENGTH}, even though the driver context's publication and IPC defaults
 * are set to twice that value.
 *
 * <p>NOTE(review): this presumably relies on {@code channel} carrying its own term-length
 * setting that overrides the context defaults; confirm against the theory data points.
 */
@Theory
@Test(timeout = 10000)
public void shouldHaveCorrectTermBufferLength(final String channel) throws Exception {
    // Context-wide defaults deliberately differ from the expected value.
    final int contextDefaultTermLength = TEST_TERM_LENGTH * 2;
    final MediaDriver.Context driverContext = new MediaDriver.Context();
    driverContext.publicationTermBufferLength(contextDefaultTermLength);
    driverContext.ipcTermBufferLength(contextDefaultTermLength);

    try (MediaDriver ignore = MediaDriver.launch(driverContext);
         Aeron aeron = Aeron.connect();
         Publication publication = aeron.addPublication(channel, STREAM_ID)) {
        final int actualTermLength = publication.termBufferLength();
        assertThat(actualTermLength, is(TEST_TERM_LENGTH));
    } finally {
        // Always remove the driver directory, whether or not the assertion passed.
        driverContext.deleteAeronDirectory();
    }
}
Usage of io.aeron.driver.MediaDriver in the project Aeron by real-logic.
Class FragmentedMessageTest, method shouldReceivePublishedMessage.
/**
 * Publishes a single message four MTUs long and checks that the subscriber's
 * {@code FragmentAssembler} delivers it to the handler exactly once, reassembled and
 * byte-identical, with the end-fragment flag set on the delivered header.
 */
@Theory
@Test(timeout = 10000)
public void shouldReceivePublishedMessage(final String channel, final ThreadingMode threadingMode) throws Exception {
    final MediaDriver.Context driverContext = new MediaDriver.Context();
    driverContext.threadingMode(threadingMode);
    final FragmentAssembler assembler = new FragmentAssembler(mockFragmentHandler);

    try (MediaDriver ignore = MediaDriver.launch(driverContext);
         Aeron aeron = Aeron.connect();
         Publication publication = aeron.addPublication(channel, STREAM_ID);
         Subscription subscription = aeron.addSubscription(channel, STREAM_ID)) {
        // Four equal slices, each filled with a distinct byte value starting at 65 ('A').
        final int sliceCount = 4;
        final UnsafeBuffer message = new UnsafeBuffer(new byte[driverContext.mtuLength() * sliceCount]);
        final int offset = 0;
        final int sliceLength = message.capacity() / sliceCount;
        for (int slice = 0; slice < sliceCount; slice++) {
            final byte fill = (byte) (65 + slice);
            message.setMemory(slice * sliceLength, sliceLength, fill);
        }

        // Retry until the publication accepts the message.
        long offerResult = publication.offer(message, offset, message.capacity());
        while (offerResult < 0L) {
            Thread.yield();
            offerResult = publication.offer(message, offset, message.capacity());
        }

        // Frame headers push the payload past four MTUs, so five fragments are expected.
        final int expectedFragmentsBecauseOfHeader = 5;
        int fragmentsReceived = 0;
        while (fragmentsReceived < expectedFragmentsBecauseOfHeader) {
            fragmentsReceived += subscription.poll(assembler, FRAGMENT_COUNT_LIMIT);
        }

        final ArgumentCaptor<DirectBuffer> bufferCaptor = ArgumentCaptor.forClass(DirectBuffer.class);
        final ArgumentCaptor<Header> headerCaptor = ArgumentCaptor.forClass(Header.class);
        verify(mockFragmentHandler, times(1)).onFragment(bufferCaptor.capture(), eq(offset), eq(message.capacity()), headerCaptor.capture());

        final DirectBuffer reassembled = bufferCaptor.getValue();
        for (int index = 0; index < message.capacity(); index++) {
            assertThat("same at i=" + index, reassembled.getByte(index), is(message.getByte(index)));
        }
        assertThat(headerCaptor.getValue().flags(), is(END_FRAG_FLAG));
    } finally {
        driverContext.deleteAeronDirectory();
    }
}
Usage of io.aeron.driver.MediaDriver in the project nd4j by deeplearning4j.
Class LowLatencyMediaDriver, method main.
/**
 * Runs a standalone media driver tuned for low latency until the process receives SIGINT.
 *
 * <p>Loads any properties files named in {@code args}, then sets fixed tuning system
 * properties (bounds checks disabled, 16 KiB MTU, 2 MiB socket buffers and initial receive
 * window). It then launches a driver with dedicated threads and busy-spin idle strategies.
 *
 * @param args paths of properties files to load before the tuning properties are applied
 */
@SuppressWarnings("checkstyle:UncommentedMain")
public static void main(final String... args) {
    MediaDriver.loadPropertiesFiles(args);

    // Applied in this order after the properties files are loaded, so presumably these
    // values take precedence over the same keys in those files.
    final String[][] tuningProperties = {
        { DISABLE_BOUNDS_CHECKS_PROP_NAME, "true" },
        { "aeron.mtu.length", "16384" },
        { "aeron.socket.so_sndbuf", "2097152" },
        { "aeron.socket.so_rcvbuf", "2097152" },
        { "aeron.rcv.initial.window.length", "2097152" },
    };
    for (final String[] property : tuningProperties) {
        setProperty(property[0], property[1]);
    }

    final MediaDriver.Context context = new MediaDriver.Context()
        .threadingMode(ThreadingMode.DEDICATED)
        .dirsDeleteOnStart(true)
        .termBufferSparseFile(false)
        .conductorIdleStrategy(new BusySpinIdleStrategy())
        .receiverIdleStrategy(new BusySpinIdleStrategy())
        .senderIdleStrategy(new BusySpinIdleStrategy());

    // Block until SIGINT; try-with-resources closes the driver afterwards.
    try (MediaDriver ignored = MediaDriver.launch(context)) {
        new SigIntBarrier().await();
    }
}
Usage of io.aeron.driver.MediaDriver in the project nd4j by deeplearning4j.
Class AeronNDArrayResponseTest, method before.
/**
 * Launches an embedded media driver for each test and stores it in {@code mediaDriver}.
 * The driver uses shared threading, deletes its directories on start, and uses busy-spin
 * idle strategies. The driver directory is printed to stdout.
 */
@Before
public void before() {
    final MediaDriver.Context driverContext = new MediaDriver.Context()
        .threadingMode(ThreadingMode.SHARED)
        .dirsDeleteOnStart(true)
        .termBufferSparseFile(false)
        .conductorIdleStrategy(new BusySpinIdleStrategy())
        .receiverIdleStrategy(new BusySpinIdleStrategy())
        .senderIdleStrategy(new BusySpinIdleStrategy());

    mediaDriver = MediaDriver.launchEmbedded(driverContext);

    final String driverDirectory = mediaDriver.aeronDirectoryName();
    System.out.println("Using media driver directory " + driverDirectory);
    System.out.println("Launched media driver");
}
Aggregations