Search in sources :

Example 1 with RemoteUIStatsStorageRouter

use of org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter in project deeplearning4j by deeplearning4j.

the class ParallelWrapperMain method runMain.

public void runMain(String... args) throws Exception {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        System.err.println(e.getMessage());
        //User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            Thread.sleep(500);
        } catch (Exception e2) {
        }
        System.exit(1);
    }
    Model model = ModelGuesser.loadModelGuess(modelPath);
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model).prefetchBuffer(prefetchSize).workers(workers).averagingFrequency(averagingFrequency).averageUpdaters(averageUpdaters).reportScoreAfterAveraging(reportScore).useLegacyAveraging(legacyAveraging).build();
    if (dataSetIteratorFactoryClazz != null) {
        DataSetIteratorProviderFactory dataSetIteratorProviderFactory = (DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
        DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(dataSetIterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else if (multiDataSetIteratorFactoryClazz != null) {
        MultiDataSetProviderFactory multiDataSetProviderFactory = (MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
        MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(iterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else {
        throw new IllegalStateException("Please provide a datasetiteraator or multi datasetiterator class");
    }
}
Also used : ParallelWrapper(org.deeplearning4j.parallelism.ParallelWrapper) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) StatsListener(org.deeplearning4j.ui.stats.StatsListener) ParameterException(com.beust.jcommander.ParameterException) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) JCommander(com.beust.jcommander.JCommander) Model(org.deeplearning4j.nn.api.Model) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) ParameterException(com.beust.jcommander.ParameterException) File(java.io.File) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)

Example 2 with RemoteUIStatsStorageRouter

use of org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter in project deeplearning4j by deeplearning4j.

the class TestRemoteReceiver method testRemoteFull.

@Test
@Ignore
public void testRemoteFull() throws Exception {
    //Use this in conjunction with startRemoteUI()
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).list().layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build()).layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build()).pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    StatsStorageRouter ssr = new RemoteUIStatsStorageRouter("http://localhost:9000");
    net.setListeners(new StatsListener(ssr), new ScoreIterationListener(1));
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    for (int i = 0; i < 500; i++) {
        net.fit(iter);
        //            Thread.sleep(100);
        Thread.sleep(100);
    }
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) CollectionStatsStorageRouter(org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter) StatsListener(org.deeplearning4j.ui.stats.StatsListener) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 3 with RemoteUIStatsStorageRouter

use of org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter in project deeplearning4j by deeplearning4j.

the class TestRemoteReceiver method testRemoteBasic.

@Test
@Ignore
public void testRemoteBasic() throws Exception {
    List<Persistable> updates = new ArrayList<>();
    List<Persistable> staticInfo = new ArrayList<>();
    List<StorageMetaData> metaData = new ArrayList<>();
    CollectionStatsStorageRouter collectionRouter = new CollectionStatsStorageRouter(metaData, staticInfo, updates);
    UIServer s = UIServer.getInstance();
    s.enableRemoteListener(collectionRouter, false);
    RemoteUIStatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
    SbeStatsReport update1 = new SbeStatsReport();
    update1.setDeviceCurrentBytes(new long[] { 1, 2 });
    update1.reportIterationCount(10);
    update1.reportIDs("sid", "tid", "wid", 123456);
    update1.reportPerformance(10, 20, 30, 40, 50);
    SbeStatsReport update2 = new SbeStatsReport();
    update2.setDeviceCurrentBytes(new long[] { 3, 4 });
    update2.reportIterationCount(20);
    update2.reportIDs("sid2", "tid2", "wid2", 123456);
    update2.reportPerformance(11, 21, 31, 40, 50);
    StorageMetaData smd1 = new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
    StorageMetaData smd2 = new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");
    SbeStatsInitializationReport init1 = new SbeStatsInitializationReport();
    init1.reportIDs("sid", "wid", "tid", 3145253452L);
    init1.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");
    remoteRouter.putUpdate(update1);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(smd1);
    Thread.sleep(100);
    remoteRouter.putStaticInfo(init1);
    Thread.sleep(100);
    remoteRouter.putUpdate(update2);
    Thread.sleep(100);
    remoteRouter.putStorageMetaData(smd2);
    Thread.sleep(2000);
    assertEquals(2, metaData.size());
    assertEquals(2, updates.size());
    assertEquals(1, staticInfo.size());
    assertEquals(Arrays.asList(update1, update2), updates);
    assertEquals(Arrays.asList(smd1, smd2), metaData);
    assertEquals(Collections.singletonList(init1), staticInfo);
}
Also used : StorageMetaData(org.deeplearning4j.api.storage.StorageMetaData) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Persistable(org.deeplearning4j.api.storage.Persistable) SbeStatsInitializationReport(org.deeplearning4j.ui.stats.impl.SbeStatsInitializationReport) SbeStatsReport(org.deeplearning4j.ui.stats.impl.SbeStatsReport) UIServer(org.deeplearning4j.ui.api.UIServer) ArrayList(java.util.ArrayList) CollectionStatsStorageRouter(org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) SbeStorageMetaData(org.deeplearning4j.ui.storage.impl.SbeStorageMetaData) Ignore(org.junit.Ignore) Test(org.junit.Test)

Aggregations

RemoteUIStatsStorageRouter (org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter)3 StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)2 CollectionStatsStorageRouter (org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter)2 StatsListener (org.deeplearning4j.ui.stats.StatsListener)2 Ignore (org.junit.Ignore)2 Test (org.junit.Test)2 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)2 JCommander (com.beust.jcommander.JCommander)1 ParameterException (com.beust.jcommander.ParameterException)1 File (java.io.File)1 ArrayList (java.util.ArrayList)1 Persistable (org.deeplearning4j.api.storage.Persistable)1 StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData)1 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)1 Model (org.deeplearning4j.nn.api.Model)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)1 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)1 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)1 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)1