Search in sources :

Example 1 with NDArrayCallback

Use of org.nd4j.aeron.ipc.NDArrayCallback in the nd4j project by deeplearning4j.

From the class ParameterServerNode, method runMain.

/**
 * Starts the status server, launches an embedded media driver when none is set yet,
 * and then creates and runs one {@link ParameterServerSubscriber} per worker.
 * <p>
 * Each subscriber gets its own copy of {@code args}. When that copy contains
 * {@code -id} (checked first) or otherwise {@code --streamId}, the value following
 * the flag is parsed as an int and offset by the worker index before the
 * subscriber is run.
 *
 * @param args the arguments passed on to every {@link ParameterServerSubscriber}
 */
public void runMain(String[] args) {
    server = StatusServer.startServer(new InMemoryStatusStorage(), statusPort);
    if (mediaDriver == null) {
        mediaDriver = MediaDriver.launchEmbedded();
    }
    log.info("Started media driver with aeron directory " + mediaDriver.aeronDirectoryName());

    // The callback and listener obtained from the first subscriber are handed to
    // every later subscriber, so that all workers update the same "server" and
    // effects such as averaging accumulate across them.
    NDArrayCallback sharedCallback = null;
    ParameterServerListener sharedListener = null;
    // Checked in this order; only the first flag that is present gets adjusted.
    final String[] streamIdFlags = {"-id", "--streamId"};

    for (int worker = 0; worker < numWorkers; worker++) {
        ParameterServerSubscriber current = new ParameterServerSubscriber(mediaDriver);
        subscriber[worker] = current;

        // ensure reuse of aeron wherever possible
        if (aeron == null) {
            aeron = Aeron.connect(getContext(mediaDriver));
        }
        current.setAeron(aeron);

        // Work on a per-worker copy so the caller's array is never modified.
        List<String> workerArgs = new ArrayList<>(Arrays.asList(args));
        for (String flag : streamIdFlags) {
            int flagIdx = workerArgs.indexOf(flag);
            if (flagIdx < 0) {
                continue;
            }
            int valueIdx = flagIdx + 1;
            int shiftedStreamId = Integer.parseInt(workerArgs.get(valueIdx)) + worker;
            workerArgs.set(valueIdx, String.valueOf(shiftedStreamId));
            break;
        }
        String[] runArgs = workerArgs.toArray(new String[args.length]);

        boolean isFirst = worker == 0;
        if (!isFirst) {
            // Both the callback AND the listener are set before run(), so that
            // run() uses these references instead of creating its own.
            current.setCallback(sharedCallback);
            current.setParameterServerListener(sharedListener);
        }
        current.run(runArgs);
        if (isFirst) {
            // Capture what the first subscriber created, for reuse by the rest.
            sharedCallback = current.getCallback();
            sharedListener = current.getParameterServerListener();
        }
    }
}
Also used : InMemoryStatusStorage(org.nd4j.parameterserver.status.play.InMemoryStatusStorage) NDArrayCallback(org.nd4j.aeron.ipc.NDArrayCallback) ParameterServerSubscriber(org.nd4j.parameterserver.ParameterServerSubscriber) ArrayList(java.util.ArrayList) ParameterServerListener(org.nd4j.parameterserver.ParameterServerListener)

Aggregations

ArrayList (java.util.ArrayList)1 NDArrayCallback (org.nd4j.aeron.ipc.NDArrayCallback)1 ParameterServerListener (org.nd4j.parameterserver.ParameterServerListener)1 ParameterServerSubscriber (org.nd4j.parameterserver.ParameterServerSubscriber)1 InMemoryStatusStorage (org.nd4j.parameterserver.status.play.InMemoryStatusStorage)1