Usage of `org.deeplearning4j.api.storage.StatsStorage` in project deeplearning4j (by deeplearning4j).
Class `TestPlayUI`, method `testUICompGraph`:
/**
 * Manual UI check for a ComputationGraph: attaches an in-memory stats store to the UI server,
 * trains a two-layer graph on Iris 100 times (pausing 100 ms between fits), then sleeps for
 * 100000 ms so the dashboard can be inspected by hand.
 */
@Test
@Ignore
public void testUICompGraph() throws Exception {
    StatsStorage statsStorage = new InMemoryStatsStorage();
    UIServer server = UIServer.getInstance();
    server.attach(statsStorage);

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(), "in")
                    .addLayer("L1", new OutputLayer.Builder()
                                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX)
                                    .nIn(4).nOut(3)
                                    .build(), "L0")
                    .pretrain(false)
                    .backprop(true)
                    .setOutputs("L1")
                    .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();
    graph.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    int fitCount = 0;
    while (fitCount < 100) {
        graph.fit(irisIter);
        Thread.sleep(100);
        fitCount++;
    }
    // Keep the JVM (and UI server) alive long enough to look at the dashboard
    Thread.sleep(100000);
}
Usage of `org.deeplearning4j.api.storage.StatsStorage` in project deeplearning4j (by deeplearning4j).
Class `TestPlayUI`, method `testUI_VAE`:
/**
 * Manual UI check for a network containing two VariationalAutoencoder layers (pretrain enabled)
 * followed by an OutputLayer. It trains on Iris 50 times with a 100 ms pause between fits, then
 * sleeps for 100000 ms so the UI can be inspected.
 */
@Test
@Ignore
public void testUI_VAE() throws Exception {
    StatsStorage statsStorage = new InMemoryStatsStorage();
    UIServer server = UIServer.getInstance();
    server.attach(statsStorage);

    MultiLayerConfiguration vaeConf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(1)
                    .learningRate(1e-5)
                    .list()
                    .layer(0, new VariationalAutoencoder.Builder()
                                    .nIn(4).nOut(3)
                                    .encoderLayerSizes(10, 11)
                                    .decoderLayerSizes(12, 13)
                                    .weightInit(WeightInit.XAVIER)
                                    .pzxActivationFunction("identity")
                                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                                    .activation(Activation.LEAKYRELU)
                                    .updater(Updater.SGD)
                                    .build())
                    .layer(1, new VariationalAutoencoder.Builder()
                                    .nIn(3).nOut(3)
                                    .encoderLayerSizes(7)
                                    .decoderLayerSizes(8)
                                    .weightInit(WeightInit.XAVIER)
                                    .pzxActivationFunction("identity")
                                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                                    .activation(Activation.LEAKYRELU)
                                    .updater(Updater.SGD)
                                    .build())
                    .layer(2, new OutputLayer.Builder().nIn(3).nOut(3).build())
                    .pretrain(true)
                    .backprop(true)
                    .build();

    MultiLayerNetwork model = new MultiLayerNetwork(vaeConf);
    model.init();
    model.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    for (int remaining = 50; remaining > 0; remaining--) {
        model.fit(irisIter);
        Thread.sleep(100);
    }
    // Leave the UI server running so results can be viewed manually
    Thread.sleep(100000);
}
Usage of `org.deeplearning4j.api.storage.StatsStorage` in project deeplearning4j (by deeplearning4j).
Class `TestStatsListener`, method `testListenerBasic`:
/**
 * Runs both listener flavours (StatsListener, then J7StatsListener) against a no-arg
 * MapDBStatsStorage. After three fits it checks the store holds exactly one session, one type ID
 * and one worker ID, a non-null static-info record, and three update records.
 */
@Test
public void testListenerBasic() {
    boolean[] listenerVariants = { false, true };
    for (boolean useJ7 : listenerVariants) {
        DataSet irisData = new IrisDataSetIterator(150, 150).next();

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .iterations(1)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .list()
                        .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(4).nOut(3).build())
                        .pretrain(false)
                        .backprop(true)
                        .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        // No-arg constructor; the original author describes this store as in-memory
        StatsStorage storage = new MapDBStatsStorage();
        if (!useJ7) {
            model.setListeners(new StatsListener(storage));
        } else {
            model.setListeners(new J7StatsListener(storage));
        }

        for (int epoch = 1; epoch <= 3; epoch++) {
            model.fit(irisData);
        }

        List<String> sessionIds = storage.listSessionIDs();
        assertEquals(1, sessionIds.size());
        String sessionID = storage.listSessionIDs().get(0);

        assertEquals(1, storage.listTypeIDsForSession(sessionID).size());
        String typeID = storage.listTypeIDsForSession(sessionID).get(0);

        assertEquals(1, storage.listWorkerIDsForSession(sessionID).size());
        String workerID = storage.listWorkerIDsForSession(sessionID).get(0);

        Persistable staticInfo = storage.getStaticInfo(sessionID, typeID, workerID);
        assertNotNull(staticInfo);
        System.out.println(staticInfo);

        // One update per fit call is expected
        List<Persistable> updateRecords = storage.getAllUpdatesAfter(sessionID, typeID, workerID, 0);
        assertEquals(3, updateRecords.size());
        for (Persistable record : updateRecords) {
            System.out.println(record);
        }
    }
}
Usage of `org.deeplearning4j.api.storage.StatsStorage` in project deeplearning4j (by deeplearning4j).
Class `TestStatsStorage`, method `testStatsStorage`:
/**
 * Exercises every StatsStorage implementation (MapDB file store, J7 file store, in-memory) with both
 * report flavours. It checks session/type/worker bookkeeping, static-info and update retrieval,
 * and the counts reported to a registered storage listener.
 *
 * Each storage is closed in a finally block so file-backed stores release their handles even
 * when an assertion fails. Temp files are registered for deletion at JVM exit.
 */
