
Example 96 with ComputationGraphConfiguration

Use of org.deeplearning4j.nn.conf.ComputationGraphConfiguration in the deeplearning4j project by deeplearning4j.

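Source: class TestSparkMultiLayerParameterAveraging, method testIterationCountsGraph. The test trains a SparkComputationGraph with a ParameterAveragingTrainingMaster at averaging frequencies of 1, 5 and 10, and checks that each call to fit() advances the network's iteration count by minibatchesPerWorkerPerEpoch.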

@Test
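public void testIterationCountsGraph() throws Exception {
    // sc (the JavaSparkContext) and numExecutors() are inherited from BaseSparkTest
    int dataSetObjSize = 5;          // examples per DataSet object in the RDD
    int batchSizePerExecutor = 25;   // minibatch size fitted by each worker
    List<DataSet> list = new ArrayList<>();
    int minibatchesPerWorkerPerEpoch = 10;
    // Load exactly enough MNIST examples for 10 minibatches of 25 on every worker (false: no binarization)
    DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize,
            batchSizePerExecutor * numExecutors() * minibatchesPerWorkerPerEpoch, false);
    while (iter.hasNext()) {
        list.add(iter.next());
    }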
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.RMSPROP)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder()
                    .nIn(28 * 28).nOut(50).activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(50).nOut(10).activation(Activation.SOFTMAX).build(), "0")
            .pretrain(false).backprop(true)
            .setOutputs("1")
            .build();
    // The iteration count should advance identically whatever the averaging frequency
    for (int avgFreq : new int[] { 1, 5, 10 }) {
        System.out.println("--- Avg freq " + avgFreq + " ---");
        SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
                new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                        .batchSizePerWorker(batchSizePerExecutor)
                        .averagingFrequency(avgFreq)
                        .repartionData(Repartition.Always)
                        .build());
        sparkNet.setListeners(new ScoreIterationListener(1));
        JavaRDD<DataSet> rdd = sc.parallelize(list);
        // A fresh network has performed no iterations
        assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount());
        // Each epoch adds one iteration per minibatch fitted by a worker
        sparkNet.fit(rdd);
        assertEquals(minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.fit(rdd);
        assertEquals(2 * minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount());
        // Remove temporary files written by the training master
        sparkNet.getTrainingMaster().deleteTempFiles(sc);
    }
}
Also used: OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer), SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph), MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator), MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet), DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest), Test (org.junit.Test)
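
The test's assertions rest on simple arithmetic. The iterator loads batchSizePerExecutor * numExecutors() * minibatchesPerWorkerPerEpoch examples, so with an even split each worker fits minibatchesPerWorkerPerEpoch minibatches of batchSizePerExecutor examples per epoch; the averaging frequency changes how often parameters are averaged, not how many minibatches are fitted. The plain-Java toy model below sketches that reasoning. It is not DL4J's ParameterAveragingTrainingMaster: the class ParamAveragingSketch, its methods runEpoch and minibatchesPerWorker, the scalar per-worker parameters, the update rule and the 4-worker count are all hypothetical.

```java
import java.util.Arrays;

// Toy model of synchronous parameter averaging, NOT DeepLearning4j's implementation.
// Each hypothetical "worker" holds one scalar parameter; the update rule is made up.
public class ParamAveragingSketch {

    // Outcome of one simulated epoch
    static final class EpochStats {
        final int iterationsPerWorker;
        final int averagingRounds;

        EpochStats(int iterationsPerWorker, int averagingRounds) {
            this.iterationsPerWorker = iterationsPerWorker;
            this.averagingRounds = averagingRounds;
        }
    }

    // Minibatches each worker fits per epoch, assuming the data is split evenly across workers
    static int minibatchesPerWorker(int totalExamples, int numWorkers, int batchSizePerWorker) {
        return totalExamples / numWorkers / batchSizePerWorker;
    }

    // One epoch: every worker fits up to averagingFrequency minibatches locally,
    // then all parameters are replaced by their mean; repeat until the epoch is done.
    static EpochStats runEpoch(double[] params, int minibatches, int averagingFrequency) {
        if (averagingFrequency < 1) {
            throw new IllegalArgumentException("averagingFrequency must be >= 1");
        }
        int iterations = 0;
        int rounds = 0;
        while (iterations < minibatches) {
            int steps = Math.min(averagingFrequency, minibatches - iterations);
            for (int w = 0; w < params.length; w++) {
                for (int s = 0; s < steps; s++) {
                    // Hypothetical local update, one per minibatch: nudges worker w's parameter toward w
                    params[w] -= 0.1 * (params[w] - w);
                }
            }
            iterations += steps; // every worker advanced by the same number of minibatches
            double mean = Arrays.stream(params).average().orElse(0.0);
            Arrays.fill(params, mean); // broadcast the averaged parameter back to every worker
            rounds++;
        }
        return new EpochStats(iterations, rounds);
    }

    public static void main(String[] args) {
        int numWorkers = 4; // assumption: the real test reads this from numExecutors()
        int batchSizePerWorker = 25;
        int minibatchesPerWorkerPerEpoch = 10;
        int totalExamples = batchSizePerWorker * numWorkers * minibatchesPerWorkerPerEpoch;
        int minibatches = minibatchesPerWorker(totalExamples, numWorkers, batchSizePerWorker);

        for (int avgFreq : new int[] {1, 5, 10}) {
            EpochStats stats = runEpoch(new double[numWorkers], minibatches, avgFreq);
            System.out.println("avgFreq=" + avgFreq
                    + " iterations/worker=" + stats.iterationsPerWorker
                    + " averaging rounds=" + stats.averagingRounds);
        }
    }
}
```

Running main prints iterations/worker=10 for every averaging frequency, with 10, 2 and 1 averaging rounds for avgFreq 1, 5 and 10 respectively, which mirrors why the test expects the same iteration count in each loop pass.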

Aggregations

ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 96
Test (org.junit.Test): 84
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 63
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 51
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 50
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 28
DataSet (org.nd4j.linalg.dataset.DataSet): 22
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 20
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 17
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 17
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 17
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 14
Random (java.util.Random): 13
ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster): 11
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 10
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 10
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 10
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 9
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 9
RnnOutputLayer (org.deeplearning4j.nn.conf.layers.RnnOutputLayer): 9