Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
The class ParallelWrapperMain, method runMain.
public void runMain(String... args) throws Exception {
    JCommander jcmdr = new JCommander(this);
    try {
        jcmdr.parse(args);
    } catch (ParameterException e) {
        System.err.println(e.getMessage());
        // User provided invalid input -> print the usage info
        jcmdr.usage();
        try {
            Thread.sleep(500);
        } catch (Exception e2) {
            // ignored: the brief pause just lets the usage text flush before exiting
        }
        System.exit(1);
    }
    Model model = ModelGuesser.loadModelGuess(modelPath);
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .prefetchBuffer(prefetchSize)
            .workers(workers)
            .averagingFrequency(averagingFrequency)
            .averageUpdaters(averageUpdaters)
            .reportScoreAfterAveraging(reportScore)
            .useLegacyAveraging(legacyAveraging)
            .build();
    if (dataSetIteratorFactoryClazz != null) {
        DataSetIteratorProviderFactory dataSetIteratorProviderFactory =
                (DataSetIteratorProviderFactory) Class.forName(dataSetIteratorFactoryClazz).newInstance();
        DataSetIterator dataSetIterator = dataSetIteratorProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(dataSetIterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else if (multiDataSetIteratorFactoryClazz != null) {
        MultiDataSetProviderFactory multiDataSetProviderFactory =
                (MultiDataSetProviderFactory) Class.forName(multiDataSetIteratorFactoryClazz).newInstance();
        MultiDataSetIterator iterator = multiDataSetProviderFactory.create();
        if (uiUrl != null) {
            // it's important that the UI can report results from parallel training
            // there's potential for StatsListener to fail if certain properties aren't set in the model
            StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://" + uiUrl);
            wrapper.setListeners(remoteUIRouter, new StatsListener(null));
        }
        wrapper.fit(iterator);
        ModelSerializer.writeModel(model, new File(modelOutputPath), true);
    } else {
        throw new IllegalStateException("Please provide a DataSetIterator or MultiDataSetIterator factory class");
    }
}
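Since runMain instantiates the factory reflectively via Class.forName(...).newInstance(), the class named by --dataSetIteratorFactoryClazz needs a public no-arg constructor. Below is a minimal sketch of such a factory, assuming the DataSetIteratorProviderFactory interface used above exposes only the no-arg create(); the class name and batch size are illustrative, not part of the project.

import java.io.IOException;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Illustrative factory: loaded reflectively by runMain, so it must be public
// and have a public no-arg constructor.
public class MnistTrainIteratorProviderFactory implements DataSetIteratorProviderFactory {
    @Override
    public DataSetIterator create() {
        try {
            // batch size 128, training split, fixed seed; values chosen for illustration
            return new MnistDataSetIterator(128, true, 12345);
        } catch (IOException e) {
            throw new RuntimeException("Could not build MNIST iterator", e);
        }
    }
}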
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
The class ParallelWrapperTest, method testParallelWrapperRun.
@Test
public void testParallelWrapperRun() throws Exception {
    int nChannels = 1;
    int outputNum = 10;
    // for GPUs you usually want a larger batch size
    int batchSize = 128;
    int nEpochs = 10;
    int iterations = 1;
    int seed = 123;
    log.info("Load data....");
    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
log.info("Build model....");
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).regularization(true).l2(0.0005).learningRate(//.biasLearningRate(0.02)
0.01).weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.NESTEROVS).momentum(0.9).list().layer(0, new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()).layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()).layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build()).layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()).layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum).activation(Activation.SOFTMAX).build()).backprop(true).pretrain(false);
// The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel
new ConvolutionLayerSetup(builder, 28, 28, 1);
MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
    // ParallelWrapper will take care of load balancing between GPUs.
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model)
            .prefetchBuffer(24)
            .workers(2)
            .averagingFrequency(3)
            .reportScoreAfterAveraging(true)
            .useLegacyAveraging(true)
            .build();
log.info("Train model....");
model.setListeners(new ScoreIterationListener(100));
long timeX = System.currentTimeMillis();
for (int i = 0; i < nEpochs; i++) {
long time1 = System.currentTimeMillis();
// Please note: we're feeding ParallelWrapper with iterator, not model directly
// wrapper.fit(mnistMultiEpochIterator);
wrapper.fit(mnistTrain);
long time2 = System.currentTimeMillis();
log.info("*** Completed epoch {}, time: {} ***", i, (time2 - time1));
}
long timeY = System.currentTimeMillis();
log.info("*** Training complete, time: {} ***", (timeY - timeX));
log.info("Evaluate model....");
Evaluation eval = new Evaluation(outputNum);
while (mnistTest.hasNext()) {
DataSet ds = mnistTest.next();
INDArray output = model.output(ds.getFeatureMatrix(), false);
eval.eval(ds.getLabels(), output);
}
log.info(eval.stats());
mnistTest.reset();
log.info("****************Example finished********************");
wrapper.shutdown();
}
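The commented-out mnistMultiEpochIterator above hints at an alternative to the explicit epoch loop: wrap the base iterator in a MultipleEpochsIterator so a single fit(...) call spans every pass. A minimal sketch using the test's own locals (MultipleEpochsIterator is dl4j's standard multi-epoch wrapper; exact behavior is assumed here, verify against your version):

import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;

// Replaces the for-loop over nEpochs with a single fit() call
DataSetIterator mnistMultiEpochIterator = new MultipleEpochsIterator(nEpochs, mnistTrain);
wrapper.fit(mnistMultiEpochIterator);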
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
The class TestParallelEarlyStopping, method testBadTuning.
@Test
public void testBadTuning() {
    // Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).updater(Updater.SGD)
            .learningRate(1.0) // Intentionally huge LR
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(10, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                            new MaxScoreIterationTerminationCondition(10)) // Initial score is ~2.5
                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                    .modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer =
            new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 2, 2, 1);
    EarlyStoppingResult result = trainer.fit();
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
            result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
    assertEquals(expDetails, result.getTerminationDetails());
    assertTrue(result.getBestModelEpoch() <= 0);
    assertNotNull(result.getBestModel());
}
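The positional arguments to EarlyStoppingParallelTrainer are easy to misread. Here is the call above with each argument annotated; the roles are inferred from how these tests use them, so treat them as assumptions to verify against the constructor in your dl4j version:

IEarlyStoppingTrainer<MultiLayerNetwork> trainer =
        new EarlyStoppingParallelTrainer<>(
                esConf,   // early stopping configuration
                net,      // model to train
                irisIter, // training DataSetIterator
                null,     // MultiDataSetIterator variant (unused here)
                2,        // number of parallel workers (assumed)
                2,        // prefetch buffer size (assumed)
                1);       // averaging frequency (assumed)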
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
The class TestParallelEarlyStopping, method testEarlyStoppingEveryNEpoch.
// Parallel training results vary wildly from the expected result;
// we need to determine whether this test is feasible, and how it
// should properly be designed.
// @Test
// public void testEarlyStoppingIris(){
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
// .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
// .updater(Updater.SGD)
// .weightInit(WeightInit.XAVIER)
// .list()
// .layer(0,new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build())
// .pretrain(false).backprop(true)
// .build();
// MultiLayerNetwork net = new MultiLayerNetwork(conf);
// net.setListeners(new ScoreIterationListener(1));
//
// DataSetIterator irisIter = new IrisDataSetIterator(50,600);
// EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
// EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
// .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
// .evaluateEveryNEpochs(1)
// .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
// .scoreCalculator(new DataSetLossCalculator(irisIter,true))
// .modelSaver(saver)
// .build();
//
// IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingParallelTrainer<>(esConf,net,irisIter,null,2,2,1);
//
// EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
// System.out.println(result);
//
// assertEquals(5, result.getTotalEpochs());
// assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition,result.getTerminationReason());
// Map<Integer,Double> scoreVsIter = result.getScoreVsEpoch();
// assertEquals(5,scoreVsIter.size());
// String expDetails = esConf.getEpochTerminationConditions().get(0).toString();
// assertEquals(expDetails, result.getTerminationDetails());
//
// MultiLayerNetwork out = result.getBestModel();
// assertNotNull(out);
//
// //Check that best score actually matches (returned model vs. manually calculated score)
// MultiLayerNetwork bestNetwork = result.getBestModel();
// irisIter.reset();
// double score = bestNetwork.score(irisIter.next());
// assertEquals(result.getBestModelScore(), score, 1e-4);
// }
@Test
public void testEarlyStoppingEveryNEpoch() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1).updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(50, 600);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                    .evaluateEveryNEpochs(2)
                    .modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer =
            new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 2, 6, 1);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);
    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
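Since evaluateEveryNEpochs(2) runs the score calculator only on every second epoch, the recorded score-vs-epoch map should hold fewer entries than the five epochs trained. A hedged extra assertion one might append to the test above (getScoreVsEpoch appears in the commented-out test earlier; exact entry indexing may differ across dl4j versions):

// Scores are recorded only for evaluated epochs, so with five epochs and
// evaluation every second epoch there should be fewer entries than epochs.
Map<Integer, Double> scoreVsEpoch = result.getScoreVsEpoch();
assertTrue(scoreVsEpoch.size() < result.getTotalEpochs());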
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
The class ParallelWrapperMainTest, method runParallelWrapperMain.
@Test
public void runParallelWrapperMain() throws Exception {
    int nChannels = 1;
    int outputNum = 10;
    // for GPUs you usually want a larger batch size
    int batchSize = 128;
    int nEpochs = 10;
    int iterations = 1;
    int seed = 123;
    int uiPort = new Random().nextInt(1000) + 9000;
    System.setProperty("org.deeplearning4j.ui.port", String.valueOf(uiPort));
    log.info("Load data....");
    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
    DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
    log.info("Build model....");
    MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .seed(seed).iterations(iterations)
            .regularization(true).l2(0.0005)
            .learningRate(0.01) //.biasLearningRate(0.02)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20)
                    .activation(Activation.IDENTITY).build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                    .kernelSize(2, 2).stride(2, 2).build())
            .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50)
                    .activation(Activation.IDENTITY).build())
            .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                    .kernelSize(2, 2).stride(2, 2).build())
            .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nOut(outputNum).activation(Activation.SOFTMAX).build())
            .backprop(true).pretrain(false);
    // The builder needs the image dimensions along with the number of channels: these are 28x28 images with one channel.
    new ConvolutionLayerSetup(builder, 28, 28, 1);
    MultiLayerConfiguration conf = builder.build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    File tempModel = new File("tmpmodel.zip");
    tempModel.deleteOnExit();
    ModelSerializer.writeModel(model, tempModel, false);
    File tmp = new File("tmpmodel.bin");
    tmp.deleteOnExit();
    ParallelWrapperMain parallelWrapperMain = new ParallelWrapperMain();
    parallelWrapperMain.runMain(new String[] {
            "--modelPath", tempModel.getAbsolutePath(),
            "--dataSetIteratorFactoryClazz", MnistDataSetIteratorProviderFactory.class.getName(),
            "--modelOutputPath", tmp.getAbsolutePath(),
            "--uiUrl", "localhost:" + uiPort });
    // Presumably gives the remote UI time to receive the final stats before the JVM exits
    Thread.sleep(30000);
}
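The test ends with a sleep and asserts nothing about the trained model. A minimal hedged follow-up: restore the model that runMain wrote to tmp and score it on the held-out iterator, mirroring the evaluation loop from testParallelWrapperRun above (restoreMultiLayerNetwork is ModelSerializer's counterpart to the writeModel calls used here):

// Load the trained model back from disk and sanity-check it on the test set
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(tmp);
Evaluation eval = new Evaluation(outputNum);
while (mnistTest.hasNext()) {
    DataSet ds = mnistTest.next();
    INDArray output = restored.output(ds.getFeatureMatrix(), false);
    eval.eval(ds.getLabels(), output);
}
log.info(eval.stats());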