Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
Method detach of class PlayUIServer.
/**
 * Detaches a StatsStorage instance from the UI server.
 * <p>
 * Every registered listener pair whose storage is this exact object (identity, not equality) is
 * deregistered from the storage and dropped. The instance is removed from the set of attached
 * storages, and every UI module is notified via {@code onDetach}. Calling this for an instance
 * that is not attached is a no-op.
 *
 * @param statsStorage the storage to detach; must not be null
 * @throws IllegalArgumentException if {@code statsStorage} is null
 */
@Override
public synchronized void detach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }
    if (!statsStorageInstances.contains(statsStorage)) {
        return; //No op: this instance is not attached
    }
    boolean found = false;
    // Iterate over a snapshot: removing from 'listeners' while iterating it directly throws
    // ConcurrentModificationException for fail-fast List implementations. The snapshot works for
    // copy-on-write lists too, whose iterators do not support remove().
    for (Pair<StatsStorage, StatsStorageListener> p : new java.util.ArrayList<>(listeners)) {
        if (p.getFirst() == statsStorage) { //Same object, not equality
            statsStorage.deregisterStatsStorageListener(p.getSecond());
            listeners.remove(p);
            found = true;
        }
    }
    // Forget the instance; otherwise the contains() check above would still treat it as attached
    statsStorageInstances.remove(statsStorage);
    for (UIModule uiModule : uiModules) {
        uiModule.onDetach(statsStorage);
    }
    if (found) {
        log.info("StatsStorage instance detached from UI: {}", statsStorage);
    }
}
Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
Method testUIMultipleSessions of class TestPlayUI.
/**
 * Manual (ignored) test: trains three small networks one after another, each reporting into its
 * own InMemoryStatsStorage attached to the shared UI server. It then sleeps for a long time so
 * the UI can be inspected by hand.
 */
@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
    final int sessionCount = 3;
    final int fitsPerSession = 20;
    int sessionIdx = 0;
    while (sessionIdx < sessionCount) {
        // Fresh storage per pass, attached to the singleton UI server
        StatsStorage storage = new InMemoryStatsStorage();
        UIServer.getInstance().attach(storage);

        // 4 -> 4 tanh dense layer, then a 4 -> 3 softmax output layer with MCXENT loss
        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .activation(Activation.TANH)
                        .nIn(4)
                        .nOut(4)
                        .build())
                .layer(1, new OutputLayer.Builder()
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(4)
                        .nOut(3)
                        .build())
                .pretrain(false)
                .backprop(true)
                .build();

        MultiLayerNetwork network = new MultiLayerNetwork(configuration);
        network.init();
        network.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

        DataSetIterator irisData = new IrisDataSetIterator(150, 150);
        for (int fit = 0; fit < fitsPerSession; fit++) {
            network.fit(irisData);
            Thread.sleep(100);
        }
        sessionIdx++;
    }
    // Keep the JVM (and therefore the UI server) alive for manual inspection
    Thread.sleep(1_000_000);
}
Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
Method testUI_RBM of class TestPlayUI.
/**
 * Manual (ignored) test: trains a two-RBM-plus-output network with layerwise pretraining
 * enabled, reporting into an InMemoryStatsStorage attached to the UI server. It then sleeps so
 * the UI can be inspected by hand.
 */
@Test
@Ignore
public void testUI_RBM() throws Exception {
    //RBM - for unsupervised layerwise pretraining
    StatsStorage statsStorage = new InMemoryStatsStorage();
    UIServer.getInstance().attach(statsStorage);

    // RBM 4 -> 3, RBM 3 -> 3, output 3 -> 3; pretraining and backprop both enabled
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(1e-5)
            .list()
            .layer(0, new RBM.Builder().nIn(4).nOut(3).build())
            .layer(1, new RBM.Builder().nIn(3).nOut(3).build())
            .layer(2, new OutputLayer.Builder().nIn(3).nOut(3).build())
            .pretrain(true)
            .backprop(true)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();
    network.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

    DataSetIterator irisData = new IrisDataSetIterator(150, 150);
    int remainingFits = 50;
    while (remainingFits > 0) {
        network.fit(irisData);
        Thread.sleep(100);
        remainingFits--;
    }
    // Keep the JVM (and therefore the UI server) alive for manual inspection
    Thread.sleep(100_000);
}
Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
Method getStaticInfo of class FlowListenerModule.
/**
 * Returns the model info of the first static-info record stored for a session under this
 * module's TYPE_ID, serialized as JSON.
 *
 * @param sessionID session to look up
 * @return "Unknown session ID" (HTTP 200) if the session is not known; an empty 200 if there is
 *         no static info or the first record is not a FlowStaticPersistable; otherwise the JSON
 *         model info
 */
