Use of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
Class ConvolutionalIterationListener, method iterationDone.
/**
 * Event listener for each iteration.
 *
 * <p>Every {@code freq} iterations this samples one example from the activations of each
 * convolutional layer of the model, rasterizes the sampled activations together with a
 * restored image of the first convolutional layer's input, and stores the rendering via
 * {@code ssr.putStaticInfo}. {@code minibatchNum} is incremented each time a rendering is stored.
 *
 * <p>Models that are neither a {@link MultiLayerNetwork} nor a {@link ComputationGraph} contribute
 * no activations; the rasterizer is still invoked with an empty list and a {@code null} image.
 *
 * @param model the model iterating
 * @param iteration the iteration number
 */
@Override
public void iterationDone(Model model, int iteration) {
    if (iteration % freq != 0) {
        return;
    }

    List<INDArray> tensors = new ArrayList<>();
    Random rnd = new Random();
    BufferedImage sourceImage = null;

    if (model instanceof MultiLayerNetwork) {
        for (Layer layer : ((MultiLayerNetwork) model).getLayers()) {
            sourceImage = sampleConvolutionalLayer(layer, rnd, tensors, sourceImage);
        }
    } else if (model instanceof ComputationGraph) {
        for (Layer layer : ((ComputationGraph) model).getLayers()) {
            sourceImage = sampleConvolutionalLayer(layer, rnd, tensors, sourceImage);
        }
    }

    BufferedImage render = rasterizeConvoLayers(tensors, sourceImage);
    Persistable p = new ConvolutionListenerPersistable(sessionID, workerID, System.currentTimeMillis(), render);
    ssr.putStaticInfo(p);
    minibatchNum++;
}

/**
 * Samples one example from a convolutional layer's activations and appends it to {@code tensors}.
 * Layers whose type is not {@code CONVOLUTIONAL} are ignored.
 *
 * @param layer the layer to inspect
 * @param rnd random source used to choose which example to sample
 * @param tensors accumulator of sampled activations; an empty list means no convolutional layer
 *        has been sampled yet
 * @param sourceImage the source image restored so far, or {@code null}
 * @return the source image to carry forward: a newly restored one when this is the first
 *         convolutional layer sampled, otherwise {@code sourceImage} unchanged
 */
