Search in sources:

Example 16 with StatsStorage

use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

The class TestPlayUI, method testUICompGraph.

@Test
@Ignore
public void testUICompGraph() throws Exception {
    // Manual UI check: trains a small two-layer ComputationGraph on Iris while the UI server
    // shows the stats collected in an in-memory storage. Ignored by default; it finishes with
    // a ~100 s sleep, presumably so the UI can be inspected by hand.
    StatsStorage storage = new InMemoryStatsStorage();
    UIServer server = UIServer.getInstance();
    server.attach(storage);

    ComputationGraphConfiguration graphConf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("L0", new DenseLayer.Builder()
                    .activation(Activation.TANH)
                    .nIn(4)
                    .nOut(4)
                    .build(), "in")
            .addLayer("L1", new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(4)
                    .nOut(3)
                    .build(), "L0")
            .pretrain(false)
            .backprop(true)
            .setOutputs("L1")
            .build();

    ComputationGraph graph = new ComputationGraph(graphConf);
    graph.init();
    graph.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

    // All 150 Iris examples in a single minibatch
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    int epoch = 0;
    while (epoch < 100) {
        graph.fit(irisIter);
        // Brief pause between fits, presumably to spread UI updates out over time
        Thread.sleep(100);
        epoch++;
    }
    Thread.sleep(100000);
}
Also used: InMemoryStatsStorage(org.deeplearning4j.ui.storage.InMemoryStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) UIServer(org.deeplearning4j.ui.api.UIServer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) StatsListener(org.deeplearning4j.ui.stats.StatsListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 17 with StatsStorage

use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

The class TestPlayUI, method testUI_VAE.

@Test
@Ignore
public void testUI_VAE() throws Exception {
    // Manual UI check for two stacked variational autoencoders (unsupervised layerwise
    // pretraining, enabled via pretrain(true)) topped by an output layer. Ignored by default;
    // it ends with a ~100 s sleep, presumably so the UI can be inspected by hand.
    StatsStorage storage = new InMemoryStatsStorage();
    UIServer server = UIServer.getInstance();
    server.attach(storage);

    MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .learningRate(1e-5)
            .list()
            .layer(0, new VariationalAutoencoder.Builder()
                    .nIn(4)
                    .nOut(3)
                    .encoderLayerSizes(10, 11)
                    .decoderLayerSizes(12, 13)
                    .weightInit(WeightInit.XAVIER)
                    .pzxActivationFunction("identity")
                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                    .activation(Activation.LEAKYRELU)
                    .updater(Updater.SGD)
                    .build())
            .layer(1, new VariationalAutoencoder.Builder()
                    .nIn(3)
                    .nOut(3)
                    .encoderLayerSizes(7)
                    .decoderLayerSizes(8)
                    .weightInit(WeightInit.XAVIER)
                    .pzxActivationFunction("identity")
                    .reconstructionDistribution(new GaussianReconstructionDistribution())
                    .activation(Activation.LEAKYRELU)
                    .updater(Updater.SGD)
                    .build())
            .layer(2, new OutputLayer.Builder()
                    .nIn(3)
                    .nOut(3)
                    .build())
            .pretrain(true)
            .backprop(true)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(netConf);
    network.init();
    network.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

    // All 150 Iris examples in a single minibatch
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    int epoch = 0;
    while (epoch < 50) {
        network.fit(irisIter);
        // Brief pause between fits, presumably to spread UI updates out over time
        Thread.sleep(100);
        epoch++;
    }
    Thread.sleep(100000);
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) InMemoryStatsStorage(org.deeplearning4j.ui.storage.InMemoryStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) UIServer(org.deeplearning4j.ui.api.UIServer) VariationalAutoencoder(org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) StatsListener(org.deeplearning4j.ui.stats.StatsListener) GaussianReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 18 with StatsStorage

use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

The class TestStatsListener, method testListenerBasic.

@Test
public void testListenerBasic() {
    // Runs the same scenario once with StatsListener and once with J7StatsListener, checking
    // that three fit() calls produce exactly one session/type/worker, a static-info record
    // and three update records in the storage.
    for (boolean useJ7 : new boolean[] { false, true }) {
        // Prefix for assertion messages so a failure identifies which listener variant broke
        String variant = useJ7 ? "J7StatsListener" : "StatsListener";

        DataSet ds = new IrisDataSetIterator(150, 150).next();
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(4).nOut(3).build())
                .pretrain(false).backprop(true).build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // No backing file is supplied to MapDBStatsStorage (intended as in-memory storage)
        StatsStorage ss = new MapDBStatsStorage();
        if (useJ7) {
            net.setListeners(new J7StatsListener(ss));
        } else {
            net.setListeners(new StatsListener(ss));
        }

        for (int i = 0; i < 3; i++) {
            net.fit(ds);
        }

        // Reuse each queried list rather than asking the storage a second time
        List<String> sids = ss.listSessionIDs();
        assertEquals(variant + ": number of session IDs", 1, sids.size());
        String sessionID = sids.get(0);

        List<String> typeIDs = ss.listTypeIDsForSession(sessionID);
        assertEquals(variant + ": number of type IDs", 1, typeIDs.size());
        String typeID = typeIDs.get(0);

        List<String> workerIDs = ss.listWorkerIDsForSession(sessionID);
        assertEquals(variant + ": number of worker IDs", 1, workerIDs.size());
        String workerID = workerIDs.get(0);

        Persistable staticInfo = ss.getStaticInfo(sessionID, typeID, workerID);
        assertNotNull(variant + ": static info", staticInfo);
        System.out.println(staticInfo);

        // 3 fit() calls with iterations(1) -> 3 update records expected
        List<Persistable> updates = ss.getAllUpdatesAfter(sessionID, typeID, workerID, 0);
        assertEquals(variant + ": number of updates", 3, updates.size());
        for (Persistable p : updates) {
            System.out.println(p);
        }
    }
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) Test(org.junit.Test)

