Example 1 with SubscriberState

use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.

the class ParameterServerClient method isReadyForNext.

/**
 * Checks whether the client is
 * ready for the next array.
 * @return true if the client is
 * ready for the next array, false otherwise
 * (including when the status request fails)
 */
public boolean isReadyForNext() {
    if (objectMapper == null)
        objectMapper = new ObjectMapper();
    try {
        int masterStream = Integer.parseInt(ndarraySendUrl.split(":")[2]);
        SubscriberState subscriberState = objectMapper.readValue(Unirest.get(String.format("http://%s:%d/state/%d", masterStatusHost, masterStatusPort, masterStream)).asJson().getBody().toString(), SubscriberState.class);
        return subscriberState.isReady();
    } catch (Exception e) {
        e.printStackTrace();
    }
    return false;
}
Also used : ObjectMapper(org.nd4j.shade.jackson.databind.ObjectMapper) SubscriberState(org.nd4j.parameterserver.model.SubscriberState)
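
The URL handling in isReadyForNext can be looked at separately from the HTTP call. The following is a minimal sketch, assuming ndarraySendUrl has the form host:port:streamId (which the split(":")[2] above implies). StatusUrls, streamIdOf and stateUrl are hypothetical names used only for illustration; they are not part of nd4j.

```java
/** Hypothetical helper, not part of nd4j: builds status server URLs. */
final class StatusUrls {

    private StatusUrls() {
    }

    /** Extracts the stream id from a "host:port:streamId" string (assumed format). */
    static int streamIdOf(String ndarraySendUrl) {
        String[] parts = ndarraySendUrl.split(":");
        if (parts.length < 3)
            throw new IllegalArgumentException("Expected host:port:streamId but got " + ndarraySendUrl);
        return Integer.parseInt(parts[2]);
    }

    /** Builds the URL of the /state/:id end point for the given stream. */
    static String stateUrl(String statusHost, int statusPort, int streamId) {
        return String.format("http://%s:%d/state/%d", statusHost, statusPort, streamId);
    }
}
```

Validating the number of parts before indexing gives a readable error message instead of an ArrayIndexOutOfBoundsException when the url is malformed.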

Example 2 with SubscriberState

use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.

the class StatusServer method startServer.

/**
 * Start a status server backed by the given status storage.
 * Note that for the port to start the server on, you should
 * set the statusServerPortField on the subscriber
 * either manually or via command line. The
 * server defaults to port 9000.
 *
 * The end points are:
 * GET /ids/: the list of ids for all of the subscribers
 * GET /state/:id: the state of the subscriber with the given id
 * GET /opType/:id: returns the opType information (master/slave)
 * GET /started/:id: if it's a master node, it returns master:started/stopped and responder:started/stopped
 * GET /connectioninfo/:id: See the SlaveConnectionInfo and MasterConnectionInfo classes for fields.
 * POST /updatestatus/:id: stores the subscriber state sent as the JSON body
 * @param statusStorage the status storage to base
 *                   the status server on
 * @param statusServerPort the port to start the server on
 * @return the started server
 */
public static Server startServer(StatusStorage statusStorage, int statusServerPort) {
    log.info("Starting server on port " + statusServerPort);
    RoutingDsl dsl = new RoutingDsl();
    dsl.GET("/ids/").routeTo(new F.Function0<Result>() {

        @Override
        public Result apply() throws Throwable {
            List<Integer> ids = statusStorage.ids();
            return ok(toJson(ids));
        }
    });
    dsl.GET("/state/:id").routeTo(new F.Function<String, Result>() {

        @Override
        public Result apply(String id) throws Throwable {
            return ok(toJson(statusStorage.getState(Integer.parseInt(id))));
        }
    });
    dsl.GET("/opType/:id").routeTo(new F.Function<String, Result>() {

        @Override
        public Result apply(String id) throws Throwable {
            return ok(toJson(ServerTypeJson.builder().type(statusStorage.getState(Integer.parseInt(id)).serverType()).build()));
        }
    });
    dsl.GET("/started/:id").routeTo(new F.Function<String, Result>() {

        @Override
        public Result apply(String id) throws Throwable {
            int streamId = Integer.parseInt(id);
            SubscriberState state = statusStorage.getState(streamId);
            if (state.isMaster()) {
                // the responder state is looked up under the master's id + 1
                return ok(toJson(MasterStatus.builder().master(state.getServerState()).responder(statusStorage.getState(streamId + 1).getServerState()).responderN(state.getTotalUpdates()).build()));
            }
            return ok(toJson(SlaveStatus.builder().slave(state.serverType()).build()));
        }
    });
    dsl.GET("/connectioninfo/:id").routeTo(new F.Function<String, Result>() {

        @Override
        public Result apply(String id) throws Throwable {
            return ok(toJson(statusStorage.getState(Integer.parseInt(id)).getConnectionInfo()));
        }
    });
    dsl.POST("/updatestatus/:id").routeTo(new F.Function<String, Result>() {

        @Override
        public Result apply(String id) throws Throwable {
            SubscriberState subscriberState = Json.fromJson(request().body().asJson(), SubscriberState.class);
            statusStorage.updateState(subscriberState);
            return ok(toJson(subscriberState));
        }
    });
    Server server = Server.forRouter(dsl.build(), Mode.PROD, statusServerPort);
    return server;
}
Also used : Server(play.server.Server) F(play.libs.F) List(java.util.List) RoutingDsl(play.routing.RoutingDsl) Result(play.mvc.Result) SubscriberState(org.nd4j.parameterserver.model.SubscriberState)

Example 3 with SubscriberState

use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.

the class StorageTests method testMapStorage.

@Test
public void testMapStorage() throws Exception {
    StatusStorage mapDb = new MapDbStatusStorage();
    assertEquals(SubscriberState.empty(), mapDb.getState(-1));
    SubscriberState noEmpty = SubscriberState.builder().isMaster(true).serverState("master").streamId(1).build();
    mapDb.updateState(noEmpty);
    assertEquals(noEmpty, mapDb.getState(1));
    // wait long enough for the stored state to be dropped
    Thread.sleep(10000);
    assertEquals(0, mapDb.numStates());
}
Also used : SubscriberState(org.nd4j.parameterserver.model.SubscriberState) Test(org.junit.Test)

Example 4 with SubscriberState

use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.

the class StorageTests method testStorage.

@Test
public void testStorage() throws Exception {
    StatusStorage statusStorage = new InMemoryStatusStorage();
    assertEquals(SubscriberState.empty(), statusStorage.getState(-1));
    SubscriberState noEmpty = SubscriberState.builder().isMaster(true).serverState("master").streamId(1).build();
    statusStorage.updateState(noEmpty);
    assertEquals(noEmpty, statusStorage.getState(1));
    // wait long enough for the stored state to be dropped
    Thread.sleep(10000);
    assertEquals(0, statusStorage.numStates());
}
Also used : SubscriberState(org.nd4j.parameterserver.model.SubscriberState) Test(org.junit.Test)
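
Both storage tests expect an updated state to be readable at once and to be gone after a ten second wait. A minimal sketch of one way to get that behavior is a map whose entries expire after a time-to-live. This is an assumption for illustration, not the actual MapDbStatusStorage or InMemoryStatusStorage code; ExpiringStateStore, its methods and the injected clock are hypothetical.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.LongSupplier;

/** Hypothetical sketch, not the nd4j implementation: drops entries older than ttlMs. */
final class ExpiringStateStore<V> {

    /** A stored value together with the time of its last update. */
    private static final class Entry<V> {
        final V value;
        final long updatedAtMs;

        Entry(V value, long updatedAtMs) {
            this.value = value;
            this.updatedAtMs = updatedAtMs;
        }
    }

    private final Map<Integer, Entry<V>> entries = new ConcurrentHashMap<>();
    private final long ttlMs;
    // injected clock so that expiry can be tested without sleeping
    private final LongSupplier clockMs;

    ExpiringStateStore(long ttlMs, LongSupplier clockMs) {
        this.ttlMs = ttlMs;
        this.clockMs = clockMs;
    }

    /** Stores the state and records the time of the update. */
    void update(int streamId, V state) {
        entries.put(streamId, new Entry<>(state, clockMs.getAsLong()));
    }

    /** Returns the stored state, or emptyState if none exists or it has expired. */
    V get(int streamId, V emptyState) {
        evictExpired();
        Entry<V> entry = entries.get(streamId);
        return entry == null ? emptyState : entry.value;
    }

    /** Number of states that have not expired yet. */
    int size() {
        evictExpired();
        return entries.size();
    }

    private void evictExpired() {
        long now = clockMs.getAsLong();
        entries.values().removeIf(entry -> now - entry.updatedAtMs > ttlMs);
    }
}
```

Passing the clock in as a LongSupplier is a design choice of this sketch: a test can advance time by changing a variable instead of calling Thread.sleep(10000).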

Example 5 with SubscriberState

use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.

the class ParameterServerSubscriber method run.

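/**
 * Parses the command line arguments and starts this subscriber,
 * either as a master or as a node that publishes to a master.
 * @param args the command line arguments
 */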
public void run(String[] args) {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        e.printStackTrace();
        // User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            Thread.sleep(500);
        } catch (Exception e2) {
            // ignored: this is only a short pause before exiting
        }
        System.exit(1);
    }
    // ensure that the update opType is configured from the command line args
    updateType = UpdateType.valueOf(updateTypeString.toUpperCase());
    if (publishMasterUrl == null && !master)
        throw new IllegalStateException("Please specify a master url or set master to true");
    // no media driver or media driver directory was supplied: launch an embedded one
    if (mediaDriver == null && mediaDriverDirectoryName == null) {
        // length of array * sizeof(float)
        int ipcLength = ArrayUtil.prod(Ints.toArray(shape)) * 4;
        // intended to be a power of 2; note that doubling alone does not guarantee that
        ipcLength *= 2;
        // padding for NDArrayMessage
        ipcLength += 64;
        // Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
        System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
        final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false).ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength).maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy()).receiverIdleStrategy(new BusySpinIdleStrategy()).senderIdleStrategy(new BusySpinIdleStrategy());
        mediaDriver = MediaDriver.launchEmbedded(mediaDriverCtx);
        // set the variable since we are using a media driver directly
        mediaDriverDirectoryName = mediaDriver.aeronDirectoryName();
        log.info("Using media driver directory " + mediaDriver.aeronDirectoryName());
    }
    if (aeron == null)
        this.aeron = Aeron.connect(getContext());
    if (master) {
        if (this.callback == null) {
            ParameterServerUpdater updater = null;
            // instantiate with shape instead of just length
            switch(updateType) {
                case HOGWILD:
                    break;
                case SYNC:
                    updater = new SynchronousParameterUpdater(new InMemoryUpdateStorage(), new InMemoryNDArrayHolder(Ints.toArray(shape)), updatesPerEpoch);
                    break;
                case SOFTSYNC:
                    updater = new SoftSyncParameterUpdater();
                    break;
                case TIME_DELAYED:
                    break;
                case CUSTOM:
                    try {
                        updater = (ParameterServerUpdater) Class.forName(System.getProperty(CUSTOM_UPDATE_TYPE)).newInstance();
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                    break;
                default:
                    throw new IllegalStateException("Illegal opType of updater");
            }
            callback = new ParameterServerListener(Ints.toArray(shape), updater);
            parameterServerListener = (ParameterServerListener) callback;
        }
        // start an extra daemon for responding to get queries
        ParameterServerListener cast = (ParameterServerListener) callback;
        responder = AeronNDArrayResponder.startSubscriber(aeron, host, port + 1, cast.getUpdater().ndArrayHolder(), streamId + 1);
        log.info("Started responder on master node " + responder.connectionUrl());
    } else {
        String[] publishMasterUrlArr = publishMasterUrl.split(":");
        // index 2 (the stream id) is read below, so three parts are required
        if (publishMasterUrlArr.length < 3)
            throw new IllegalStateException("Please specify publish master url as host:port:streamId");
        callback = new PublishingListener(String.format("aeron:udp?endpoint=%s:%s", publishMasterUrlArr[0], publishMasterUrlArr[1]), Integer.parseInt(publishMasterUrlArr[2]), getContext());
    }
    log.info("Starting subscriber on " + host + ":" + port + " and stream " + streamId);
    AtomicBoolean running = new AtomicBoolean(true);
    // start a node
    subscriber = AeronNDArraySubscriber.startSubscriber(aeron, host, port, callback, streamId, running);
    while (!subscriber.launched()) {
        LockSupport.parkNanos(100000);
    }
    // Only schedule this if a remote server is available.
    if (CheckSocket.remotePortTaken(statusServerHost, statusServerPort, 10000)) {
        scheduledExecutorService = Executors.newScheduledThreadPool(1);
        final AtomicInteger failCount = new AtomicInteger(0);
        scheduledExecutorService.scheduleAtFixedRate(() -> {
            try {
                // give up once three heartbeats have failed
                if (failCount.get() >= 3)
                    return;
                SubscriberState subscriberState = asState();
                JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
                String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, streamId);
                HttpResponse<String> entity = Unirest.post(url).header("Content-Type", "application/json").body(jsonObject).asString();
            } catch (Exception e) {
                failCount.incrementAndGet();
                if (failCount.get() >= 3) {
                    log.warn("Failed to send update, shutting down likely?", e);
                }
            }
        }, 0, heartbeatMs, TimeUnit.MILLISECONDS);
    } else {
        log.info("No status server found. Will not send heartbeats. Specified host was " + statusServerHost + " and port was " + statusServerPort);
    }
    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
        close();
    }));
// set the server for the status of the master and slave nodes
}
Also used : SoftSyncParameterUpdater(org.nd4j.parameterserver.updater.SoftSyncParameterUpdater) ParameterServerUpdater(org.nd4j.parameterserver.updater.ParameterServerUpdater) ParameterException(com.beust.jcommander.ParameterException) SubscriberState(org.nd4j.parameterserver.model.SubscriberState) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) MediaDriver(io.aeron.driver.MediaDriver) JSONObject(org.json.JSONObject) InMemoryNDArrayHolder(org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) JCommander(com.beust.jcommander.JCommander) BusySpinIdleStrategy(org.agrona.concurrent.BusySpinIdleStrategy) SynchronousParameterUpdater(org.nd4j.parameterserver.updater.SynchronousParameterUpdater) InMemoryUpdateStorage(org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage)
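
A comment in run says the buffer length must be a power of 2, yet multiplying by 2 and adding 64 does not in general produce one (for example, 1000 floats give 1000 * 4 * 2 + 64 = 8064). One possible way to satisfy the comment is to round the computed length up to the next power of two. The sketch below shows that rounding only; BufferLengths and nextPowerOfTwo are hypothetical names and are not part of nd4j or Aeron, and whether the rounded value suits a given Aeron configuration is not verified here.

```java
/** Hypothetical helper, not part of nd4j or Aeron. */
final class BufferLengths {

    private BufferLengths() {
    }

    /** Returns the smallest power of two that is >= value, for 1 <= value <= 2^30. */
    static int nextPowerOfTwo(int value) {
        if (value < 1 || value > (1 << 30))
            throw new IllegalArgumentException("value out of range: " + value);
        // highest set bit of value, e.g. 4096 for 8064
        int highest = Integer.highestOneBit(value);
        // already a power of two: keep it, otherwise move to the next one
        return highest == value ? value : highest << 1;
    }
}
```

With the numbers above, BufferLengths.nextPowerOfTwo(8064) yields 8192, which could then be passed where ipcLength is used.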

Aggregations

SubscriberState (org.nd4j.parameterserver.model.SubscriberState): 5
Test (org.junit.Test): 2
JCommander (com.beust.jcommander.JCommander): 1
ParameterException (com.beust.jcommander.ParameterException): 1
MediaDriver (io.aeron.driver.MediaDriver): 1
List (java.util.List): 1
AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 1
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 1
BusySpinIdleStrategy (org.agrona.concurrent.BusySpinIdleStrategy): 1
JSONObject (org.json.JSONObject): 1
InMemoryNDArrayHolder (org.nd4j.aeron.ndarrayholder.InMemoryNDArrayHolder): 1
ParameterServerUpdater (org.nd4j.parameterserver.updater.ParameterServerUpdater): 1
SoftSyncParameterUpdater (org.nd4j.parameterserver.updater.SoftSyncParameterUpdater): 1
SynchronousParameterUpdater (org.nd4j.parameterserver.updater.SynchronousParameterUpdater): 1
InMemoryUpdateStorage (org.nd4j.parameterserver.updater.storage.InMemoryUpdateStorage): 1
ObjectMapper (org.nd4j.shade.jackson.databind.ObjectMapper): 1
F (play.libs.F): 1
Result (play.mvc.Result): 1
RoutingDsl (play.routing.RoutingDsl): 1
Server (play.server.Server): 1