private BufferedImage sampleConvolutionalLayer(Layer layer, Random rnd, List<INDArray> tensors, BufferedImage sourceImage) {
    if (layer.type() != Layer.Type.CONVOLUTIONAL) {
        return sourceImage;
    }

    INDArray output = layer.activate();
    // Random.nextInt(bound) rejects a non-positive bound, so a minibatch with a single example
    // must not reach it; in that case the only available index (0) is sampled.
    int sampleDim = output.shape()[0] > 1 ? rnd.nextInt(output.shape()[0] - 1) + 1 : 0;

    BufferedImage restored = sourceImage;
    if (tensors.isEmpty()) {
        // First convolutional layer encountered: rebuild the input image for the same sample.
        INDArray inputs = ((ConvolutionLayer) layer).input();
        try {
            restored = restoreRGBImage(inputs.tensorAlongDimension(sampleDim, new int[] { 3, 2, 1 }));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    tensors.add(output.tensorAlongDimension(sampleDim, 3, 2, 1));
    return restored;
}
Use of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
Class TrainModule, method getLayerLearningRates.
/**
 * Collects the learning rates reported for one layer's parameters across a series of updates.
 *
 * <p>Only {@link StatsReport} entries of {@code updates} contribute a data point; any other
 * entries are skipped, and the returned arrays are sized to the number of points actually
 * collected so that skipped entries do not leave trailing zero-valued points.
 *
 * @param layerIdx index into {@code gi.getOriginalVertexName()} identifying the layer
 * @param gi graph info used to resolve the layer name; when {@code null} an empty map is returned
 * @param updates updates to scan, may be {@code null}
 * @param iterationCounts optional iteration counts, looked up by the position of the update within
 *        {@code updates}; when {@code null} the count is read from each report
 * @param modelType model type; for {@code ModelType.Layer} the parameter-name prefix is the layer
 *        name itself, otherwise the layer name followed by {@code "_"}
 * @return a map with keys {@code "iterCounts"} (int[]), {@code "paramNames"} (sorted list of
 *         parameter names with the prefix removed) and {@code "lrs"} (parameter name to float[])
 */
private Map<String, Object> getLayerLearningRates(int layerIdx, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts, ModelType modelType) {
    if (gi == null) {
        return Collections.emptyMap();
    }
    String layerName = gi.getOriginalVertexName().get(layerIdx);
    // The prefix depends only on the layer name and model type, so compute it once.
    String prefix = (modelType == ModelType.Layer ? layerName : layerName + "_");

    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    Map<String, float[]> byName = new HashMap<>();
    int used = 0;
    if (updates != null) {
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport)) {
                continue;
            }
            StatsReport sp = (StatsReport) u;
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                iterCounts[used] = iterationCounts.get(uCount);
            }

            Map<String, Double> lrs = sp.getLearningRates();
            // A report without learning-rate data still counts as a point (rates stay at 0).
            if (lrs != null) {
                for (Map.Entry<String, Double> lr : lrs.entrySet()) {
                    String paramName = lr.getKey();
                    if (!paramName.startsWith(prefix)) {
                        continue;
                    }
                    // startsWith guarantees the name is at least as long as the prefix.
                    String layerParamName = paramName.substring(prefix.length());
                    float[] lrThisParam = byName.get(layerParamName);
                    if (lrThisParam == null) {
                        lrThisParam = new float[size];
                        byName.put(layerParamName, lrThisParam);
                    }
                    lrThisParam[used] = lr.getValue().floatValue();
                }
            }
            used++;
        }
    }

    // Entries that were not StatsReport instances left unused trailing slots; drop them.
    if (used < size) {
        iterCounts = java.util.Arrays.copyOf(iterCounts, used);
        for (Map.Entry<String, float[]> series : byName.entrySet()) {
            series.setValue(java.util.Arrays.copyOf(series.getValue(), used));
        }
    }

    List<String> paramNames = new ArrayList<>(byName.keySet());
    //Sorted for consistency
    Collections.sort(paramNames);

    Map<String, Object> ret = new HashMap<>();
    ret.put("iterCounts", iterCounts);
    ret.put("paramNames", paramNames);
    ret.put("lrs", byName);
    return ret;
}
Use of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
Class TrainModule, method getDefaultSession.
/**
 * Selects a default session when none is currently selected.
 *
 * <p>Scans every known session, reads the timestamp of the first static-info record stored for
 * it under {@code StatsListener.TYPE_ID}, and makes the session with the greatest timestamp the
 * current one. Sessions without static info are ignored; if no session qualifies,
 * {@code currentSessionID} is left unset.
 */
private void getDefaultSession() {
    if (currentSessionID == null) {
        String newestSession = null;
        long newestTimestamp = Long.MIN_VALUE;

        for (Map.Entry<String, StatsStorage> known : knownSessionIDs.entrySet()) {
            String candidate = known.getKey();
            List<Persistable> infos = known.getValue().getAllStaticInfos(candidate, StatsListener.TYPE_ID);
            if (infos != null && !infos.isEmpty()) {
                long timestamp = infos.get(0).getTimeStamp();
                if (timestamp > newestTimestamp) {
                    newestTimestamp = timestamp;
                    newestSession = candidate;
                }
            }
        }

        if (newestSession != null) {
            currentSessionID = newestSession;
        }
    }
}
Use of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
Class TrainModule, method getSystemData.
/**
 * Builds the JSON payload describing system information for the current session.
 *
 * <p>The payload contains {@code "updateTimestamp"}, {@code "memory"}, {@code "hardware"} and
 * {@code "software"}. Memory data is computed from the updates recorded in the five minutes
 * leading up to the most recent update across all workers. When there is no current session, no
 * storage registered for it, or no updates yet, the memory helper is given {@code null} for the
 * recent updates.
 *
 * @return an OK result whose body is the JSON-serialized payload
 */
