Example 1 with SparkComputationGraph

Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j.

From the class TestKryoWarning, method doTestCG.

private static void doTestCG(SparkConf sparkConf) {
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    try {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .graphBuilder()
                .addInputs("in")
                .addLayer("0", new OutputLayer.Builder().nIn(10).nOut(10).build(), "in")
                .setOutputs("0")
                .pretrain(false).backprop(true)
                .build();
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1).build();
        //Constructing the SparkComputationGraph is sufficient to trigger (or not trigger)
        //the Kryo configuration warning that this test checks for
        SparkListenable scg = new SparkComputationGraph(sc, conf, tm);
    } finally {
        sc.stop();
    }
}
Also used : SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster)
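
The issue doTestCG exercises is Spark's serializer configuration: ND4J's off-heap INDArrays do not round-trip correctly through a default Kryo setup, so SparkComputationGraph logs a warning when Kryo is misconfigured. Below is a minimal, hedged sketch of a SparkConf that could be passed in; it assumes the nd4j-kryo module (which provides org.nd4j.Nd4jRegistrator) is on the classpath, and the local-mode settings are illustrative, not taken from the test.

SparkConf kryoConf = new SparkConf()
        .setMaster("local[*]")
        .setAppName("kryo-warning-test")
        //nd4j-kryo's registrator makes INDArray serialization safe under Kryo;
        //omitting it is what is expected to trigger the warning this test checks for
        .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
doTestCG(kryoConf);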

Example 2 with SparkComputationGraph

Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j.

From the class TestCompareParameterAveragingSparkVsSingleMachine, method testAverageEveryStepGraph.

@Test
public void testAverageEveryStepGraph() {
    //Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing
    // the learning on a single machine for synchronous distributed training.
    //BUT: this is *ONLY* the case if all workers get an identical number of examples. That won't hold if
    // we use RDD.randomSplit (which is what happens when we call .fit(JavaRDD<DataSet>) on a data set
    // that needs splitting), since the split can give a number of examples that isn't divisible by the
    // number of workers (like 39 examples on 4 executors).
    //This also holds ONLY for the SGD updater.
    int miniBatchSizePerWorker = 10;
    int nWorkers = 4;
    for (boolean saveUpdater : new boolean[] { true, false }) {
        JavaSparkContext sc = getContext(nWorkers);
        try {
            //Do training locally, for 3 minibatches
            int[] seeds = { 1, 2, 3 };
            ComputationGraph net = new ComputationGraph(getGraphConf(12345, Updater.SGD));
            net.init();
            INDArray initialParams = net.params().dup();
            for (int i = 0; i < seeds.length; i++) {
                DataSet ds = getOneDataSet(miniBatchSizePerWorker * nWorkers, seeds[i]);
                if (!saveUpdater)
                    net.setUpdater(null);
                net.fit(ds);
            }
            INDArray finalParams = net.params().dup();
            //Do training on Spark with one executor, for 3 separate minibatches
            TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater);
            SparkComputationGraph sparkNet = new SparkComputationGraph(sc, getGraphConf(12345, Updater.SGD), tm);
            sparkNet.setCollectTrainingStats(true);
            INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
            for (int i = 0; i < seeds.length; i++) {
                List<DataSet> list = getOneDataSetAsIndividalExamples(miniBatchSizePerWorker * nWorkers, seeds[i]);
                JavaRDD<DataSet> rdd = sc.parallelize(list);
                sparkNet.fit(rdd);
            }
            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
            INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
            float[] fp = finalParams.data().asFloat();
            float[] fps = finalSparkParams.data().asFloat();
            System.out.println("Initial (Local) params:       " + Arrays.toString(initialParams.data().asFloat()));
            System.out.println("Initial (Spark) params:       " + Arrays.toString(initialSparkParams.data().asFloat()));
            System.out.println("Final (Local) params: " + Arrays.toString(fp));
            System.out.println("Final (Spark) params: " + Arrays.toString(fps));
            assertEquals(initialParams, initialSparkParams);
            assertNotEquals(initialParams, finalParams);
            assertArrayEquals(fp, fps, 1e-5f);
            double sparkScore = sparkNet.getScore();
            assertTrue(sparkScore > 0.0);
            assertEquals(net.score(), sparkScore, 1e-3);
        } finally {
            sc.stop();
        }
    }
}
Also used : SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) DataSet(org.nd4j.linalg.dataset.DataSet) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) INDArray(org.nd4j.linalg.api.ndarray.INDArray) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) Test(org.junit.Test)
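
The getTrainingMaster(...) helper isn't shown in this snippet. Below is a plausible, hedged reconstruction built from the ParameterAveragingTrainingMaster.Builder options that appear verbatim in Example 4; the first argument is the averaging frequency, and frequency 1 means parameters are averaged after every minibatch, which is what the single-machine comparison requires. The real test fixture may set additional options.

//Hedged reconstruction of the test's getTrainingMaster helper (not from the snippet above)
private static TrainingMaster getTrainingMaster(int avgFreq, int miniBatchSizePerWorker, boolean saveUpdater) {
    return new ParameterAveragingTrainingMaster.Builder(1)   //assumes 1 example per DataSet object in the RDD
            .averagingFrequency(avgFreq)
            .batchSizePerWorker(miniBatchSizePerWorker)
            .saveUpdater(saveUpdater)
            .workerPrefetchNumBatches(0)
            .build();
}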

Example 3 with SparkComputationGraph

Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j.

From the class TestCompareParameterAveragingSparkVsSingleMachine, method testAverageEveryStepGraphCNN.

@Test
public void testAverageEveryStepGraphCNN() {
    //Idea: averaging every step with SGD (SGD updater + optimizer) is mathematically identical to doing
    // the learning on a single machine for synchronous distributed training.
    //BUT: this is *ONLY* the case if all workers get an identical number of examples. That won't hold if
    // we use RDD.randomSplit (which is what happens when we call .fit(JavaRDD<DataSet>) on a data set
    // that needs splitting), since the split can give a number of examples that isn't divisible by the
    // number of workers (like 39 examples on 4 executors).
    //This also holds ONLY for the SGD updater.
    int miniBatchSizePerWorker = 10;
    int nWorkers = 4;
    for (boolean saveUpdater : new boolean[] { true, false }) {
        JavaSparkContext sc = getContext(nWorkers);
        try {
            //Do training locally, for 3 minibatches
            int[] seeds = { 1, 2, 3 };
            ComputationGraph net = new ComputationGraph(getGraphConfCNN(12345, Updater.SGD));
            net.init();
            INDArray initialParams = net.params().dup();
            for (int i = 0; i < seeds.length; i++) {
                DataSet ds = getOneDataSetCNN(miniBatchSizePerWorker * nWorkers, seeds[i]);
                if (!saveUpdater)
                    net.setUpdater(null);
                net.fit(ds);
            }
            INDArray finalParams = net.params().dup();
            //Do training on Spark with one executor, for 3 separate minibatches
            TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater);
            SparkComputationGraph sparkNet = new SparkComputationGraph(sc, getGraphConfCNN(12345, Updater.SGD), tm);
            sparkNet.setCollectTrainingStats(true);
            INDArray initialSparkParams = sparkNet.getNetwork().params().dup();
            for (int i = 0; i < seeds.length; i++) {
                List<DataSet> list = getOneDataSetAsIndividalExamplesCNN(miniBatchSizePerWorker * nWorkers, seeds[i]);
                JavaRDD<DataSet> rdd = sc.parallelize(list);
                sparkNet.fit(rdd);
            }
            System.out.println(sparkNet.getSparkTrainingStats().statsAsString());
            INDArray finalSparkParams = sparkNet.getNetwork().params().dup();
            System.out.println("Initial (Local) params:  " + Arrays.toString(initialParams.data().asFloat()));
            System.out.println("Initial (Spark) params:  " + Arrays.toString(initialSparkParams.data().asFloat()));
            System.out.println("Final (Local) params:    " + Arrays.toString(finalParams.data().asFloat()));
            System.out.println("Final (Spark) params:    " + Arrays.toString(finalSparkParams.data().asFloat()));
            assertArrayEquals(initialParams.data().asFloat(), initialSparkParams.data().asFloat(), 1e-8f);
            assertArrayEquals(finalParams.data().asFloat(), finalSparkParams.data().asFloat(), 1e-6f);
            double sparkScore = sparkNet.getScore();
            assertTrue(sparkScore > 0.0);
            assertEquals(net.score(), sparkScore, 1e-3);
        } finally {
            sc.stop();
        }
    }
}
Also used : SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) INDArray(org.nd4j.linalg.api.ndarray.INDArray) DataSet(org.nd4j.linalg.dataset.DataSet) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) TrainingMaster(org.deeplearning4j.spark.api.TrainingMaster) Test(org.junit.Test)
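
The getContext(nWorkers) helper used by both comparison tests is also assumed rather than shown. A minimal local-mode sketch consistent with how such tests typically run (hypothetical; the real fixture may set extra Spark properties):

//Hedged sketch: a local JavaSparkContext with one local thread per simulated worker
private static JavaSparkContext getContext(int nWorkers) {
    SparkConf conf = new SparkConf()
            .setMaster("local[" + nWorkers + "]")
            .setAppName("ParameterAveragingCompareTest");
    return new JavaSparkContext(conf);
}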

Example 4 with SparkComputationGraph

Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j.

From the class TestSparkMultiLayerParameterAveraging, method testVaePretrainSimpleCG.

@Test
public void testVaePretrainSimpleCG() {
    //Simple sanity check on pretraining
    int nIn = 8;
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(Updater.RMSPROP)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new VariationalAutoencoder.Builder()
                    .nIn(8).nOut(10)
                    .encoderLayerSizes(12).decoderLayerSizes(13)
                    .reconstructionDistribution(new GaussianReconstructionDistribution("identity"))
                    .build(), "in")
            .setOutputs("0")
            .pretrain(true).backprop(false)
            .build();
    //Do training on Spark with one executor, for 3 separate minibatches
    int rddDataSetNumExamples = 10;
    int totalAveragings = 5;
    int averagingFrequency = 3;
    ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(rddDataSetNumExamples)
            .averagingFrequency(averagingFrequency)
            .batchSizePerWorker(rddDataSetNumExamples)
            .saveUpdater(true)
            .workerPrefetchNumBatches(0)
            .build();
    Nd4j.getRandom().setSeed(12345);
    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf.clone(), tm);
    List<DataSet> trainData = new ArrayList<>();
    int nDataSets = numExecutors() * totalAveragings * averagingFrequency;
    for (int i = 0; i < nDataSets; i++) {
        trainData.add(new DataSet(Nd4j.rand(rddDataSetNumExamples, nIn), null));
    }
    JavaRDD<DataSet> data = sc.parallelize(trainData);
    sparkNet.fit(data);
}
Also used : SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) VariationalAutoencoder(org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) LabeledPoint(org.apache.spark.mllib.regression.LabeledPoint) GaussianReconstructionDistribution(org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)
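
The test ends at fit(...), but the pretrained graph can be pulled back out of the Spark wrapper for local use. A short hedged sketch follows; featureMatrix is a hypothetical stand-in for new unlabeled rows, sized to match the VAE's nIn(8) above.

//Retrieve the locally usable network after distributed pretraining
ComputationGraph pretrained = sparkNet.getNetwork();
//Hypothetical unlabeled input: 5 rows, 8 columns to match nIn(8) in the config above
INDArray featureMatrix = Nd4j.rand(5, 8);
//Forward pass; this graph has a single output ("0"), so outputSingle applies
INDArray activation = pretrained.outputSingle(featureMatrix);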

Example 5 with SparkComputationGraph

Use of org.deeplearning4j.spark.impl.graph.SparkComputationGraph in project deeplearning4j.

From the class ParameterAveragingTrainingMaster, method processResults.

private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    if (collectTrainingStats)
        stats.logAggregateStartTime();
    ParameterAveragingAggregationTuple tuple = results.aggregate(null, new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
    INDArray params = tuple.getParametersSum();
    int aggCount = tuple.getAggregationsCount();
    SparkTrainingStats aggregatedStats = tuple.getSparkTrainingStats();
    if (collectTrainingStats)
        stats.logAggregationEndTime();
    if (collectTrainingStats)
        stats.logProcessParamsUpdaterStart();
    if (params != null) {
        params.divi(aggCount);
        INDArray updaterState = tuple.getUpdaterStateSum();
        //Updater state may be null if all workers use SGD updaters, for example
        if (updaterState != null)
            updaterState.divi(aggCount);
        if (network != null) {
            MultiLayerNetwork net = network.getNetwork();
            net.setParameters(params);
            if (updaterState != null)
                net.getUpdater().setStateViewArray(null, updaterState, false);
            network.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
        } else {
            ComputationGraph g = graph.getNetwork();
            g.setParams(params);
            if (updaterState != null)
                g.getUpdater().setStateViewArray(updaterState);
            graph.setScore(tuple.getScoreSum() / tuple.getAggregationsCount());
        }
    } else {
        log.info("Skipping imbalanced split with no data for all executors");
    }
    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(aggregatedStats);
    }
    if (statsStorage != null) {
        Collection<StorageMetaData> meta = tuple.getListenerMetaData();
        if (meta != null && meta.size() > 0) {
            statsStorage.putStorageMetaData(meta);
        }
        Collection<Persistable> staticInfo = tuple.getListenerStaticInfo();
        if (staticInfo != null && staticInfo.size() > 0) {
            statsStorage.putStaticInfo(staticInfo);
        }
        Collection<Persistable> updates = tuple.getListenerUpdates();
        if (updates != null && updates.size() > 0) {
            statsStorage.putUpdate(updates);
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    log.info("Completed training of split {} of {}", splitNum, totalSplits);
    if (params != null) {
        //Params may be null for edge case (empty RDD)
        if (network != null) {
            MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations();
            int numUpdates = network.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        } else {
            ComputationGraphConfiguration conf = graph.getNetwork().getConfiguration();
            int numUpdates = graph.getNetwork().conf().getNumIterations() * averagingFrequency;
            conf.setIterationCount(conf.getIterationCount() + numUpdates);
        }
    }
}
Also used : ParameterAveragingElementCombineFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) ParameterAveragingAggregationTuple(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ParameterAveragingElementAddFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
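
At its core, processResults performs plain element-wise averaging: the aggregate(...) pass sums worker parameter vectors (via ParameterAveragingElementAddFunction and ParameterAveragingElementCombineFunction), and params.divi(aggCount) then divides the sum in place. A self-contained ND4J sketch of just that arithmetic, with made-up worker vectors:

//Hypothetical flattened parameter vectors from three workers after one averaging interval
INDArray w1 = Nd4j.create(new float[] {1f, 2f, 3f});
INDArray w2 = Nd4j.create(new float[] {3f, 2f, 1f});
INDArray w3 = Nd4j.create(new float[] {2f, 2f, 2f});
//The aggregation step produces a running sum of the worker parameters...
INDArray paramsSum = w1.add(w2).add(w3);
int aggCount = 3;
//...and processResults divides in place, exactly like params.divi(aggCount) above
INDArray averaged = paramsSum.divi(aggCount);
System.out.println(averaged);   //prints [2.00, 2.00, 2.00] (formatting approximate)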

Aggregations

SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph): 11 usages
Test (org.junit.Test): 9 usages
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 8 usages
DataSet (org.nd4j.linalg.dataset.DataSet): 7 usages
BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest): 6 usages
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 6 usages
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 5 usages
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 5 usages
ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 5 usages
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 5 usages
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 4 usages
TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster): 4 usages
SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats): 4 usages
MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 4 usages
File (java.io.File): 3 usages
LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint): 3 usages
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 3 usages
ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster): 3 usages
MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator): 2 usages
MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration): 2 usages