
Example 1 with InMemoryModelSaver

Use of org.deeplearning4j.earlystopping.saver.InMemoryModelSaver in project deeplearning4j by deeplearning4j.

From the class TestParallelEarlyStopping, method testBadTuning.

@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD)
            .learningRate(1.0) //Intentionally huge LR
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(10, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                    new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 2, 2, 1);
    EarlyStoppingResult result = trainer.fit();
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
    assertEquals(expDetails, result.getTerminationDetails());
    assertTrue(result.getBestModelEpoch() <= 0);
    assertNotNull(result.getBestModel());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
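
InMemoryModelSaver keeps the best and latest checkpoints in memory only. As a minimal sketch (not part of the test above), the saver could be swapped for a file-based implementation such as LocalFileModelSaver from the same org.deeplearning4j.earlystopping.saver package; the directory argument and exact constructor are assumptions here, not copied from the test:

    // Hedged sketch: assumes LocalFileModelSaver(String directory) writes checkpoints to disk
    String checkpointDir = "/tmp/earlystopping"; // hypothetical path; any writable directory
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new LocalFileModelSaver(checkpointDir);
    // The rest of the EarlyStoppingConfiguration.Builder chain stays exactly as in testBadTuning.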

Example 2 with InMemoryModelSaver

Use of org.deeplearning4j.earlystopping.saver.InMemoryModelSaver in project deeplearning4j by deeplearning4j.

From the class TestParallelEarlyStopping, method testEarlyStoppingEveryNEpoch.

// Parallel training results vary wildly relative to the expected result.
// Need to determine whether this test is feasible, and how it should
// be properly designed.
//    @Test
//    public void testEarlyStoppingIris(){
//        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
//                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
//                .updater(Updater.SGD)
//                .weightInit(WeightInit.XAVIER)
//                .list()
//                .layer(0,new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build())
//                .pretrain(false).backprop(true)
//                .build();
//        MultiLayerNetwork net = new MultiLayerNetwork(conf);
//        net.setListeners(new ScoreIterationListener(1));
//
//        DataSetIterator irisIter = new IrisDataSetIterator(50,600);
//        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
//        EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
//            .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
//            .evaluateEveryNEpochs(1)
//            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
//            .scoreCalculator(new DataSetLossCalculator(irisIter,true))
//            .modelSaver(saver)
//            .build();
//
//        IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingParallelTrainer<>(esConf,net,irisIter,null,2,2,1);
//
//        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
//        System.out.println(result);
//
//        assertEquals(5, result.getTotalEpochs());
//        assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition,result.getTerminationReason());
//        Map<Integer,Double> scoreVsIter = result.getScoreVsEpoch();
//        assertEquals(5,scoreVsIter.size());
//        String expDetails = esConf.getEpochTerminationConditions().get(0).toString();
//        assertEquals(expDetails, result.getTerminationDetails());
//
//        MultiLayerNetwork out = result.getBestModel();
//        assertNotNull(out);
//
//        //Check that best score actually matches (returned model vs. manually calculated score)
//        MultiLayerNetwork bestNetwork = result.getBestModel();
//        irisIter.reset();
//        double score = bestNetwork.score(irisIter.next());
//        assertEquals(result.getBestModelScore(), score, 1e-4);
//    }
@Test
public void testEarlyStoppingEveryNEpoch() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(50, 600);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .evaluateEveryNEpochs(2)
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 2, 6, 1);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);
    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) Test(org.junit.Test)

Example 3 with InMemoryModelSaver

Use of org.deeplearning4j.earlystopping.saver.InMemoryModelSaver in project deeplearning4j by deeplearning4j.

From the class TestEarlyStopping, method testListeners.

@Test
public void testListeners() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .modelSaver(saver)
            .build();
    LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener();
    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter, listener);
    trainer.fit();
    assertEquals(1, listener.onStartCallCount);
    assertEquals(5, listener.onEpochCallCount);
    assertEquals(1, listener.onCompletionCallCount);
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) EarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
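
LoggingEarlyStoppingListener is a test-local helper that simply counts how often each callback fires. A minimal sketch of such a listener, assuming the EarlyStoppingListener interface exposes onStart/onEpoch/onCompletion callbacks with roughly these signatures (the exact parameter lists and package are assumptions, not copied from the library):

    // Counting listener sketch; callback signatures are assumed, not verbatim from DL4J.
    public class LoggingEarlyStoppingListener implements EarlyStoppingListener<MultiLayerNetwork> {
        public int onStartCallCount = 0;
        public int onEpochCallCount = 0;
        public int onCompletionCallCount = 0;

        @Override
        public void onStart(EarlyStoppingConfiguration<MultiLayerNetwork> esConfig, MultiLayerNetwork net) {
            onStartCallCount++;      // expected once, before training starts
        }

        @Override
        public void onEpoch(int epochNum, double score, EarlyStoppingConfiguration<MultiLayerNetwork> esConfig,
                        MultiLayerNetwork net) {
            onEpochCallCount++;      // expected once per epoch (5 times with MaxEpochsTerminationCondition(5))
        }

        @Override
        public void onCompletion(EarlyStoppingResult<MultiLayerNetwork> esResult) {
            onCompletionCallCount++; // expected once, when early stopping terminates
        }
    }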

Example 4 with InMemoryModelSaver

Use of org.deeplearning4j.earlystopping.saver.InMemoryModelSaver in project deeplearning4j by deeplearning4j.

From the class TestEarlyStopping, method testEarlyStoppingEveryNEpoch.

@Test
public void testEarlyStoppingEveryNEpoch() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .evaluateEveryNEpochs(2)
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);
    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) EarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) Test(org.junit.Test)
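
Besides the total epochs and termination reason asserted here, the EarlyStoppingResult also exposes the per-epoch scores and the best model (the same getters appear in the disabled testEarlyStoppingIris above). A short sketch of inspecting the result after fit() returns:

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

    // Scores recorded at the evaluated epochs (every 2nd epoch here, per evaluateEveryNEpochs(2))
    Map<Integer, Double> scoreVsEpoch = result.getScoreVsEpoch();
    for (Map.Entry<Integer, Double> entry : scoreVsEpoch.entrySet()) {
        System.out.println("epoch " + entry.getKey() + " -> score " + entry.getValue());
    }

    // Best checkpoint held by the InMemoryModelSaver
    System.out.println("best epoch: " + result.getBestModelEpoch() + ", best score: " + result.getBestModelScore());
    MultiLayerNetwork best = result.getBestModel();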

Example 5 with InMemoryModelSaver

Use of org.deeplearning4j.earlystopping.saver.InMemoryModelSaver in project deeplearning4j by deeplearning4j.

From the class TestEarlyStopping, method testNoImprovementNEpochsTermination.

@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0
    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD)
            .learningRate(0.0)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                    new ScoreImprovementEpochTerminationCondition(5))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                    new MaxScoreIterationTerminationCondition(7.5)) //Initial score is ~2.5
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    EarlyStoppingResult result = trainer.fit();
    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    assertEquals(6, result.getTotalEpochs());
    assertEquals(0, result.getBestModelEpoch());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) ScoreImprovementEpochTerminationCondition(org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) EarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) IEarlyStoppingTrainer(org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) DataSetLossCalculator(org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) ListDataSetIterator(org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)
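
Since InMemoryModelSaver holds its checkpoints only in memory, the best model is lost when the process exits. A hedged sketch of persisting it after fit() returns, assuming DL4J's ModelSerializer.writeModel(model, file, saveUpdater) utility and a hypothetical output path:

    MultiLayerNetwork best = result.getBestModel();
    // Hypothetical output file; 'true' also saves the updater state so training can resume later
    ModelSerializer.writeModel(best, new File("best-earlystopping-model.zip"), true);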

Aggregations

InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 27
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 27
Test (org.junit.Test): 27
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 26
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 25
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 22
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 19
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 17
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 17
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 17
EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration): 13
MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition): 13
IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer): 13
DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator): 11
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 11
DataSet (org.nd4j.linalg.dataset.DataSet): 11
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 10
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 10
ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator): 9
EarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer): 9