Search in sources:

Example 81 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project (by deeplearning4j).

Class TransferLearningComplex, method testMergeAndFreeze.

/**
 * Builds a two-branch graph (in1 -> A -> B and in2 -> C, both feeding D -> out), freezes only
 * layer "C" via {@code setFeatureExtractor("C")}, and checks that:
 * (a) only C is wrapped in a FrozenLayer, while A and B on the other input branch stay trainable;
 * (b) the fine-tune learning rate (2e-2) is applied to every layer, while the settings the
 *     fine-tune configuration does not mention (ADAM updater, LEAKYRELU activation) are kept.
 */
@Test
public void testMergeAndFreeze() {
    // in1 -> A -> B -> merge, in2 -> C -> merge -> D -> out
    //Goal here: test a number of things...
    // (a) Ensure that freezing C doesn't impact A and B. Only C should be frozen in this config
    // (b) Test global override (should be selective)
    // D receives B and C together (the graph builder merges multiple layer inputs), so its
    // nIn must be B's nOut (8) plus C's nOut (6).
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(Updater.ADAM)
            .learningRate(1e-4)
            .activation(Activation.LEAKYRELU)
            .graphBuilder()
            .addInputs("in1", "in2")
            .addLayer("A", new DenseLayer.Builder().nIn(10).nOut(9).build(), "in1")
            .addLayer("B", new DenseLayer.Builder().nIn(9).nOut(8).build(), "A")
            .addLayer("C", new DenseLayer.Builder().nIn(7).nOut(6).build(), "in2")
            .addLayer("D", new DenseLayer.Builder().nIn(8 + 6).nOut(5).build(), "B", "C")
            .addLayer("out", new OutputLayer.Builder().nIn(5).nOut(4).build(), "D")
            .setOutputs("out")
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    // Log the topological order to make failures easier to diagnose.
    int[] topologicalOrder = graph.topologicalSortOrder();
    org.deeplearning4j.nn.graph.vertex.GraphVertex[] vertices = graph.getVertices();
    for (int i = 0; i < topologicalOrder.length; i++) {
        org.deeplearning4j.nn.graph.vertex.GraphVertex v = vertices[topologicalOrder[i]];
        log.info(i + "\t" + v.getVertexName());
    }

    ComputationGraph graph2 = new TransferLearning.GraphBuilder(graph)
            .fineTuneConfiguration(new FineTuneConfiguration.Builder().learningRate(2e-2).build())
            .setFeatureExtractor("C")
            .build();

    boolean cFound = false;
    Layer[] layers = graph2.getLayers();
    for (Layer l : layers) {
        String name = l.conf().getLayer().getLayerName();
        log.info(name + "\t frozen: " + (l instanceof FrozenLayer));
        if ("C".equals(name)) {
            //Only C should be frozen in this config
            cFound = true;
            assertTrue(name, l instanceof FrozenLayer);
        } else {
            assertFalse(name, l instanceof FrozenLayer);
        }
        // The fine-tune learning rate must reach every layer; other global settings must be kept.
        assertEquals(Updater.ADAM, l.conf().getLayer().getUpdater());
        assertEquals(2e-2, l.conf().getLayer().getLearningRate(), 1e-5);
        assertEquals(Activation.LEAKYRELU.getActivationFunction(), l.conf().getLayer().getActivationFn());
    }
    assertTrue(cFound);
}
Also used: Layer(org.deeplearning4j.nn.api.Layer) OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 82 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project (by deeplearning4j).

Class TestListenerSetting, method testSettingListenersUnsupervised.

/**
 * Checks that listeners set on a MultiLayerNetwork or ComputationGraph reach every layer,
 * including the pretrainable ones (RBM, AutoEncoder, VariationalAutoencoder). It also checks that
 * the network/graph itself reports the same two listeners, in the order they were set.
 */
@Test
public void testSettingListenersUnsupervised() {
    //Pretrain layers should get copies of the listeners, in addition to the network/graph itself
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list()
            .layer(0, new RBM.Builder().nIn(10).nOut(10).build())
            .layer(1, new AutoEncoder.Builder().nIn(10).nOut(10).build())
            .layer(2, new VariationalAutoencoder.Builder().nIn(10).nOut(10).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(), new TestRoutingListener());
    for (Layer l : net.getLayers()) {
        assertScoreThenRoutingListeners(l.getClass().toString(), l.getListeners());
    }
    assertScoreThenRoutingListeners("MultiLayerNetwork", net.getListeners());

    ComputationGraphConfiguration gConf = new NeuralNetConfiguration.Builder().graphBuilder()
            .addInputs("in")
            .addLayer("0", new RBM.Builder().nIn(10).nOut(10).build(), "in")
            .addLayer("1", new AutoEncoder.Builder().nIn(10).nOut(10).build(), "0")
            .addLayer("2", new VariationalAutoencoder.Builder().nIn(10).nOut(10).build(), "1")
            .setOutputs("2")
            .build();
    ComputationGraph cg = new ComputationGraph(gConf);
    cg.init();
    cg.setListeners(new ScoreIterationListener(), new TestRoutingListener());
    for (Layer l : cg.getLayers()) {
        assertScoreThenRoutingListeners(l.getClass().toString(), l.getListeners());
    }
    assertScoreThenRoutingListeners("ComputationGraph", cg.getListeners());
}

/**
 * Asserts that {@code listeners} holds exactly two entries, in iteration order: a
 * ScoreIterationListener followed by a TestRoutingListener.
 *
 * @param context   label included in every failure message to identify the listeners' owner
 * @param listeners the listeners reported by a layer, network or graph
 */
