use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
the class RemoteUIStatsStorageRouter method tryPost.
/**
 * Encodes a single queued item as a small JSON document and writes it to the connection
 * returned by {@code getConnection()}.
 * <p>
 * The JSON object has three string fields: {@code "type"} (the {@link StorageType} name),
 * {@code "class"} (the class name of the encoded object) and {@code "data"} (the encoded bytes,
 * base64). Exactly one payload is sent per call: the meta data if present, otherwise the
 * static info if present, otherwise the update.
 *
 * @param toPost item to send; its first non-null payload in the order meta / static info / update is used
 * @return {@code true} if the remote side answered with HTTP 200; {@code false} for any other
 *         response code or if reading the response failed (a warning is logged in both cases)
 * @throws IOException if writing the request body to the connection fails
 */
private boolean tryPost(ToPost toPost) throws IOException {
    HttpURLConnection connection = getConnection();

    String className;
    byte[] asBytes;
    StorageType type;
    if (toPost.getMeta() != null) {
        StorageMetaData smd = toPost.getMeta();
        className = smd.getClass().getName();
        asBytes = smd.encode();
        type = StorageType.MetaData;
    } else if (toPost.getStaticInfo() != null) {
        Persistable p = toPost.getStaticInfo();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.StaticInfo;
    } else {
        Persistable p = toPost.getUpdate();
        className = p.getClass().getName();
        asBytes = p.encode();
        type = StorageType.Update;
    }

    String base64 = DatatypeConverter.printBase64Binary(asBytes);
    Map<String, String> jsonObj = new LinkedHashMap<>();
    jsonObj.put("type", type.name());
    jsonObj.put("class", className);
    jsonObj.put("data", base64);

    String str;
    try {
        str = objectMapper.writeValueAsString(jsonObj);
    } catch (Exception e) {
        //Should never get an exception from simple Map<String,String>
        throw new RuntimeException(e);
    }

    // try-with-resources: the stream is closed even if the write fails.
    // Explicit UTF-8 instead of writeBytes(), which silently drops the high byte of every char.
    try (DataOutputStream dos = new DataOutputStream(connection.getOutputStream())) {
        dos.write(str.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        dos.flush();
    }

    try {
        int responseCode = connection.getResponseCode();
        if (responseCode != 200) {
            // For HTTP error statuses getInputStream() throws, and the body (if any) is only
            // available via getErrorStream(); fall back to the input stream for non-error codes.
            java.io.InputStream bodyStream = connection.getErrorStream();
            if (bodyStream == null) {
                bodyStream = connection.getInputStream();
            }
            StringBuilder response = new StringBuilder();
            try (BufferedReader in = new BufferedReader(
                            new InputStreamReader(bodyStream, java.nio.charset.StandardCharsets.UTF_8))) {
                String inputLine;
                while ((inputLine = in.readLine()) != null) {
                    response.append(inputLine);
                }
            }
            if (responseCode == 403) {
                log.warn("Error posting to remote UI at {} (Response code: 403)." + " Remote listener support is not enabled? use UIServer.getInstance().enableRemoteListener()", url);
            } else {
                log.warn("Error posting to remote UI - received response code {}\tContent: {}", responseCode, response.toString());
            }
            return false;
        }
    } catch (IOException e) {
        // getMessage() may legitimately be null; guard before inspecting it
        String msg = e.getMessage();
        if (msg != null && msg.contains("403 for URL")) {
            log.warn("Error posting to remote UI at {} (Response code: 403)." + " Remote listener support is not enabled? use UIServer.getInstance().enableRemoteListener()", url, e);
        } else {
            log.warn("Error posting to remote UI at {}", url, e);
        }
        return false;
    }
    return true;
}
use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
the class ParameterAveragingElementCombineFunction method call.
/**
 * Merges two partial aggregation tuples into a single tuple.
 * <p>
 * If either argument is {@code null}, or has no parameter sum, the other argument is returned
 * unchanged. Otherwise parameter sums and updater state sums are added in place into the arrays
 * held by {@code v1}, scores and aggregation counts are summed, training stats are combined, and
 * the listener collections of {@code v2} are appended to those of {@code v1}.
 *
 * @param v1 first partial result (may be {@code null}); its arrays and collections are modified in place
 * @param v2 second partial result (may be {@code null})
 * @return the combined tuple, or whichever input is usable when the other one is empty
 */
@Override
public ParameterAveragingAggregationTuple call(ParameterAveragingAggregationTuple v1, ParameterAveragingAggregationTuple v2) throws Exception {
    if (v1 == null) {
        return v2;
    }
    if (v2 == null) {
        return v1;
    }

    //Handle edge case of less data than executors: in this case, one (or both) of v1 and v2 might not have any contents...
    if (v1.getParametersSum() == null) {
        return v2;
    }
    if (v2.getParametersSum() == null) {
        return v1;
    }

    // In-place add: the result is accumulated into v1's parameter array
    INDArray newParams = v1.getParametersSum().addi(v2.getParametersSum());

    INDArray updaterStateSum = v1.getUpdaterStateSum();
    if (updaterStateSum == null) {
        updaterStateSum = v2.getUpdaterStateSum();
    } else if (v2.getUpdaterStateSum() != null) {
        updaterStateSum.addi(v2.getUpdaterStateSum());
    }

    double scoreSum = v1.getScoreSum() + v2.getScoreSum();
    int aggregationCount = v1.getAggregationsCount() + v2.getAggregationsCount();

    SparkTrainingStats stats = v1.getSparkTrainingStats();
    SparkTrainingStats otherStats = v2.getSparkTrainingStats();
    if (otherStats != null) {
        if (stats == null) {
            stats = otherStats;
        } else {
            stats.addOtherTrainingStats(otherStats);
        }
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    // Listener data: keep v1's collection when present and append v2's contents to it
    Collection<StorageMetaData> listenerMetaData = v1.getListenerMetaData();
    Collection<StorageMetaData> otherMetaData = v2.getListenerMetaData();
    if (listenerMetaData == null) {
        listenerMetaData = otherMetaData;
    } else if (otherMetaData != null) {
        listenerMetaData.addAll(otherMetaData);
    }

    Collection<Persistable> listenerStaticInfo = v1.getListenerStaticInfo();
    Collection<Persistable> otherStaticInfo = v2.getListenerStaticInfo();
    if (listenerStaticInfo == null) {
        listenerStaticInfo = otherStaticInfo;
    } else if (otherStaticInfo != null) {
        listenerStaticInfo.addAll(otherStaticInfo);
    }

    Collection<Persistable> listenerUpdates = v1.getListenerUpdates();
    Collection<Persistable> otherUpdates = v2.getListenerUpdates();
    if (listenerUpdates == null) {
        listenerUpdates = otherUpdates;
    } else if (otherUpdates != null) {
        listenerUpdates.addAll(otherUpdates);
    }

    return new ParameterAveragingAggregationTuple(newParams, updaterStateSum, scoreSum, aggregationCount, stats, listenerMetaData, listenerStaticInfo, listenerUpdates);
}
use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
the class TestRemoteReceiver method testRemoteBasic.
/**
 * Sends updates, storage meta data and static info through a {@link RemoteUIStatsStorageRouter}
 * pointed at {@code http://localhost:9000} and checks that the collection-backed router
 * registered with the {@link UIServer} ends up holding exactly the objects that were sent,
 * in the order they were sent. Ignored by default.
 */
@Test
@Ignore
public void testRemoteBasic() throws Exception {
    // Receiving side: plain lists filled by the collection-backed router
    List<Persistable> receivedUpdates = new ArrayList<>();
    List<Persistable> receivedStaticInfo = new ArrayList<>();
    List<StorageMetaData> receivedMetaData = new ArrayList<>();
    CollectionStatsStorageRouter collectionRouter =
                    new CollectionStatsStorageRouter(receivedMetaData, receivedStaticInfo, receivedUpdates);

    UIServer uiServer = UIServer.getInstance();
    uiServer.enableRemoteListener(collectionRouter, false);

    // Sending side
    RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");

    SbeStatsReport firstUpdate = new SbeStatsReport();
    firstUpdate.setDeviceCurrentBytes(new long[] {1, 2});
    firstUpdate.reportIterationCount(10);
    firstUpdate.reportIDs("sid", "tid", "wid", 123456);
    firstUpdate.reportPerformance(10, 20, 30, 40, 50);

    SbeStatsReport secondUpdate = new SbeStatsReport();
    secondUpdate.setDeviceCurrentBytes(new long[] {3, 4});
    secondUpdate.reportIterationCount(20);
    secondUpdate.reportIDs("sid2", "tid2", "wid2", 123456);
    secondUpdate.reportPerformance(11, 21, 31, 40, 50);

    StorageMetaData firstMeta =
                    new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
    StorageMetaData secondMeta =
                    new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");

    SbeStatsInitializationReport initReport = new SbeStatsInitializationReport();
    initReport.reportIDs("sid", "wid", "tid", 3145253452L);
    initReport.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");

    // Post the items one at a time, pausing between posts
    remoteRouter.putUpdate(firstUpdate);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(firstMeta);
    Thread.sleep(100);
    remoteRouter.putStaticInfo(initReport);
    Thread.sleep(100);
    remoteRouter.putUpdate(secondUpdate);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(secondMeta);

    // Longer wait before checking what was received
    Thread.sleep(2000);

    assertEquals(2, receivedMetaData.size());
    assertEquals(2, receivedUpdates.size());
    assertEquals(1, receivedStaticInfo.size());

    assertEquals(Arrays.asList(firstUpdate, secondUpdate), receivedUpdates);
    assertEquals(Arrays.asList(firstMeta, secondMeta), receivedMetaData);
    assertEquals(Collections.singletonList(initReport), receivedStaticInfo);
}
use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
the class TestStorageMetaData method testStorageMetaData.
/**
 * Round-trips {@link SbeStorageMetaData} through {@code encode()} / {@code decode()}:
 * first with every field populated (including extra meta data), then with all string
 * fields {@code null}, checking in both cases that re-encoding yields the same bytes.
 */
@Test
public void testStorageMetaData() {
    Serializable extraMeta = "ExtraMetaData";
    long timeStamp = 123456;

    // Fully populated instance
    StorageMetaData original = new SbeStorageMetaData(timeStamp, "sessionID", "typeID", "workerID",
                    "org.some.class.InitType", "org.some.class.UpdateType", extraMeta);
    byte[] encoded = original.encode();

    StorageMetaData decoded = new SbeStorageMetaData();
    decoded.decode(encoded);

    assertEquals(original, decoded);
    assertArrayEquals(encoded, decoded.encode());

    //Sanity check: null values
    StorageMetaData withNulls = new SbeStorageMetaData(0, null, null, null, null, (String) null);
    byte[] encodedNulls = withNulls.encode();

    StorageMetaData decodedNulls = new SbeStorageMetaData();
    decodedNulls.decode(encodedNulls);

    //In practice, we don't want these things to ever be null anyway...
    assertNullOrZeroLength(decodedNulls.getSessionID());
    assertNullOrZeroLength(decodedNulls.getTypeID());
    assertNullOrZeroLength(decodedNulls.getWorkerID());
    assertNullOrZeroLength(decodedNulls.getInitTypeClass());
    assertNullOrZeroLength(decodedNulls.getUpdateTypeClass());

    assertArrayEquals(encodedNulls, decodedNulls.encode());
}
use of org.deeplearning4j.api.storage.StorageMetaData in project deeplearning4j by deeplearning4j.
the class ParameterAveragingTrainingWorker method getFinalResult.
/**
 * Builds the training result for the given network.
 * <p>
 * The result carries the network's parameters and score, the updater state view (only when
 * {@code saveUpdater} is set and the network has an updater), and - when the configured
 * router is a {@link VanillaStatsStorageRouter} - the listener meta data, static info and
 * updates that router has collected. Fields that are not available are passed as {@code null}.
 *
 * @param network the network to take parameters, score and updater state from
 * @return the result object for this worker
 */
@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        ComputationGraphUpdater updater = network.getUpdater();
        updaterState = (updater == null) ? null : updater.getStateViewArray();
    }

    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }

    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;

    // instanceof is false for null, so a missing provider or router simply leaves the collections null
    StatsStorageRouter router = (listenerRouterProvider == null) ? null : listenerRouterProvider.getRouter();
    if (router instanceof VanillaStatsStorageRouter) {
        //TODO this is ugly... need to find a better solution
        VanillaStatsStorageRouter vanillaRouter = (VanillaStatsStorageRouter) router;
        storageMetaData = vanillaRouter.getStorageMetaData();
        listenerStaticInfo = vanillaRouter.getStaticInfo();
        listenerUpdates = vanillaRouter.getUpdates();
    }

    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData, listenerStaticInfo, listenerUpdates);
}
Aggregations