Example 19 with StatsStorage

use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

The class TestStatsStorage, method testStatsStorage.

@Test
public void testStatsStorage() throws IOException {
    // Exercises the StatsStorage contract (session/type/worker listing, static info, updates,
    // listener notification counts) against each storage implementation. useJ7Storage is
    // passed through to getInitReport/getReport, which are defined elsewhere in this class.
    for (boolean useJ7Storage : new boolean[] { false, true }) {
        for (int i = 0; i < 3; i++) {
            StatsStorage ss = createTestStatsStorage(i);
            try {
                CountingListener l = new CountingListener();
                ss.registerStatsStorageListener(l);
                assertEquals(1, ss.getListeners().size());
                assertEquals(0, ss.listSessionIDs().size());
                assertNull(ss.getLatestUpdate("sessionID", "typeID", "workerID"));
                assertEquals(0, ss.listSessionIDs().size());
                ss.putStaticInfo(getInitReport(0, 0, 0, useJ7Storage));
                assertEquals(1, l.countNewSession);
                assertEquals(1, l.countNewWorkerId);
                assertEquals(1, l.countStaticInfo);
                assertEquals(0, l.countUpdate);
                assertEquals(Collections.singletonList("sid0"), ss.listSessionIDs());
                assertTrue(ss.sessionExists("sid0"));
                assertFalse(ss.sessionExists("sid1"));
                Persistable expected = getInitReport(0, 0, 0, useJ7Storage);
                Persistable p = ss.getStaticInfo("sid0", "tid0", "wid0");
                assertEquals(expected, p);
                List<Persistable> allStatic = ss.getAllStaticInfos("sid0", "tid0");
                assertEquals(Collections.singletonList(expected), allStatic);
                assertNull(ss.getLatestUpdate("sid0", "tid0", "wid0"));
                assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
                assertEquals(0, ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0).size());
                assertEquals(0, ss.getNumUpdateRecordsFor("sid0"));
                assertEquals(0, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));

                //Put first update
                ss.putUpdate(getReport(0, 0, 0, 12345, useJ7Storage));
                assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
                assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
                assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
                assertEquals(Collections.singletonList(getReport(0, 0, 0, 12345, useJ7Storage)), ss.getAllUpdatesAfter("sid0", "tid0", "wid0", 0));
                assertEquals(1, ss.getNumUpdateRecordsFor("sid0"));
                assertEquals(1, ss.getNumUpdateRecordsFor("sid0", "tid0", "wid0"));
                List<Persistable> list = ss.getLatestUpdateAllWorkers("sid0", "tid0");
                assertEquals(1, list.size());
                assertEquals(getReport(0, 0, 0, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12345));
                assertEquals(1, l.countNewSession);
                assertEquals(1, l.countNewWorkerId);
                assertEquals(1, l.countStaticInfo);
                assertEquals(1, l.countUpdate);

                //Put second update
                ss.putUpdate(getReport(0, 0, 0, 12346, useJ7Storage));
                assertEquals(1, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
                assertEquals(Collections.singletonList("tid0"), ss.listTypeIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSession("sid0"));
                assertEquals(Collections.singletonList("wid0"), ss.listWorkerIDsForSessionAndType("sid0", "tid0"));
                assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid0"));
                assertEquals(getReport(0, 0, 0, 12346, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid0", 12346));
                ss.putUpdate(getReport(0, 0, 1, 12345, useJ7Storage));
                assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
                assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getLatestUpdate("sid0", "tid0", "wid1"));
                assertEquals(getReport(0, 0, 1, 12345, useJ7Storage), ss.getUpdate("sid0", "tid0", "wid1", 12345));
                assertEquals(1, l.countNewSession);
                assertEquals(2, l.countNewWorkerId);
                assertEquals(1, l.countStaticInfo);
                assertEquals(3, l.countUpdate);

                //Put static info and update with different session, type and worker IDs
                ss.putStaticInfo(getInitReport(100, 200, 300, useJ7Storage));
                assertEquals(2, ss.getLatestUpdateAllWorkers("sid0", "tid0").size());
                ss.putUpdate(getReport(100, 200, 300, 12346, useJ7Storage));
                assertEquals(Collections.singletonList(getReport(100, 200, 300, 12346, useJ7Storage)), ss.getLatestUpdateAllWorkers("sid100", "tid200"));
                assertEquals(Collections.singletonList("tid200"), ss.listTypeIDsForSession("sid100"));
                assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSession("sid100"));
                assertEquals(Collections.singletonList("wid300"), ss.listWorkerIDsForSessionAndType("sid100", "tid200"));
                assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getLatestUpdate("sid100", "tid200", "wid300"));
                assertEquals(getReport(100, 200, 300, 12346, useJ7Storage), ss.getUpdate("sid100", "tid200", "wid300", 12346));
                assertEquals(2, l.countNewSession);
                assertEquals(3, l.countNewWorkerId);
                assertEquals(2, l.countStaticInfo);
                assertEquals(4, l.countUpdate);
            } finally {
                // Release any resources (e.g. open files) held by the storage before the next
                // iteration creates a new one; previously these were never closed.
                ss.close();
            }
        }
    }
}

