Search in sources :

Example 1 with MultipleEpochsIterator

use of org.deeplearning4j.datasets.iterator.MultipleEpochsIterator in project deeplearning4j by deeplearning4j.

the class TestEarlyStopping method testEarlyStoppingGetBestModel.

@Test
public void testEarlyStoppingGetBestModel() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).weightInit(WeightInit.XAVIER).list().layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build()).pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>().epochTerminationConditions(new MaxEpochsTerminationCondition(5)).iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)).scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);
    MultiLayerNetwork mln = result.getBestModel();
    assertEquals(net.getnLayers(), mln.getnLayers());
    assertEquals(net.conf().getNumIterations(), mln.conf().getNumIterations());
    assertEquals(net.conf().getOptimizationAlgo(), mln.conf().getOptimizationAlgo());
    assertEquals(net.conf().getLayer().getActivationFn().toString(), mln.conf().getLayer().getActivationFn().toString());
    assertEquals(net.conf().getLayer().getUpdater(), mln.conf().getLayer().getUpdater());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) EarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer) MultipleEpochsIterator(org.deeplearning4j.datasets.iterator.MultipleEpochsIterator) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Example 2 with MultipleEpochsIterator

use of org.deeplearning4j.datasets.iterator.MultipleEpochsIterator in project deeplearning4j by deeplearning4j.

the class TestEarlyStopping method testEarlyStoppingIrisMultiEpoch.

@Test
public void testEarlyStoppingIrisMultiEpoch() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).weightInit(WeightInit.XAVIER).list().layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build()).pretrain(false).backprop(true).build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>().epochTerminationConditions(new MaxEpochsTerminationCondition(5)).iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)).scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver).build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);
    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    Map<Integer, Double> scoreVsIter = result.getScoreVsEpoch();
    assertEquals(5, scoreVsIter.size());
    String expDetails = esConf.getEpochTerminationConditions().get(0).toString();
    assertEquals(expDetails, result.getTerminationDetails());
    MultiLayerNetwork out = result.getBestModel();
    assertNotNull(out);
    //Check that best score actually matches (returned model vs. manually calculated score)
    MultiLayerNetwork bestNetwork = result.getBestModel();
    irisIter.reset();
    double score = bestNetwork.score(irisIter.next());
    assertEquals(result.getBestModelScore(), score, 1e-2);
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) EarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer) MultipleEpochsIterator(org.deeplearning4j.datasets.iterator.MultipleEpochsIterator) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Aggregations

MultipleEpochsIterator (org.deeplearning4j.datasets.iterator.MultipleEpochsIterator)2 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)2 ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator)2 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)2 DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator)2 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)2 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)2 EarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer)2 IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer)2 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)2 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)2 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)2 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)2 Test (org.junit.Test)2 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)2