private Result getStaticInfo(String sessionID) {
    boolean sessionKnown = knownSessionIDs.containsKey(sessionID);
    if (!sessionKnown) {
        return ok("Unknown session ID");
    }
    List<Persistable> staticInfos = knownSessionIDs.get(sessionID).getAllStaticInfos(sessionID, TYPE_ID);
    if (staticInfos == null || staticInfos.isEmpty()) {
        return ok();
    }
    // Only the first record is considered
    Persistable first = staticInfos.get(0);
    if (first instanceof FlowStaticPersistable) {
        return ok(Json.toJson(((FlowStaticPersistable) first).getModelInfo()));
    }
    return ok();
}
Use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.
Method processRequest of class HistogramModule.
/**
 * Builds the histogram view for a session from the data reported by its first worker.
 *
 * @param sessionId session to report on
 * @return 404 if the session is unknown. Returns an empty JSON map if the session has no worker,
 *         no StatsInitializationReport, or no StatsReport updates yet. Otherwise returns a
 *         JSON-serialized CompactModelAndGradient with the latest parameter/update histograms,
 *         the score history and the per-layer mean magnitude histories.
 */
private Result processRequest(String sessionId) {
    //TODO cache the relevant info and update, rather than querying StatsStorage and building from scratch each time
    StatsStorage ss = knownSessionIDs.get(sessionId);
    if (ss == null) {
        return Results.notFound("Unknown session ID: " + sessionId);
    }
    List<String> workerIDs = ss.listWorkerIDsForSession(sessionId);
    if (workerIDs == null || workerIDs.isEmpty()) {
        // Known session but no worker has reported yet; workerIDs.get(0) would throw
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    // Only the first worker's data is displayed
    String workerId = workerIDs.get(0);
    Persistable staticInfo = ss.getStaticInfo(sessionId, StatsListener.TYPE_ID, workerId);
    if (!(staticInfo instanceof StatsInitializationReport)) {
        // Covers null (no init report yet) and avoids a ClassCastException on unexpected types
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }
    StatsInitializationReport initReport = (StatsInitializationReport) staticInfo;

    //Infer layer names from param names: the prefix before the first '_' of each name.
    //LinkedHashSet keeps layers in first-seen order and drops duplicates.
    Set<String> layerNameSet = new LinkedHashSet<>();
    for (String paramName : initReport.getModelParamNames()) {
        layerNameSet.add(paramName.split("_")[0]);
    }
    List<String> layerNameList = new ArrayList<>(layerNameSet);

    // Sort a private copy in timestamp order; the storage's returned list may be shared or unmodifiable
    List<Persistable> list = new ArrayList<>();
    List<Persistable> updates = ss.getAllUpdatesAfter(sessionId, StatsListener.TYPE_ID, workerId, 0);
    if (updates != null) {
        list.addAll(updates);
    }
    Collections.sort(list, (a, b) -> Long.compare(a.getTimeStamp(), b.getTimeStamp()));

    List<Double> scoreList = new ArrayList<>(list.size());
    //List.get(i) -> layer i. Maps: parameter for the given layer
    List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    //List.get(i) -> layer i. Maps: updates for the given layer
    List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    for (int i = 0; i < layerNameList.size(); i++) {
        meanMagHistoryParams.add(new HashMap<>());
        meanMagHistoryUpdates.add(new HashMap<>());
    }

    StatsReport last = null;
    for (Persistable p : list) {
        if (!(p instanceof StatsReport)) {
            log.debug("Encountered unexpected type: {}", p);
            continue;
        }
        StatsReport sp = (StatsReport) p;
        scoreList.add(sp.getScore());
        //Mean magnitudes
        if (sp.hasSummaryStats(StatsType.Parameters, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Parameters), layerNameList, meanMagHistoryParams);
        }
        if (sp.hasSummaryStats(StatsType.Updates, SummaryType.MeanMagnitudes)) {
            updateMeanMagnitudeMaps(sp.getMeanMagnitudes(StatsType.Updates), layerNameList, meanMagHistoryUpdates);
        }
        last = sp;
    }
    if (last == null) {
        // No StatsReport received yet: nothing to build histograms from (previously an NPE)
        return Results.ok(Json.toJson(Collections.emptyMap()));
    }

    Map<String, Map> newParams = getHistogram(last.getHistograms(StatsType.Parameters));
    Map<String, Map> newGrad = getHistogram(last.getHistograms(StatsType.Updates));
    // 'last' is non-null, so scoreList holds at least its score as the final element
    double lastScore = scoreList.get(scoreList.size() - 1);

    CompactModelAndGradient g = new CompactModelAndGradient();
    g.setGradients(newGrad);
    g.setParameters(newParams);
    g.setScore(lastScore);
    g.setScores(scoreList);
    g.setUpdateMagnitudes(meanMagHistoryUpdates);
    g.setParamMagnitudes(meanMagHistoryParams);
    g.setLastUpdateTime(last.getTimeStamp());
    return Results.ok(Json.toJson(g));
}
Aggregations