Usage of org.nd4j.parameterserver.updater.SynchronousParameterUpdater in the nd4j project by deeplearning4j.
The class ParameterServerSubscriber, method run.
/**
 * Parses the command line into this instance's fields and starts the subscriber.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Parse {@code args} with JCommander (bound to {@code this}); on invalid input, log the error,
 *       print the usage text and exit the JVM with status 1.</li>
 *   <li>If neither a media driver nor a media driver directory was supplied, launch an embedded
 *       Aeron media driver sized from {@code shape}; connect an Aeron client if none is set.</li>
 *   <li>Master: unless a callback was already set, build a {@code ParameterServerListener} whose
 *       updater is chosen by the configured update type, then start a responder on
 *       {@code port + 1} / stream {@code streamId + 1}.
 *       Slave: build a {@code PublishingListener} that publishes to the master described by
 *       {@code publishMasterUrl} ({@code host:port:streamId}).</li>
 *   <li>Start the NDArray subscriber on {@code host:port} / {@code streamId} and wait until it
 *       reports that it has launched.</li>
 *   <li>If {@code CheckSocket} reports the status server port as taken, post this subscriber's
 *       state to it every {@code heartbeatMs} milliseconds.</li>
 *   <li>Register a JVM shutdown hook that calls {@code close()}.</li>
 * </ol>
 *
 * @param args command-line arguments
 * @throws IllegalStateException if neither a master url nor master mode is configured, if the
 *         publish master url is not of the form {@code host:port:streamId}, if update type CUSTOM
 *         is selected without the custom-updater system property, or if the master's listener
 *         ends up without an updater
 * @throws IllegalArgumentException if the update type string does not name an {@code UpdateType}
 *         constant (a {@code NumberFormatException}, also an IllegalArgumentException, is thrown
 *         if the stream id in the publish master url is not an integer)
 */
