use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.
the class ParameterServerClient method isReadyForNext.
/**
 * Asks the master status server whether the subscriber on this client's
 * send stream is ready to receive another array.
 * <p>
 * The stream id is the third {@code :}-separated field of {@code ndarraySendUrl}.
 * The subscriber state is fetched as JSON from
 * {@code http://<masterStatusHost>:<masterStatusPort>/state/<streamId>}.
 * Any failure (malformed url, HTTP error, unparseable response) has its stack
 * trace printed and is reported as "not ready".
 *
 * @return true if the remote subscriber reports itself ready;
 *         false otherwise, including when the state could not be retrieved
 */
public boolean isReadyForNext() {
    if (objectMapper == null) {
        objectMapper = new ObjectMapper();
    }
    try {
        String[] sendUrlParts = ndarraySendUrl.split(":");
        int streamId = Integer.parseInt(sendUrlParts[2]);
        String stateUrl = String.format("http://%s:%d/state/%d", masterStatusHost, masterStatusPort, streamId);
        String stateJson = Unirest.get(stateUrl).asJson().getBody().toString();
        SubscriberState remoteState = objectMapper.readValue(stateJson, SubscriberState.class);
        return remoteState.isReady();
    } catch (Exception e) {
        e.printStackTrace();
        return false;
    }
}
use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.
the class StatusServer method startServer.
/**
 * Start an HTTP status server backed by the given {@link StatusStorage}.
 * <p>
 * The endpoints are:
 * <ul>
 *   <li>{@code GET /ids/}: the list of ids of all stored subscribers</li>
 *   <li>{@code GET /state/:id}: the full {@link SubscriberState} stored under the id</li>
 *   <li>{@code GET /opType/:id}: the server type (e.g. master/slave) of the subscriber</li>
 *   <li>{@code GET /started/:id}: for a master node, the master's server state, the server
 *       state stored under {@code id + 1} (its responder) and the master's total update count;
 *       for a slave node, its server type</li>
 *   <li>{@code GET /connectioninfo/:id}: the subscriber's connection info</li>
 *   <li>{@code POST /updatestatus/:id}: stores the JSON {@link SubscriberState} sent in the
 *       request body and echoes it back (the {@code :id} path segment is not used)</li>
 * </ul>
 *
 * @param statusStorage the storage the endpoints read subscriber states from and write them to
 * @param statusServerPort the port to start the server on
 * @return the server created for the routes above
 */
public static Server startServer(StatusStorage statusStorage, int statusServerPort) {
    log.info("Starting server on port " + statusServerPort);
    RoutingDsl dsl = new RoutingDsl();
    dsl.GET("/ids/").routeTo(new F.Function0<Result>() {
        @Override
        public Result apply() throws Throwable {
            List<Integer> ids = statusStorage.ids();
            return ok(toJson(ids));
        }
    });
    dsl.GET("/state/:id").routeTo(new F.Function<String, Result>() {
        @Override
        public Result apply(String id) throws Throwable {
            return ok(toJson(statusStorage.getState(Integer.parseInt(id))));
        }
    });
    dsl.GET("/opType/:id").routeTo(new F.Function<String, Result>() {
        @Override
        public Result apply(String id) throws Throwable {
            // build() the DTO: serializing the builder object itself does not produce the ServerTypeJson fields
            return ok(toJson(ServerTypeJson.builder().type(statusStorage.getState(Integer.parseInt(id)).serverType()).build()));
        }
    });
    dsl.GET("/started/:id").routeTo(new F.Function<String, Result>() {
        @Override
        public Result apply(String id) throws Throwable {
            int streamId = Integer.parseInt(id);
            SubscriberState state = statusStorage.getState(streamId);
            if (state.isMaster()) {
                // the responder's state is looked up under the next id
                SubscriberState responderState = statusStorage.getState(streamId + 1);
                return ok(toJson(MasterStatus.builder().master(state.getServerState()).responder(responderState.getServerState()).responderN(state.getTotalUpdates()).build()));
            }
            return ok(toJson(SlaveStatus.builder().slave(state.serverType()).build()));
        }
    });
    dsl.GET("/connectioninfo/:id").routeTo(new F.Function<String, Result>() {
        @Override
        public Result apply(String id) throws Throwable {
            return ok(toJson(statusStorage.getState(Integer.parseInt(id)).getConnectionInfo()));
        }
    });
    dsl.POST("/updatestatus/:id").routeTo(new F.Function<String, Result>() {
        @Override
        public Result apply(String id) throws Throwable {
            SubscriberState subscriberState = Json.fromJson(request().body().asJson(), SubscriberState.class);
            statusStorage.updateState(subscriberState);
            return ok(toJson(subscriberState));
        }
    });
    return Server.forRouter(dsl.build(), Mode.PROD, statusServerPort);
}
use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.
the class StorageTests method testMapStorage.
/**
 * Verifies that {@link MapDbStatusStorage} returns the empty state for an unknown id,
 * stores a state under its stream id, and holds no states after a 10 second wait.
 */
