Search in sources :

Example 41 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the project deeplearning4j (by deeplearning4j).

From the class RegressionTest060, method regressionTestCGLSTM1.

@Test
public void regressionTestCGLSTM1() throws Exception {
    // Restore a ComputationGraph serialized with DL4J 0.6.0 and verify the
    // configuration survives deserialization unchanged.
    File modelFile = new ClassPathResource("regression_testing/060/060_ModelSerializer_Regression_CG_LSTM_1.zip").getTempFileFromArchive();
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(modelFile, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();

    // Graph-level checks: three layer vertices, backprop on, pretraining off
    assertEquals(3, graphConf.getVertices().size());
    assertTrue(graphConf.isBackprop());
    assertFalse(graphConf.isPretrain());

    // Vertex "0": GravesLSTM, tanh, 3 -> 4, elementwise gradient clipping at 1.5
    GravesLSTM lstmLayer = (GravesLSTM) ((LayerVertex) graphConf.getVertices().get("0")).getLayerConf().getLayer();
    assertEquals("tanh", lstmLayer.getActivationFn().toString());
    assertEquals(3, lstmLayer.getNIn());
    assertEquals(4, lstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstmLayer.getGradientNormalization());
    assertEquals(1.5, lstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": bidirectional LSTM, softsign, 4 -> 4, same clipping setup
    GravesBidirectionalLSTM biLstmLayer = (GravesBidirectionalLSTM) ((LayerVertex) graphConf.getVertices().get("1")).getLayerConf().getLayer();
    assertEquals("softsign", biLstmLayer.getActivationFn().toString());
    assertEquals(4, biLstmLayer.getNIn());
    assertEquals(4, biLstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstmLayer.getGradientNormalization());
    assertEquals(1.5, biLstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RNN output layer, softmax + MCXENT loss, 4 -> 5
    RnnOutputLayer outputLayer = (RnnOutputLayer) ((LayerVertex) graphConf.getVertices().get("2")).getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertEquals("softmax", outputLayer.getActivationFn().toString());
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) File(java.io.File) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 42 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the project deeplearning4j (by deeplearning4j).

From the class RegressionTest071, method regressionTestCGLSTM1.

@Test
public void regressionTestCGLSTM1() throws Exception {
    // Restore a ComputationGraph serialized with DL4J 0.7.1 and verify the
    // configuration survives deserialization unchanged.
    File modelFile = new ClassPathResource("regression_testing/071/071_ModelSerializer_Regression_CG_LSTM_1.zip").getTempFileFromArchive();
    ComputationGraph restored = ModelSerializer.restoreComputationGraph(modelFile, true);
    ComputationGraphConfiguration graphConf = restored.getConfiguration();

    // Graph-level checks: three layer vertices, backprop on, pretraining off
    assertEquals(3, graphConf.getVertices().size());
    assertTrue(graphConf.isBackprop());
    assertFalse(graphConf.isPretrain());

    // Vertex "0": GravesLSTM, tanh, 3 -> 4, elementwise gradient clipping at 1.5
    GravesLSTM lstmLayer = (GravesLSTM) ((LayerVertex) graphConf.getVertices().get("0")).getLayerConf().getLayer();
    assertEquals("tanh", lstmLayer.getActivationFn().toString());
    assertEquals(3, lstmLayer.getNIn());
    assertEquals(4, lstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, lstmLayer.getGradientNormalization());
    assertEquals(1.5, lstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "1": bidirectional LSTM, softsign, 4 -> 4, same clipping setup
    GravesBidirectionalLSTM biLstmLayer = (GravesBidirectionalLSTM) ((LayerVertex) graphConf.getVertices().get("1")).getLayerConf().getLayer();
    assertEquals("softsign", biLstmLayer.getActivationFn().toString());
    assertEquals(4, biLstmLayer.getNIn());
    assertEquals(4, biLstmLayer.getNOut());
    assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, biLstmLayer.getGradientNormalization());
    assertEquals(1.5, biLstmLayer.getGradientNormalizationThreshold(), 1e-5);

    // Vertex "2": RNN output layer, softmax + MCXENT loss, 4 -> 5
    RnnOutputLayer outputLayer = (RnnOutputLayer) ((LayerVertex) graphConf.getVertices().get("2")).getLayerConf().getLayer();
    assertEquals(4, outputLayer.getNIn());
    assertEquals(5, outputLayer.getNOut());
    assertEquals("softmax", outputLayer.getActivationFn().toString());
    assertTrue(outputLayer.getLossFn() instanceof LossMCXENT);
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) LossMCXENT(org.nd4j.linalg.lossfunctions.impl.LossMCXENT) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) File(java.io.File) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 43 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the project deeplearning4j (by deeplearning4j).

From the class ModelSerializerTest, method testWriteCGModelInputStream.

@Test
public void testWriteCGModelInputStream() throws Exception {
    // Build a minimal two-layer graph: in -> dense(4->2) -> out(2->3, MCXENT)
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false).backprop(true)
            .build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();

    // Serialize (including updater state) to a temp file...
    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();
    ModelSerializer.writeModel(cg, tempFile, true);

    // ...then restore it via the InputStream overload.
    // Fix: the original leaked the FileInputStream; close it with try-with-resources.
    ComputationGraph network;
    try (FileInputStream fis = new FileInputStream(tempFile)) {
        network = ModelSerializer.restoreComputationGraph(fis);
    }

    // Round-trip must preserve configuration, parameters, and updater state
    assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson());
    assertEquals(cg.params(), network.params());
    assertEquals(cg.getUpdater(), network.getUpdater());
}
Also used : DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) File(java.io.File) FileInputStream(java.io.FileInputStream) Test(org.junit.Test)

