Search in sources:

Example 21 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project (by deeplearning4j).

From class TestParallelEarlyStoppingUI, method testParallelStatsListenerCompatibility.

/**
 * Manual check that a StatsListener attached to a network keeps working when that network is
 * trained by the parallel early-stopping trainer, with results reported to the UI server.
 */
@Test
//To be run manually
@Ignore
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer server = UIServer.getInstance();

    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new OutputLayer.Builder()
                    .nIn(3)
                    .nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(configuration);

    // The UI must be able to report results from parallel training; StatsListener could
    // fail here if certain properties are not set on the model.
    StatsStorage storage = new InMemoryStatsStorage();
    model.setListeners(new StatsListener(storage));
    server.attach(storage);

    DataSetIterator iris = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> earlyStoppingConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(500))
                    .scoreCalculator(new DataSetLossCalculator(iris, true))
                    .evaluateEveryNEpochs(2)
                    .modelSaver(modelSaver)
                    .build();

    // NOTE(review): trailing ints presumably are (workers, prefetch buffer, averaging
    // frequency) and null the unused multi-dataset iterator - confirm against the constructor.
    IEarlyStoppingTrainer<MultiLayerNetwork> parallelTrainer =
            new EarlyStoppingParallelTrainer<>(earlyStoppingConf, model, iris, null, 3, 6, 2);

    EarlyStoppingResult<MultiLayerNetwork> outcome = parallelTrainer.fit();
    System.out.println(outcome);
    // Training is expected to stop via the epoch-count condition.
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, outcome.getTerminationReason());
}
Also used: InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) InMemoryStatsStorage(org.deeplearning4j.ui.storage.InMemoryStatsStorage) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) UIServer(org.deeplearning4j.ui.api.UIServer) StatsListener(org.deeplearning4j.ui.stats.StatsListener) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 22 with StatsStorage

Use of org.deeplearning4j.api.storage.StatsStorage in the deeplearning4j project (by deeplearning4j).

From class TestListeners, method testStatsCollection.

/**
 * Trains a SparkDl4jMultiLayer on Iris with a StatsListener and checks the collected stats.
 *
 * <p>Checks that the stats storage ends up with exactly one session, one type ID, and one worker
 * per executor. It also checks that all worker IDs share a common prefix and end in a counter
 * in {@code [0, nExecutors)}.
 */
@Test
public void testStatsCollection() {
    JavaSparkContext sc = getContext();
    int nExecutors = numExecutors();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER)
                    .activation(Activation.RELU).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100).nOut(3).activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build())
            .pretrain(false)
            .backprop(true)
            .build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
            .batchSizePerWorker(5)
            .averagingFrequency(6)
            .build();
    SparkDl4jMultiLayer net = new SparkDl4jMultiLayer(sc, conf, tm);

    //In-memory
    StatsStorage ss = new MapDBStatsStorage();
    // The listener is built without storage; the storage is handed to setListeners alongside it.
    net.setListeners(ss, Collections.singletonList(new StatsListener(null)));

    // 120 examples; with 4 executors that is 30 examples per executor, i.e. 6 updates of
    // batch size 5 per executor (batchSizePerWorker=5, averagingFrequency=6).
    List<DataSet> list = new IrisDataSetIterator(120, 150).next().asList();
    JavaRDD<DataSet> rdd = sc.parallelize(list);
    net.fit(rdd);

    List<String> sessions = ss.listSessionIDs();
    System.out.println("Sessions: " + sessions);
    assertEquals(1, sessions.size());

    String sid = sessions.get(0);
    List<String> typeIDs = ss.listTypeIDsForSession(sid);
    List<String> workers = ss.listWorkerIDsForSession(sid);
    System.out.println(sid + "\t" + typeIDs + "\t" + workers);

    List<Persistable> lastUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
    System.out.println(lastUpdates);

    System.out.println("Static info:");
    for (String wid : workers) {
        Persistable staticInfo = ss.getStaticInfo(sid, StatsListener.TYPE_ID, wid);
        // Previously fetched but never printed; include it so this section shows the static info.
        System.out.println(sid + "\t" + wid + "\t" + staticInfo);
    }

    assertEquals(1, typeIDs.size());
    assertEquals(nExecutors, workers.size());

    // Worker IDs are compared as "<shared prefix><counter>". The counter is read as the final
    // character only, which assumes nExecutors < 10.
    String firstWorker = workers.get(0);
    String firstWorkerSubstring = firstWorker.substring(0, firstWorker.length() - 1);
    for (String wid : workers) {
        String widSubstring = wid.substring(0, wid.length() - 1);
        assertEquals(firstWorkerSubstring, widSubstring);
        String counterVal = wid.substring(wid.length() - 1, wid.length());
        int cv = Integer.parseInt(counterVal);
        assertTrue(0 <= cv && cv < nExecutors);
    }
}
Also used: Persistable(org.deeplearning4j.api.storage.Persistable) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) SparkDl4jMultiLayer(org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer) MapDBStatsStorage(org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) StatsStorage(org.deeplearning4j.api.storage.StatsStorage) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) StatsListener(org.deeplearning4j.ui.stats.StatsListener) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

StatsStorage (org.deeplearning4j.api.storage.StatsStorage): 22; Persistable (org.deeplearning4j.api.storage.Persistable): 14; Test (org.junit.Test): 10; MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 8; IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 7; InMemoryStatsStorage (org.deeplearning4j.ui.storage.InMemoryStatsStorage): 7; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 6; StatsListener (org.deeplearning4j.ui.stats.StatsListener): 6; Ignore (org.junit.Ignore): 6; UIServer (org.deeplearning4j.ui.api.UIServer): 5; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 5; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 4; OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 4; ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 4; StatsInitializationReport (org.deeplearning4j.ui.stats.api.StatsInitializationReport): 4; MapDBStatsStorage (org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage): 4; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 3; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 3; StatsReport (org.deeplearning4j.ui.stats.api.StatsReport): 3; File (java.io.File): 2