
Example 1 with ParallelWrapper

Use of org.deeplearning4j.parallelism.ParallelWrapper in project deeplearning4j by deeplearning4j.

From class ParallelWrapperMain, method runMain:

public void runMain(String... args) throws Exception {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        System.err.println(e.getMessage());
        //User provides invalid input -> print the usage info
        jcmdr.usage();
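        try {
            // Short pause, presumably so the usage text is flushed before System.exit
            Thread.sleep(500);
        } catch (InterruptedException e2) {
            // Restore the interrupt flag instead of silently swallowing it
            Thread.currentThread().interrupt();
        }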
        System.exit(1);
    }
    Model model = ModelGuesser.loadModelGuess(modelPath);
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .prefetchBuffer(prefetchSize)
            .workers(workers)
            .averagingFrequency(averagingFrequency)
            .averageUpdaters(averageUpdaters)
            .reportScoreAfterAveraging(reportScore)
            .useLegacyAveraging(legacyAveraging)
            .build();
    if (dataSetIteratorFactoryClazz != null) {
        DataSetIteratorProviderFactory dataSetIteratorProviderFactory = (DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
        DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(dataSetIterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else if (multiDataSetIteratorFactoryClazz != null) {
        MultiDataSetProviderFactory multiDataSetProviderFactory = (MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
        MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(iterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else {
        throw new IllegalStateException("Please provide a DataSetIterator or MultiDataSetIterator factory class");
    }
}
Also used: ParallelWrapper(org.deeplearning4j.parallelism.ParallelWrapper) StatsStorageRouter(org.deeplearning4j.api.storage.StatsStorageRouter) RemoteUIStatsStorageRouter(org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter) StatsListener(org.deeplearning4j.ui.stats.StatsListener) ParameterException(com.beust.jcommander.ParameterException) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) JCommander(com.beust.jcommander.JCommander) Model(org.deeplearning4j.nn.api.Model) File(java.io.File) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
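runMain never names a concrete iterator class: it receives a class name (dataSetIteratorFactoryClazz or multiDataSetIteratorFactoryClazz), loads it with Class.forName, instantiates it and casts it to a factory interface. The sketch below reproduces that pattern with the JDK only, so it runs without DL4J on the classpath. IteratorFactory, ListIteratorFactory and loadFactory are hypothetical stand-ins for DataSetIteratorProviderFactory and friends, not DL4J types. It also swaps the Class.newInstance() call used in the original, deprecated since Java 9, for getDeclaredConstructor().newInstance().

```java
import java.util.Arrays;
import java.util.Iterator;

public class ReflectiveFactoryDemo {

    // Hypothetical stand-in for DataSetIteratorProviderFactory: a factory whose
    // concrete class is only known at runtime (e.g. passed as a CLI argument).
    public interface IteratorFactory {
        Iterator<String> create();
    }

    // Hypothetical concrete factory; reflection needs a public no-arg constructor.
    public static class ListIteratorFactory implements IteratorFactory {
        public ListIteratorFactory() {}

        @Override
        public Iterator<String> create() {
            return Arrays.asList("batch-0", "batch-1", "batch-2").iterator();
        }
    }

    // Load a factory from its binary class name, failing with a clear message
    // if the class is missing, not instantiable, or of the wrong type.
    public static IteratorFactory loadFactory(String className) {
        try {
            Class<?> clazz = Class.forName(className);
            if (!IteratorFactory.class.isAssignableFrom(clazz)) {
                throw new IllegalArgumentException(className + " does not implement IteratorFactory");
            }
            // Class.newInstance() is deprecated since Java 9; go through the no-arg constructor instead.
            return (IteratorFactory) clazz.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException("Cannot instantiate " + className, e);
        }
    }

    public static void main(String[] args) {
        // getName() yields the binary name (ReflectiveFactoryDemo$ListIteratorFactory),
        // which is what Class.forName expects for a nested class.
        IteratorFactory factory = loadFactory(ListIteratorFactory.class.getName());
        Iterator<String> it = factory.create();
        while (it.hasNext()) {
            System.out.println(it.next());
        }
    }
}
```

Checking isAssignableFrom before the cast turns a confusing ClassCastException into an error that names the offending class, which matters when the class name comes straight from user input.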

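The builder's workers and averagingFrequency settings suggest parameter averaging: as we understand ParallelWrapper's averaging mode, each worker trains its own copy of the model and the copies are averaged every averagingFrequency iterations. The toy below illustrates only that idea on a one-parameter "model" using plain Java. It is not DL4J's implementation. The per-worker targets, the quadratic loss, the learning rate lr and the methods train and average are all hypothetical, chosen so the effect is easy to check.

```java
import java.util.Arrays;

public class ParameterAveragingSketch {

    // Element-wise mean of every worker's parameter vector.
    public static double[] average(double[][] workerParams) {
        int n = workerParams[0].length;
        double[] avg = new double[n];
        for (double[] params : workerParams) {
            for (int i = 0; i < n; i++) {
                avg[i] += params[i];
            }
        }
        for (int i = 0; i < n; i++) {
            avg[i] /= workerParams.length;
        }
        return avg;
    }

    // Toy data-parallel loop: worker w runs gradient descent on its own hypothetical
    // loss (x - targets[w])^2; every averagingFrequency iterations all workers are
    // overwritten with the averaged parameters.
    public static double[] train(double[] targets, int iterations, int averagingFrequency, double lr) {
        int workers = targets.length;
        double[][] params = new double[workers][1]; // every worker starts at x = 0
        for (int iter = 1; iter <= iterations; iter++) {
            for (int w = 0; w < workers; w++) {
                double grad = 2.0 * (params[w][0] - targets[w]);
                params[w][0] -= lr * grad;
            }
            if (iter % averagingFrequency == 0) {
                double[] avg = average(params);
                for (int w = 0; w < workers; w++) {
                    params[w] = avg.clone(); // each worker gets its own copy
                }
            }
        }
        return average(params);
    }

    public static void main(String[] args) {
        double[] targets = {1.0, 2.0, 3.0, 6.0};
        System.out.println(Arrays.toString(train(targets, 50, 5, 0.25)));
    }
}
```

The averaged parameter settles near the mean of the workers' targets (3.0 here). A larger averaging frequency lets the copies drift further apart between synchronizations in exchange for less communication; this toy does not try to show DL4J's actual cost trade-off.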