Search in sources :

Example 1 with DataSetToMultiDataSetFn

use of org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn in project deeplearning4j by deeplearning4j.

the class ParameterAveragingTrainingMaster method executeTraining.

@Override
public void executeTraining(SparkComputationGraph graph, JavaRDD<DataSet> trainingData) {
    if (numWorkers == null)
        numWorkers = graph.getSparkContext().defaultParallelism();
    JavaRDD<MultiDataSet> mdsTrainingData = trainingData.map(new DataSetToMultiDataSetFn());
    executeTrainingMDS(graph, mdsTrainingData);
}
Also used : MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn)

Example 2 with DataSetToMultiDataSetFn

use of org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn in project deeplearning4j by deeplearning4j.

the class TestEarlyStoppingSparkCompGraph method testEarlyStoppingIris.

@Test
public void testEarlyStoppingIris() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in").setOutputs("0").pretrain(false).backprop(true).build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>().epochTerminationConditions(new MaxEpochsTerminationCondition(5)).iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)).scoreCalculator(new SparkLossCalculatorComputationGraph(irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc())).modelSaver(saver).build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf, net, irisData.map(new DataSetToMultiDataSetFn()));
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();
    System.out.println(result);
    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    Map<Integer, Double> scoreVsIter = result.getScoreVsEpoch();
    assertEquals(5, scoreVsIter.size());
    String expDetails = esConf.getEpochTerminationConditions().get(0).toString();
    assertEquals(expDetails, result.getTerminationDetails());
    ComputationGraph out = result.getBestModel();
    assertNotNull(out);
    //Check that best score actually matches (returned model vs. manually calculated score)
    ComputationGraph bestNetwork = result.getBestModel();
    double score = bestNetwork.score(new IrisDataSetIterator(150, 150).next());
    double bestModelScore = result.getBestModelScore();
    assertEquals(bestModelScore, score, 1e-3);
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Example 3 with DataSetToMultiDataSetFn

use of org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn in project deeplearning4j by deeplearning4j.

the class TestEarlyStoppingSparkCompGraph method testNoImprovementNEpochsTermination.

@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).learningRate(0.0).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in").setOutputs("0").pretrain(false).backprop(true).build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>().epochTerminationConditions(new MaxEpochsTerminationCondition(100), new ScoreImprovementEpochTerminationCondition(5)).iterationTerminationConditions(//Initial score is ~2.5
    new MaxScoreIterationTerminationCondition(7.5)).scoreCalculator(new SparkLossCalculatorComputationGraph(irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc())).modelSaver(saver).build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf, net, irisData.map(new DataSetToMultiDataSetFn()));
    EarlyStoppingResult result = trainer.fit();
    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    //Normally expect 6 epochs exactly; get a little more than that here due to rounding + order of operations
    assertTrue(result.getTotalEpochs() < 12);
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DataSet(org.nd4j.linalg.dataset.DataSet) ScoreImprovementEpochTerminationCondition(org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) Test(org.junit.Test)

Example 4 with DataSetToMultiDataSetFn

use of org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn in project deeplearning4j by deeplearning4j.

the class TestEarlyStoppingSparkCompGraph method testListeners.

@Test
public void testListeners() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in").setOutputs("0").pretrain(false).backprop(true).build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>().epochTerminationConditions(new MaxEpochsTerminationCondition(5)).iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES)).scoreCalculator(new SparkLossCalculatorComputationGraph(irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc())).modelSaver(saver).build();
    LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf, net, irisData.map(new DataSetToMultiDataSetFn()));
    trainer.setListener(listener);
    trainer.fit();
    assertEquals(1, listener.onStartCallCount);
    assertEquals(5, listener.onEpochCallCount);
    assertEquals(1, listener.onCompletionCallCount);
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Example 5 with DataSetToMultiDataSetFn

use of org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn in project deeplearning4j by deeplearning4j.

the class TestEarlyStoppingSparkCompGraph method testTimeTermination.

@Test
public void testTimeTermination() {
    //test termination after max time
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).updater(Updater.SGD).learningRate(1e-6).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in").addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in").setOutputs("0").pretrain(false).backprop(true).build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));
    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf = new EarlyStoppingConfiguration.Builder<ComputationGraph>().epochTerminationConditions(new MaxEpochsTerminationCondition(10000)).iterationTerminationConditions(new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS), //Initial score is ~2.5
    new MaxScoreIterationTerminationCondition(7.5)).scoreCalculator(new SparkLossCalculatorComputationGraph(irisData.map(new DataSetToMultiDataSetFn()), true, sc.sc())).modelSaver(saver).build();
    TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0);
    IEarlyStoppingTrainer<ComputationGraph> trainer = new SparkEarlyStoppingGraphTrainer(getContext().sc(), tm, esConf, net, irisData.map(new DataSetToMultiDataSetFn()));
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) (endTime - startTime) / 1000;
    assertTrue(durationSeconds >= 3);
    assertTrue(durationSeconds <= 9);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingConfiguration(org.deeplearning4j.earlystopping.EarlyStoppingConfiguration) DataSetToMultiDataSetFn(org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) InMemoryModelSaver(org.deeplearning4j.earlystopping.saver.InMemoryModelSaver) MaxEpochsTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition) SparkEarlyStoppingGraphTrainer(org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) EarlyStoppingResult(org.deeplearning4j.earlystopping.EarlyStoppingResult) SparkLossCalculatorComputationGraph(org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) MaxScoreIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition) MaxTimeIterationTerminationCondition(org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition) Test(org.junit.Test)

Aggregations

DataSetToMultiDataSetFn (org.deeplearning4j.spark.impl.graph.dataset.DataSetToMultiDataSetFn)6 EarlyStoppingConfiguration (org.deeplearning4j.earlystopping.EarlyStoppingConfiguration)5 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)5 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)5 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)5 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)5 TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster)5 SparkEarlyStoppingGraphTrainer (org.deeplearning4j.spark.earlystopping.SparkEarlyStoppingGraphTrainer)5 SparkLossCalculatorComputationGraph (org.deeplearning4j.spark.earlystopping.SparkLossCalculatorComputationGraph)5 ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster)5 Test (org.junit.Test)5 DataSet (org.nd4j.linalg.dataset.DataSet)5 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)4 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)4 EarlyStoppingResult (org.deeplearning4j.earlystopping.EarlyStoppingResult)3 MaxScoreIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition)3 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)1 ScoreImprovementEpochTerminationCondition (org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition)1