public Result getSystemData() {
    Long lastUpdate = lastUpdateForSession.get(currentSessionID);
    if (lastUpdate == null) {
        lastUpdate = -1L;
    }
    I18N i18n = I18NProvider.getInstance();

    //First: get the MOST RECENT update...
    //Then get all updates from most recent - 5 minutes -> TODO make this configurable...
    StatsStorage ss = (currentSessionID == null ? null : knownSessionIDs.get(currentSessionID));
    // A session ID without registered storage is treated like "no session" rather than
    // dereferencing a null storage below.
    boolean noData = (ss == null);

    List<Persistable> allStatic = (noData ? Collections.<Persistable>emptyList()
                    : ss.getAllStaticInfos(currentSessionID, StatsListener.TYPE_ID));
    List<Persistable> latestUpdates = (noData ? Collections.<Persistable>emptyList()
                    : ss.getLatestUpdateAllWorkers(currentSessionID, StatsListener.TYPE_ID));

    long lastUpdateTime = -1;
    if (latestUpdates == null || latestUpdates.isEmpty()) {
        noData = true;
    } else {
        for (Persistable p : latestUpdates) {
            lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
        }
    }

    //TODO Make configurable
    long fromTime = lastUpdateTime - 5 * 60 * 1000;
    List<Persistable> lastNMinutes = (noData ? null
                    : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, fromTime));

    Map<String, Object> mem = getMemory(allStatic, lastNMinutes, i18n);
    Pair<Map<String, Object>, Map<String, Object>> hwSwInfo = getHardwareSoftwareInfo(allStatic, i18n);

    Map<String, Object> ret = new HashMap<>();
    ret.put("updateTimestamp", lastUpdate);
    ret.put("memory", mem);
    ret.put("hardware", hwSwInfo.getFirst());
    ret.put("software", hwSwInfo.getSecond());
    return ok(Json.toJson(ret));
}
Use of org.deeplearning4j.api.storage.Persistable in the deeplearning4j project (by deeplearning4j).
Class TrainModule, method getModelData.
/**
 * Builds the JSON payload with per-layer model data for the current session and worker.
 *
 * <p>The payload always contains {@code "updateTimestamp"}. When both a model configuration and
 * graph info are available it additionally contains {@code "layerInfo"}, {@code "meanMag"},
 * {@code "activations"}, {@code "learningRates"}, {@code "paramHist"} and {@code "updateHist"}.
 * If more than {@code maxChartPoints} updates are available they are subsampled, always keeping
 * the first point; keeping the most recent one relies on every update being a
 * {@link StatsReport}.
 *
 * @param str the layer index as a decimal string; it is parsed with {@link Integer#parseInt},
 *        so a non-numeric value results in a {@link NumberFormatException}
 * @return an OK result whose body is the JSON-serialized payload
 */
