Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class TrainModule, method getConfig.
/**
 * Loads the model configuration recorded for the currently selected session.
 *
 * <p>The first static-info record stored under {@code StatsListener.TYPE_ID} is read and its
 * configuration JSON is deserialized according to the recorded model class name:
 * a name ending in {@code "MultiLayerNetwork"} fills the first element, a name ending in
 * {@code "ComputationGraph"} fills the second, and anything else is parsed as a single
 * {@code NeuralNetConfiguration} into the third.
 *
 * @return a triple with exactly one non-null element, or {@code null} when no session is
 *         selected, the session has no storage or no usable static info, or the fallback
 *         single-layer configuration cannot be parsed
 */
private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
    // Read the field once so the null check and both lookups below use the same value.
    String sessionId = currentSessionID;
    if (sessionId == null) {
        return null;
    }

    // Map.get returns null for an unknown key; previously this led to a NullPointerException.
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return null;
    }

    List<Persistable> allStatic = ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return null;
    }

    // Only an initialization report carries the model class name and configuration JSON.
    Persistable first = allStatic.get(0);
    if (!(first instanceof StatsInitializationReport)) {
        return null;
    }
    StatsInitializationReport report = (StatsInitializationReport) first;
    String modelClass = report.getModelClassName();
    String config = report.getModelConfigJson();
    if (modelClass == null || config == null) {
        return null;
    }

    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config);
        return new Triple<>(conf, null, null);
    }
    if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config);
        return new Triple<>(null, conf, null);
    }

    // Any other model class: best-effort parse as a single-layer configuration.
    try {
        NeuralNetConfiguration layer = NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
        return new Triple<>(null, null, layer);
    } catch (Exception e) {
        // NOTE(review): should go through the class logger if one is available in this file.
        e.printStackTrace();
        return null;
    }
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class TrainModule, method getLayerActivations.
/**
 * Collects the activation mean and standard deviation history of one graph vertex.
 *
 * <p>Only updates that are {@code StatsReport} instances and that contain both an activation
 * mean and an activation standard deviation for the vertex contribute a data point.
 * Non-finite values are replaced by {@code NAN_REPLACEMENT_VALUE}.
 *
 * @param index           position of the vertex in the lists exposed by {@code gi}
 * @param gi              graph description; {@code null} yields {@code EMPTY_TRIPLE}
 * @param updates         stats updates to scan; {@code null} is treated as no updates
 * @param iterationCounts iteration number for each element of {@code updates} (same positions);
 *                        when {@code null}, the iteration count stored in each report is used
 * @return a triple of equally sized arrays (iteration counts, means, standard deviations), or
 *         {@code EMPTY_TRIPLE} when {@code gi} is {@code null}, {@code index} is out of range,
 *         or the vertex is an input
 */