public void run(String[] args) {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        log.error("Invalid command line arguments", e);
        // User provided invalid input -> print the usage info and exit
        jcmdr.usage();
        try {
            // presumably a short pause so the output above is flushed before the JVM exits
            Thread.sleep(500);
        } catch (InterruptedException ignored) {
            Thread.currentThread().interrupt();
        }
        System.exit(1);
    }

    // Ensure that the update type is configured from the command line args.
    // Locale.ROOT keeps the upper-casing locale independent (e.g. "hogwild" under a Turkish locale).
    updateType = UpdateType.valueOf(updateTypeString.toUpperCase(java.util.Locale.ROOT));
    if (publishMasterUrl == null && !master) {
        throw new IllegalStateException("Please specify a master url or set master to true");
    }

    // No media driver supplied: launch an embedded one sized for the configured array shape.
    if (mediaDriver == null && mediaDriverDirectoryName == null) {
        // length of array * sizeof(float)
        int ipcLength = ArrayUtil.prod(Ints.toArray(shape)) * 4;
        // NOTE(review): an earlier comment said this length "must be a power of 2", but doubling
        // and adding 64 bytes does not guarantee that; confirm the term-length requirements of the
        // Aeron version in use.
        ipcLength *= 2;
        // padding for NDArrayMessage
        ipcLength += 64;
        // Length in bytes for the SO_RCVBUF, 0 means use OS default. This needs to be larger than Receiver Window.
        System.setProperty("aeron.socket.so_rcvbuf", String.valueOf(ipcLength));
        final MediaDriver.Context mediaDriverCtx = new MediaDriver.Context()
                .threadingMode(ThreadingMode.DEDICATED)
                .dirsDeleteOnStart(deleteDirectoryOnStart)
                .termBufferSparseFile(false)
                .ipcTermBufferLength(ipcLength)
                .publicationTermBufferLength(ipcLength)
                .maxTermBufferLength(ipcLength)
                .conductorIdleStrategy(new BusySpinIdleStrategy())
                .receiverIdleStrategy(new BusySpinIdleStrategy())
                .senderIdleStrategy(new BusySpinIdleStrategy());
        mediaDriver = MediaDriver.launchEmbedded(mediaDriverCtx);
        // set the variable since we are using a media driver directly
        mediaDriverDirectoryName = mediaDriver.aeronDirectoryName();
        log.info("Using media driver directory " + mediaDriver.aeronDirectoryName());
    }
    if (aeron == null) {
        this.aeron = Aeron.connect(getContext());
    }

    if (master) {
        if (this.callback == null) {
            ParameterServerUpdater updater = null;
            // instantiate with shape instead of just length
            switch (updateType) {
                case HOGWILD:
                    // no updater is created for this update type here
                    break;
                case SYNC:
                    updater = new SynchronousParameterUpdater(new InMemoryUpdateStorage(),
                            new InMemoryNDArrayHolder(Ints.toArray(shape)), updatesPerEpoch);
                    break;
                case SOFTSYNC:
                    updater = new SoftSyncParameterUpdater();
                    break;
                case TIME_DELAYED:
                    // no updater is created for this update type here
                    break;
                case CUSTOM: {
                    // the system property names the updater class to instantiate reflectively
                    String customClassName = System.getProperty(CUSTOM_UPDATE_TYPE);
                    if (customClassName == null) {
                        throw new IllegalStateException("Update type CUSTOM requires the system property "
                                + CUSTOM_UPDATE_TYPE + " to name a ParameterServerUpdater class");
                    }
                    try {
                        updater = (ParameterServerUpdater) Class.forName(customClassName)
                                .getDeclaredConstructor()
                                .newInstance();
                    } catch (Exception e) {
                        throw new RuntimeException("Could not instantiate custom updater " + customClassName, e);
                    }
                    break;
                }
                default:
                    throw new IllegalStateException("Illegal opType of updater");
            }
            callback = new ParameterServerListener(Ints.toArray(shape), updater);
            parameterServerListener = (ParameterServerListener) callback;
        }
        // start an extra daemon for responding to get queries
        ParameterServerListener cast = (ParameterServerListener) callback;
        ParameterServerUpdater masterUpdater = cast.getUpdater();
        if (masterUpdater == null) {
            // Fail with a clear message instead of an anonymous NullPointerException below.
            throw new IllegalStateException("The master's ParameterServerListener has no updater (update type "
                    + updateType + "), so it cannot serve parameter queries");
        }
        responder = AeronNDArrayResponder.startSubscriber(aeron, host, port + 1,
                masterUpdater.ndArrayHolder(), streamId + 1);
        log.info("Started responder on master node " + responder.connectionUrl());
    } else {
        // Expected format: host:port:streamId (all three parts are used below).
        String[] publishMasterUrlArr = publishMasterUrl.split(":");
        if (publishMasterUrlArr.length < 3) {
            throw new IllegalStateException("Please specify publish master url as host:port:streamId");
        }
        callback = new PublishingListener(
                String.format("aeron:udp?endpoint=%s:%s", publishMasterUrlArr[0], publishMasterUrlArr[1]),
                Integer.parseInt(publishMasterUrlArr[2]), getContext());
    }

    log.info("Starting subscriber on " + host + ":" + port + " and stream " + streamId);
    AtomicBoolean running = new AtomicBoolean(true);
    // start a node
    subscriber = AeronNDArraySubscriber.startSubscriber(aeron, host, port, callback, streamId, running);
    // wait, parking 100 microseconds per check, until the subscriber reports it has launched
    while (!subscriber.launched()) {
        LockSupport.parkNanos(100000);
    }

    // Only schedule this if a remote server is available.
    if (CheckSocket.remotePortTaken(statusServerHost, statusServerPort, 10000)) {
        scheduledExecutorService = Executors.newScheduledThreadPool(1);
        // total (not consecutive) failures; the counter is never reset
        final AtomicInteger failCount = new AtomicInteger(0);
        scheduledExecutorService.scheduleAtFixedRate(() -> {
            try {
                // after 3 failures the task keeps firing but no longer posts anything
                if (failCount.get() >= 3) {
                    return;
                }
                SubscriberState subscriberState = asState();
                JSONObject jsonObject = new JSONObject(objectMapper.writeValueAsString(subscriberState));
                String url = String.format("http://%s:%d/updatestatus/%d", statusServerHost, statusServerPort, streamId);
                // the response is not inspected; only a thrown exception counts as a failure
                Unirest.post(url).header("Content-Type", "application/json").body(jsonObject).asString();
            } catch (Exception e) {
                if (failCount.incrementAndGet() >= 3) {
                    log.warn("Failed to send update, shutting down likely?", e);
                }
            }
        }, 0, heartbeatMs, TimeUnit.MILLISECONDS);
    } else {
        log.info("No status server found. Will not send heartbeats. Specified host was " + statusServerHost + " and port was " + statusServerPort);
    }

    Runtime.getRuntime().addShutdownHook(new Thread(this::close));
}
Aggregations