Example 6 with DataSetToMultiDataSetFn

Use of org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn in project deeplearning4j by deeplearning4j.

From the class TestEarlyStoppingSparkCompGraph, method testBadTuning.

@Test
public void testBadTuning() {
    // Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .updater(Updater.SGD).learningRate(2.0) // Intentionally huge LR
            .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.IDENTITY)
                    .lossFunction(LossFunctions.LossFunction.MSE).build(), "in")
            .setOutputs("0").pretrain(false).backprop(true).build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));

    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
            .iterationTerminationConditions(
                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                    new MaxScoreIterationTerminationCondition(7.5)) // Initial score is ~2.5
            .scoreCalculator(new SparkLossCalculatorComputationGraph(
                    irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc()))
            .modelSaver(saver).build();

    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm,
            esConf, net, irisData.map(new DataSetToMultiDataSetFn()));
    EarlyStoppingResult result = trainer.fit();

    // Training is expected to stop within the first few epochs, on the max-score iteration condition
    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(7.5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used:
    InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)
    MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)
    SparkEarlyStoppingGraphTrainer (org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer)
    DataSet (org.nd4j.linalg.dataset.DataSet)
    NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)
    ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster)
    TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster)
    EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult)
    EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration)
    SparkLossCalculatorComputationGraph (org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph)
    ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)
    DataSetToMultiDataSetFn (org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn)
    ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)
    MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition)
    ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)
    MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)
    Test (org.junit.Test)
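
The test above relies on a deliberately huge learning rate making the score blow up past the 7.5 threshold. The following is a minimal, dependency-free sketch of that idea, not the deeplearning4j implementation: the class MaxScoreDivergenceSketch, its methods, and the toy loss f(w) = w^2 are all hypothetical names and assumptions chosen for illustration. Treating a NaN score as a reason to stop is also a choice made for this sketch.

```java
// Illustrative sketch only: plain gradient descent on f(w) = w^2 with a
// "max score" stop rule. Nothing here uses or mirrors DL4J classes.
class MaxScoreDivergenceSketch {

    /** Hypothetical stop rule: terminate when the score exceeds maxScore or is NaN. */
    static boolean exceedsMaxScore(double score, double maxScore) {
        return score > maxScore || Double.isNaN(score);
    }

    /**
     * Runs gradient descent on f(w) = w^2 starting from w = 1.5 (initial score 2.25).
     * Returns the iteration index at which the stop rule fired, or -1 if it never fired.
     */
    static int iterationsUntilTermination(double learningRate, double maxScore, int maxIterations) {
        double w = 1.5;
        for (int i = 0; i < maxIterations; i++) {
            double score = w * w;               // loss at the current parameter
            if (exceedsMaxScore(score, maxScore)) {
                return i;                       // stop: score has diverged past the threshold
            }
            double gradient = 2.0 * w;          // d/dw of w^2
            w -= learningRate * gradient;       // SGD-style update
        }
        return -1;
    }

    public static void main(String[] args) {
        // Learning rate 2.0: each step maps w to -3w, so the score grows 9x per iteration.
        System.out.println(iterationsUntilTermination(2.0, 7.5, 100)); // prints 1
        // Learning rate 0.1: each step maps w to 0.8w, so the score shrinks and the rule never fires.
        System.out.println(iterationsUntilTermination(0.1, 7.5, 100)); // prints -1
    }
}
```

With the large step size the update overshoots the minimum and the score jumps from 2.25 to 20.25 after a single update, which is the kind of behaviour the test expects the max-score condition to catch within the first few epochs.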

Aggregations

DataSetToMultiDataSetFn (org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn): 6
EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration): 5
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 5
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 5
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 5
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 5
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 5
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 5
TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster): 5
SparkEarlyStoppingGraphTrainer (org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer): 5
SparkLossCalculatorComputationGraph (org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph): 5
ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster): 5
Test (org.junit.Test): 5
DataSet (org.nd4j.linalg.dataset.DataSet): 5
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 4
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 4
EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult): 3
MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition): 3
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 1
ScoreImprovementEpochTerminationCondition (org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition): 1