Use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.
Class HistogramModule, method processRequest.
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }

    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);

    //TODO checks
    StatsInitializationReport initReport = (StatsInitializationReport) ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerIDs.get(0));
    if (initReport == null)
        return Results.ok(Json.toJson(Collections.EMPTY_MAP));

    String[] paramNames = initReport.getModelParamNames();

    //Infer layer names from param names...
    Set<String> layerNameSet = new LinkedHashSet<>();
    for (String s : paramNames) {
        String[] split = s.split("_");
        layerNameSet.add(split[0]);
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);

    List<Persistable> list = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerIDs.get(0), 0);
    Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));

    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Map: parameter name -> mean magnitude history for that layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    //List.get(i) -> layer i. Map: update (gradient) name -> mean magnitude history for that layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }

    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());

        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }

    if (last == null) {
        //No StatsReport updates recorded yet for this session: avoid NPE below
        return Results.ok(Json.toJson(Collections.EMPTY_MAP));
    }

    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    double lastScore = (scoreList.size() == 0 ? 0.0 : scoreList.get(scoreList.size() - 1));

    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    // g.setPath(subPath);
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    // g.setLayerNames(layerNames);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
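The helper updateMeanMagnitudeMaps called above is not included in this listing. The sketch below is an assumption based only on how it is called here (parameter names of the form layerName_paramName, one history map per layer); the actual HistogramModule implementation may differ.

private static void updateMeanMagnitudeMaps(Map<String, Double> current, List<String> layerNames,
                List<Map<String, List<Double>>> meanMagHistory) {
    //Append each parameter's current mean magnitude to the history list of its layer
    for (Map.Entry<String, Double> entry : current.entrySet()) {
        String paramName = entry.getKey();
        int idx = paramName.indexOf('_');
        if (idx <= 0)
            continue;
        int layer = layerNames.indexOf(paramName.substring(0, idx));
        if (layer < 0)
            continue;
        meanMagHistory.get(layer)
                        .computeIfAbsent(paramName.substring(idx + 1), k -> new ArrayList<>())
                        .add(entry.getValue());
    }
}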
Use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.
Class TrainModule, method getLayerLearningRates.
private Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi, List<Persistable> updates,
                List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return Collections.emptyMap();
    }
    String layerName = gi.getOriginalVertexName().get(layerIdx);

    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    Map<String, float[]> byName = new HashMap<>();
    int used = 0;
    if (updates != null) {
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) u;
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                iterCounts[used] = iterationCounts.get(uCount);
            }

            //TODO PROPER VALIDATION ETC, ERROR HANDLING
            Map<String, Double> lrs = sp.getLearningRates();

            String prefix;
            if (modelType == ModelType.Layer) {
                prefix = layerName;
            } else {
                prefix = layerName + "_";
            }

            for (String p : lrs.keySet()) {
                if (p.startsWith(prefix)) {
                    String layerParamName = p.substring(Math.min(p.length(), prefix.length()));
                    if (!byName.containsKey(layerParamName)) {
                        byName.put(layerParamName, new float[size]);
                    }
                    float[] lrThisParam = byName.get(layerParamName);
                    lrThisParam[used] = lrs.get(p).floatValue();
                }
            }
            used++;
        }
    }

    List<String> paramNames = new ArrayList<>(byName.keySet());
    //Sorted for consistency
    Collections.sort(paramNames);

    Map<String, Object> ret = new HashMap<>();
    ret.put("iterCounts", iterCounts);
    ret.put("paramNames", paramNames);
    ret.put("lrs", byName);
    return ret;
}
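For reference, the map returned above pairs each parameter name with a float[] of learning rates aligned to the iterCounts array. A hypothetical caller (variable names assumed, not from the source) could read it back as follows:

Map<String, Object> lrData = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, modelType);
int[] iters = (int[]) lrData.get("iterCounts");
@SuppressWarnings("unchecked")
Map<String, float[]> lrByParam = (Map<String, float[]>) lrData.get("lrs");
for (Map.Entry<String, float[]> e : lrByParam.entrySet()) {
    float[] lr = e.getValue();
    //lr[i] is the learning rate of parameter e.getKey() at iteration iters[i]
    System.out.println(e.getKey() + ": " + lr.length + " points, first at iteration " + (iters.length > 0 ? iters[0] : -1));
}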
Use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.
Class TrainModule, method getModelData.
private Result getModelData(String str) {
    Long lastUpdateTime = lastUpdateForSession.get(currentSessionID);
    if (lastUpdateTime == null)
        lastUpdateTime = -1L;

    //TODO validation
    int layerIdx = Integer.parseInt(str);

    I18N i18N = I18NProvider.getInstance();

    //Model info for layer
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }

    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdateTime);

    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        return ok(Json.toJson(result));
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok(Json.toJson(result));
    }

    //Get static layer info
    String[][] layerInfoTable = getLayerInfoTable(layerIdx, gi, i18N, noData, ss, wid);
    result.put("layerInfo", layerInfoTable);

    //First: get all data, and subsample it if necessary, to avoid returning too many points...
    List<Persistable> updates = (noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    List<Integer> iterationCounts = null;
    boolean needToHandleLegacyIterCounts = false;
    if (updates != null && updates.size() > maxChartPoints) {
        int subsamplingFrequency = updates.size() / maxChartPoints;
        List<Persistable> subsampled = new ArrayList<>();
        iterationCounts = new ArrayList<>();
        int pCount = -1;
        int lastUpdateIdx = updates.size() - 1;
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport))
                continue;
            StatsReport sr = (StatsReport) p;
            pCount++;
            int iterCount = sr.getIterationCount();
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
                //Skip this to subsample the data
                if (pCount != lastUpdateIdx)
                    //Always keep the most recent value
                    continue;
            }
            subsampled.add(p);
            iterationCounts.add(iterCount);
        }
        updates = subsampled;
    } else if (updates != null) {
        iterationCounts = new ArrayList<>(updates.size());
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport))
                continue;
            StatsReport sr = (StatsReport) p;
            int iterCount = sr.getIterationCount();
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;
            iterationCounts.add(iterCount);
        }
    }

    //Now ensure the proper (monotonic) iteration counts are used
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(iterationCounts);
    }

    //Get mean magnitudes line chart
    ModelType mt;
    if (conf.getFirst() != null)
        mt = ModelType.MLN;
    else if (conf.getSecond() != null)
        mt = ModelType.CG;
    else
        mt = ModelType.Layer;
    MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
    Map<String, Object> mmRatioMap = new HashMap<>();
    mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
    mmRatioMap.put("iterCounts", mm.getIterations());
    mmRatioMap.put("ratios", mm.getRatios());
    mmRatioMap.put("paramMM", mm.getParamMM());
    mmRatioMap.put("updateMM", mm.getUpdateMM());
    result.put("meanMag", mmRatioMap);

    //Get activations line chart for layer
    Triple<int[], float[], float[]> activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
    Map<String, Object> activationMap = new HashMap<>();
    activationMap.put("iterCount", activationsData.getFirst());
    activationMap.put("mean", activationsData.getSecond());
    activationMap.put("stdev", activationsData.getThird());
    result.put("activations", activationMap);

    //Get learning rate vs. time chart for layer
    Map<String, Object> lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
    result.put("learningRates", lrs);

    //Parameters histogram data
    Persistable lastUpdate = (updates != null && updates.size() > 0 ? updates.get(updates.size() - 1) : null);
    Map<String, Object> paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
    result.put("paramHist", paramHistograms);

    //Updates histogram data
    Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
    result.put("updateHist", updateHistograms);

    return ok(Json.toJson(result));
}
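cleanLegacyIterationCounts is not shown on this page. The sketch below is only an assumption of the kind of repair it performs (older listeners reported an iteration count that reset each epoch, so the counts must be made monotonically increasing before charting); the actual TrainModule implementation may differ.

private static void cleanLegacyIterationCounts(List<Integer> iterationCounts) {
    if (iterationCounts == null || iterationCounts.size() < 2)
        return;
    int offset = 0;
    int prevRaw = iterationCounts.get(0);
    int prevAdjusted = prevRaw;
    for (int i = 1; i < iterationCounts.size(); i++) {
        int raw = iterationCounts.get(i);
        if (raw <= prevRaw) {
            //Count was reset (legacy per-epoch numbering): continue from the last adjusted value
            offset = prevAdjusted + 1;
        }
        int adjusted = raw + offset;
        iterationCounts.set(i, adjusted);
        prevRaw = raw;
        prevAdjusted = adjusted;
    }
}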
Use of org.deeplearning4j.ui.stats.api.StatsReport in project deeplearning4j by deeplearning4j.
Class TrainModule, method getMemory.
private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWorkers, List<Persistable> updatesLastNMinutes, I18N i18n) {
    Map<String, Object> ret = new HashMap<>();

    //First: map workers to JVMs
    Set<String> jvmIDs = new HashSet<>();
    Map<String, String> workersToJvms = new HashMap<>();
    Map<String, Integer> workerNumDevices = new HashMap<>();
    Map<String, String[]> deviceNames = new HashMap<>();
    for (Persistable p : staticInfoAllWorkers) {
        //TODO validation/checks
        StatsInitializationReport init = (StatsInitializationReport) p;
        String jvmuid = init.getSwJvmUID();
        workersToJvms.put(p.getWorkerID(), jvmuid);
        jvmIDs.add(jvmuid);

        int nDevices = init.getHwNumDevices();
        workerNumDevices.put(p.getWorkerID(), nDevices);
        if (nDevices > 0) {
            String[] deviceNamesArr = init.getHwDeviceDescription();
            deviceNames.put(p.getWorkerID(), deviceNamesArr);
        }
    }

    List<String> jvmList = new ArrayList<>(jvmIDs);
    Collections.sort(jvmList);

    //For each unique JVM, collect memory info
    //Do this by selecting the first worker
    int count = 0;
    for (String jvm : jvmList) {
        List<String> workersForJvm = new ArrayList<>();
        for (String s : workersToJvms.keySet()) {
            if (workersToJvms.get(s).equals(jvm)) {
                workersForJvm.add(s);
            }
        }
        Collections.sort(workersForJvm);
        String wid = workersForJvm.get(0);

        int numDevices = workerNumDevices.get(wid);
        Map<String, Object> jvmData = new HashMap<>();

        List<Long> timestamps = new ArrayList<>();
        List<Float> fracJvm = new ArrayList<>();
        List<Float> fracOffHeap = new ArrayList<>();
        long[] lastBytes = new long[2 + numDevices];
        long[] lastMaxBytes = new long[2 + numDevices];
        List<List<Float>> fracDeviceMem = null;
        if (numDevices > 0) {
            fracDeviceMem = new ArrayList<>(numDevices);
            for (int i = 0; i < numDevices; i++) {
                fracDeviceMem.add(new ArrayList<>());
            }
        }

        for (Persistable p : updatesLastNMinutes) {
            //TODO single pass
            if (!p.getWorkerID().equals(wid))
                continue;
            if (!(p instanceof StatsReport))
                continue;
            StatsReport sp = (StatsReport) p;

            timestamps.add(sp.getTimeStamp());

            long jvmCurrentBytes = sp.getJvmCurrentBytes();
            long jvmMaxBytes = sp.getJvmMaxBytes();
            long ohCurrentBytes = sp.getOffHeapCurrentBytes();
            long ohMaxBytes = sp.getOffHeapMaxBytes();

            double jvmFrac = jvmCurrentBytes / ((double) jvmMaxBytes);
            double offheapFrac = ohCurrentBytes / ((double) ohMaxBytes);
            if (Double.isNaN(jvmFrac))
                jvmFrac = 0.0;
            if (Double.isNaN(offheapFrac))
                offheapFrac = 0.0;
            fracJvm.add((float) jvmFrac);
            fracOffHeap.add((float) offheapFrac);
            lastBytes[0] = jvmCurrentBytes;
            lastBytes[1] = ohCurrentBytes;
            lastMaxBytes[0] = jvmMaxBytes;
            lastMaxBytes[1] = ohMaxBytes;

            if (numDevices > 0) {
                long[] devBytes = sp.getDeviceCurrentBytes();
                long[] devMaxBytes = sp.getDeviceMaxBytes();
                for (int i = 0; i < numDevices; i++) {
                    double frac = devBytes[i] / ((double) devMaxBytes[i]);
                    if (Double.isNaN(frac))
                        frac = 0.0;
                    fracDeviceMem.get(i).add((float) frac);
                    lastBytes[2 + i] = devBytes[i];
                    lastMaxBytes[2 + i] = devMaxBytes[i];
                }
            }
        }

        List<List<Float>> fracUtilized = new ArrayList<>();
        fracUtilized.add(fracJvm);
        fracUtilized.add(fracOffHeap);

        String[] seriesNames = new String[2 + numDevices];
        seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
        seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
        boolean[] isDevice = new boolean[2 + numDevices];
        String[] devNames = deviceNames.get(wid);
        for (int i = 0; i < numDevices; i++) {
            seriesNames[2 + i] = devNames != null && devNames.length > i ? devNames[i] : "";
            fracUtilized.add(fracDeviceMem.get(i));
            isDevice[2 + i] = true;
        }

        jvmData.put("times", timestamps);
        jvmData.put("isDevice", isDevice);
        jvmData.put("seriesNames", seriesNames);
        jvmData.put("values", fracUtilized);
        jvmData.put("currentBytes", lastBytes);
        jvmData.put("maxBytes", lastMaxBytes);

        ret.put(String.valueOf(count), jvmData);
        count++;
    }
    return ret;
}
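A hypothetical read-back of the per-JVM entries built above, using the same key names as the puts in getMemory (the surrounding variables are assumed to be in scope; the unchecked casts only illustrate the structure):

Map<String, Object> memory = getMemory(staticInfoAllWorkers, updatesLastNMinutes, i18n);
Map<String, Object> jvm0 = (Map<String, Object>) memory.get("0"); //JVM entries are keyed "0", "1", ... in sorted-UID order
List<Long> times = (List<Long>) jvm0.get("times");                 //timestamp of each StatsReport
List<List<Float>> values = (List<List<Float>>) jvm0.get("values"); //series order: JVM fraction, off-heap fraction, then one per device
long[] currentBytes = (long[]) jvm0.get("currentBytes");           //most recent absolute usage, same ordering as "seriesNames"
long[] maxBytes = (long[]) jvm0.get("maxBytes");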