Search in sources :

Example 1 with SoftSyncParameterUpdater

Use of org.nd4j.parameterserver.updater.SoftSyncParameterUpdater in project nd4j by deeplearning4j.

The method run of the class ParameterServerSubscriber.

/**
 * Parses the command line arguments and starts this parameter server subscriber.
 *
 * <p>The method performs the following steps in order:
 * <ol>
 *   <li>Parses {@code args} with JCommander; on a {@link ParameterException} it prints the
 *       usage text and terminates the JVM with exit status 1.</li>
 *   <li>Resolves {@code updateType} from {@code updateTypeString}.</li>
 *   <li>Launches an embedded Aeron media driver when neither {@code mediaDriver} nor
 *       {@code mediaDriverDirectoryName} has been set, and connects {@code aeron} if needed.</li>
 *   <li>On a master node, creates the {@code ParameterServerListener} callback (unless one was
 *       already supplied) and starts a responder on {@code port + 1} / {@code streamId + 1}.
 *       On a non-master node, creates a {@code PublishingListener} that targets the master
 *       described by {@code publishMasterUrl}.</li>
 *   <li>Starts the subscriber and waits until it reports that it has launched.</li>
 *   <li>If the status server is reachable, schedules a heartbeat every {@code heartbeatMs}
 *       milliseconds that posts this subscriber's state as JSON.</li>
 *   <li>Registers a JVM shutdown hook that calls {@code close()}.</li>
 * </ol>
 *
 * @param args the command line arguments to parse
 * @throws IllegalStateException if this node is not a master and no master url was given,
 *         if {@code publishMasterUrl} is not of the form {@code host:port:streamId},
 *         or if the update type is not handled
 * @throws RuntimeException if the custom updater class cannot be loaded or instantiated
 */