@Test
public void testMapStorage() throws Exception {
    StatusStorage mapDb = new MapDbStatusStorage();
    // id -1 was never stored, so the empty state is expected
    assertEquals(SubscriberState.empty(), mapDb.getState(-1));
    SubscriberState noEmpty = SubscriberState.builder().isMaster(true).serverState("master").streamId(1).build();
    mapDb.updateState(noEmpty);
    // the state is expected to be retrievable by its streamId
    assertEquals(noEmpty, mapDb.getState(1));
    // NOTE(review): presumably waits for the storage's entry expiry; this sleep must stay
    // longer than that timeout — confirm against MapDbStatusStorage.
    Thread.sleep(10000);
    // assertEquals reports the actual count on failure, unlike assertTrue(x == 0)
    assertEquals(0, mapDb.numStates());
}
use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.
the class StorageTests method testStorage.
/**
 * Verifies that {@link InMemoryStatusStorage} returns the empty state for an unknown id,
 * stores a state under its stream id, and holds no states after a 10 second wait.
 */
@Test
public void testStorage() throws Exception {
    StatusStorage statusStorage = new InMemoryStatusStorage();
    // id -1 was never stored, so the empty state is expected
    assertEquals(SubscriberState.empty(), statusStorage.getState(-1));
    SubscriberState noEmpty = SubscriberState.builder().isMaster(true).serverState("master").streamId(1).build();
    statusStorage.updateState(noEmpty);
    // the state is expected to be retrievable by its streamId
    assertEquals(noEmpty, statusStorage.getState(1));
    // NOTE(review): presumably waits for the storage's entry expiry; this sleep must stay
    // longer than that timeout — confirm against InMemoryStatusStorage.
    Thread.sleep(10000);
    // assertEquals reports the actual count on failure, unlike assertTrue(x == 0)
    assertEquals(0, statusStorage.numStates());
}
use of org.nd4j.parameterserver.model.SubscriberState in project nd4j by deeplearning4j.
the class ParameterServerSubscriber method run.
/**
 * Parses the given command line arguments into this subscriber's fields and starts it.
 * <p>
 * Steps:
 * <ol>
 *   <li>parse {@code args} with JCommander; on invalid input print usage and exit with status 1</li>
 *   <li>launch an embedded Aeron media driver if neither a driver nor a driver directory was given</li>
 *   <li>connect to Aeron if no client was given</li>
 *   <li>master: create the parameter server listener (unless a callback is already set) and
 *       start a responder on {@code port + 1}, stream {@code streamId + 1};
 *       slave: create a publishing listener for {@code publishMasterUrl}</li>
 *   <li>start the NDArray subscriber on {@code host:port}, stream {@code streamId}, and wait
 *       until it reports launched</li>
 *   <li>if the status server host/port is reachable, post this subscriber's state there every
 *       {@code heartbeatMs} milliseconds</li>
 *   <li>register a shutdown hook that calls {@link #close()}</li>
 * </ol>
 *
 * @param args the command line arguments
 * @throws IllegalStateException if this is not a master and no master url is given, or if the
 *         master url is not of the form {@code host:port:streamId}
 */
public void run(String[] args) {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        log.error("Invalid command line arguments", e);
        // User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            // presumably gives the usage output a moment to flush before the JVM exits
            Thread.sleep(500);
        } catch (InterruptedException e2) {
            Thread.currentThread().interrupt();
        }
        System.exit(1);
    }
    // ensure that the update opType is configured from the command line args;
    // Locale.ROOT keeps the enum lookup independent of the default locale (e.g. Turkish dotted I)
    updateType = UpdateType.valueOf(updateTypeString.toUpperCase(java.util.Locale.ROOT));
    if (publishMasterUrl == null && !master)
        throw new IllegalStateException("Please specify a master url or set master to true");
    // for a remote one
    if (mediaDriver == null && mediaDriverDirectoryName == null) {
        launchEmbeddedMediaDriverForShape();
    }
    if (aeron == null)
        this.aeron = Aeron.connect(getContext());
    if (master) {
        if (this.callback == null) {
            callback = new ParameterServerListener(Ints.toArray(shape), createMasterUpdater());
            parameterServerListener = (ParameterServerListener) callback;
        }
        // start an extra daemon for responding to get queries
        ParameterServerListener cast = (ParameterServerListener) callback;
        responder = AeronNDArrayResponder.startSubscriber(aeron, host, port + 1, cast.getUpdater().ndArrayHolder(), streamId + 1);
        log.info("Started responder on master node " + responder.connectionUrl());
    } else {
        // three fields are read below: host, port and the stream id
        String[] publishMasterUrlArr = publishMasterUrl.split(":");
        if (publishMasterUrlArr.length < 3)
            throw new IllegalStateException("Please specify publish master url as host:port:streamId");
        callback = new PublishingListener(String.format("aeron:udp?endpoint=%s:%s", publishMasterUrlArr[0], publishMasterUrlArr[1]), Integer.parseInt(publishMasterUrlArr[2]), getContext());
    }
    log.info("Starting subscriber on " + host + ":" + port + " and stream " + streamId);
    AtomicBoolean running = new AtomicBoolean(true);
    // start a node
    subscriber = AeronNDArraySubscriber.startSubscriber(aeron, host, port, callback, streamId, running);
    // poll every 100 microseconds until the subscriber reports it has launched
    while (!subscriber.launched()) {
        LockSupport.parkNanos(100000);
    }
    // Only schedule this if a remote server is available.
    if (CheckSocket.remotePortTaken(statusServerHost, statusServerPort, 10000)) {
        startStatusHeartbeat();
    } else {
        log.info("No status server found. Will not send heartbeats. Specified host was " + statusServerHost + " and port was " + statusServerPort);
    }
    Runtime.getRuntime().addShutdownHook(new Thread(() -> {
        close();
    }));
}

