Search in sources:

Example 1 with Sin

Use of org.nd4j.linalg.api.ops.impl.transforms.Sin in the deeplearning4j project (by deeplearning4j).

From class TestEarlyStopping, method testMinImprovementNEpochsTermination.

/**
 * Verifies that {@link ScoreImprovementEpochTerminationCondition} ends training once the score has
 * not improved by more than {@code minImprovement} for {@code maxEpochsWithoutImprovement}
 * consecutive epochs.
 *
 * <p>The learning rate is set to 0.0 so the model is not expected to improve, which makes the
 * improvement-based condition the one that stops training (the max-epochs condition is only an
 * upper bound).
 */
@Test
public void testMinImprovementNEpochsTermination() {
    // Simulate a model whose score never improves by setting LR = 0.0.
    Random rng = new Random(123);
    Nd4j.getRandom().setSeed(12345);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .iterations(10)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.0)
            .updater(Updater.NESTEROVS)
            .momentum(0.9)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(1)
                    .nOut(20)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.IDENTITY)
                    .nIn(20)
                    .nOut(1)
                    .build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    // Training data: y = sin(x) sampled at nSamples evenly spaced points in [-10, 10], shuffled.
    int nSamples = 100;
    INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
    INDArray y = Nd4j.getExecutioner().execAndReturn(new Sin(x.dup()));
    DataSet allData = new DataSet(x, y);
    List<DataSet> list = allData.asList();
    Collections.shuffle(list, rng);
    // Batch size equals the number of examples, so each epoch is a single minibatch.
    DataSetIterator training = new ListDataSetIterator(list, nSamples);

    // Patience: number of consecutive epochs allowed without an improvement > minImprovement.
    int maxEpochsWithoutImprovement = 5;
    double minImprovement = 0.0009;
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(
                            // Upper bound only; the improvement condition is expected to fire first.
                            new MaxEpochsTerminationCondition(1000),
                            // Stop after maxEpochsWithoutImprovement epochs without any improvement
                            // greater than minImprovement.
                            new ScoreImprovementEpochTerminationCondition(
                                    maxEpochsWithoutImprovement, minImprovement))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(3, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculator(training, true))
                    .modelSaver(saver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, training);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

    // Expected total is patience + 1 (6 for a patience of 5); presumably the first epoch only
    // establishes the baseline score before the no-improvement count starts.
    assertEquals(maxEpochsWithoutImprovement + 1, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition,
            result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(
            maxEpochsWithoutImprovement, minImprovement).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used: ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator), DataSet (org.nd4j.linalg.dataset.DataSet), ScoreImprovementEpochTerminationCondition (org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), Random (java.util.Random), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver), MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition), IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer), EarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Sin (org.nd4j.linalg.api.ops.impl.transforms.Sin), DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator), IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition), Test (org.junit.Test)

Aggregations

Random (java.util.Random): 1, IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 1, ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator): 1, InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 1, DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator): 1, MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 1, MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 1, ScoreImprovementEpochTerminationCondition (org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition): 1, EarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer): 1, IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer): 1, MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 1, MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1, ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 1, Test (org.junit.Test): 1, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1, Sin (org.nd4j.linalg.api.ops.impl.transforms.Sin): 1, DataSet (org.nd4j.linalg.dataset.DataSet): 1, DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 1