Example 44 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the project deeplearning4j (by deeplearning4j).

From the class BaseStatsListener, method iterationDone.

/**
 * Main stats-collection hook, invoked once per training iteration. Depending on the
 * active {@code StatsUpdateConfiguration}, this gathers performance, memory, GC,
 * learning-rate, histogram and summary statistics, bundles them into a StatsReport,
 * and forwards the report via the router.
 *
 * @param model     the model being trained (MultiLayerNetwork, ComputationGraph, or a single Layer)
 * @param iteration the current iteration count
 */
@Override
public void iterationDone(Model model, int iteration) {
    StatsUpdateConfiguration config = updateConfig;
    ModelInfo modelInfo = getModelInfo(model);
    boolean backpropParamsOnly = backpropParamsOnly(model);
    long currentTime = getTime();
    // First iteration for this model: record start time and run one-off initialization
    if (modelInfo.iterCount == 0) {
        modelInfo.initTime = currentTime;
        doInit(model);
    }
    if (config.collectPerformanceStats()) {
        updateExamplesMinibatchesCounts(model);
    }
    // Honor the reporting frequency: on non-reporting iterations just record the
    // iteration count and bail out without building a report
    if (config.reportingFrequency() > 1 && (iteration == 0 || iteration % config.reportingFrequency() != 0)) {
        modelInfo.iterCount = iteration;
        return;
    }
    StatsReport report = getNewStatsReport();
    //TODO support NTP time
    report.reportIDs(getSessionID(model), TYPE_ID, workerID, System.currentTimeMillis());
    //--- Performance and System Stats ---
    if (config.collectPerformanceStats()) {
        //Stats to collect: total runtime, total examples, total minibatches, iterations/second, examples/second
        double examplesPerSecond;
        double minibatchesPerSecond;
        if (modelInfo.iterCount == 0) {
            //Not possible to work out perf/second: first iteration...
            examplesPerSecond = 0.0;
            minibatchesPerSecond = 0.0;
        } else {
            // Rates are relative to the last report, not the start of training
            long deltaTimeMS = currentTime - modelInfo.lastReportTime;
            examplesPerSecond = 1000.0 * modelInfo.examplesSinceLastReport / deltaTimeMS;
            minibatchesPerSecond = 1000.0 * modelInfo.minibatchesSinceLastReport / deltaTimeMS;
        }
        long totalRuntimeMS = currentTime - modelInfo.initTime;
        report.reportPerformance(totalRuntimeMS, modelInfo.totalExamples, modelInfo.totalMinibatches, examplesPerSecond, minibatchesPerSecond);
        // Reset the per-report counters now that they have been consumed
        modelInfo.examplesSinceLastReport = 0;
        modelInfo.minibatchesSinceLastReport = 0;
    }
    if (config.collectMemoryStats()) {
        // JVM heap usage
        Runtime runtime = Runtime.getRuntime();
        long jvmTotal = runtime.totalMemory();
        long jvmMax = runtime.maxMemory();
        //Off-heap memory
        long offheapTotal = Pointer.totalBytes();
        long offheapMax = Pointer.maxBytes();
        //GPU
        long[] gpuCurrentBytes = null;
        long[] gpuMaxBytes = null;
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        int nDevices = nativeOps.getAvailableDevices();
        if (nDevices > 0) {
            gpuCurrentBytes = new long[nDevices];
            gpuMaxBytes = new long[nDevices];
            for (int i = 0; i < nDevices; i++) {
                try {
                    Pointer p = getDevicePointer(i);
                    if (p == null) {
                        // Device pointer unavailable: report zero rather than failing
                        gpuMaxBytes[i] = 0;
                        gpuCurrentBytes[i] = 0;
                    } else {
                        // Current use = total device memory minus free device memory
                        gpuMaxBytes[i] = nativeOps.getDeviceTotalMemory(p);
                        gpuCurrentBytes[i] = gpuMaxBytes[i] - nativeOps.getDeviceFreeMemory(p);
                    }
                } catch (Exception e) {
                    // Best-effort: memory stats must never abort the iteration
                    e.printStackTrace();
                }
            }
        }
        report.reportMemoryUse(jvmTotal, jvmMax, offheapTotal, offheapMax, gpuCurrentBytes, gpuMaxBytes);
    }
    if (config.collectGarbageCollectionStats()) {
        if (modelInfo.lastReportIteration == -1 || gcBeans == null) {
            //Haven't reported GC stats before... snapshot the baseline counts/times
            gcBeans = ManagementFactory.getGarbageCollectorMXBeans();
            gcStatsAtLastReport = new HashMap<>();
            for (GarbageCollectorMXBean bean : gcBeans) {
                long count = bean.getCollectionCount();
                long timeMs = bean.getCollectionTime();
                gcStatsAtLastReport.put(bean.getName(), new Pair<>(count, timeMs));
            }
        } else {
            // Report the delta since the last snapshot, then update the snapshot in place
            for (GarbageCollectorMXBean bean : gcBeans) {
                long count = bean.getCollectionCount();
                long timeMs = bean.getCollectionTime();
                Pair<Long, Long> lastStats = gcStatsAtLastReport.get(bean.getName());
                long deltaGCCount = count - lastStats.getFirst();
                long deltaGCTime = timeMs - lastStats.getSecond();
                lastStats.setFirst(count);
                lastStats.setSecond(timeMs);
                report.reportGarbageCollection(bean.getName(), (int) deltaGCCount, (int) deltaGCTime);
            }
        }
    }
    //--- General ---
    //Always report score
    report.reportScore(model.score());
    if (config.collectLearningRates()) {
        // Learning-rate key format depends on the model type:
        // MultiLayerNetwork -> "<layerIdx>_<param>", ComputationGraph -> "<layerName>_<param>"
        Map<String, Double> lrs = new HashMap<>();
        if (model instanceof MultiLayerNetwork) {
            //Need to append "0_", "1_" etc to param names from layers...
            int layerIdx = 0;
            for (Layer l : ((MultiLayerNetwork) model).getLayers()) {
                NeuralNetConfiguration conf = l.conf();
                Map<String, Double> layerLrs = conf.getLearningRateByParam();
                Set<String> backpropParams = l.paramTable(true).keySet();
                for (Map.Entry<String, Double> entry : layerLrs.entrySet()) {
                    if (!backpropParams.contains(entry.getKey()))
                        //Skip pretrain params
                        continue;
                    lrs.put(layerIdx + "_" + entry.getKey(), entry.getValue());
                }
                layerIdx++;
            }
        } else if (model instanceof ComputationGraph) {
            for (Layer l : ((ComputationGraph) model).getLayers()) {
                //Need to append layer name
                NeuralNetConfiguration conf = l.conf();
                Map<String, Double> layerLrs = conf.getLearningRateByParam();
                String layerName = conf.getLayer().getLayerName();
                Set<String> backpropParams = l.paramTable(true).keySet();
                for (Map.Entry<String, Double> entry : layerLrs.entrySet()) {
                    if (!backpropParams.contains(entry.getKey()))
                        //Skip pretrain params
                        continue;
                    lrs.put(layerName + "_" + entry.getKey(), entry.getValue());
                }
            }
        } else if (model instanceof Layer) {
            // Single layer: parameter names are used as-is, no prefix
            Layer l = (Layer) model;
            Map<String, Double> map = l.conf().getLearningRateByParam();
            lrs.putAll(map);
        }
        report.reportLearningRates(lrs);
    }
    // Histograms for each requested stats type (parameters, gradients, updates, activations)
    if (config.collectHistograms(StatsType.Parameters)) {
        Map<String, Histogram> paramHistograms = getHistograms(model.paramTable(backpropParamsOnly), config.numHistogramBins(StatsType.Parameters));
        report.reportHistograms(StatsType.Parameters, paramHistograms);
    }
    if (config.collectHistograms(StatsType.Gradients)) {
        Map<String, Histogram> gradientHistograms = getHistograms(gradientsPreUpdateMap, config.numHistogramBins(StatsType.Gradients));
        report.reportHistograms(StatsType.Gradients, gradientHistograms);
    }
    if (config.collectHistograms(StatsType.Updates)) {
        Map<String, Histogram> updateHistograms = getHistograms(model.gradient().gradientForVariable(), config.numHistogramBins(StatsType.Updates));
        report.reportHistograms(StatsType.Updates, updateHistograms);
    }
    if (config.collectHistograms(StatsType.Activations)) {
        Map<String, Histogram> activationHistograms = getHistograms(activationsMap, config.numHistogramBins(StatsType.Activations));
        report.reportHistograms(StatsType.Activations, activationHistograms);
    }
    // Mean values per stats type
    if (config.collectMean(StatsType.Parameters)) {
        Map<String, Double> meanParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean);
        report.reportMean(StatsType.Parameters, meanParams);
    }
    if (config.collectMean(StatsType.Gradients)) {
        Map<String, Double> meanGradients = calculateSummaryStats(gradientsPreUpdateMap, StatType.Mean);
        report.reportMean(StatsType.Gradients, meanGradients);
    }
    if (config.collectMean(StatsType.Updates)) {
        Map<String, Double> meanUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Mean);
        report.reportMean(StatsType.Updates, meanUpdates);
    }
    if (config.collectMean(StatsType.Activations)) {
        Map<String, Double> meanActivations = calculateSummaryStats(activationsMap, StatType.Mean);
        report.reportMean(StatsType.Activations, meanActivations);
    }
    // Standard deviations per stats type
    if (config.collectStdev(StatsType.Parameters)) {
        Map<String, Double> stdevParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev);
        report.reportStdev(StatsType.Parameters, stdevParams);
    }
    if (config.collectStdev(StatsType.Gradients)) {
        Map<String, Double> stdevGradient = calculateSummaryStats(gradientsPreUpdateMap, StatType.Stdev);
        report.reportStdev(StatsType.Gradients, stdevGradient);
    }
    if (config.collectStdev(StatsType.Updates)) {
        Map<String, Double> stdevUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.Stdev);
        report.reportStdev(StatsType.Updates, stdevUpdates);
    }
    if (config.collectStdev(StatsType.Activations)) {
        Map<String, Double> stdevActivations = calculateSummaryStats(activationsMap, StatType.Stdev);
        report.reportStdev(StatsType.Activations, stdevActivations);
    }
    // Mean magnitudes per stats type
    if (config.collectMeanMagnitudes(StatsType.Parameters)) {
        Map<String, Double> meanMagParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams);
    }
    if (config.collectMeanMagnitudes(StatsType.Gradients)) {
        Map<String, Double> meanMagGradients = calculateSummaryStats(gradientsPreUpdateMap, StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Gradients, meanMagGradients);
    }
    if (config.collectMeanMagnitudes(StatsType.Updates)) {
        Map<String, Double> meanMagUpdates = calculateSummaryStats(model.gradient().gradientForVariable(), StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Updates, meanMagUpdates);
    }
    if (config.collectMeanMagnitudes(StatsType.Activations)) {
        Map<String, Double> meanMagActivations = calculateSummaryStats(activationsMap, StatType.MeanMagnitude);
        report.reportMeanMagnitudes(StatsType.Activations, meanMagActivations);
    }
    long endTime = getTime();
    //Amount of time required to calculate all histograms, means etc.
    report.reportStatsCollectionDurationMS((int) (endTime - currentTime));
    // Record bookkeeping state for the next report, then dispatch
    modelInfo.lastReportTime = currentTime;
    modelInfo.lastReportIteration = iteration;
    report.reportIterationCount(iteration);
    this.router.putUpdate(report);
    modelInfo.iterCount = iteration;
    // Activations are only valid for the iteration just reported; clear the reference
    activationsMap = null;
}
Also used : DefaultStatsUpdateConfiguration(org.deeplearning4j.ui.stats.impl.DefaultStatsUpdateConfiguration) GarbageCollectorMXBean(java.lang.management.GarbageCollectorMXBean) Pointer(org.bytedeco.javacpp.Pointer) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) NativeOps(org.nd4j.nativeblas.NativeOps) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer)

