Usage of org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter in the deeplearning4j project (by deeplearning4j).
Class ParallelWrapperMain, method runMain.
/**
 * Command-line entry point: parses {@code args} into this object's fields, loads the model found at
 * {@code modelPath}, trains it with a {@link ParallelWrapper} on the iterator produced by the
 * configured factory class, and writes the trained model to {@code modelOutputPath}.
 *
 * <p>If {@code dataSetIteratorFactoryClazz} is set it is used; otherwise
 * {@code multiDataSetIteratorFactoryClazz} is used. When {@code uiUrl} is set, training stats are
 * routed to a remote UI at {@code "http://" + uiUrl}.
 *
 * <p>On invalid arguments the usage text is printed and the JVM exits with status 1.
 *
 * @param args command-line arguments, parsed by JCommander into this instance
 * @throws IllegalStateException if neither iterator factory class was provided
 * @throws Exception if model loading, factory instantiation, training or serialization fails
 */
public void runMain(String... args) throws Exception {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        System.err.println(e.getMessage());
        //User provides invalid input -> print the usage info
        jcmdr.usage();
        try {
            // NOTE(review): presumably gives the usage output time to appear before exiting
            Thread.sleep(500);
        } catch (InterruptedException e2) {
            // We exit right below anyway; just preserve the interrupt status instead of swallowing it.
            Thread.currentThread().interrupt();
        }
        System.exit(1);
    }

    Model model = ModelGuesser.loadModelGuess(modelPath);
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .prefetchBuffer(prefetchSize)
            .workers(workers)
            .averagingFrequency(averagingFrequency)
            .averageUpdaters(averageUpdaters)
            .reportScoreAfterAveraging(reportScore)
            .useLegacyAveraging(legacyAveraging)
            .build();

    if (dataSetIteratorFactoryClazz != null) {
        DataSetIteratorProviderFactory dataSetIteratorProviderFactory =
                instantiateIteratorFactory(dataSetIteratorFactoryClazz, DataSetIteratorProviderFactory.class);
        DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
        attachRemoteUiListenersIfConfigured(wrapper);
        wrapper.fit(dataSetIterator);
    } else if (multiDataSetIteratorFactoryClazz != null) {
        MultiDataSetProviderFactory multiDataSetProviderFactory =
                instantiateIteratorFactory(multiDataSetIteratorFactoryClazz, MultiDataSetProviderFactory.class);
        MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
        attachRemoteUiListenersIfConfigured(wrapper);
        wrapper.fit(iterator);
    } else {
        throw new IllegalStateException("Please provide a datasetiterator or multi datasetiterator class");
    }

    // Both training branches end here; the else-branch above has already thrown.
    ModelSerializer.writeModel(model, new File(modelOutputPath), true);
}

/**
 * Loads {@code className} and creates an instance through its no-arg constructor, checked
 * against {@code type}.
 *
 * @throws ClassCastException if the instantiated class is not a {@code type}
 * @throws Exception if the class cannot be found, accessed or constructed
 */
private static <T> T instantiateIteratorFactory(String className, Class<T> type) throws Exception {
    // Constructor.newInstance replaces the deprecated Class.newInstance, which could
    // propagate checked exceptions from the constructor without declaring them.
    return type.cast(Class.forName(className).getDeclaredConstructor().newInstance());
}

/** Registers a StatsListener routed to the remote UI when {@code uiUrl} is set; otherwise does nothing. */
private void attachRemoteUiListenersIfConfigured(ParallelWrapper wrapper) {
    if (uiUrl == null) {
        return;
    }
    // it's important that the UI can report results from parallel training
    // there's potential for StatsListener to fail if certain properties aren't set in the model
    StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
    // The router is handed to the wrapper; the StatsListener itself is created without one.
    wrapper.setListeners(remoteUIRouter, new StatsListener(null));
}
Usage of org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter in the deeplearning4j project (by deeplearning4j).
Class TestRemoteReceiver, method testRemoteFull.
/**
 * Manual end-to-end check: trains a small two-layer network on Iris while a StatsListener posts
 * stats to a UI server at {@code http://localhost:9000}. Intended to be run together with
 * {@code startRemoteUI()}; ignored by default.
 */
