Use of org.nd4j.parameterserver.ParameterServerSubscriber in project nd4j by deeplearning4j.
Class ParameterServerClientPartialTest, method before.
/**
 * Launches an embedded media driver, then starts one master and one slave
 * {@link ParameterServerSubscriber} on top of it and waits until both
 * subscribers report that they have launched.
 *
 * @throws Exception if a node cannot be started, if either node has not
 *                   launched once the wait budget is used up, or if the
 *                   waiting thread is interrupted
 */
@BeforeClass
public static void before() throws Exception {
    final MediaDriver.Context ctx = new MediaDriver.Context()
            .threadingMode(ThreadingMode.SHARED)
            .dirsDeleteOnStart(true)
            .termBufferSparseFile(false)
            .conductorIdleStrategy(new BusySpinIdleStrategy())
            .receiverIdleStrategy(new BusySpinIdleStrategy())
            .senderIdleStrategy(new BusySpinIdleStrategy());
    mediaDriver = MediaDriver.launchEmbedded(ctx);
    aeron = Aeron.connect(getContext());

    // Master node: random port so repeated runs are less likely to collide.
    masterNode = new ParameterServerSubscriber(mediaDriver);
    masterNode.setAeron(aeron);
    int masterPort = 40223 + new java.util.Random().nextInt(13000);
    int masterStatusPort = masterPort - 2000;
    masterNode.run(new String[] { "-m", "true", "-p", String.valueOf(masterPort), "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(masterStatusPort), "-s", "2,2", "-u", String.valueOf(1) });
    assertTrue(masterNode.isMaster());
    assertEquals(masterPort, masterNode.getPort());
    assertEquals("localhost", masterNode.getHost());
    assertEquals(11, masterNode.getStreamId());
    assertEquals(12, masterNode.getResponder().getStreamId());
    assertEquals(masterNode.getMasterArray(), Nd4j.create(new int[] { 2, 2 }));

    // Slave node: points at the master through the master's connection url.
    slaveNode = new ParameterServerSubscriber(mediaDriver);
    slaveNode.setAeron(aeron);
    int slavePort = masterPort + 100;
    int slaveStatusPort = slavePort - 2000;
    slaveNode.run(new String[] { "-p", String.valueOf(slavePort), "-h", "localhost", "-id", "10", "-pm", masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(slaveStatusPort), "-u", String.valueOf(1) });
    assertFalse(slaveNode.isMaster());
    assertEquals(slavePort, slaveNode.getPort());
    assertEquals("localhost", slaveNode.getHost());
    assertEquals(10, slaveNode.getStreamId());

    // Poll until BOTH nodes are up. The counter has to start at 0: it used to
    // start at 10, which made the "tries < 10" guard false on the first check
    // so the loop body never ran. The condition uses "||" because the tests
    // need the master and the slave, not just one of them.
    final int maxTries = 10;
    final long pollIntervalMillis = 10000;
    int tries = 0;
    while ((!masterNode.subscriberLaunched() || !slaveNode.subscriberLaunched()) && tries < maxTries) {
        Thread.sleep(pollIntervalMillis);
        tries++;
    }
    if (!masterNode.subscriberLaunched() || !slaveNode.subscriberLaunched()) {
        throw new IllegalStateException("Failed to start master and slave node");
    }
    log.info("Using media driver directory " + mediaDriver.aeronDirectoryName());
    log.info("Launched media driver");
}
Use of org.nd4j.parameterserver.ParameterServerSubscriber in project nd4j by deeplearning4j.
Class ParameterServerNode, method runMain.
/**
 * Starts this node: brings up the status server, makes sure a media driver
 * is available and then runs {@code numWorkers}
 * {@link ParameterServerSubscriber} instances with the supplied arguments.
 * <p>
 * If the arguments contain {@code -id} (or, when that is absent,
 * {@code --streamId}), worker {@code i} is run with that stream id
 * incremented by {@code i}.
 *
 * @param args the arguments for each {@link ParameterServerSubscriber}
 */
public void runMain(String[] args) {
    server = StatusServer.startServer(new InMemoryStatusStorage(), statusPort);
    if (mediaDriver == null) {
        mediaDriver = MediaDriver.launchEmbedded();
    }
    log.info("Started media driver with aeron directory " + mediaDriver.aeronDirectoryName());

    // The callback and listener created by the first worker are handed to
    // every later worker, so all subscribers update the same "server" and
    // anything accumulative (e.g. averaging) sees every update.
    NDArrayCallback sharedCallback = null;
    ParameterServerListener sharedListener = null;

    for (int worker = 0; worker < numWorkers; worker++) {
        ParameterServerSubscriber current = new ParameterServerSubscriber(mediaDriver);
        subscriber[worker] = current;

        // one aeron instance is reused by all workers
        if (aeron == null) {
            aeron = Aeron.connect(getContext(mediaDriver));
        }
        current.setAeron(aeron);

        // Offset the stream id by the worker index; only the first flag
        // that is present ("-id" before "--streamId") is rewritten.
        List<String> workerArgs = new ArrayList<>(Arrays.asList(args));
        for (String flag : new String[] { "-id", "--streamId" }) {
            int flagIdx = workerArgs.indexOf(flag);
            if (flagIdx < 0) {
                continue;
            }
            int valueIdx = flagIdx + 1;
            int shiftedId = Integer.parseInt(workerArgs.get(valueIdx)) + worker;
            workerArgs.set(valueIdx, String.valueOf(shiftedId));
            break;
        }
        String[] runArgs = workerArgs.toArray(new String[args.length]);

        boolean firstWorker = worker == 0;
        if (!firstWorker) {
            // both the callback AND the listener are injected before run(),
            // so run() uses these references instead of creating its own
            current.setCallback(sharedCallback);
            current.setParameterServerListener(sharedListener);
        }
        current.run(runArgs);
        if (firstWorker) {
            // remember what the first worker created so the rest can share it
            sharedCallback = current.getCallback();
            sharedListener = current.getParameterServerListener();
        }
    }
}
Use of org.nd4j.parameterserver.ParameterServerSubscriber in project nd4j by deeplearning4j.
Class ParameterServerClientTest, method before.
/**
 * Launches an embedded media driver configured via
 * {@code AeronUtil.getMediaDriverContext(parameterLength)}, then starts one
 * master and one slave {@link ParameterServerSubscriber} on top of it and
 * waits until both subscribers report that they have launched.
 *
 * @throws Exception if a node cannot be started, if either node has not
 *                   launched once the wait budget is used up, or if the
 *                   waiting thread is interrupted
 */
@BeforeClass
public static void before() throws Exception {
    mediaDriver = MediaDriver.launchEmbedded(AeronUtil.getMediaDriverContext(parameterLength));
    System.setProperty("play.server.dir", "/tmp");
    aeron = Aeron.connect(getContext());

    // Master node: random port so repeated runs are less likely to collide.
    masterNode = new ParameterServerSubscriber(mediaDriver);
    masterNode.setAeron(aeron);
    int masterPort = 40323 + new java.util.Random().nextInt(3000);
    masterNode.run(new String[] { "-m", "true", "-s", "1," + String.valueOf(parameterLength), "-p", String.valueOf(masterPort), "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sp", "33000", "-u", String.valueOf(1) });
    assertTrue(masterNode.isMaster());
    assertEquals(masterPort, masterNode.getPort());
    assertEquals("localhost", masterNode.getHost());
    assertEquals(11, masterNode.getStreamId());
    assertEquals(12, masterNode.getResponder().getStreamId());

    // Slave node: points at the master through the master's connection url.
    slaveNode = new ParameterServerSubscriber(mediaDriver);
    slaveNode.setAeron(aeron);
    slaveNode.run(new String[] { "-p", String.valueOf(masterPort + 100), "-h", "localhost", "-id", "10", "-pm", masterNode.getSubscriber().connectionUrl(), "-md", mediaDriver.aeronDirectoryName(), "-sp", "31000", "-u", String.valueOf(1) });
    assertFalse(slaveNode.isMaster());
    assertEquals(masterPort + 100, slaveNode.getPort());
    assertEquals("localhost", slaveNode.getHost());
    assertEquals(10, slaveNode.getStreamId());

    // Poll until BOTH nodes are up. The counter has to start at 0: it used to
    // start at 10, which made the "tries < 10" guard false on the first check
    // so the loop body never ran. The condition uses "||" because the tests
    // need the master and the slave, not just one of them.
    final int maxTries = 10;
    final long pollIntervalMillis = 10000;
    int tries = 0;
    while ((!masterNode.subscriberLaunched() || !slaveNode.subscriberLaunched()) && tries < maxTries) {
        Thread.sleep(pollIntervalMillis);
        tries++;
    }
    if (!masterNode.subscriberLaunched() || !slaveNode.subscriberLaunched()) {
        throw new IllegalStateException("Failed to start master and slave node");
    }
    log.info("Using media driver directory " + mediaDriver.aeronDirectoryName());
    log.info("Launched media driver");
}
Aggregations