Usage of org.nd4j.parameterserver.ParameterServerListener in the nd4j project (deeplearning4j).
Class ParameterServerClientTest, method testServer.
/**
 * End-to-end check of the client push/pull flow.
 *
 * <p>The client pushes a ones-vector of length {@code parameterLength} to the slave node's
 * subscriber. After a fixed wait, the master's listener must report exactly one update and hold
 * that same array, and {@code client.getArray()} must return it as well.
 */
@Test
public void testServer() throws Exception {
    // Randomized to reduce port collisions between test runs.
    int subscriberPort = 40625 + new java.util.Random().nextInt(100);
    ParameterServerClient client = ParameterServerClient.builder()
            .aeron(aeron)
            .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl())
            .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl())
            .subscriberHost("localhost")
            .subscriberPort(subscriberPort)
            .subscriberStream(12)
            .build();
    assertEquals(String.format("localhost:%d:12", subscriberPort), client.connectionUrl());
    // flow 1:
    /*
     * Client (localhost:<subscriberPort>:12) sends the array to the listener on the slave node,
     * which publishes to the master node,
     * which adds the array for parameter averaging.
     * In this case totalN should be 1.
     */
    client.pushNDArray(Nd4j.ones(parameterLength));
    log.info("Pushed ndarray");
    // NOTE(review): fixed wait for the asynchronous Aeron round trip; polling with a timeout
    // would make this test faster and less timing-sensitive.
    Thread.sleep(30000);
    ParameterServerListener listener = (ParameterServerListener) masterNode.getCallback();
    assertEquals(1, listener.getUpdater().numUpdates());
    assertEquals(Nd4j.ones(parameterLength), listener.getUpdater().ndArrayHolder().get());
    INDArray arr = client.getArray();
    // Compare against the same length that was pushed, not a hard-coded size.
    assertEquals(Nd4j.ones(parameterLength), arr);
}
Usage of org.nd4j.parameterserver.ParameterServerListener in the nd4j project (deeplearning4j).
Class ParameterServerNode, method runMain.
/**
 * Run this node with the given args.
 * These args are the same ones that a {@link ParameterServerSubscriber} takes.
 *
 * <p>Starts the status server on {@code statusPort} and launches an embedded media driver if
 * none was supplied. It then creates and runs {@code numWorkers} subscribers that share one
 * Aeron instance. If a stream id is passed via {@code -id} (or, failing that, {@code --streamId}),
 * worker {@code i} runs with that stream id plus {@code i}.
 *
 * <p>The callback and listener created by the first subscriber are handed to every later
 * subscriber. All subscribers therefore update the same shared "server", which lets effects
 * such as averaging accumulate across them.
 *
 * @param args the arguments for the {@link ParameterServerSubscriber}
 * @throws IllegalArgumentException if {@code -id}/{@code --streamId} is not followed by a value
 */
public void runMain(String[] args) {
    server = StatusServer.startServer(new InMemoryStatusStorage(), statusPort);
    if (mediaDriver == null) {
        mediaDriver = MediaDriver.launchEmbedded();
    }
    log.info("Started media driver with aeron directory {}", mediaDriver.aeronDirectoryName());
    // Cached from the first subscriber and shared with all later ones (see method doc).
    NDArrayCallback parameterServerListener = null;
    ParameterServerListener cast = null;
    for (int i = 0; i < numWorkers; i++) {
        subscriber[i] = new ParameterServerSubscriber(mediaDriver);
        // ensure reuse of aeron wherever possible
        if (aeron == null) {
            aeron = Aeron.connect(getContext(mediaDriver));
        }
        subscriber[i].setAeron(aeron);
        String[] workerArgs = offsetStreamId(args, i);
        if (i > 0) {
            // Set both the callback AND the listener before run(), so run() uses these
            // shared references instead of creating its own.
            subscriber[i].setCallback(parameterServerListener);
            subscriber[i].setParameterServerListener(cast);
        }
        subscriber[i].run(workerArgs);
        if (i == 0) {
            parameterServerListener = subscriber[i].getCallback();
            cast = subscriber[i].getParameterServerListener();
        }
    }
}

/**
 * Returns a copy of {@code args} in which the value following {@code -id} (checked first) or
 * {@code --streamId} is increased by {@code offset}. If neither flag is present, the copy is
 * unchanged.
 */
private static String[] offsetStreamId(String[] args, int offset) {
    List<String> multiArgs = new ArrayList<>(Arrays.asList(args));
    String flag = null;
    if (multiArgs.contains("-id")) {
        flag = "-id";
    } else if (multiArgs.contains("--streamId")) {
        flag = "--streamId";
    }
    if (flag != null) {
        int streamIdIdx = multiArgs.indexOf(flag) + 1;
        if (streamIdIdx >= multiArgs.size()) {
            throw new IllegalArgumentException("Missing stream id value after " + flag);
        }
        int streamId = Integer.parseInt(multiArgs.get(streamIdIdx)) + offset;
        multiArgs.set(streamIdIdx, String.valueOf(streamId));
    }
    return multiArgs.toArray(new String[0]);
}
Usage of org.nd4j.parameterserver.ParameterServerListener in the nd4j project (deeplearning4j).
Class ParameterServerClientPartialTest, method testServer.
/**
 * End-to-end check of a partial (indexed) update.
 *
 * <p>The client pushes an {@code NDArrayMessage} wrapping a ones-vector of length 2 with index
 * {@code {0}}. After a fixed wait, the master's listener must report exactly one update. Its
 * 2x2 array must have ones in column 0 and zeros elsewhere, and {@code client.getArray()} must
 * return the same array.
 */
@Test
public void testServer() throws Exception {
    // Randomized (as in ParameterServerClientTest) to reduce port collisions between runs.
    // NOTE(review): confirm the 40325-40424 range does not overlap ports used by the
    // master/slave nodes configured elsewhere in this class.
    int subscriberPort = 40325 + new java.util.Random().nextInt(100);
    ParameterServerClient client = ParameterServerClient.builder()
            .aeron(aeron)
            .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl())
            .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl())
            .subscriberHost("localhost")
            .subscriberPort(subscriberPort)
            .subscriberStream(12)
            .build();
    assertEquals(String.format("localhost:%d:12", subscriberPort), client.connectionUrl());
    // flow 1:
    /*
     * Client (localhost:<subscriberPort>:12) sends the array to the listener on the slave node,
     * which publishes to the master node,
     * which adds the array for parameter averaging.
     * In this case totalN should be 1.
     */
    client.pushNDArrayMessage(NDArrayMessage.of(Nd4j.ones(2), new int[] { 0 }, 0));
    log.info("Pushed ndarray");
    // NOTE(review): fixed wait for the asynchronous round trip; polling with a timeout
    // would be faster and less timing-sensitive.
    Thread.sleep(30000);
    ParameterServerListener listener = (ParameterServerListener) masterNode.getCallback();
    assertEquals(1, listener.getUpdater().numUpdates());
    // Expected state: 2x2 zeros with column 0 filled with ones.
    INDArray assertion = Nd4j.create(new int[] { 2, 2 });
    assertion.getColumn(0).addi(1.0);
    assertEquals(assertion, listener.getUpdater().ndArrayHolder().get());
    INDArray arr = client.getArray();
    assertEquals(assertion, arr);
}
Aggregations