/**
 * Creates the storage implementation under test.
 *
 * @param storageType 0 = MapDBStatsStorage on a fresh temp file, 1 = J7FileStatsStorage on a
 *                    fresh temp file, 2 = InMemoryStatsStorage
 * @return a new, empty storage instance
 * @throws IOException if the temp file cannot be created or removed
 * @throws IllegalArgumentException for any other storageType
 */
private static StatsStorage createTestStatsStorage(int storageType) throws IOException {
    switch (storageType) {
        case 0:
            return new MapDBStatsStorage.Builder().file(newAbsentTempFile("TestMapDbStatsStore", ".db")).build();
        case 1:
            return new J7FileStatsStorage(newAbsentTempFile("TestJ7FileStatsStore", ".db"));
        case 2:
            return new InMemoryStatsStorage();
        default:
            throw new IllegalArgumentException("Unknown storage type: " + storageType);
    }
}

/**
 * Reserves a unique temp-file path, deletes the file so that it does not exist, and schedules
 * the path for deletion when the JVM exits.
 *
 * @throws IOException if the file cannot be created or deleted
 */
private static File newAbsentTempFile(String prefix, String suffix) throws IOException {
    File f = Files.createTempFile(prefix, suffix).toFile();
    //Don't want file to exist... Files.delete throws instead of silently returning false
    Files.delete(f.toPath());
    // Clean up whatever the storage writes at this path once the test JVM exits
    f.deleteOnExit();
    return f;
}
Also used: MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) J7FileStatsStorage(org.deeplearning4j.ui.storage.sqlite.J7FileStatsStorage) Persistable(org.deeplearning4j.api.storage.Persistable) File(java.io.File) Test(org.junit.Test)

Example 20 with StatsStorage

use of org.deeplearning4j.api.storage.StatsStorage in project deeplearning4j by deeplearning4j.

The class FlowListenerModule, method getUpdate.

/**
 * Returns the model state from the latest flow update of the given session, rendered as JSON.
 *
 * @param sessionID session to look up in {@code knownSessionIDs}
 * @return {@code ok("Unknown session ID")} if the session is not known; an empty {@code ok()}
 *         if there is no update yet or the first update is not a {@link FlowUpdatePersistable};
 *         otherwise {@code ok} with the JSON-serialized model state
 */
private Result getUpdate(String sessionID) {
    // Single lookup instead of containsKey() + get(): removes the window in which the entry
    // could vanish between the two calls (if knownSessionIDs is modified concurrently -- not
    // visible here) and avoids an NPE on a null mapping.
    StatsStorage ss = knownSessionIDs.get(sessionID);
    if (ss == null) {
        return ok("Unknown session ID");
    }
    List<Persistable> list = ss.getLatestUpdateAllWorkers(sessionID, TYPE_ID);
    if (list == null || list.isEmpty()) {
        return ok();
    }
    // Only the first worker's latest update is rendered
    Persistable p = list.get(0);
    if (!(p instanceof FlowUpdatePersistable)) {
        return ok();
    }
    FlowUpdatePersistable f = (FlowUpdatePersistable) p;
    return ok(Json.toJson(f.getModelState()));
}
Also used: StatsStorage(org.deeplearning4j.api.storage.StatsStorage) FlowUpdatePersistable(org.deeplearning4j.ui.flow.data.FlowUpdatePersistable) FlowStaticPersistable(org.deeplearning4j.ui.flow.data.FlowStaticPersistable) Persistable(org.deeplearning4j.api.storage.Persistable)

Aggregations

StatsStorage (org.deeplearning4j.api.storage.StatsStorage)22 Persistable (org.deeplearning4j.api.storage.Persistable)14 Test (org.junit.Test)10 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)8 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)7 InMemoryStatsStorage (org.deeplearning4j.ui.storage.InMemoryStatsStorage)7 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)6 StatsListener (org.deeplearning4j.ui.stats.StatsListener)6 Ignore (org.junit.Ignore)6 UIServer (org.deeplearning4j.ui.api.UIServer)5 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)4 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)4 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)4 StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport)4 MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage)4 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)3 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)3 StatsReport (org.deeplearning4j.ui.stats.api.StatsReport)3 File (java.io.File)2