
Example 1 with StatsStorageRouter

Use of org.deeplearning4j.api.storage.StatsStorageRouter in project deeplearning4j by deeplearning4j.

Source: class ParallelWrapperMain, method runMain.

public void runMain(String... args) throws Exception {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        System.err.println(e.getMessage());
        //User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            Thread.sleep(500);
        } catch (InterruptedException e2) {
            // Interrupted during the pause; exit regardless
        }
        System.exit(1);
    }
    Model model = ModelGuesser.loadModelGuess(modelPath);
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .prefetchBuffer(prefetchSize)
            .workers(workers)
            .averagingFrequency(averagingFrequency)
            .averageUpdaters(averageUpdaters)
            .reportScoreAfterAveraging(reportScore)
            .useLegacyAveraging(legacyAveraging)
            .build();
    if (dataSetIteratorFactoryClazz != null) {
        DataSetIteratorProviderFactory dataSetIteratorProviderFactory = (DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
        DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(dataSetIterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else if (multiDataSetIteratorFactoryClazz != null) {
        MultiDataSetProviderFactory multiDataSetProviderFactory = (MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
        MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(iterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else {
        throw new IllegalStateException("Please provide a DataSetIterator or MultiDataSetIterator factory class");
    }
}
Also used: ParallelWrapper (org.deeplearning4j.parallelism.ParallelWrapper), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter), RemoteUIStatsStorageRouter (org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter), StatsListener (org.deeplearning4j.ui.stats.StatsListener), ParameterException (com.beust.jcommander.ParameterException), MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator), JCommander (com.beust.jcommander.JCommander), Model (org.deeplearning4j.nn.api.Model), File (java.io.File), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
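Both branches of runMain create a user-supplied factory from its class name via reflection and then ask it for an iterator. The sketch below isolates that step in plain Java so it runs without DL4J. `IteratorFactory`, `DemoFactory` and `loadFactory` are hypothetical stand-ins (for DataSetIteratorProviderFactory and an implementation of it), and the factory returns an iterator of strings instead of a real DataSetIterator. It uses `getDeclaredConstructor().newInstance()`, the replacement for `Class.newInstance()`, which is deprecated since Java 9, and it checks the type up front so a wrong class name fails with a clear message rather than a ClassCastException.

```java
import java.util.Arrays;
import java.util.Iterator;

public class FactoryLoading {

    /** Hypothetical stand-in for DataSetIteratorProviderFactory. */
    public interface IteratorFactory {
        Iterator<String> create();
    }

    /** Hypothetical implementation; reflection needs an accessible no-arg constructor. */
    public static class DemoFactory implements IteratorFactory {
        @Override
        public Iterator<String> create() {
            return Arrays.asList("batch-0", "batch-1").iterator();
        }
    }

    /** Loads a factory by binary class name, failing early if it has the wrong type. */
    public static IteratorFactory loadFactory(String className) throws ReflectiveOperationException {
        Class<?> clazz = Class.forName(className);
        if (!IteratorFactory.class.isAssignableFrom(clazz)) {
            throw new IllegalArgumentException(className + " does not implement IteratorFactory");
        }
        // Replacement for the deprecated clazz.newInstance()
        return (IteratorFactory) clazz.getDeclaredConstructor().newInstance();
    }

    public static void main(String[] args) throws Exception {
        // Nested classes use '$' in their binary name
        IteratorFactory factory = loadFactory(FactoryLoading.class.getName() + "$DemoFactory");
        Iterator<String> it = factory.create();
        while (it.hasNext()) {
            System.out.println(it.next());
        }
    }
}
```

In a command-line tool like ParallelWrapperMain, the class name would come from the parsed arguments, and the factory class must be on the classpath of the JVM running the tool.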

Example 2 with StatsStorageRouter

Use of org.deeplearning4j.api.storage.StatsStorageRouter in project deeplearning4j by deeplearning4j.

Source: class TestRemoteReceiver, method testRemoteFull.

@Test
@Ignore
public void testRemoteFull() throws Exception {
    //Use this in conjunction with startRemoteUI()
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .list()
            .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
            .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    StatsStorageRouter ssr = new RemoteUIStatsStorageRouter("http://localhost:9000");
    net.setListeners(new StatsListener(ssr), new ScoreIterationListener(1));
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    for (int i = 0; i < 500; i++) {
        net.fit(iter);
        Thread.sleep(100);
    }
}
Also used: OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer), IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter), RemoteUIStatsStorageRouter (org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter), CollectionStatsStorageRouter (org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter), StatsListener (org.deeplearning4j.ui.stats.StatsListener), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), Ignore (org.junit.Ignore), Test (org.junit.Test)
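testRemoteFull registers two listeners on the same network: a StatsListener that hands its records to the router, and `ScoreIterationListener(1)`, whose argument is the print frequency (1 means every iteration). The sketch below shows only the fan-out and frequency-gating idea in plain Java. `IterationCallback`, `ScoreLogger` and `FrequencyListenerSketch` are hypothetical names, not DL4J's listener API, and the scores fed in are synthetic placeholders rather than training output.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FrequencyListenerSketch {

    /** Hypothetical callback, loosely modeled on an iteration listener. */
    public interface IterationCallback {
        void iterationDone(int iteration, double score);
    }

    /** Records a line every `frequency` iterations (frequency 1 = every iteration). */
    public static class ScoreLogger implements IterationCallback {
        private final int frequency;
        private final List<String> lines = new ArrayList<>();

        public ScoreLogger(int frequency) {
            if (frequency <= 0) {
                throw new IllegalArgumentException("frequency must be positive");
            }
            this.frequency = frequency;
        }

        @Override
        public void iterationDone(int iteration, double score) {
            if (iteration % frequency == 0) {
                lines.add("Score at iteration " + iteration + " is " + score);
            }
        }

        public List<String> getLines() {
            return lines;
        }
    }

    public static void main(String[] args) {
        ScoreLogger every = new ScoreLogger(1);
        ScoreLogger everyFifth = new ScoreLogger(5);
        List<IterationCallback> callbacks = Arrays.asList(every, everyFifth);
        for (int i = 0; i < 10; i++) {
            double syntheticScore = 1.0 / (i + 1); // placeholder values, not real training output
            for (IterationCallback cb : callbacks) {
                cb.iterationDone(i, syntheticScore); // every callback sees every iteration
            }
        }
        System.out.println(every.getLines().size());      // 10
        System.out.println(everyFifth.getLines().size()); // 2 (iterations 0 and 5)
    }
}
```

Each callback decides for itself whether to act on a given iteration, so a cheap console logger and a heavier stats collector can run side by side at different rates.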

Example 3 with StatsStorageRouter

Use of org.deeplearning4j.api.storage.StatsStorageRouter in project deeplearning4j by deeplearning4j.

Source: class ParameterAveragingTrainingWorker, method getFinalResult (ComputationGraph overload).

@Override
public ParameterAveragingTrainingResult getFinalResult(ComputationGraph network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        ComputationGraphUpdater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }
    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData, listenerStaticInfo, listenerUpdates);
}
Also used: StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), Persistable (org.deeplearning4j.api.storage.Persistable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)
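In getFinalResult, listener output is only returned when the worker's router is a VanillaStatsStorageRouter: the worker reads the storage metadata, static info and updates that the listeners wrote into it and packs them into the training result; for any other router the three collections stay null. The sketch below reproduces that pattern in plain Java. `Router`, `CollectingRouter`, `WorkerResult` and `finalResult` are hypothetical; plain strings stand in for StorageMetaData and Persistable, and the method names are chosen to echo the three collections read above, not copied from DL4J's interface.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class CollectingRouterSketch {

    /** Hypothetical router interface; strings stand in for DL4J's record types. */
    public interface Router {
        void putStorageMetaData(String metaData);

        void putStaticInfo(String staticInfo);

        void putUpdate(String update);
    }

    /** Buffers every record in memory so the worker can hand it back later. */
    public static class CollectingRouter implements Router {
        private final List<String> metaData = new ArrayList<>();
        private final List<String> staticInfo = new ArrayList<>();
        private final List<String> updates = new ArrayList<>();

        @Override
        public synchronized void putStorageMetaData(String m) { metaData.add(m); }

        @Override
        public synchronized void putStaticInfo(String s) { staticInfo.add(s); }

        @Override
        public synchronized void putUpdate(String u) { updates.add(u); }

        // Getters return copies so callers cannot mutate the buffers
        public synchronized Collection<String> getStorageMetaData() { return new ArrayList<>(metaData); }

        public synchronized Collection<String> getStaticInfo() { return new ArrayList<>(staticInfo); }

        public synchronized Collection<String> getUpdates() { return new ArrayList<>(updates); }
    }

    /** Hypothetical result holder: score plus listener records (null when nothing was collected). */
    public static class WorkerResult {
        public final double score;
        public final Collection<String> metaData;
        public final Collection<String> staticInfo;
        public final Collection<String> updates;

        public WorkerResult(double score, Collection<String> metaData,
                            Collection<String> staticInfo, Collection<String> updates) {
            this.score = score;
            this.metaData = metaData;
            this.staticInfo = staticInfo;
            this.updates = updates;
        }
    }

    /** Mirrors the instanceof check in getFinalResult: only a collecting router has data to return. */
    public static WorkerResult finalResult(double score, Router router) {
        Collection<String> metaData = null;
        Collection<String> staticInfo = null;
        Collection<String> updates = null;
        if (router instanceof CollectingRouter) {
            CollectingRouter collecting = (CollectingRouter) router;
            metaData = collecting.getStorageMetaData();
            staticInfo = collecting.getStaticInfo();
            updates = collecting.getUpdates();
        }
        return new WorkerResult(score, metaData, staticInfo, updates);
    }

    public static void main(String[] args) {
        CollectingRouter router = new CollectingRouter();
        router.putStorageMetaData("session-meta");
        router.putStaticInfo("model-summary");
        router.putUpdate("iteration-0");
        router.putUpdate("iteration-1");
        WorkerResult result = finalResult(0.42, router);
        System.out.println(result.updates); // [iteration-0, iteration-1]
    }
}
```

Buffering on the worker and shipping the records back with the result avoids needing a network path from each executor to the UI, at the cost of holding the records in memory until training on that worker finishes.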

Example 4 with StatsStorageRouter

Use of org.deeplearning4j.api.storage.StatsStorageRouter in project deeplearning4j by deeplearning4j.

Source: class ParameterAveragingTrainingWorker, method getFinalResult (MultiLayerNetwork overload).

@Override
public ParameterAveragingTrainingResult getFinalResult(MultiLayerNetwork network) {
    INDArray updaterState = null;
    if (saveUpdater) {
        Updater u = network.getUpdater();
        if (u != null)
            updaterState = u.getStateViewArray();
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    Collection<StorageMetaData> storageMetaData = null;
    Collection<Persistable> listenerStaticInfo = null;
    Collection<Persistable> listenerUpdates = null;
    if (listenerRouterProvider != null) {
        StatsStorageRouter r = listenerRouterProvider.getRouter();
        if (r instanceof VanillaStatsStorageRouter) {
            //TODO this is ugly... need to find a better solution
            VanillaStatsStorageRouter ssr = (VanillaStatsStorageRouter) r;
            storageMetaData = ssr.getStorageMetaData();
            listenerStaticInfo = ssr.getStaticInfo();
            listenerUpdates = ssr.getUpdates();
        }
    }
    return new ParameterAveragingTrainingResult(network.params(), updaterState, network.score(), storageMetaData, listenerStaticInfo, listenerUpdates);
}
Also used: StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData), GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner), Persistable (org.deeplearning4j.api.storage.Persistable), INDArray (org.nd4j.linalg.api.ndarray.INDArray), VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter), ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater), Updater (org.deeplearning4j.nn.api.Updater), MultiLayerUpdater (org.deeplearning4j.nn.updater.MultiLayerUpdater), StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter)
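Examples 3 and 4 differ only in the network type; both return the worker's flattened parameters, updater state and score in a ParameterAveragingTrainingResult. As the class name indicates, these per-worker results are then combined by averaging. The sketch below shows the core idea as a simple element-wise mean over plain double arrays (standing in for INDArray) plus a mean score. `ParameterAveragingSketch` and `WorkerOutput` are hypothetical, and this is not the Spark implementation's actual aggregation code, which also handles updater state and listener data.

```java
import java.util.Arrays;
import java.util.List;

public class ParameterAveragingSketch {

    /** Hypothetical per-worker output: flattened parameters plus final score. */
    public static final class WorkerOutput {
        final double[] params;
        final double score;

        public WorkerOutput(double[] params, double score) {
            this.params = params;
            this.score = score;
        }
    }

    /** Element-wise mean of all workers' parameter vectors. */
    public static double[] averageParams(List<WorkerOutput> outputs) {
        if (outputs.isEmpty()) {
            throw new IllegalArgumentException("no worker outputs");
        }
        int n = outputs.get(0).params.length;
        double[] sum = new double[n];
        for (WorkerOutput o : outputs) {
            if (o.params.length != n) {
                throw new IllegalArgumentException("parameter length mismatch");
            }
            for (int i = 0; i < n; i++) {
                sum[i] += o.params[i];
            }
        }
        for (int i = 0; i < n; i++) {
            sum[i] /= outputs.size(); // sum -> mean, in place
        }
        return sum;
    }

    /** Mean of the workers' scores. */
    public static double averageScore(List<WorkerOutput> outputs) {
        if (outputs.isEmpty()) {
            throw new IllegalArgumentException("no worker outputs");
        }
        double total = 0.0;
        for (WorkerOutput o : outputs) {
            total += o.score;
        }
        return total / outputs.size();
    }

    public static void main(String[] args) {
        List<WorkerOutput> outputs = Arrays.asList(
                new WorkerOutput(new double[] {1.0, 2.0, 3.0}, 0.5),
                new WorkerOutput(new double[] {3.0, 4.0, 5.0}, 1.5));
        System.out.println(Arrays.toString(averageParams(outputs))); // [2.0, 3.0, 4.0]
        System.out.println(averageScore(outputs));                   // 1.0
    }
}
```

Because each worker returns its full parameter vector, the combining step needs the vectors to have identical length and layout, which is why the sketch rejects mismatched lengths instead of silently truncating.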

Aggregations

Number of examples above in which each class appears:

StatsStorageRouter (org.deeplearning4j.api.storage.StatsStorageRouter): 4
Persistable (org.deeplearning4j.api.storage.Persistable): 2
StorageMetaData (org.deeplearning4j.api.storage.StorageMetaData): 2
RemoteUIStatsStorageRouter (org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter): 2
ComputationGraphUpdater (org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater): 2
VanillaStatsStorageRouter (org.deeplearning4j.spark.impl.listeners.VanillaStatsStorageRouter): 2
StatsListener (org.deeplearning4j.ui.stats.StatsListener): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 2
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 2
JCommander (com.beust.jcommander.JCommander): 1
ParameterException (com.beust.jcommander.ParameterException): 1
File (java.io.File): 1
CollectionStatsStorageRouter (org.deeplearning4j.api.storage.impl.CollectionStatsStorageRouter): 1
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 1
Model (org.deeplearning4j.nn.api.Model): 1
Updater (org.deeplearning4j.nn.api.Updater): 1
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 1
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 1