/**
 * Launches an embedded Aeron media driver with buffers sized from {@code shape}. Also sets the
 * {@code aeron.socket.so_rcvbuf} system property and records the driver directory in
 * {@code mediaDriverDirectoryName}.
 *
 * @throws ArithmeticException if the computed buffer length does not fit in an int
 */
private void launchEmbeddedMediaDriverForShape() {
    // length of array * sizeof(float); exact arithmetic fails loudly instead of wrapping negative
    int ipcLength = Math.multiplyExact(ArrayUtil.prod(Ints.toArray(shape)), 4);
    // doubled. NOTE(review): the original comment said "must be a power of 2", but doubling (and the
    // +64 below) does not yield a power of two in general — confirm what Aeron's term buffer validation requires.
    ipcLength = Math.multiplyExact(ipcLength, 2);
    // padding for NDArrayMessage
    ipcLength = Math.addExact(ipcLength, 64);
    // Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
    System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
    final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(deleteDirectoryOnStart).termBufferSparseFile(false).ipcTermBufferLength(ipcLength).publicationTermBufferLength(ipcLength).maxTermBufferLength(ipcLength).conductorIdleStrategy(new BusySpinIdleStrategy()).receiverIdleStrategy(new BusySpinIdleStrategy()).senderIdleStrategy(new BusySpinIdleStrategy());
    mediaDriver = MediaDriver.launchEmbedded(mediaDriverCtx);
    // set the variable since we are using a media driver directly
    mediaDriverDirectoryName = mediaDriver.aeronDirectoryName();
    log.info("Using media driver directory " + mediaDriver.aeronDirectoryName());
}

/**
 * Creates the parameter server updater for a master node according to {@code updateType}.
 *
 * @return the updater, or {@code null} for HOGWILD and TIME_DELAYED, which have no updater here
 * @throws IllegalStateException if the CUSTOM updater class property is unset, or the type is unknown
 * @throws RuntimeException wrapping any failure to instantiate the CUSTOM updater class
 */
private ParameterServerUpdater createMasterUpdater() {
    switch (updateType) {
        case HOGWILD:
        case TIME_DELAYED:
            // NOTE(review): run() later calls getUpdater().ndArrayHolder() on the listener; unless
            // ParameterServerListener substitutes a default, a null updater fails there — confirm.
            return null;
        case SYNC:
            return new SynchronousParameterUpdater(new InMemoryUpdateStorage(), new InMemoryNDArrayHolder(Ints.toArray(shape)), updatesPerEpoch);
        case SOFTSYNC:
            return new SoftSyncParameterUpdater();
        case CUSTOM:
            String customUpdaterClass = System.getProperty(CUSTOM_UPDATE_TYPE);
            if (customUpdaterClass == null)
                throw new IllegalStateException("System property " + CUSTOM_UPDATE_TYPE + " must name the custom updater class");
            try {
                return (ParameterServerUpdater) Class.forName(customUpdaterClass).getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        default:
            throw new IllegalStateException("Illegal opType of updater");
    }
}

/**
 * Schedules a single-threaded task that posts this subscriber's state as JSON to
 * {@code /updatestatus/<streamId>} on the status server every {@code heartbeatMs} milliseconds.
 * Once 3 posts have failed in total, the task stops sending (it stays scheduled but returns immediately).
 */
private void startStatusHeartbeat() {
    scheduledExecutorService = Executors.newScheduledThreadPool(1);
    final AtomicInteger failCount = new AtomicInteger(0);
    scheduledExecutorService.scheduleAtFixedRate(() -> {
        if (failCount.get() >= 3)
            return;
        try {
            SubscriberState subscriberState = asState();
            JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
            String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, streamId);
            Unirest.post(url).header("Content-Type", "application/json").body(jsonObject).asString();
        } catch (Exception e) {
            if (failCount.incrementAndGet() >= 3) {
                log.warn("Failed to send update, shutting down likely?", e);
            }
        }
    }, 0, heartbeatMs, TimeUnit.MILLISECONDS);
}
Aggregations