
Example 11 with DataSetLossCalculator

Use of org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator in project deeplearning4j by deeplearning4j.

From class TestParallelEarlyStoppingUI, method testParallelStatsListenerCompatibility.

@Test
//To be run manually
@Ignore
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer uiServer = UIServer.getInstance();
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new OutputLayer.Builder().nIn(3).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    // it's important that the UI can report results from parallel training
    // there's potential for StatsListener to fail if certain properties aren't set in the model
    StatsStorage statsStorage = new InMemoryStatsStorage();
    net.setListeners(new StatsListener(statsStorage));
    uiServer.attach(statsStorage);
    DataSetIterator irisIter = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(500))
            .scoreCalculator(new DataSetLossCalculator(irisIter, true))
            .evaluateEveryNEpochs(2)
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 3, 6, 2);
    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
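The esConf in this test asks for at most 500 epochs, a score every 2 epochs from DataSetLossCalculator(irisIter, true), and an InMemoryModelSaver to hold the best model. The framework-free Java sketch that follows illustrates that control flow under stated assumptions: the `true` (averaging) flag is assumed to mean an example-weighted mean of per-minibatch losses, and Batch, Result, averageLoss and fit are hypothetical names, not DL4J API. It does not reproduce DL4J's exact epoch-counting or model-saving rules.

```java
import java.util.List;
import java.util.function.IntFunction;

/** Hypothetical sketch of an early-stopping loop with an averaged-loss score calculator. */
public class EarlyStoppingSketch {

    /** Hypothetical minibatch summary: its mean loss and how many examples it held. */
    static final class Batch {
        final double meanLoss;
        final int numExamples;

        Batch(double meanLoss, int numExamples) {
            this.meanLoss = meanLoss;
            this.numExamples = numExamples;
        }
    }

    /** Outcome of the sketch: why it stopped, the best epoch, its score, and epochs run. */
    static final class Result {
        final String terminationReason;
        final int bestEpoch;
        final double bestScore;
        final int epochsRun;

        Result(String terminationReason, int bestEpoch, double bestScore, int epochsRun) {
            this.terminationReason = terminationReason;
            this.bestEpoch = bestEpoch;
            this.bestScore = bestScore;
            this.epochsRun = epochsRun;
        }
    }

    /** Example-weighted mean loss over all minibatches (assumed averaging behaviour). */
    static double averageLoss(List<Batch> batches) {
        double lossSum = 0.0;
        int exCount = 0;
        for (Batch b : batches) {
            lossSum += b.meanLoss * b.numExamples; // turn the batch mean back into a sum
            exCount += b.numExamples;
        }
        return lossSum / exCount;
    }

    /**
     * Stops after maxEpochs (the role of MaxEpochsTerminationCondition) and scores on
     * epochs 0, N, 2N, ... where N = evaluateEveryN; heldOutBatches stands in for the
     * score calculator's iterator.
     */
    static Result fit(int maxEpochs, int evaluateEveryN, IntFunction<List<Batch>> heldOutBatches) {
        double bestScore = Double.POSITIVE_INFINITY;
        int bestEpoch = -1;
        int epoch = 0;
        while (epoch < maxEpochs) {
            // One epoch of training would run here; this sketch only models scoring.
            if (epoch % evaluateEveryN == 0) {
                double score = averageLoss(heldOutBatches.apply(epoch));
                if (score < bestScore) { // lower loss wins; a model saver would keep a copy here
                    bestScore = score;
                    bestEpoch = epoch;
                }
            }
            epoch++;
        }
        return new Result("EpochTerminationCondition", bestEpoch, bestScore, epoch);
    }
}
```

Because the only stopping rule is the epoch limit, the sketch always reports an epoch-based termination, which is the same outcome the test asserts with EarlyStoppingResult.TerminationReason.EpochTerminationCondition.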
Also used: InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver), MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition), InMemoryStatsStorage (org.deeplearning4j.ui.storage.InMemoryStatsStorage), StatsStorage (org.deeplearning4j.api.storage.StatsStorage), IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), UIServer (org.deeplearning4j.ui.api.UIServer), StatsListener (org.deeplearning4j.ui.stats.StatsListener), EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration), MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork), DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), Ignore (org.junit.Ignore), Test (org.junit.Test)
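The trainer is constructed as EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 3, 6, 2). Assuming the trailing int arguments are the worker count, the prefetch buffer size and the averaging frequency (check the Javadoc for your DL4J version), each worker trains its own copy of the model and the copies are periodically averaged. The plain-array sketch below models only that synchronous averaging step; ParameterAveragingSketch, average, train and localStep are hypothetical names, and prefetching is not modelled.

```java
import java.util.Arrays;
import java.util.function.BiConsumer;

/** Hypothetical sketch of synchronous parameter averaging across workers. */
public class ParameterAveragingSketch {

    /** Element-wise mean of every worker's parameter vector. */
    static double[] average(double[][] workerParams) {
        int n = workerParams[0].length;
        double[] mean = new double[n];
        for (double[] p : workerParams) {
            for (int i = 0; i < n; i++) {
                mean[i] += p[i];
            }
        }
        for (int i = 0; i < n; i++) {
            mean[i] /= workerParams.length;
        }
        return mean;
    }

    /**
     * Runs `iterations` local steps on each worker; after every `averagingFrequency`
     * steps all workers are reset to the averaged parameters. `localStep` stands in
     * for one minibatch update on worker w (hypothetical).
     */
    static double[][] train(double[][] workerParams, int iterations, int averagingFrequency,
                            BiConsumer<Integer, double[]> localStep) {
        for (int it = 1; it <= iterations; it++) {
            for (int w = 0; w < workerParams.length; w++) {
                localStep.accept(w, workerParams[w]); // each worker updates only its own copy
            }
            if (it % averagingFrequency == 0) {
                double[] mean = average(workerParams);
                for (int w = 0; w < workerParams.length; w++) {
                    workerParams[w] = Arrays.copyOf(mean, mean.length); // broadcast the average
                }
            }
        }
        return workerParams;
    }
}
```

Between averaging points the workers drift apart, so a larger averaging frequency means less synchronisation overhead but more divergence between the model copies.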

Aggregations

IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 11
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 11
DataSetLossCalculator (org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator): 11
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 11
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 11
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 11
Test (org.junit.Test): 11
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 11
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 10
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 9
ListDataSetIterator (org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator): 8
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 8
EarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer): 8
IEarlyStoppingTrainer (org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer): 8
EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration): 3
MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition): 3
MultipleEpochsIterator (org.deeplearning4j.datasets.iterator.MultipleEpochsIterator): 2
ScoreImprovementEpochTerminationCondition (org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition): 2
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 2
Random (java.util.Random): 1