Search in sources :

Example 6 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

From class TestTrainingStatsCollection, method testStatsCollection.

@Test
public void testStatsCollection() throws Exception {
    int nWorkers = 4;
    SparkConf sparkConf = new SparkConf();
    sparkConf.setMaster("local[" + nWorkers + "]");
    sparkConf.setAppName("Test");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    try {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).list()
                .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build())
                .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build())
                .pretrain(false).backprop(true).build();
        int miniBatchSizePerWorker = 10;
        int averagingFrequency = 5;
        int numberOfAveragings = 3;
        int totalExamples = nWorkers * miniBatchSizePerWorker * averagingFrequency * numberOfAveragings;
        Nd4j.getRandom().setSeed(12345);
        List<DataSet> list = new ArrayList<>();
        for (int i = 0; i < totalExamples; i++) {
            INDArray f = Nd4j.rand(1, 10);
            INDArray l = Nd4j.rand(1, 10);
            list.add(new DataSet(f, l));
        }
        //Partition explicitly at creation time. RDDs are immutable, so a bare rdd.repartition(n) statement whose
        //result is not assigned has no effect at all.
        JavaRDD<DataSet> rdd = sc.parallelize(list, nWorkers);
        ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(nWorkers, 1)
                .averagingFrequency(averagingFrequency).batchSizePerWorker(miniBatchSizePerWorker).saveUpdater(true)
                .workerPrefetchNumBatches(0).repartionData(Repartition.Always).build();
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
        sparkNet.setCollectTrainingStats(true);
        sparkNet.fit(rdd);

        //Collect the expected keys: each stats class exposes its stat names via a static "columnNames" collection
        List<String> expectedStatNames = new ArrayList<>();
        Class<?>[] classes = new Class<?>[] {CommonSparkTrainingStats.class,
                        ParameterAveragingTrainingMasterStats.class, ParameterAveragingTrainingWorkerStats.class};
        for (Class<?> statsClass : classes) {
            Field field = statsClass.getDeclaredField("columnNames");
            field.setAccessible(true);
            @SuppressWarnings("unchecked") //Field is read reflectively; its declared element type is String
            Collection<String> names = (Collection<String>) field.get(null);
            expectedStatNames.addAll(names);
        }
        System.out.println(expectedStatNames);

        SparkTrainingStats stats = sparkNet.getSparkTrainingStats();
        Set<String> actualKeySet = stats.getKeySet();
        assertEquals(expectedStatNames.size(), actualKeySet.size());
        for (String s : actualKeySet) {
            assertTrue("Unexpected stat key: " + s, expectedStatNames.contains(s));
            assertNotNull(stats.getValue(s));
        }
        String statsAsString = stats.statsAsString();
        System.out.println(statsAsString);
        //One line per stat
        assertEquals(actualKeySet.size(), statsAsString.split("\n").length);

        //Go through nested stats
        //First: master stats
        assertTrue(stats instanceof ParameterAveragingTrainingMasterStats);
        ParameterAveragingTrainingMasterStats masterStats = (ParameterAveragingTrainingMasterStats) stats;

        List<EventStats> exportTimeStats = masterStats.getParameterAveragingMasterExportTimesMs();
        assertEquals(1, exportTimeStats.size());
        assertDurationGreaterZero(exportTimeStats);
        assertNonNullFields(exportTimeStats);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(exportTimeStats, 1, 1, 1);

        List<EventStats> countRddTime = masterStats.getParameterAveragingMasterCountRddSizeTimesMs();
        //occurs once per fit
        assertEquals(1, countRddTime.size());
        assertDurationGreaterEqZero(countRddTime);
        assertNonNullFields(countRddTime);
        //should occur only in master once
        assertExpectedNumberMachineIdsJvmIdsThreadIds(countRddTime, 1, 1, 1);

        List<EventStats> broadcastCreateTime = masterStats.getParameterAveragingMasterBroadcastCreateTimesMs();
        assertEquals(numberOfAveragings, broadcastCreateTime.size());
        assertDurationGreaterEqZero(broadcastCreateTime);
        assertNonNullFields(broadcastCreateTime);
        //only 1 thread for master
        assertExpectedNumberMachineIdsJvmIdsThreadIds(broadcastCreateTime, 1, 1, 1);

        List<EventStats> fitTimes = masterStats.getParameterAveragingMasterFitTimesMs();
        //i.e., number of times fit(JavaRDD<DataSet>) was called
        assertEquals(1, fitTimes.size());
        assertDurationGreaterZero(fitTimes);
        assertNonNullFields(fitTimes);
        //only 1 thread for master
        assertExpectedNumberMachineIdsJvmIdsThreadIds(fitTimes, 1, 1, 1);

        List<EventStats> splitTimes = masterStats.getParameterAveragingMasterSplitTimesMs();
        //Splitting of the data set is executed once only (i.e., one fit(JavaRDD<DataSet>) call)
        assertEquals(1, splitTimes.size());
        assertDurationGreaterEqZero(splitTimes);
        assertNonNullFields(splitTimes);
        //only 1 thread for master
        assertExpectedNumberMachineIdsJvmIdsThreadIds(splitTimes, 1, 1, 1);

        List<EventStats> aggregateTimesMs = masterStats.getParamaterAveragingMasterAggregateTimesMs();
        assertEquals(numberOfAveragings, aggregateTimesMs.size());
        assertDurationGreaterEqZero(aggregateTimesMs);
        assertNonNullFields(aggregateTimesMs);
        //only 1 thread for master
        assertExpectedNumberMachineIdsJvmIdsThreadIds(aggregateTimesMs, 1, 1, 1);

        List<EventStats> processParamsTimesMs = masterStats.getParameterAveragingMasterProcessParamsUpdaterTimesMs();
        assertEquals(numberOfAveragings, processParamsTimesMs.size());
        assertDurationGreaterEqZero(processParamsTimesMs);
        assertNonNullFields(processParamsTimesMs);
        //only 1 thread for master
        assertExpectedNumberMachineIdsJvmIdsThreadIds(processParamsTimesMs, 1, 1, 1);

        List<EventStats> repartitionTimesMs = masterStats.getParameterAveragingMasterRepartitionTimesMs();
        assertEquals(numberOfAveragings, repartitionTimesMs.size());
        assertDurationGreaterEqZero(repartitionTimesMs);
        assertNonNullFields(repartitionTimesMs);
        //only 1 thread for master
        assertExpectedNumberMachineIdsJvmIdsThreadIds(repartitionTimesMs, 1, 1, 1);

        //Second: Common spark training stats
        SparkTrainingStats commonStats = masterStats.getNestedTrainingStats();
        assertNotNull(commonStats);
        assertTrue(commonStats instanceof CommonSparkTrainingStats);
        CommonSparkTrainingStats cStats = (CommonSparkTrainingStats) commonStats;

        List<EventStats> workerFlatMapTotalTimeMs = cStats.getWorkerFlatMapTotalTimeMs();
        assertEquals(numberOfAveragings * nWorkers, workerFlatMapTotalTimeMs.size());
        assertDurationGreaterZero(workerFlatMapTotalTimeMs);
        assertNonNullFields(workerFlatMapTotalTimeMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(workerFlatMapTotalTimeMs, 1, 1, nWorkers);

        List<EventStats> workerFlatMapGetInitialModelTimeMs = cStats.getWorkerFlatMapGetInitialModelTimeMs();
        assertEquals(numberOfAveragings * nWorkers, workerFlatMapGetInitialModelTimeMs.size());
        assertDurationGreaterEqZero(workerFlatMapGetInitialModelTimeMs);
        assertNonNullFields(workerFlatMapGetInitialModelTimeMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(workerFlatMapGetInitialModelTimeMs, 1, 1, nWorkers);

        //1 for every time we get a data set
        int expectedNumMinibatchesProcessed = numberOfAveragings * nWorkers * averagingFrequency;

        List<EventStats> workerFlatMapDataSetGetTimesMs = cStats.getWorkerFlatMapDataSetGetTimesMs();
        int numMinibatchesProcessed = workerFlatMapDataSetGetTimesMs.size();
        //Sometimes random split is just bad - some executors might miss out on getting the expected amount of data
        assertTrue("DataSet get events: " + numMinibatchesProcessed,
                        numMinibatchesProcessed >= expectedNumMinibatchesProcessed - 5);
        assertDurationGreaterEqZero(workerFlatMapDataSetGetTimesMs);
        assertNonNullFields(workerFlatMapDataSetGetTimesMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(workerFlatMapDataSetGetTimesMs, 1, 1, nWorkers);

        List<EventStats> workerFlatMapProcessMiniBatchTimesMs = cStats.getWorkerFlatMapProcessMiniBatchTimesMs();
        assertTrue("Process minibatch events: " + workerFlatMapProcessMiniBatchTimesMs.size(),
                        workerFlatMapProcessMiniBatchTimesMs.size() >= expectedNumMinibatchesProcessed - 5);
        assertDurationGreaterEqZero(workerFlatMapProcessMiniBatchTimesMs);
        assertNonNullFields(workerFlatMapProcessMiniBatchTimesMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(workerFlatMapProcessMiniBatchTimesMs, 1, 1, nWorkers);

        //Third: ParameterAveragingTrainingWorker stats
        SparkTrainingStats paramAvgStats = cStats.getNestedTrainingStats();
        assertNotNull(paramAvgStats);
        assertTrue(paramAvgStats instanceof ParameterAveragingTrainingWorkerStats);
        ParameterAveragingTrainingWorkerStats pStats = (ParameterAveragingTrainingWorkerStats) paramAvgStats;

        List<EventStats> parameterAveragingWorkerBroadcastGetValueTimeMs =
                        pStats.getParameterAveragingWorkerBroadcastGetValueTimeMs();
        assertEquals(numberOfAveragings * nWorkers, parameterAveragingWorkerBroadcastGetValueTimeMs.size());
        assertDurationGreaterEqZero(parameterAveragingWorkerBroadcastGetValueTimeMs);
        assertNonNullFields(parameterAveragingWorkerBroadcastGetValueTimeMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(parameterAveragingWorkerBroadcastGetValueTimeMs, 1, 1, nWorkers);

        List<EventStats> parameterAveragingWorkerInitTimeMs = pStats.getParameterAveragingWorkerInitTimeMs();
        assertEquals(numberOfAveragings * nWorkers, parameterAveragingWorkerInitTimeMs.size());
        assertDurationGreaterEqZero(parameterAveragingWorkerInitTimeMs);
        assertNonNullFields(parameterAveragingWorkerInitTimeMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(parameterAveragingWorkerInitTimeMs, 1, 1, nWorkers);

        List<EventStats> parameterAveragingWorkerFitTimesMs = pStats.getParameterAveragingWorkerFitTimesMs();
        assertTrue("Worker fit events: " + parameterAveragingWorkerFitTimesMs.size(),
                        parameterAveragingWorkerFitTimesMs.size() >= expectedNumMinibatchesProcessed - 5);
        assertDurationGreaterEqZero(parameterAveragingWorkerFitTimesMs);
        assertNonNullFields(parameterAveragingWorkerFitTimesMs);
        assertExpectedNumberMachineIdsJvmIdsThreadIds(parameterAveragingWorkerFitTimesMs, 1, 1, nWorkers);
        assertNull(pStats.getNestedTrainingStats());

        //Finally: try exporting stats
        String tempDir = System.getProperty("java.io.tmpdir");
        String outDir = FilenameUtils.concat(tempDir, "dl4j_testTrainingStatsCollection");
        stats.exportStatFiles(outDir, sc.sc());
        String htmlPlotsPath = FilenameUtils.concat(outDir, "AnalysisPlots.html");
        StatsUtils.exportStatsAsHtml(stats, htmlPlotsPath, sc);
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        StatsUtils.exportStatsAsHTML(stats, baos);
        baos.close();
        String html = new String(baos.toByteArray(), java.nio.charset.StandardCharsets.UTF_8);
        assertTrue("HTML stats export produced no output", html.length() > 0);
    } finally {
        sc.stop();
    }
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) ParameterAveragingTrainingMasterStats(org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats) DataSet(org.nd4j.linalg.dataset.DataSet) CommonSparkTrainingStats(org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) Field(java.lang.reflect.Field) EventStats(org.deeplearning4j.spark.stats.EventStats) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) SparkDl4jMultiLayer(org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ParameterAveragingTrainingWorkerStats(org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingWorkerStats) ByteArrayOutputStream(java.io.ByteArrayOutputStream) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) INDArray(org.nd4j.linalg.api.ndarray.INDArray) CommonSparkTrainingStats(org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats) SparkConf(org.apache.spark.SparkConf) Test(org.junit.Test)

Example 7 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

From class ParameterAveragingTrainingMaster, method processResults.

/**
 * Folds the per-worker training results of one split back into the master model.
 * <p>
 * The worker results are aggregated into a single tuple of sums. The parameter sum and the updater-state sum (if
 * any) are divided by the aggregation count and written into {@code network} when it is non-null, otherwise into
 * {@code graph}. If the aggregated parameters are null (logged as an imbalanced split with no data), the model is
 * left untouched and the iteration count is not advanced. Listener metadata/static info/updates carried by the
 * results are forwarded to {@code statsStorage} when it is set.
 *
 * @param network     the MultiLayerNetwork wrapper to update; presumably null when a graph is being trained (TODO confirm with callers)
 * @param graph       the ComputationGraph wrapper; only used when {@code network} is null
 * @param results     per-worker training results to aggregate
 * @param splitNum    number of this split, used only for logging
 * @param totalSplits total number of splits, used only for logging
 */
private void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<ParameterAveragingTrainingResult> results, int splitNum, int totalSplits) {
    //--- Aggregation: reduce all worker results into one tuple of sums ---
    if (collectTrainingStats) {
        stats.logAggregateStartTime();
    }
    ParameterAveragingAggregationTuple summed = results.aggregate(null, new ParameterAveragingElementAddFunction(), new ParameterAveragingElementCombineFunction());
    INDArray paramSum = summed.getParametersSum();
    int resultCount = summed.getAggregationsCount();
    SparkTrainingStats workerStats = summed.getSparkTrainingStats();
    if (collectTrainingStats) {
        stats.logAggregationEndTime();
        stats.logProcessParamsUpdaterStart();
    }

    //--- Averaging: turn the sums into means and install them in the model ---
    boolean splitWasEmpty = (paramSum == null);
    if (splitWasEmpty) {
        log.info("Skipping imbalanced split with no data for all executors");
    } else {
        paramSum.divi(resultCount);
        //Updater state may be null, e.g. if all layers use SGD
        INDArray updaterSum = summed.getUpdaterStateSum();
        if (updaterSum != null) {
            updaterSum.divi(resultCount);
        }
        if (network == null) {
            ComputationGraph cg = graph.getNetwork();
            cg.setParams(paramSum);
            if (updaterSum != null) {
                cg.getUpdater().setStateViewArray(updaterSum);
            }
            graph.setScore(summed.getScoreSum() / resultCount);
        } else {
            MultiLayerNetwork mln = network.getNetwork();
            mln.setParameters(paramSum);
            if (updaterSum != null) {
                mln.getUpdater().setStateViewArray(null, updaterSum, false);
            }
            network.setScore(summed.getScoreSum() / resultCount);
        }
    }
    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(workerStats);
    }

    //--- Listener data: forward whatever the workers' listeners produced ---
    if (statsStorage != null) {
        Collection<StorageMetaData> metaData = summed.getListenerMetaData();
        if (metaData != null && !metaData.isEmpty()) {
            statsStorage.putStorageMetaData(metaData);
        }
        Collection<Persistable> staticInfo = summed.getListenerStaticInfo();
        if (staticInfo != null && !staticInfo.isEmpty()) {
            statsStorage.putStaticInfo(staticInfo);
        }
        Collection<Persistable> updates = summed.getListenerUpdates();
        if (updates != null && !updates.isEmpty()) {
            statsStorage.putUpdate(updates);
        }
    }

    //Grid executioners queue ops; flush (blocking) before reporting the split as done
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    log.info("Completed training of split {} of {}", splitNum, totalSplits);

    //--- Iteration bookkeeping: nothing was trained for an empty split (e.g. empty RDD) ---
    if (splitWasEmpty) {
        return;
    }
    if (network == null) {
        ComputationGraph cg = graph.getNetwork();
        ComputationGraphConfiguration cgConf = cg.getConfiguration();
        int numUpdates = cg.conf().getNumIterations() * averagingFrequency;
        cgConf.setIterationCount(cgConf.getIterationCount() + numUpdates);
    } else {
        MultiLayerNetwork mln = network.getNetwork();
        MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations();
        int numUpdates = mln.conf().getNumIterations() * averagingFrequency;
        mlnConf.setIterationCount(mlnConf.getIterationCount() + numUpdates);
    }
}
Also used : ParameterAveragingElementCombineFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementCombineFunction) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) ParameterAveragingAggregationTuple(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingAggregationTuple) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ParameterAveragingElementAddFunction(org.deeplearning4j.spark.impl.paramavg.aggregator.ParameterAveragingElementAddFunction) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork)

