Usage example of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project — class TestStatsStorage, method testFileStatsStore.
@Test
public void testFileStatsStore() throws IOException {
    // Exercise both file-backed StatsStorage implementations (MapDB and J7 file storage),
    // each with both serialization modes, verifying listener callbacks, CRUD queries,
    // and persistence across a close/re-open cycle.
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int backend = 0; backend < 2; backend++) {
            File storeFile = (backend == 0)
                            ? Files.createTempFile("TestMapDbStatsStore", ".db").toFile()
                            : Files.createTempFile("TestSqliteStatsStore", ".db").toFile();
            // The storage implementations expect to create the file themselves
            storeFile.delete();

            StatsStorage storage = (backend == 0)
                            ? new MapDBStatsStorage.Builder().file(storeFile).build()
                            : new J7FileStatsStorage(storeFile);

            CountingListener listener = new CountingListener();
            storage.registerStatsStorageListener(listener);
            assertEquals(1, storage.getListeners().size());

            // Fresh store: no sessions, no updates
            assertEquals(0, storage.listSessionIDs().size());
            assertNull(storage.getLatestUpdate("sessionID", "typeID", "workerID"));
            assertEquals(0, storage.listSessionIDs().size());

            // First static info: should trigger new-session, new-worker and static-info callbacks
            storage.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
            assertEquals(1, listener.countNewSession);
            assertEquals(1, listener.countNewWorkerId);
            assertEquals(1, listener.countStaticInfo);
            assertEquals(0, listener.countUpdate);

            assertEquals(Collections.singletonList("sid0"), storage.listSessionIDs());
            assertTrue(storage.sessionExists("sid0"));
            assertFalse(storage.sessionExists("sid1"));

            Persistable expectedInit = getInitReport(0, 0, 0, useJ7Storage);
            Persistable retrievedStatic = storage.getStaticInfo("sid0", "tid0", "wid0");
            assertEquals(expectedInit, retrievedStatic);
            List<Persistable> staticList = storage.getAllStaticInfos("sid0", "tid0");
            assertEquals(Collections.singletonList(expectedInit), staticList);
            assertNull(storage.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), storage.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), storage.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), storage.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(0, storage.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
            assertEquals(0, storage.getNumUpdateRecordsFor("sid0"));
            assertEquals(0, storage.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));

            // First update record
            storage.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
            assertEquals(1, storage.getNumUpdateRecordsFor("sid0"));
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), storage.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(Collections.singletonList("tid0"), storage.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), storage.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), storage.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), storage.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
            assertEquals(1, storage.getNumUpdateRecordsFor("sid0"));
            assertEquals(1, storage.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
            List<Persistable> latestAllWorkers = storage.getLatestUpdateAllWorkers("sid0", "tid0");
            assertEquals(1, latestAllWorkers.size());
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), storage.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(1, listener.countNewSession);
            assertEquals(1, listener.countNewWorkerId);
            assertEquals(1, listener.countStaticInfo);
            assertEquals(1, listener.countUpdate);

            // Second update for the same worker: replaces "latest" but worker count is unchanged
            storage.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
            assertEquals(1, storage.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(Collections.singletonList("tid0"), storage.listTypeIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), storage.listWorkerIDsForSession("sid0"));
            assertEquals(Collections.singletonList("wid0"), storage.listWorkerIDsForSessionAndType("sid0", "tid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), storage.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), storage.getUpdate("sid0", "tid0", "wid0", 12346));

            // Update from a second worker in the same session
            storage.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
            assertEquals(2, storage.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), storage.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), storage.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(1, listener.countNewSession);
            assertEquals(2, listener.countNewWorkerId);
            assertEquals(1, listener.countStaticInfo);
            assertEquals(3, listener.countUpdate);

            // Static info and update under entirely different session/type/worker IDs
            storage.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
            assertEquals(2, storage.getLatestUpdateAllWorkers("sid0", "tid0").size());
            storage.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
            assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), storage.getLatestUpdateAllWorkers("sid100", "tid200"));
            assertEquals(Collections.singletonList("tid200"), storage.listTypeIDsForSession("sid100"));
            List<String> workerIdsSid100 = storage.listWorkerIDsForSession("sid100");
            System.out.println("temp: " + workerIdsSid100);
            assertEquals(Collections.singletonList("wid300"), storage.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), storage.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), storage.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), storage.getUpdate("sid100", "tid200", "wid300", 12346));
            assertEquals(2, listener.countNewSession);
            assertEquals(3, listener.countNewWorkerId);
            assertEquals(2, listener.countStaticInfo);
            assertEquals(4, listener.countUpdate);

            // Close and re-open: all content must survive the round trip to disk
            storage.close();
            assertTrue(storage.isClosed());
            storage = (backend == 0)
                            ? new MapDBStatsStorage.Builder().file(storeFile).build()
                            : new J7FileStatsStorage(storeFile);
            assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), storage.getUpdate("sid0", "tid0", "wid0", 12345));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), storage.getLatestUpdate("sid0", "tid0", "wid0"));
            assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), storage.getUpdate("sid0", "tid0", "wid0", 12346));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), storage.getLatestUpdate("sid0", "tid0", "wid1"));
            assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), storage.getUpdate("sid0", "tid0", "wid1", 12345));
            assertEquals(2, storage.getLatestUpdateAllWorkers("sid0", "tid0").size());
            assertEquals(1, storage.getLatestUpdateAllWorkers("sid100", "tid200").size());
            assertEquals(Collections.singletonList("tid200"), storage.listTypeIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), storage.listWorkerIDsForSession("sid100"));
            assertEquals(Collections.singletonList("wid300"), storage.listWorkerIDsForSessionAndType("sid100", "tid200"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), storage.getLatestUpdate("sid100", "tid200", "wid300"));
            assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), storage.getUpdate("sid100", "tid200", "wid300", 12346));
        }
    }
}
Usage example of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project — class TrainModule, method getConfig.
/**
 * Retrieves the model configuration for the current session from its stats storage.
 * <p>
 * Exactly one element of the returned triple is non-null, selected by the stored
 * model class name: MultiLayerNetwork config, ComputationGraph config, or (as a
 * fallback) a single-layer NeuralNetConfiguration parsed from the config JSON.
 *
 * @return the parsed configuration triple, or {@code null} if there is no current
 *         session, no static info, or the fallback JSON parse fails
 */
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
    boolean noData = currentSessionID == null;
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    // Type-safe Collections.emptyList() instead of raw Collections.EMPTY_LIST (unchecked warning)
    List<Persistable> allStatic = (noData ? Collections.<Persistable>emptyList() : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    if (allStatic.isEmpty())
        return null;
    StatsInitializationReport p = (StatsInitializationReport) allStatic.get(0);
    String modelClass = p.getModelClassName();
    String config = p.getModelConfigJson();
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config);
        return new Triple<>(conf, null, null);
    } else if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config);
        return new Triple<>(null, conf, null);
    } else {
        try {
            // Unknown model class: best-effort parse as a single NeuralNetConfiguration
            NeuralNetConfiguration layer = NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
            return new Triple<>(null, null, layer);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    return null;
}
Usage example of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project — class TrainModule, method getModelGraph.
/**
 * Returns the current session's model graph structure as JSON.
 * <p>
 * Responds with an empty OK result (rather than an error) when there is no
 * current session, no static info has been stored yet, or no graph info is
 * available — callers treat all of those as "nothing to render yet".
 *
 * @return HTTP 200 with the graph info as JSON, or an empty HTTP 200
 */
private Result getModelGraph() {
    boolean noData = currentSessionID == null;
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    // Type-safe Collections.emptyList() instead of raw Collections.EMPTY_LIST (unchecked warning)
    List<Persistable> allStatic = (noData ? Collections.<Persistable>emptyList() : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    if (allStatic.isEmpty()) {
        return ok();
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok();
    }
    return ok(Json.toJson(gi));
}
Usage example of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project — class TrainModule, method sessionInfo.
/**
 * Summarizes every known session as JSON: worker count, earliest static-info
 * timestamp ("initTime"), most recent update timestamp ("lastUpdate"), worker
 * IDs, and basic model info (type, layer count, parameter count).
 *
 * @return HTTP 200 with a map of session ID to its summary map
 */
private Result sessionInfo() {
    Map<String, Object> perSession = new HashMap<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        Map<String, Object> summary = new HashMap<>();
        String sessionId = entry.getKey();
        StatsStorage storage = entry.getValue();

        List<String> workerIDs = storage.listWorkerIDsForSessionAndType(sessionId, StatsListener.TYPE_ID);
        int workerCount = (workerIDs != null) ? workerIDs.size() : 0;

        // Earliest static-info timestamp is taken as the session start time
        List<Persistable> staticInfo = storage.getAllStaticInfos(sessionId, StatsListener.TYPE_ID);
        long initTime = Long.MAX_VALUE;
        if (staticInfo != null) {
            for (Persistable record : staticInfo) {
                initTime = Math.min(record.getTimeStamp(), initTime);
            }
        }

        // Latest timestamp across all workers' most recent updates
        long lastUpdateTime = Long.MIN_VALUE;
        for (Persistable record : storage.getLatestUpdateAllWorkers(sessionId, StatsListener.TYPE_ID)) {
            lastUpdateTime = Math.max(lastUpdateTime, record.getTimeStamp());
        }

        summary.put("numWorkers", workerCount);
        // Sentinel values mean "no data yet" — render as empty string
        summary.put("initTime", initTime == Long.MAX_VALUE ? "" : initTime);
        summary.put("lastUpdate", lastUpdateTime == Long.MIN_VALUE ? "" : lastUpdateTime);
        if (workerCount > 0) {
            summary.put("workers", workerIDs);
        }

        // Model info: type, # layers, # params — from the first static-info record
        if (staticInfo != null && staticInfo.size() > 0) {
            StatsInitializationReport report = (StatsInitializationReport) staticInfo.get(0);
            String modelClassName = report.getModelClassName();
            if (modelClassName.endsWith("MultiLayerNetwork")) {
                modelClassName = "MultiLayerNetwork";
            } else if (modelClassName.endsWith("ComputationGraph")) {
                modelClassName = "ComputationGraph";
            }
            summary.put("modelType", modelClassName);
            summary.put("numLayers", report.getModelNumLayers());
            summary.put("numParams", report.getModelNumParams());
        } else {
            summary.put("modelType", "");
            summary.put("numLayers", "");
            summary.put("numParams", "");
        }

        perSession.put(sessionId, summary);
    }
    return ok(Json.toJson(perSession));
}
Usage example of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project — class TrainModule, method getOverviewData.
/**
 * Builds the JSON payload for the training overview page: score-vs-iteration
 * chart data, per-layer update:parameter ratios, per-layer standard deviations
 * (activations, gradients, updates), a performance table, and a model table.
 * <p>
 * Data comes from the current session / current worker's update records;
 * long histories are subsampled down to at most {@code maxChartPoints} points.
 * NOTE(review): this method accumulates state across a single ordered pass over
 * the updates — statement order matters; do not reorder queries or loop bodies.
 *
 * @return HTTP 200 with the overview data as JSON (empty tables when no data)
 */
private Result getOverviewData() {
Long lastUpdate = lastUpdateForSession.get(currentSessionID);
// -1 signals "no update seen yet" to the client
if (lastUpdate == null)
lastUpdate = -1L;
I18N i18N = I18NProvider.getInstance();
boolean noData = currentSessionID == null;
//First pass (optimize later): query all data...
StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
String wid = getWorkerIdForIndex(currentWorkerIdx);
if (wid == null) {
noData = true;
}
List<Integer> scoresIterCount = new ArrayList<>();
List<Double> scores = new ArrayList<>();
Map<String, Object> result = new HashMap<>();
result.put("updateTimestamp", lastUpdate);
result.put("scores", scores);
result.put("scoresIter", scoresIterCount);
//Get scores info
List<Persistable> updates = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
if (updates == null || updates.size() == 0) {
noData = true;
}
//Collect update ratios for weights
//Collect standard deviations: activations, gradients, updates
//Mean magnitude (updates) / mean magnitude (parameters)
Map<String, List<Double>> updateRatios = new HashMap<>();
result.put("updateRatios", updateRatios);
Map<String, List<Double>> stdevActivations = new HashMap<>();
Map<String, List<Double>> stdevGradients = new HashMap<>();
Map<String, List<Double>> stdevUpdates = new HashMap<>();
result.put("stdevActivations", stdevActivations);
result.put("stdevGradients", stdevGradients);
result.put("stdevUpdates", stdevUpdates);
// Pre-seed the per-layer history maps from the FIRST update record so every
// chart series exists (possibly empty) even if later records lack some keys
if (!noData) {
Persistable u = updates.get(0);
if (u instanceof StatsReport) {
StatsReport sp = (StatsReport) u;
Map<String, Double> map = sp.getMeanMagnitudes(StatsType.Parameters);
if (map != null) {
// Only parameters whose name ends in "w" (weights) get a ratio series
for (String s : map.keySet()) {
if (!s.toLowerCase().endsWith("w"))
//TODO: more robust "weights only" approach...
continue;
updateRatios.put(s, new ArrayList<>());
}
}
Map<String, Double> stdGrad = sp.getStdev(StatsType.Gradients);
if (stdGrad != null) {
for (String s : stdGrad.keySet()) {
if (!s.toLowerCase().endsWith("w"))
//TODO: more robust "weights only" approach...
continue;
stdevGradients.put(s, new ArrayList<>());
}
}
Map<String, Double> stdUpdate = sp.getStdev(StatsType.Updates);
if (stdUpdate != null) {
for (String s : stdUpdate.keySet()) {
if (!s.toLowerCase().endsWith("w"))
//TODO: more robust "weights only" approach...
continue;
stdevUpdates.put(s, new ArrayList<>());
}
}
// Activations: all keys included, not just weights
Map<String, Double> stdAct = sp.getStdev(StatsType.Activations);
if (stdAct != null) {
for (String s : stdAct.keySet()) {
stdevActivations.put(s, new ArrayList<>());
}
}
}
}
StatsReport last = null;
int lastIterCount = -1;
//Legacy issue - Spark training - iteration counts are used to be reset... which means: could go 0,1,2,0,1,2, etc...
//Or, it could equally go 4,8,4,8,... or 5,5,5,5 - depending on the collection and averaging frequencies
//Now, it should use the proper iteration counts
boolean needToHandleLegacyIterCounts = false;
if (!noData) {
double lastScore;
int totalUpdates = updates.size();
// Subsample so the chart gets at most maxChartPoints points
int subsamplingFrequency = 1;
if (totalUpdates > maxChartPoints) {
subsamplingFrequency = totalUpdates / maxChartPoints;
}
int pCount = -1;
int lastUpdateIdx = updates.size() - 1;
for (Persistable u : updates) {
pCount++;
if (!(u instanceof StatsReport))
continue;
last = (StatsReport) u;
int iterCount = last.getIterationCount();
// Non-increasing iteration counts indicate legacy (reset-per-epoch) numbering
if (iterCount <= lastIterCount) {
needToHandleLegacyIterCounts = true;
}
lastIterCount = iterCount;
if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
//Skip this - subsample the data
if (pCount != lastUpdateIdx)
//Always keep the most recent value
continue;
}
scoresIterCount.add(iterCount);
lastScore = last.getScore();
// NaN/Inf scores are replaced with a sentinel so the JSON stays valid
if (Double.isFinite(lastScore)) {
scores.add(lastScore);
} else {
scores.add(NAN_REPLACEMENT_VALUE);
}
//Update ratios: mean magnitudes(updates) / mean magnitudes (parameters)
Map<String, Double> updateMM = last.getMeanMagnitudes(StatsType.Updates);
Map<String, Double> paramMM = last.getMeanMagnitudes(StatsType.Parameters);
if (updateMM != null && paramMM != null && updateMM.size() > 0 && paramMM.size() > 0) {
// Iterate over the pre-seeded key set so series stay aligned across records
for (String s : updateRatios.keySet()) {
List<Double> ratioHistory = updateRatios.get(s);
double currUpdate = updateMM.getOrDefault(s, 0.0);
double currParam = paramMM.getOrDefault(s, 0.0);
double ratio = currUpdate / currParam;
if (Double.isFinite(ratio)) {
ratioHistory.add(ratio);
} else {
ratioHistory.add(NAN_REPLACEMENT_VALUE);
}
}
}
//Standard deviations: gradients, updates, activations
Map<String, Double> stdGrad = last.getStdev(StatsType.Gradients);
Map<String, Double> stdUpd = last.getStdev(StatsType.Updates);
Map<String, Double> stdAct = last.getStdev(StatsType.Activations);
if (stdGrad != null) {
for (String s : stdevGradients.keySet()) {
double d = stdGrad.getOrDefault(s, 0.0);
stdevGradients.get(s).add(fixNaN(d));
}
}
if (stdUpd != null) {
for (String s : stdevUpdates.keySet()) {
double d = stdUpd.getOrDefault(s, 0.0);
stdevUpdates.get(s).add(fixNaN(d));
}
}
if (stdAct != null) {
for (String s : stdevActivations.keySet()) {
double d = stdAct.getOrDefault(s, 0.0);
stdevActivations.get(s).add(fixNaN(d));
}
}
}
}
// Rewrite reset-style iteration counts into a monotonic sequence for the chart
if (needToHandleLegacyIterCounts) {
cleanLegacyIterationCounts(scoresIterCount);
}
//----- Performance Info -----
// Table rows: [label, value]; start time / total runtime (rows 0-1) are left blank here
String[][] perfInfo = new String[][] { { i18N.getMessage("train.overview.perftable.startTime"), "" }, { i18N.getMessage("train.overview.perftable.totalRuntime"), "" }, { i18N.getMessage("train.overview.perftable.lastUpdate"), "" }, { i18N.getMessage("train.overview.perftable.totalParamUpdates"), "" }, { i18N.getMessage("train.overview.perftable.updatesPerSec"), "" }, { i18N.getMessage("train.overview.perftable.examplesPerSec"), "" } };
if (last != null) {
perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp())));
perfInfo[3][1] = String.valueOf(last.getTotalMinibatches());
perfInfo[4][1] = String.valueOf(df2.format(last.getMinibatchesPerSecond()));
perfInfo[5][1] = String.valueOf(df2.format(last.getExamplesPerSecond()));
}
result.put("perf", perfInfo);
// ----- Model Info -----
String[][] modelInfo = new String[][] { { i18N.getMessage("train.overview.modeltable.modeltype"), "" }, { i18N.getMessage("train.overview.modeltable.nLayers"), "" }, { i18N.getMessage("train.overview.modeltable.nParams"), "" } };
if (!noData) {
Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
if (p != null) {
StatsInitializationReport initReport = (StatsInitializationReport) p;
int nLayers = initReport.getModelNumLayers();
long numParams = initReport.getModelNumParams();
String className = initReport.getModelClassName();
String modelType;
if (className.endsWith("MultiLayerNetwork")) {
modelType = "MultiLayerNetwork";
} else if (className.endsWith("ComputationGraph")) {
modelType = "ComputationGraph";
} else {
// Unknown model class: display the simple (unqualified) class name
modelType = className;
if (modelType.lastIndexOf('.') > 0) {
modelType = modelType.substring(modelType.lastIndexOf('.') + 1);
}
}
modelInfo[0][1] = modelType;
modelInfo[1][1] = String.valueOf(nLayers);
modelInfo[2][1] = String.valueOf(numParams);
}
}
result.put("model", modelInfo);
return Results.ok(Json.toJson(result));
}
Aggregations