Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j by deeplearning4j.
From class TestSparkMultiLayerParameterAveraging, method testIterationCountsGraph:
@Test
public void testIterationCountsGraph() throws Exception {
    int dataSetObjSize = 5;
    int batchSizePerExecutor = 25;
    int minibatchesPerWorkerPerEpoch = 10;
    // Total example count is sized so that each worker sees exactly minibatchesPerWorkerPerEpoch minibatches per epoch
    List<DataSet> list = new ArrayList<>();
    DataSetIterator iter = new MnistDataSetIterator(dataSetObjSize,
                    batchSizePerExecutor * numExecutors() * minibatchesPerWorkerPerEpoch, false);
    while (iter.hasNext()) {
        list.add(iter.next());
    }

    // Simple two-layer graph: dense tanh layer followed by a softmax output layer
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .updater(Updater.RMSPROP)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .iterations(1)
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50)
                                    .activation(Activation.TANH).build(), "in")
                    .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .nIn(50).nOut(10).activation(Activation.SOFTMAX).build(), "0")
                    .pretrain(false).backprop(true).setOutputs("1").build();

    for (int avgFreq : new int[] { 1, 5, 10 }) {
        System.out.println("--- Avg freq " + avgFreq + " ---");
        SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(),
                        new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                                        .batchSizePerWorker(batchSizePerExecutor)
                                        .averagingFrequency(avgFreq)
                                        .repartionData(Repartition.Always).build());
        sparkNet.setListeners(new ScoreIterationListener(1));
        JavaRDD<DataSet> rdd = sc.parallelize(list);

        // Iteration count starts at zero and should advance by minibatchesPerWorkerPerEpoch on each fit() call
        assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.fit(rdd);
        assertEquals(minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.fit(rdd);
        assertEquals(2 * minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount());
        sparkNet.getTrainingMaster().deleteTempFiles(sc);
    }
}
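The expected iteration counts in the assertions follow from how the data is sized. Below is a minimal sketch of that arithmetic, assuming numExecutors() returns the executor count used by the test base class; the value 2 used here is only an illustrative assumption, and the class and variable names are hypothetical.

// Sketch of the arithmetic behind the iteration-count assertions (values mirror the test above).
public class IterationCountArithmetic {
    public static void main(String[] args) {
        int numExecutors = 2;                  // assumption for illustration; the test uses numExecutors()
        int batchSizePerExecutor = 25;         // minibatch size used by each worker
        int minibatchesPerWorkerPerEpoch = 10;

        // Total examples created by the MnistDataSetIterator in the test
        int totalExamples = batchSizePerExecutor * numExecutors * minibatchesPerWorkerPerEpoch; // 500
        int examplesPerWorker = totalExamples / numExecutors;                                   // 250
        int minibatchesPerWorker = examplesPerWorker / batchSizePerExecutor;                    // 10

        // Each fit(rdd) call should therefore advance the averaged network's iteration count by 10,
        // matching assertEquals(minibatchesPerWorkerPerEpoch, ...) after the first fit and
        // assertEquals(2 * minibatchesPerWorkerPerEpoch, ...) after the second.
        System.out.println(minibatchesPerWorker); // prints 10
    }
}

Because repartionData(Repartition.Always) splits the RDD evenly across the workers and the averaging frequency only changes how often parameters are averaged, not how many minibatches each worker processes, the same iteration counts are expected for all three avgFreq values in the loop.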