@Test
@Ignore
public void testRemoteFull() throws Exception {
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .activation(Activation.TANH)
                    .nIn(4)
                    .nOut(4)
                    .build())
            .layer(1, new OutputLayer.Builder()
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(4)
                    .nOut(3)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();

    MultiLayerNetwork network = new MultiLayerNetwork(configuration);
    network.init();

    StatsStorageRouter remoteRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
    network.setListeners(new StatsListener(remoteRouter), new ScoreIterationListener(1));

    DataSetIterator irisIterator = new IrisDataSetIterator(150, 150);
    final int fitCalls = 500;
    for (int call = 0; call < fitCalls; call++) {
        network.fit(irisIterator);
        // Pause between fits so updates arrive at the UI at a viewable pace.
        Thread.sleep(100);
    }
}
Usage of org.deeplearning4j.api.storage.impl.RemoteUIStatsStorageRouter in the deeplearning4j project (by deeplearning4j).
Class TestRemoteReceiver, method testRemoteBasic.
/**
 * Round-trip check for remote stats posting: records sent through a
 * {@link RemoteUIStatsStorageRouter} to {@code http://localhost:9000} are expected to arrive, in
 * order, in the lists backing a {@link CollectionStatsStorageRouter} registered as the UI
 * server's remote listener. Ignored by default.
 */
@Test
@Ignore
public void testRemoteBasic() throws Exception {
    // Receiving side: the UI server's remote listener writes into these lists.
    List<Persistable> receivedUpdates = new ArrayList<>();
    List<Persistable> receivedStaticInfo = new ArrayList<>();
    List<StorageMetaData> receivedMetaData = new ArrayList<>();
    CollectionStatsStorageRouter receivingRouter =
            new CollectionStatsStorageRouter(receivedMetaData, receivedStaticInfo, receivedUpdates);
    UIServer uiServer = UIServer.getInstance();
    uiServer.enableRemoteListener(receivingRouter, false);

    // Sending side.
    RemoteUIStatsStorageRouter sendingRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");

    SbeStatsReport firstUpdate = new SbeStatsReport();
    firstUpdate.setDeviceCurrentBytes(new long[] { 1, 2 });
    firstUpdate.reportIterationCount(10);
    firstUpdate.reportIDs("sid", "tid", "wid", 123456);
    firstUpdate.reportPerformance(10, 20, 30, 40, 50);

    SbeStatsReport secondUpdate = new SbeStatsReport();
    secondUpdate.setDeviceCurrentBytes(new long[] { 3, 4 });
    secondUpdate.reportIterationCount(20);
    secondUpdate.reportIDs("sid2", "tid2", "wid2", 123456);
    secondUpdate.reportPerformance(11, 21, 31, 40, 50);

    StorageMetaData firstMeta =
            new SbeStorageMetaData(123, "sid", "typeid", "wid", "initTypeClass", "updaterTypeClass");
    StorageMetaData secondMeta =
            new SbeStorageMetaData(456, "sid2", "typeid2", "wid2", "initTypeClass2", "updaterTypeClass2");

    SbeStatsInitializationReport initReport = new SbeStatsInitializationReport();
    initReport.reportIDs("sid", "wid", "tid", 3145253452L);
    initReport.reportHardwareInfo(1, 2, 3, 4, null, null, "2344253");

    // Interleave record kinds with short pauses between posts.
    final long gapMillis = 100;
    sendingRouter.putUpdate(firstUpdate);
    Thread.sleep(gapMillis);
    sendingRouter.putStorageMetaData(firstMeta);
    Thread.sleep(gapMillis);
    sendingRouter.putStaticInfo(initReport);
    Thread.sleep(gapMillis);
    sendingRouter.putUpdate(secondUpdate);
    Thread.sleep(gapMillis);
    sendingRouter.putStorageMetaData(secondMeta);

    // NOTE(review): presumably delivery is asynchronous, so wait before checking what arrived.
    Thread.sleep(2000);

    assertEquals(2, receivedMetaData.size());
    assertEquals(2, receivedUpdates.size());
    assertEquals(1, receivedStaticInfo.size());
    assertEquals(Arrays.asList(firstUpdate, secondUpdate), receivedUpdates);
    assertEquals(Arrays.asList(firstMeta, secondMeta), receivedMetaData);
    assertEquals(Collections.singletonList(initReport), receivedStaticInfo);
}
Aggregations