private static void assertScoreThenRoutingListeners(String context, Collection<IterationListener> listeners) {
    assertEquals(context, 2, listeners.size());
    IterationListener[] lArr = listeners.toArray(new IterationListener[2]);
    assertTrue(context, lArr[0] instanceof ScoreIterationListener);
    assertTrue(context, lArr[1] instanceof TestRoutingListener);
}
Also used: VariationalAutoencoder(org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) Layer(org.deeplearning4j.nn.api.Layer) MultiLayerConfiguration(org.deeplearning4j.nn.conf.MultiLayerConfiguration) RoutingIterationListener(org.deeplearning4j.api.storage.listener.RoutingIterationListener) ScoreIterationListener(org.deeplearning4j.optimize.listeners.ScoreIterationListener) IterationListener(org.deeplearning4j.optimize.api.IterationListener) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) RBM(org.deeplearning4j.nn.conf.layers.RBM) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) MultiLayerNetwork(org.deeplearning4j.nn.multilayer.MultiLayerNetwork) AutoEncoder(org.deeplearning4j.nn.conf.layers.AutoEncoder) Test(org.junit.Test)

Example 83 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project (by deeplearning4j).

Class TransferLearningHelperTest, method tesUnfrozenSubset.

/**
 * Freezes the graph at "denseCentre2" with TransferLearningHelper and checks that
 * {@code unfrozenGraph()} returns the expected unfrozen subset. In that subset, the activations of
 * "denseCentre1" and "denseCentre2" become graph inputs next to the original "inRight" input, and
 * every vertex downstream of them is kept with its original configuration. The subset must match
 * the hand-built expected graph in both configuration JSON and parameter count.
 */
@Test
public void tesUnfrozenSubset() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .seed(124)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    /*
                             (inCentre)                        (inRight)
                                |                                |
                            denseCentre0                         |
                                |                                |
                 ,--------  denseCentre1                       denseRight0
                /               |                                |
        subsetLeft(0-3)    denseCentre2 ---- denseRight ----  mergeRight
              |                 |                                |
         denseLeft0        denseCentre3                        denseRight1
              |                 |                                |
          (outLeft)         (outCentre)                        (outRight)
        
         */
    ComputationGraphConfiguration conf = overallConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();

    TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
    ComputationGraph modelSubset = helper.unfrozenGraph();

    ComputationGraphConfiguration expectedConf = overallConf.graphBuilder()
            //inputs are in sorted order
            .addInputs("denseCentre1", "denseCentre2", "inRight")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();
    // Initialising the expected graph also checks that the hand-written expected config is valid.
    ComputationGraph expectedModel = new ComputationGraph(expectedConf);
    expectedModel.init();

    assertEquals(expectedConf.toJson(), modelSubset.getConfiguration().toJson());
    // Identical configurations must also give identically sized parameter vectors.
    assertEquals(expectedModel.numParams(), modelSubset.numParams());
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Example 84 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project (by deeplearning4j).

Class ModelSerializerTest, method testWriteCGModel.

/**
 * Round-trips a small ComputationGraph through ModelSerializer, writing the updater as well, and
 * checks that the restored graph has the same configuration JSON, parameters and updater as the
 * original.
 */
@Test
public void testWriteCGModel() throws Exception {
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.1)
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(2).nOut(3).build(), "dense")
            .setOutputs("out")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph cg = new ComputationGraph(config);
    cg.init();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    // Fallback only: the finally block below removes the file as soon as the test finishes.
    tempFile.deleteOnExit();
    try {
        // true = also serialize the updater, which is compared after restoring.
        ModelSerializer.writeModel(cg, tempFile, true);
        ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile);

        // Expected value (original graph) first, actual value (restored graph) second.
        assertEquals(cg.getConfiguration().toJson(), network.getConfiguration().toJson());
        assertEquals(cg.params(), network.params());
        assertEquals(cg.getUpdater(), network.getUpdater());
    } finally {
        // Best-effort cleanup; deleteOnExit above still covers a failed delete.
        tempFile.delete();
    }
}
Also used: DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) File(java.io.File) Test(org.junit.Test)

Example 85 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in the deeplearning4j project (by deeplearning4j).

Class KerasModelConfigurationTest, method importKerasMlpModelMultilossConfigTest.

/**
 * Imports the Keras model configuration in
 * {@code modelimport/keras/configs/mlp_fapi_multiloss_config.json} with training-config
 * enforcement enabled, and checks that the resulting ComputationGraph configuration can be built
 * and initialised. The test passes when no exception is thrown.
 */
@Test
public void importKerasMlpModelMultilossConfigTest() throws Exception {
    ClassPathResource configResource = new ClassPathResource("modelimport/keras/configs/mlp_fapi_multiloss_config.json", KerasModelConfigurationTest.class.getClassLoader());
    ComputationGraphConfiguration config;
    // The classpath stream was previously never closed; try-with-resources releases it.
    try (java.io.InputStream configStream = configResource.getInputStream()) {
        config = new KerasModel.ModelBuilder()
                .modelJsonInputStream(configStream)
                .enforceTrainingConfig(true)
                .buildModel()
                .getComputationGraphConfiguration();
    }
    ComputationGraph model = new ComputationGraph(config);
    model.init();
}
Also used: ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Aggregations

ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)109 Test (org.junit.Test)73 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)63 INDArray (org.nd4j.linalg.api.ndarray.INDArray)62 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)36 DataSet (org.nd4j.linalg.dataset.DataSet)25 NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution)22 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)21 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)19 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)19 ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener)17 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)17 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)14 Layer (org.deeplearning4j.nn.api.Layer)14 Random (java.util.Random)11 InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver)10 MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition)10 TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster)10 MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition)9 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)9