private Result getModelData(String str) {
    Long lastUpdateTime = lastUpdateForSession.get(currentSessionID);
    if (lastUpdateTime == null) {
        lastUpdateTime = -1L;
    }
    //TODO validation
    int layerIdx = Integer.parseInt(str);
    I18N i18N = I18NProvider.getInstance();

    //Model info for layer
    boolean noData = currentSessionID == null;
    //First pass (optimize later): query all data...
    StatsStorage ss = (noData ? null : knownSessionIDs.get(currentSessionID));
    if (ss == null) {
        // No storage registered for the session: nothing can be queried below.
        noData = true;
    }
    String wid = getWorkerIdForIndex(currentWorkerIdx);
    if (wid == null) {
        noData = true;
    }

    Map<String, Object> result = new HashMap<>();
    result.put("updateTimestamp", lastUpdateTime);

    Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> conf = getConfig();
    if (conf == null) {
        return ok(Json.toJson(result));
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok(Json.toJson(result));
    }

    // Get static layer info
    String[][] layerInfoTable = getLayerInfoTable(layerIdx, gi, i18N, noData, ss, wid);
    result.put("layerInfo", layerInfoTable);

    //First: get all data, and subsample it if necessary, to avoid returning too many points...
    List<Persistable> updates = (noData ? null
                    : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
    List<Integer> iterationCounts = null;
    boolean needToHandleLegacyIterCounts = false;
    if (updates != null && updates.size() > maxChartPoints) {
        int subsamplingFrequency = updates.size() / maxChartPoints;
        List<Persistable> subsampled = new ArrayList<>();
        iterationCounts = new ArrayList<>();
        int pCount = -1;
        int lastUpdateIdx = updates.size() - 1;
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport)) {
                continue;
            }
            StatsReport sr = (StatsReport) p;
            pCount++;

            int iterCount = sr.getIterationCount();
            // A non-increasing iteration count marks the data as needing cleanup below.
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            lastIterCount = iterCount;

            boolean offStride = pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0;
            //Skip off-stride points to subsample the data, but always keep the most recent value
            if (offStride && pCount != lastUpdateIdx) {
                continue;
            }
            subsampled.add(p);
            iterationCounts.add(iterCount);
        }
        updates = subsampled;
    } else if (updates != null) {
        iterationCounts = new ArrayList<>(updates.size());
        int lastIterCount = -1;
        for (Persistable p : updates) {
            if (!(p instanceof StatsReport)) {
                continue;
            }
            StatsReport sr = (StatsReport) p;
            int iterCount = sr.getIterationCount();
            if (iterCount <= lastIterCount) {
                needToHandleLegacyIterCounts = true;
            }
            // Track the previous count so the non-increasing check above can actually trigger,
            // exactly as in the subsampling branch.
            lastIterCount = iterCount;
            iterationCounts.add(iterCount);
        }
    }

    //Now, it should use the proper iteration counts
    if (needToHandleLegacyIterCounts) {
        cleanLegacyIterationCounts(iterationCounts);
    }

    //Get mean magnitudes line chart
    ModelType mt;
    if (conf.getFirst() != null) {
        mt = ModelType.MLN;
    } else if (conf.getSecond() != null) {
        mt = ModelType.CG;
    } else {
        mt = ModelType.Layer;
    }
    MeanMagnitudes mm = getLayerMeanMagnitudes(layerIdx, gi, updates, iterationCounts, mt);
    Map<String, Object> mmRatioMap = new HashMap<>();
    mmRatioMap.put("layerParamNames", mm.getRatios().keySet());
    mmRatioMap.put("iterCounts", mm.getIterations());
    mmRatioMap.put("ratios", mm.getRatios());
    mmRatioMap.put("paramMM", mm.getParamMM());
    mmRatioMap.put("updateMM", mm.getUpdateMM());
    result.put("meanMag", mmRatioMap);

    //Get activations line chart for layer
    Triple<int[], float[], float[]> activationsData = getLayerActivations(layerIdx, gi, updates, iterationCounts);
    Map<String, Object> activationMap = new HashMap<>();
    activationMap.put("iterCount", activationsData.getFirst());
    activationMap.put("mean", activationsData.getSecond());
    activationMap.put("stdev", activationsData.getThird());
    result.put("activations", activationMap);

    //Get learning rate vs. time chart for layer
    Map<String, Object> lrs = getLayerLearningRates(layerIdx, gi, updates, iterationCounts, mt);
    result.put("learningRates", lrs);

    //Parameters histogram data
    Persistable lastUpdate = (updates != null && updates.size() > 0 ? updates.get(updates.size() - 1) : null);
    Map<String, Object> paramHistograms = getHistograms(layerIdx, gi, StatsType.Parameters, lastUpdate);
    result.put("paramHist", paramHistograms);

    //Updates histogram data
    Map<String, Object> updateHistograms = getHistograms(layerIdx, gi, StatsType.Updates, lastUpdate);
    result.put("updateHist", updateHistograms);

    return ok(Json.toJson(result));
}
Aggregations