private Triple<int[], float[], float[]> getLayerActivations(int index, TrainModuleUtils.GraphInfo gi, List<Persistable> updates, List<Integer> iterationCounts) {
    if (gi == null) {
        return EMPTY_TRIPLE;
    }

    // Validate the index BEFORE any list access; previously the vertex-type lookup ran first
    // and an out-of-range index threw instead of returning the empty result.
    List<String> vertexTypes = gi.getVertexTypes();
    List<String> origNames = gi.getOriginalVertexName();
    if (vertexTypes == null || origNames == null
            || index < 0 || index >= vertexTypes.size() || index >= origNames.size()) {
        return EMPTY_TRIPLE;
    }

    //Index may be for an input, for example
    if ("input".equalsIgnoreCase(vertexTypes.get(index))) {
        return EMPTY_TRIPLE;
    }

    String layerName = origNames.get(index);
    int size = (updates == null ? 0 : updates.size());
    int[] iterCounts = new int[size];
    float[] mean = new float[size];
    float[] stdev = new float[size];
    int used = 0;

    if (updates != null) {
        // uCount tracks the position in 'updates' so it can index the parallel iterationCounts list.
        int uCount = -1;
        for (Persistable u : updates) {
            uCount++;
            if (!(u instanceof StatsReport)) {
                continue;
            }
            StatsReport sp = (StatsReport) u;

            // Slot 'used' is simply overwritten by the next report if this one is skipped below.
            if (iterationCounts == null) {
                iterCounts[used] = sp.getIterationCount();
            } else {
                iterCounts[used] = iterationCounts.get(uCount);
            }

            Map<String, Double> means = sp.getMean(StatsType.Activations);
            Map<String, Double> stdevs = sp.getStdev(StatsType.Activations);
            Double layerMean = (means == null ? null : means.get(layerName));
            Double layerStdev = (stdevs == null ? null : stdevs.get(layerName));
            // Skip reports lacking either statistic; a missing stdev used to cause a NullPointerException.
            if (layerMean == null || layerStdev == null) {
                continue;
            }

            mean[used] = layerMean.floatValue();
            stdev[used] = layerStdev.floatValue();
            if (!Float.isFinite(mean[used])) {
                mean[used] = (float) NAN_REPLACEMENT_VALUE;
            }
            if (!Float.isFinite(stdev[used])) {
                stdev[used] = (float) NAN_REPLACEMENT_VALUE;
            }
            used++;
        }
    }

    // Trim the arrays when some updates did not contribute a data point.
    if (used != iterCounts.length) {
        iterCounts = Arrays.copyOf(iterCounts, used);
        mean = Arrays.copyOf(mean, used);
        stdev = Arrays.copyOf(stdev, used);
    }
    return new Triple<>(iterCounts, mean, stdev);
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class TrainModule, method getModelGraph.
/**
 * Returns the graph description of the current session's model as JSON.
 *
 * @return an {@code ok()} result carrying the JSON form of the graph info, or an empty
 *         {@code ok()} result when no session is selected, the session has no storage or no
 *         static info, or no graph info is available
 */
private Result getModelGraph() {
    // Read the field once so the null check and both lookups below use the same value.
    String sessionId = currentSessionID;

    // Map.get returns null for an unknown key; previously this led to a NullPointerException.
    StatsStorage ss = (sessionId == null ? null : knownSessionIDs.get(sessionId));
    if (ss == null) {
        return ok();
    }

    List<Persistable> allStatic = ss.getAllStaticInfos(sessionId, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return ok();
    }

    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok();
    }
    return ok(Json.toJson(gi));
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class TrainModule, method sessionInfo.
/**
 * Builds a JSON summary of every known session.
 *
 * <p>For each session ID the summary holds: {@code numWorkers}, {@code initTime} (earliest
 * static-info timestamp), {@code lastUpdate} (latest update timestamp over all workers),
 * {@code workers} (only when there is at least one worker), and the model details
 * {@code modelType}, {@code numLayers} and {@code numParams}. Values that cannot be
 * determined are reported as empty strings.
 *
 * @return an {@code ok} result whose body is the JSON map from session ID to its summary
 */
private Result sessionInfo() {
    //Display, for each session: session ID, start time, number of workers, last update
    Map<String, Object> dataEachSession = new HashMap<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        String sid = entry.getKey();
        StatsStorage ss = entry.getValue();
        Map<String, Object> dataThisSession = new HashMap<>();

        List<String> workerIDs = ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID);
        int workerCount = (workerIDs == null ? 0 : workerIDs.size());

        // Earliest static-info timestamp; Long.MAX_VALUE is the "nothing found" sentinel.
        List<Persistable> staticInfo = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
        long initTime = Long.MAX_VALUE;
        if (staticInfo != null) {
            for (Persistable p : staticInfo) {
                initTime = Math.min(p.getTimeStamp(), initTime);
            }
        }

        // Latest update timestamp; Long.MIN_VALUE is the "nothing found" sentinel.
        // The list is null-checked like staticInfo above; it was previously iterated unguarded.
        long lastUpdateTime = Long.MIN_VALUE;
        List<Persistable> lastUpdatesAllWorkers = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
        if (lastUpdatesAllWorkers != null) {
            for (Persistable p : lastUpdatesAllWorkers) {
                lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
            }
        }

        dataThisSession.put("numWorkers", workerCount);
        dataThisSession.put("initTime", initTime == Long.MAX_VALUE ? "" : initTime);
        dataThisSession.put("lastUpdate", lastUpdateTime == Long.MIN_VALUE ? "" : lastUpdateTime);

        // add hashmap of workers
        if (workerCount > 0) {
            dataThisSession.put("workers", workerIDs);
        }

        //Model info: type, # layers, # params...
        // Only an initialization report carries model details; anything else (including a
        // missing record) falls back to empty placeholders instead of failing the request.
        Persistable first = (staticInfo == null || staticInfo.isEmpty() ? null : staticInfo.get(0));
        if (first instanceof StatsInitializationReport) {
            StatsInitializationReport sr = (StatsInitializationReport) first;
            String modelClassName = sr.getModelClassName();
            if (modelClassName == null) {
                modelClassName = "";
            } else if (modelClassName.endsWith("MultiLayerNetwork")) {
                modelClassName = "MultiLayerNetwork";
            } else if (modelClassName.endsWith("ComputationGraph")) {
                modelClassName = "ComputationGraph";
            }
            int numLayers = sr.getModelNumLayers();
            long numParams = sr.getModelNumParams();
            dataThisSession.put("modelType", modelClassName);
            dataThisSession.put("numLayers", numLayers);
            dataThisSession.put("numParams", numParams);
        } else {
            dataThisSession.put("modelType", "");
            dataThisSession.put("numLayers", "");
            dataThisSession.put("numParams", "");
        }

        dataEachSession.put(sid, dataThisSession);
    }
    return ok(Json.toJson(dataEachSession));
}
Use of org.deeplearning4j.api.storage.Persistable in project deeplearning4j by deeplearning4j.
The class TrainModule, method getLayerInfoTable.
/**
 * Builds the two-column (label, value) table describing one vertex of the model.
 *
 * <p>The first two rows are always the layer name and the layer type. When static info is
 * available for the given worker, further rows are appended depending on the layer: input and
 * output size for feed-forward layers, parameter count, weight initialization and updater when
 * the layer has parameters, kernel/stride/padding (and pooling type) for convolution and
 * subsampling layers, and finally the activation function.
 *
 * @param layerIdx index of the vertex within the lists exposed by {@code gi}
 * @param gi       graph description used for vertex names, types and info
 * @param i18N     source of the localized row labels
 * @param noData   when {@code true}, only the name and (empty) type rows are produced
 * @param ss       storage queried for the static info; a {@code null} storage is treated
 *                 like {@code noData}
 * @param wid      worker ID whose static info is read
 * @return the table rows, each a two-element array of label and value
 */
