Example 11 with SparkComputationGraph

Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j by deeplearning4j.

From the class TestSparkMultiLayerParameterAveraging, method testIterationCountsGraph.

@Test
public void testIterationCountsGraph() throws Exception {
    int dataSetObjSize = 5;
    int batchSizePerExecutor = 25;
    List<DataSet> list = new ArrayList<>();
    int minibatchesPerWorkerPerEpoch = 10;
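    // Load exactly enough MNIST examples for each executor to process minibatchesPerWorkerPerEpoch minibatches per epoch.
    DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize,
            batchSizePerExecutor * numExecutors() * minibatchesPerWorkerPerEpoch, false);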
    while (iter.hasNext()) {
        list.add(iter.next());
    }
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.RMSPROP)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder()
                    .nIn(28 * 28).nOut(50).activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(50).nOut(10).activation(Activation.SOFTMAX).build(), "0")
            .pretrain(false).backprop(true)
            .setOutputs("1")
            .build();
    for (int avgFreq : new int[] { 1, 5, 10 }) {
        System.out.println("--- Avg freq " + avgFreq + " ---");
        SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
                new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                        .batchSizePerWorker(batchSizePerExecutor)
                        .averagingFrequency(avgFreq)
                        .repartionData(Repartition.Always)
                        .build());
        sparkNet.setListeners(new ScoreIterationListener(1));
        JavaRDD<DataSet> rdd = sc.parallelize(list);
        assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.fit(rdd);
        assertEquals(minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.fit(rdd);
        assertEquals(2 * minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.getTrainingMaster().deleteTempFiles(sc);
    }
}
Also used:
    OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)
    SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph)
    MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)
    MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)
    DataSet (org.nd4j.linalg.dataset.DataSet)
    NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)
    LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)
    ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)
    ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)
    IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)
    DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)
    BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)
    Test (org.junit.Test)
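
The assertions in testIterationCountsGraph imply that the iteration counter advances once per minibatch processed by a single worker, and is not summed across workers. The arithmetic behind the expected values can be sketched in plain Java without Spark or DL4J. This is only an illustration of the counting, assuming the data is split evenly across workers. The class name IterationCountSketch, the method minibatchesPerWorker, and the worker count of 4 are hypothetical; the real test takes the worker count from numExecutors().

```java
class IterationCountSketch {

    /**
     * Minibatches one worker processes in a single fit() call,
     * assuming the examples are divided evenly among the workers.
     */
    static int minibatchesPerWorker(int totalExamples, int numWorkers, int batchSizePerWorker) {
        int examplesPerWorker = totalExamples / numWorkers;
        return examplesPerWorker / batchSizePerWorker;
    }

    public static void main(String[] args) {
        int numWorkers = 4;                     // hypothetical stand-in for numExecutors()
        int batchSizePerWorker = 25;            // batchSizePerExecutor in the test
        int minibatchesPerWorkerPerEpoch = 10;  // same value as in the test

        // Same product the test passes to MnistDataSetIterator as the example count.
        int totalExamples = batchSizePerWorker * numWorkers * minibatchesPerWorkerPerEpoch;

        int perFit = minibatchesPerWorker(totalExamples, numWorkers, batchSizePerWorker);
        System.out.println("expected iteration count after 1 fit: " + perFit);
        System.out.println("expected iteration count after 2 fits: " + 2 * perFit);
    }
}
```

Under these assumptions the per-fit count comes out equal to minibatchesPerWorkerPerEpoch whatever the worker count is, which matches the test expecting the same totals for every averaging frequency in { 1, 5, 10 }.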