Example 45 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the project deeplearning4j (by deeplearning4j).

From the class TestSparkComputationGraph, method testSeedRepeatability.

@Test
public void testSeedRepeatability() throws Exception {
    // Simple 4-in / 3-out network used to check that distributed training is
    // deterministic given the same RNG seed.
    ComputationGraphConfiguration graphConfig = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .updater(Updater.RMSPROP)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build(), "in")
            .addLayer("1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "0")
            .setOutputs("1")
            .pretrain(false).backprop(true)
            .build();

    // Three identically-initialized local networks
    Nd4j.getRandom().setSeed(12345);
    ComputationGraph net1 = new ComputationGraph(graphConfig);
    net1.init();
    Nd4j.getRandom().setSeed(12345);
    ComputationGraph net2 = new ComputationGraph(graphConfig);
    net2.init();
    Nd4j.getRandom().setSeed(12345);
    ComputationGraph net3 = new ComputationGraph(graphConfig);
    net3.init();

    // Spark nets 1 and 2 share rngSeed 12345; net 3 uses a different seed (98765)
    SparkComputationGraph sparkNet1 = new SparkComputationGraph(sc, net1, new ParameterAveragingTrainingMaster.Builder(1).workerPrefetchNumBatches(5).batchSizePerWorker(5).averagingFrequency(1).repartionData(Repartition.Always).rngSeed(12345).build());
    //Training master IDs are only unique if they are created at least 1 ms apart...
    Thread.sleep(100);
    SparkComputationGraph sparkNet2 = new SparkComputationGraph(sc, net2, new ParameterAveragingTrainingMaster.Builder(1).workerPrefetchNumBatches(5).batchSizePerWorker(5).averagingFrequency(1).repartionData(Repartition.Always).rngSeed(12345).build());
    Thread.sleep(100);
    SparkComputationGraph sparkNet3 = new SparkComputationGraph(sc, net3, new ParameterAveragingTrainingMaster.Builder(1).workerPrefetchNumBatches(5).batchSizePerWorker(5).averagingFrequency(1).repartionData(Repartition.Always).rngSeed(98765).build());

    // Full iris dataset, one example per DataSet, parallelized as an RDD
    List<DataSet> trainingData = new ArrayList<>();
    DataSetIterator irisIter = new IrisDataSetIterator(1, 150);
    while (irisIter.hasNext()) {
        trainingData.add(irisIter.next());
    }
    JavaRDD<DataSet> trainingRdd = sc.parallelize(trainingData);

    sparkNet1.fit(trainingRdd);
    sparkNet2.fit(trainingRdd);
    sparkNet3.fit(trainingRdd);

    INDArray params1 = sparkNet1.getNetwork().params();
    INDArray params2 = sparkNet2.getNetwork().params();
    INDArray params3 = sparkNet3.getNetwork().params();

    sparkNet1.getTrainingMaster().deleteTempFiles(sc);
    sparkNet2.getTrainingMaster().deleteTempFiles(sc);
    sparkNet3.getTrainingMaster().deleteTempFiles(sc);

    // Same seed => identical parameters; different seed => different parameters
    assertEquals(params1, params2);
    assertNotEquals(params1, params3);
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.DataSet) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) ParameterAveragingTrainingMaster(org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) RecordReaderMultiDataSetIterator(org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)109 Test (org.junit.Test)73 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)63 INDArray (org.nd4j.linalg.api.ndarray.INDArray)62 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)36 DataSet (org.nd4j.linalg.dataset.DataSet)25 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)22 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)21 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)19 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)19 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)17 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)17 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)14 Layer (org.deeplearning4j.nn.api.Layer)14 Random (java.util.Random)11 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)10 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)10 TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster)10 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)9 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)9