private String[][] getLayerInfoTable(int layerIdx, TrainModuleUtils.GraphInfo gi, I18N i18N, boolean noData, StatsStorage ss, String wid) {
    List<String[]> layerInfoRows = new ArrayList<>();
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerName"), gi.getVertexNames().get(layerIdx) });
    // Row 1 is a placeholder; its value is filled in once the layer type is known.
    layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerType"), "" });

    if (noData || ss == null) {
        return layerInfoRows.toArray(new String[0][]);
    }

    Persistable p = ss.getStaticInfo(currentSessionID, StatsListener.TYPE_ID, wid);
    if (!(p instanceof StatsInitializationReport)) {
        return layerInfoRows.toArray(new String[0][]);
    }

    StatsInitializationReport initReport = (StatsInitializationReport) p;
    String configJson = initReport.getModelConfigJson();
    String modelClass = initReport.getModelClassName();
    if (modelClass == null) {
        // Unknown model class: none of the branches below applies, the type stays empty.
        modelClass = "";
    }

    //TODO error handling...
    String layerType = "";
    Layer layer = null;
    NeuralNetConfiguration nnc = null;
    if (modelClass.endsWith("MultiLayerNetwork")) {
        MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson);
        //-1 because of input
        int confIdx = layerIdx - 1;
        if (confIdx >= 0) {
            nnc = conf.getConf(confIdx);
            layer = nnc.getLayer();
        } else {
            //Input layer
            layerType = "Input";
        }
    } else if (modelClass.endsWith("ComputationGraph")) {
        ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(configJson);
        String vertexName = gi.getVertexNames().get(layerIdx);
        Map<String, GraphVertex> vertices = conf.getVertices();
        // Single lookup; a missing vertex yields null and fails the instanceof test.
        GraphVertex gv = vertices.get(vertexName);
        if (gv instanceof LayerVertex) {
            nnc = ((LayerVertex) gv).getLayerConf();
            layer = nnc.getLayer();
        } else if (conf.getNetworkInputs().contains(vertexName)) {
            layerType = "Input";
        } else if (gv != null) {
            layerType = gv.getClass().getSimpleName();
        }
    } else if (modelClass.endsWith("VariationalAutoencoder")) {
        layerType = gi.getVertexTypes().get(layerIdx);
        Map<String, String> map = gi.getVertexInfo().get(layerIdx);
        if (map != null) {
            for (Map.Entry<String, String> entry : map.entrySet()) {
                layerInfoRows.add(new String[] { entry.getKey(), entry.getValue() });
            }
        }
    }

    if (layer != null) {
        layerType = getLayerType(layer);

        String activationFn = null;
        if (layer instanceof FeedForwardLayer) {
            FeedForwardLayer ffl = (FeedForwardLayer) layer;
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNIn"), String.valueOf(ffl.getNIn()) });
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSize"), String.valueOf(ffl.getNOut()) });
            // A layer without an activation function simply gets no activation row.
            if (layer.getActivationFn() != null) {
                activationFn = layer.getActivationFn().toString();
            }
        }

        int nParams = layer.initializer().numParams(nnc);
        layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(nParams) });
        if (nParams > 0) {
            // Null-safe in the same way as the updater below.
            WeightInit wi = layer.getWeightInit();
            String str = (wi == null ? "" : wi.toString());
            if (wi == WeightInit.DISTRIBUTION) {
                str += layer.getDist();
            }
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str });
            Updater u = layer.getUpdater();
            String us = (u == null ? "" : u.toString());
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerUpdater"), us });
            //TODO: Maybe L1/L2, dropout, updater-specific values etc
        }

        if (layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer) {
            int[] kernel;
            int[] stride;
            int[] padding;
            if (layer instanceof ConvolutionLayer) {
                ConvolutionLayer cl = (ConvolutionLayer) layer;
                kernel = cl.getKernelSize();
                stride = cl.getStride();
                padding = cl.getPadding();
            } else {
                SubsamplingLayer ssl = (SubsamplingLayer) layer;
                kernel = ssl.getKernelSize();
                stride = ssl.getStride();
                padding = ssl.getPadding();
                // Subsampling layers never get an activation row.
                activationFn = null;
                layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"), ssl.getPoolingType().toString() });
            }
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnKernel"), Arrays.toString(kernel) });
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnStride"), Arrays.toString(stride) });
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerCnnPadding"), Arrays.toString(padding) });
        }

        if (activationFn != null) {
            layerInfoRows.add(new String[] { i18N.getMessage("train.model.layerinfotable.layerActivationFn"), activationFn });
        }
    }

    // Fill in the layer-type placeholder row created at the top.
    layerInfoRows.get(1)[1] = layerType;
    return layerInfoRows.toArray(new String[0][]);
}
Aggregations