@Test
public void testStatsStorage() throws IOException {
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int i = 0; i < 3; i++) {
            StatsStorage ss = createEmptyStatsStorage(i);
            try {
                CountingListener l = new CountingListener();
                ss.registerStatsStorageListener(l);
                assertEquals(1, ss.getListeners().size());
                assertEquals(0, ss.listSessionIDs().size());
                assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
                assertEquals(0, ss.listSessionIDs().size());

                ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
                assertEquals(1, l.countNewSession);
                assertEquals(1, l.countNewWorkerId);
                assertEquals(1, l.countStaticInfo);
                assertEquals(0, l.countUpdate);
                assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
                assertTrue(ss.sessionExists("sid0"));
                assertFalse(ss.sessionExists("sid1"));
                Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
                Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
                assertEquals(expected, p);
                List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
                assertEquals(Collections.singletonList(expected), allStatic);
                assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
                assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
                assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
                assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
                assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));

                //Put first update
                ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
                assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
                assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
                assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
                assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)),
                                ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
                assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
                assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
                List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
                assertEquals(1, list.size());
                assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
                assertEquals(1, l.countNewSession);
                assertEquals(1, l.countNewWorkerId);
                assertEquals(1, l.countStaticInfo);
                assertEquals(1, l.countUpdate);

                //Put second update
                ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
                assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
                assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
                assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
                assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));

                //Update from a second worker in the same session/type
                ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
                assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
                assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
                assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
                assertEquals(1, l.countNewSession);
                assertEquals(2, l.countNewWorkerId);
                assertEquals(1, l.countStaticInfo);
                assertEquals(3, l.countUpdate);

                //Put static info and update with different session, type and worker IDs
                ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
                assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
                ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
                assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)),
                                ss.getLatestUpdateAllWorkers("sid100", "tid200"));
                assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
                assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
                assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
                assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
                assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
                assertEquals(2, l.countNewSession);
                assertEquals(3, l.countNewWorkerId);
                assertEquals(2, l.countStaticInfo);
                assertEquals(4, l.countUpdate);
            } finally {
                // Release file handles held by file-backed stores before the next iteration
                ss.close();
            }
        }
    }
}

/**
 * Creates a fresh, empty storage for the given variant.
 * 0 = MapDBStatsStorage on a temp file, 1 = J7FileStatsStorage on a temp file, 2 = InMemoryStatsStorage.
 * Temp paths are deleted before use (the stores must not see an existing file) and are
 * registered with deleteOnExit so they do not accumulate in the temp directory.
 */
private static StatsStorage createEmptyStatsStorage(int variant) throws IOException {
    switch (variant) {
        case 0: {
            File f = Files.createTempFile("TestMapDbStatsStore", ".db").toFile();
            //Don't want file to exist...
            f.delete();
            f.deleteOnExit();
            return new MapDBStatsStorage.Builder().file(f).build();
        }
        case 1: {
            File f2 = Files.createTempFile("TestJ7FileStatsStore", ".db").toFile();
            //Don't want file to exist...
            f2.delete();
            f2.deleteOnExit();
            return new J7FileStatsStorage(f2);
        }
        case 2:
            return new InMemoryStatsStorage();
        default:
            throw new RuntimeException();
    }
}
Usage of `org.deeplearning4j.api.storage.StatsStorage` in project deeplearning4j (by deeplearning4j).
Class `FlowListenerModule`, method `getUpdate`:
/**
 * Builds the response for the latest flow update of a session.
 * Unknown sessions are answered with ok("Unknown session ID"). An empty response is returned when
 * the store has no updates for TYPE_ID, or when the first update is not a FlowUpdatePersistable.
 * Otherwise the first update's model state is returned serialized as JSON.
 */
private Result getUpdate(String sessionID) {
    if (!knownSessionIDs.containsKey(sessionID)) {
        return ok("Unknown session ID");
    }
    StatsStorage storage = knownSessionIDs.get(sessionID);
    List<Persistable> latest = storage.getLatestUpdateAllWorkers(sessionID, TYPE_ID);
    boolean noUpdates = latest == null || latest.isEmpty();
    if (noUpdates) {
        return ok();
    }
    // Only the first entry of the per-worker list is inspected
    Persistable first = latest.get(0);
    if (first instanceof FlowUpdatePersistable) {
        FlowUpdatePersistable flowUpdate = (FlowUpdatePersistable) first;
        return ok(Json.toJson(flowUpdate.getModelState()));
    }
    return ok();
}
Aggregations