Example 8 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

From class TestPreProcessedData, method testPreprocessedDataCompGraphDataSet.

@Test
public void testPreprocessedDataCompGraphDataSet() {
    //Test _loading_ of preprocessed DataSet data
    int dataSetObjSize = 5;
    int batchSizePerExecutor = 10;
    String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_testpreprocdata2");
    File f = new File(path);
    //File.delete() silently fails on a non-empty directory, so files left behind by an earlier run would survive
    //and be loaded by fit(...) below, skewing the fit counts. Clear the directory contents explicitly instead.
    if (f.isDirectory()) {
        File[] stale = f.listFiles();
        if (stale != null) {
            for (File old : stale) {
                assertTrue("Could not delete stale file " + old, old.delete());
            }
        }
    } else if (f.exists()) {
        assertTrue("Could not delete " + f, f.delete());
    }
    assertTrue("Could not create directory " + path, f.isDirectory() || f.mkdirs());
    DataSetIterator iter = new IrisDataSetIterator(5, 150);
    int i = 0;
    while (iter.hasNext()) {
        File f2 = new File(FilenameUtils.concat(path, "data" + (i++) + ".bin"));
        iter.next().save(f2);
    }
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).graphBuilder()
            .addInputs("in")
            .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(3).nOut(3).activation(Activation.SOFTMAX).build(), "0")
            .setOutputs("1").pretrain(false).backprop(true).build();
    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf,
            new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                    .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(1)
                    .repartionData(Repartition.Always).build());
    sparkNet.setCollectTrainingStats(true);
    sparkNet.fit("file:///" + path.replaceAll("\\\\", "/"));
    SparkTrainingStats sts = sparkNet.getSparkTrainingStats();
    //4 'fits' per averaging (4 executors, 1 averaging freq); 10 examples each -> 40 examples per fit. 150/40 = 3 averagings (round down); 3*4 = 12
    int expNumFits = 12;
    //Unfortunately: perfect partitioning isn't guaranteed by SparkUtils.balancedRandomSplit (esp. if original partitions are all size 1
    // which appears to be occurring at least some of the time), but we should get close to what we expect...
    int numFits = sts.getValue("ParameterAveragingWorkerFitTimesMs").size();
    assertTrue("Expected about " + expNumFits + " worker fits, got " + numFits, Math.abs(expNumFits - numFits) < 3);
    assertEquals(3, sts.getValue("ParameterAveragingMasterMapPartitionsTimesMs").size());
}
Also used : SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) File(java.io.File) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) PortableDataStreamDataSetIterator(org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 9 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

From class TestPreProcessedData, method testPreprocessedDataCompGraphMultiDataSet.

@Test
public void testPreprocessedDataCompGraphMultiDataSet() throws IOException {
    //Test _loading_ of preprocessed MultiDataSet data
    int dataSetObjSize = 5;
    int batchSizePerExecutor = 10;
    String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_testpreprocdata3");
    File f = new File(path);
    //File.delete() silently fails on a non-empty directory, so files left behind by an earlier run would survive
    //and be loaded by fitMultiDataSet(...) below, skewing the fit counts. Clear the directory contents explicitly.
    if (f.isDirectory()) {
        File[] stale = f.listFiles();
        if (stale != null) {
            for (File old : stale) {
                assertTrue("Could not delete stale file " + old, old.delete());
            }
        }
    } else if (f.exists()) {
        assertTrue("Could not delete " + f, f.delete());
    }
    assertTrue("Could not create directory " + path, f.isDirectory() || f.mkdirs());
    DataSetIterator iter = new IrisDataSetIterator(5, 150);
    int i = 0;
    while (iter.hasNext()) {
        File f2 = new File(FilenameUtils.concat(path, "data" + (i++) + ".bin"));
        DataSet ds = iter.next();
        MultiDataSet mds = new MultiDataSet(ds.getFeatures(), ds.getLabels());
        mds.save(f2);
    }
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).graphBuilder()
            .addInputs("in")
            .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(3).nOut(3).activation(Activation.SOFTMAX).build(), "0")
            .setOutputs("1").pretrain(false).backprop(true).build();
    SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf,
            new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                    .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(1)
                    .repartionData(Repartition.Always).build());
    sparkNet.setCollectTrainingStats(true);
    sparkNet.fitMultiDataSet("file:///" + path.replaceAll("\\\\", "/"));
    SparkTrainingStats sts = sparkNet.getSparkTrainingStats();
    //4 'fits' per averaging (4 executors, 1 averaging freq); 10 examples each -> 40 examples per fit. 150/40 = 3 averagings (round down); 3*4 = 12
    int expNumFits = 12;
    //Unfortunately: perfect partitioning isn't guaranteed by SparkUtils.balancedRandomSplit (esp. if original partitions are all size 1
    // which appears to be occurring at least some of the time), but we should get close to what we expect...
    int numFits = sts.getValue("ParameterAveragingWorkerFitTimesMs").size();
    assertTrue("Expected about " + expNumFits + " worker fits, got " + numFits, Math.abs(expNumFits - numFits) < 3);
    assertEquals(3, sts.getValue("ParameterAveragingMasterMapPartitionsTimesMs").size());
}
Also used : SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) File(java.io.File) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) PortableDataStreamDataSetIterator(org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Example 10 with SparkTrainingStats

use of org.deeplearning4j.spark.api.stats.SparkTrainingStats in project deeplearning4j by deeplearning4j.

From class TestPreProcessedData, method testPreprocessedData.

@Test
public void testPreprocessedData() {
    //Test _loading_ of preprocessed data
    int dataSetObjSize = 5;
    int batchSizePerExecutor = 10;
    String path = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "dl4j_testpreprocdata");
    File f = new File(path);
    //File.delete() silently fails on a non-empty directory, so files left behind by an earlier run would survive
    //and be loaded by fit(...) below, skewing the fit counts. Clear the directory contents explicitly instead.
    if (f.isDirectory()) {
        File[] stale = f.listFiles();
        if (stale != null) {
            for (File old : stale) {
                assertTrue("Could not delete stale file " + old, old.delete());
            }
        }
    } else if (f.exists()) {
        assertTrue("Could not delete " + f, f.delete());
    }
    assertTrue("Could not create directory " + path, f.isDirectory() || f.mkdirs());
    DataSetIterator iter = new IrisDataSetIterator(5, 150);
    int i = 0;
    while (iter.hasNext()) {
        File f2 = new File(FilenameUtils.concat(path, "data" + (i++) + ".bin"));
        iter.next().save(f2);
    }
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1).list()
            .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.TANH).build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(3).nOut(3).activation(Activation.SOFTMAX).build())
            .pretrain(false).backprop(true).build();
    SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf,
            new ParameterAveragingTrainingMaster.Builder(numExecutors(), dataSetObjSize)
                    .batchSizePerWorker(batchSizePerExecutor).averagingFrequency(1)
                    .repartionData(Repartition.Always).build());
    sparkNet.setCollectTrainingStats(true);
    sparkNet.fit("file:///" + path.replaceAll("\\\\", "/"));
    SparkTrainingStats sts = sparkNet.getSparkTrainingStats();
    //4 'fits' per averaging (4 executors, 1 averaging freq); 10 examples each -> 40 examples per fit. 150/40 = 3 averagings (round down); 3*4 = 12
    int expNumFits = 12;
    //Unfortunately: perfect partitioning isn't guaranteed by SparkUtils.balancedRandomSplit (esp. if original partitions are all size 1
    // which appears to be occurring at least some of the time), but we should get close to what we expect...
    int numFits = sts.getValue("ParameterAveragingWorkerFitTimesMs").size();
    assertTrue("Expected about " + expNumFits + " worker fits, got " + numFits, Math.abs(expNumFits - numFits) < 3);
    assertEquals(3, sts.getValue("ParameterAveragingMasterMapPartitionsTimesMs").size());
}
Also used : IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) SparkTrainingStats(org.deeplearning4j.spark.api.stats.SparkTrainingStats) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) SparkDl4jMultiLayer(org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer) File(java.io.File) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) PortableDataStreamDataSetIterator(org.deeplearning4j.spark.iterator.PortableDataStreamDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

SparkTrainingStats (org.deeplearning4j.spark.api.stats.SparkTrainingStats)17 Test (org.junit.Test)8 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)8 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)7 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)7 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)7 BaseSparkTest (org.deeplearning4j.spark.BaseSparkTest)7 INDArray (org.nd4j.linalg.api.ndarray.INDArray)7 DataSet (org.nd4j.linalg.dataset.DataSet)7 File (java.io.File)6 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)6 SparkDl4jMultiLayer (org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer)5 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)5 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)5 LabeledPoint (org.apache.spark.mllib.regression.LabeledPoint)4 Pair (org.deeplearning4j.berkeley.Pair)4 MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator)4 SparkComputationGraph (org.deeplearning4j.spark.impl.graph.SparkComputationGraph)4 ParameterAveragingTrainingMaster (org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster)4 Path (java.nio.file.Path)3