public void run(String[] args) {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        e.printStackTrace();
        // User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            // short pause before exiting
            Thread.sleep(500);
        } catch (InterruptedException interrupted) {
            // We are exiting anyway; keep the interrupt flag set instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        System.exit(1);
    }
    // ensure that the update opType is configured from the command line args
    updateType = UpdateType.valueOf(updateTypeString.toUpperCase());
    if (publishMasterUrl == null && !master)
        throw new IllegalStateException("Please specify a master url or set master to true");
    // No media driver instance or directory was supplied: launch an embedded one.
    if (mediaDriver == null && mediaDriverDirectoryName == null) {
        // length of array * sizeof(float)
        int ipcLength = ArrayUtil.prod(Ints.toArray(shape)) * 4;
        // NOTE(review): the original comment says "must be a power of 2", but doubling and then
        // adding 64 does not produce a power of two for arbitrary shapes - confirm against the
        // Aeron term buffer length requirements.
        ipcLength *= 2;
        // padding for NDArrayMessage
        ipcLength += 64;
        // Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
        System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
        final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context()
                .threadingMode(ThreadingMode.DEDICATED)
                .dirsDeleteOnStart(deleteDirectoryOnStart)
                .termBufferSparseFile(false)
                .ipcTermBufferLength(ipcLength)
                .publicationTermBufferLength(ipcLength)
                .maxTermBufferLength(ipcLength)
                .conductorIdleStrategy(new BusySpinIdleStrategy())
                .receiverIdleStrategy(new BusySpinIdleStrategy())
                .senderIdleStrategy(new BusySpinIdleStrategy());
        mediaDriver = MediaDriver.launchEmbedded(mediaDriverCtx);
        // set the variable since we are using a media driver directly
        mediaDriverDirectoryName = mediaDriver.aeronDirectoryName();
        log.info("Using media driver directory " + mediaDriver.aeronDirectoryName());
    }
    if (aeron == null)
        this.aeron = Aeron.connect(getContext());
    if (master) {
        if (this.callback == null) {
            ParameterServerUpdater updater = null;
            // instantiate with shape instead of just length
            // NOTE(review): HOGWILD and TIME_DELAYED leave the updater null, yet the responder
            // setup below dereferences cast.getUpdater() - confirm these modes are supported.
            switch(updateType) {
                case HOGWILD:
                    break;
                case SYNC:
                    updater = new SynchronousParameterUpdater(new InMemoryUpdateStorage(), new InMemoryNDArrayHolder(Ints.toArray(shape)), updatesPerEpoch);
                    break;
                case SOFTSYNC:
                    updater = new SoftSyncParameterUpdater();
                    break;
                case TIME_DELAYED:
                    break;
                case CUSTOM:
                    try {
                        // the class name is read from the system property named by CUSTOM_UPDATE_TYPE
                        updater = (ParameterServerUpdater) Class.forName(System.getProperty(CUSTOM_UPDATE_TYPE)).newInstance();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                    break;
                default:
                    throw new IllegalStateException("Illegal opType of updater");
            }
            callback = new ParameterServerListener(Ints.toArray(shape), updater);
            parameterServerListener = (ParameterServerListener) callback;
        }
        // start an extra daemon for responding to get queries
        ParameterServerListener cast = (ParameterServerListener) callback;
        responder = AeronNDArrayResponder.startSubscriber(aeron, host, port + 1, cast.getUpdater().ndArrayHolder(), streamId + 1);
        log.info("Started responder on master node " + responder.connectionUrl());
    } else {
        String[] publishMasterUrlArr = publishMasterUrl.split(":");
        // Three components are consumed below (host, port, stream id), so all three must be
        // present; checking for only two let "host:port" through and failed on index 2.
        if (publishMasterUrlArr.length < 3)
            throw new IllegalStateException("Please specify publish master url as host:port:streamId");
        callback = new PublishingListener(String.format("aeron:udp?endpoint=%s:%s", publishMasterUrlArr[0], publishMasterUrlArr[1]), Integer.parseInt(publishMasterUrlArr[2]), getContext());
    }
    log.info("Starting subscriber on " + host + ":" + port + " and stream " + streamId);
    AtomicBoolean running = new AtomicBoolean(true);
    // start a node
    subscriber = AeronNDArraySubscriber.startSubscriber(aeron, host, port, callback, streamId, running);
    // wait until the subscriber reports that it has launched
    while (!subscriber.launched()) {
        LockSupport.parkNanos(100000);
    }
    // Only schedule this if a remote server is available.
    if (CheckSocket.remotePortTaken(statusServerHost, statusServerPort, 10000)) {
        scheduledExecutorService = Executors.newScheduledThreadPool(1);
        final AtomicInteger failCount = new AtomicInteger(0);
        scheduledExecutorService.scheduleAtFixedRate(() -> {
            try {
                // after three failures the task keeps firing but no longer attempts to post
                if (failCount.get() >= 3)
                    return;
                SubscriberState subscriberState = asState();
                JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
                String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, streamId);
                // the response body is not inspected; only a thrown exception counts as a failure
                Unirest.post(url).header("Content-Type", "application/json").body(jsonObject).asString();
            } catch (Exception e) {
                failCount.incrementAndGet();
                if (failCount.get() >= 3) {
                    log.warn("Failed to send update, shutting down likely?", e);
                }
            }
        }, 0, heartbeatMs, TimeUnit.MILLISECONDS);
    } else {
        log.info("No status server found. Will not send heartbeats. Specified host was " + statusServerHost + " and port was " + statusServerPort);
    }
    // release resources via close() when the JVM shuts down
    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
        close();
    }));
}
Also used : SoftSyncParameterUpdater(org.nd4j.parameterserver.updater.SoftSyncParameterUpdater) ParameterServerUpdater(org.nd4j.parameterserver.updater.ParameterServerUpdater) ParameterException(com.beust.jcommander.ParameterException) SubscriberState(org.nd4j.parameterserver.model.SubscriberState) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) MediaDriver(io.aeron.driver.MediaDriver) JSONObject(org.json.JSONObject) InMemoryNDArrayHolder(org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) JCommander(com.beust.jcommander.JCommander) BusySpinIdleStrategy(org.agrona.concurrent.BusySpinIdleStrategy) ParameterException(com.beust.jcommander.ParameterException) SynchronousParameterUpdater(org.nd4j.parameterserver.updater.SynchronousParameterUpdater) InMemoryUpdateStorage(org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage)

Aggregations

JCommander (com.beust.jcommander.JCommander)1 ParameterException (com.beust.jcommander.ParameterException)1 MediaDriver (io.aeron.driver.MediaDriver)1 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)1 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)1 BusySpinIdleStrategy (org.agrona.concurrent.BusySpinIdleStrategy)1 JSONObject (org.json.JSONObject)1 InMemoryNDArrayHolder (org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder)1 SubscriberState (org.nd4j.parameterserver.model.SubscriberState)1 ParameterServerUpdater (org.nd4j.parameterserver.updater.ParameterServerUpdater)1 SoftSyncParameterUpdater (org.nd4j.parameterserver.updater.SoftSyncParameterUpdater)1 SynchronousParameterUpdater (org.nd4j.parameterserver.updater.SynchronousParameterUpdater)1 InMemoryUpdateStorage (org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage)1