Search in sources :

Example 1 with ParameterServerListener

use of org.nd4j.parameterserver.ParameterServerListener in project nd4j by deeplearning4j.

In class ParameterServerClientTest, method testServer:

@Test
public void testServer() throws Exception {
    // The subscriber port is drawn from the window 40625-40724.
    // NOTE(review): presumably randomised to avoid clashing with other tests' ports — confirm.
    int subscriberPort = 40625 + new java.util.Random().nextInt(100);
    ParameterServerClient client = ParameterServerClient.builder()
            .aeron(aeron)
            .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl())
            .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl())
            .subscriberHost("localhost")
            .subscriberPort(subscriberPort)
            .subscriberStream(12)
            .build();
    assertEquals(String.format("localhost:%d:12", subscriberPort), client.connectionUrl());

    /*
     * Flow 1:
     * The client (localhost:subscriberPort, stream 12) sends an array to the
     * listener on the slave node, which publishes to the master node, which
     * adds the array for parameter averaging.
     * In this case totalN should be 1.
     * (The slave and master ports are configured outside this method.)
     */
    client.pushNDArray(Nd4j.ones(parameterLength));
    log.info("Pushed ndarray");
    // Fixed wait to give the pushed array time to reach the master node.
    Thread.sleep(30000);

    ParameterServerListener listener = (ParameterServerListener) masterNode.getCallback();
    assertEquals(1, listener.getUpdater().numUpdates());

    // Both the master's holder and the array fetched back through the client
    // must match what was pushed, so both are checked against the same
    // parameterLength-sized expectation rather than a hard-coded length.
    INDArray expected = Nd4j.ones(parameterLength);
    assertEquals(expected, listener.getUpdater().ndArrayHolder().get());
    INDArray arr = client.getArray();
    assertEquals(expected, arr);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) ParameterServerListener(org.nd4j.parameterserver.ParameterServerListener) Test(org.junit.Test)

Example 2 with ParameterServerListener

use of org.nd4j.parameterserver.ParameterServerListener in project nd4j by deeplearning4j.

In class ParameterServerNode, method runMain:

/**
 * Starts this node with the supplied arguments.
 * The arguments are the same ones a {@link ParameterServerSubscriber} accepts.
 * If a stream id is passed via {@code -id} or {@code --streamId}, the worker at
 * index {@code i} is run with that id plus {@code i}.
 *
 * @param args the arguments for the {@link ParameterServerSubscriber}
 */
public void runMain(String[] args) {
    server = StatusServer.startServer(new InMemoryStatusStorage(), statusPort);
    if (mediaDriver == null) {
        mediaDriver = MediaDriver.launchEmbedded();
    }
    log.info("Started media driver with aeron directory " + mediaDriver.aeronDirectoryName());

    // The callback and listener created by the first subscriber are kept and
    // handed to every later subscriber, so that all subscribers update the
    // same "server". This gives a shared pool whose updates (e.g. averaging)
    // accumulate in one place.
    NDArrayCallback sharedCallback = null;
    ParameterServerListener sharedListener = null;

    for (int worker = 0; worker < numWorkers; worker++) {
        ParameterServerSubscriber current = new ParameterServerSubscriber(mediaDriver);
        subscriber[worker] = current;

        // Reuse a single aeron instance wherever possible.
        if (aeron == null) {
            aeron = Aeron.connect(getContext(mediaDriver));
        }
        current.setAeron(aeron);

        // Work on a private copy of the arguments; "-id" takes precedence
        // over "--streamId" when locating the stream id flag.
        List<String> workerArgs = new ArrayList<>(Arrays.asList(args));
        int flagPos = workerArgs.indexOf("-id");
        if (flagPos < 0) {
            flagPos = workerArgs.indexOf("--streamId");
        }
        if (flagPos >= 0) {
            // The value following the flag is offset by the worker index.
            int valuePos = flagPos + 1;
            int offsetStreamId = Integer.parseInt(workerArgs.get(valuePos)) + worker;
            workerArgs.set(valuePos, String.valueOf(offsetStreamId));
        }

        boolean firstWorker = worker == 0;
        if (!firstWorker) {
            // Both the callback AND the listener are set before running, so
            // run() uses these references instead of creating its own.
            current.setCallback(sharedCallback);
            current.setParameterServerListener(sharedListener);
        }

        current.run(workerArgs.toArray(new String[args.length]));

        if (firstWorker) {
            // Capture what the first subscriber produced for the others to share.
            sharedCallback = current.getCallback();
            sharedListener = current.getParameterServerListener();
        }
    }
}
Also used : InMemoryStatusStorage(org.nd4j.parameterserver.status.play.InMemoryStatusStorage) NDArrayCallback(org.nd4j.aeron.ipc.NDArrayCallback) ParameterServerSubscriber(org.nd4j.parameterserver.ParameterServerSubscriber) ArrayList(java.util.ArrayList) ParameterServerListener(org.nd4j.parameterserver.ParameterServerListener)

Example 3 with ParameterServerListener

use of org.nd4j.parameterserver.ParameterServerListener in project nd4j by deeplearning4j.

In class ParameterServerClientPartialTest, method testServer:

@Test
public void testServer() throws Exception {
    ParameterServerClient client = ParameterServerClient.builder()
            .aeron(aeron)
            .ndarrayRetrieveUrl(masterNode.getResponder().connectionUrl())
            .ndarraySendUrl(slaveNode.getSubscriber().connectionUrl())
            .subscriberHost("localhost")
            .subscriberPort(40325)
            .subscriberStream(12)
            .build();
    assertEquals("localhost:40325:12", client.connectionUrl());

    /*
     * Flow 1:
     * The client (localhost:40325, stream 12) sends a partial update to the
     * listener on the slave node, which publishes to the master node, which
     * adds the array for parameter averaging.
     * In this case totalN should be 1.
     * (The slave and master ports are configured outside this method.)
     */
    client.pushNDArrayMessage(NDArrayMessage.of(Nd4j.ones(2), new int[] { 0 }, 0));
    log.info("Pushed ndarray");
    // Fixed wait to give the pushed message time to reach the master node.
    Thread.sleep(30000);

    ParameterServerListener masterListener = (ParameterServerListener) masterNode.getCallback();
    assertEquals(1, masterListener.getUpdater().numUpdates());

    // Expected state: an array created with shape {2, 2} whose column 0 has
    // had 1.0 added in place.
    INDArray expected = Nd4j.create(new int[] { 2, 2 });
    expected.getColumn(0).addi(1.0);
    assertEquals(expected, masterListener.getUpdater().ndArrayHolder().get());

    // The array served back through the client must match as well.
    INDArray fetched = client.getArray();
    assertEquals(expected, fetched);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) ParameterServerListener(org.nd4j.parameterserver.ParameterServerListener) Test(org.junit.Test)

Aggregations

ParameterServerListener (org.nd4j.parameterserver.ParameterServerListener): 3, Test (org.junit.Test): 2, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2, ArrayList (java.util.ArrayList): 1, NDArrayCallback (org.nd4j.aeron.ipc.NDArrayCallback): 1, ParameterServerSubscriber (org.nd4j.parameterserver.ParameterServerSubscriber): 1, InMemoryStatusStorage (org.nd4j.parameterserver.status.play.InMemoryStatusStorage): 1