Search in sources :

Example 6 with MergeVertex

use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

the class TransferLearningComplex method testAddOutput.

/**
 * Checks that a second output layer can be attached to an initialised graph via
 * {@code TransferLearning.GraphBuilder} and that the edited graph can be fitted.
 *
 * <p>The source graph has two inputs ("inCentre", "inRight") and one output ("outRight").
 * The edit adds "outCentre" (nIn 2, nOut 3) on top of "denseCentre0", so the edited graph
 * is expected to expose two output arrays.
 */
@Test
public void testAddOutput() {
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);

    ComputationGraphConfiguration conf = overallConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
            // The merge feeds 2 + 2 = 4 activations into outRight (nIn 4).
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(),
                    "mergeRight")
            .setOutputs("outRight")
            .build();
    ComputationGraph modelToTune = new ComputationGraph(conf);
    modelToTune.init();

    // Both outputs must be listed: the assertion below expects two output arrays, and the
    // label arrays passed to fit() are ordered (2x2 for outRight, 2x3 for outCentre).
    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune)
            .addLayer("outCentre",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(3).build(),
                    "denseCentre0")
            .setOutputs("outRight", "outCentre")
            .build();

    assertEquals(2, modelNow.getNumOutputArrays());

    // Features: one 2x2 array per input. Labels: one array per output, in setOutputs order.
    MultiDataSet rand = new MultiDataSet(
            new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 2) },
            new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 3) });
    modelNow.fit(rand);
    log.info(modelNow.summary());
}
Also used : DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 7 with MergeVertex

use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

the class TestComputationGraphNetwork method printSummary.

/**
 * Builds a two-input, three-output computation graph, prints its summary, then builds a
 * transfer-learning copy with {@code setFeatureExtractor("denseCentre2")} and prints that
 * summary as well. The test makes no assertions; it only exercises {@code summary()}.
 */
@Test
public void printSummary() {
    final LossFunctions.LossFunction mse = LossFunctions.LossFunction.MSE;

    NeuralNetConfiguration.Builder baseBuilder = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);

    ComputationGraphConfiguration graphConf = baseBuilder.graphBuilder()
            .addInputs("inCentre", "inRight")
            // Centre branch: inCentre -> denseCentre0..3 -> outCentre
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(mse).nIn(7).nOut(4).build(), "denseCentre3")
            // Left branch: subset of denseCentre1 -> denseLeft0 -> outLeft
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(mse).nIn(5).nOut(6).build(), "denseLeft0")
            // Right branch: denseRight (from denseCentre2) merged with denseRight0 (from inRight)
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(mse).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    ComputationGraph original = new ComputationGraph(graphConf);
    original.init();
    System.out.println(original.summary());

    TransferLearning.GraphBuilder transferBuilder = new TransferLearning.GraphBuilder(original);
    ComputationGraph tuned = transferBuilder.setFeatureExtractor("denseCentre2").build();
    System.out.println(tuned.summary());
}
Also used : MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Example 8 with MergeVertex

use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

the class TransferLearningHelperTest method tesUnfrozenSubset.

/**
 * Compares the configuration of the graph returned by
 * {@code TransferLearningHelper.unfrozenGraph()} (helper created with "denseCentre2")
 * against a hand-built configuration, using their JSON forms.
 *
 * <p>The hand-built configuration takes "denseCentre1", "denseCentre2" and "inRight" as its
 * inputs and contains only the vertices that come after them in the full graph.
 */
@Test
public void tesUnfrozenSubset() {
    final LossFunctions.LossFunction mse = LossFunctions.LossFunction.MSE;

    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .seed(124)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);

    /*
     * Topology of the full graph:
     *
     *                      (inCentre)                        (inRight)
     *                          |                                |
     *                      denseCentre0                         |
     *                          |                                |
     *           ,--------  denseCentre1                       denseRight0
     *          /               |                                |
     *  subsetLeft(0-3)    denseCentre2 ---- denseRight ----  mergeRight
     *        |                 |                                |
     *   denseLeft0        denseCentre3                        denseRight1
     *        |                 |                                |
     *    (outLeft)         (outCentre)                        (outRight)
     */
    ComputationGraphConfiguration fullConf = overallConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(mse).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(mse).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(mse).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    ComputationGraph fullModel = new ComputationGraph(fullConf);
    fullModel.init();

    ComputationGraph unfrozen = new TransferLearningHelper(fullModel, "denseCentre2").unfrozenGraph();

    ComputationGraphConfiguration expectedConf = overallConf.graphBuilder()
            // inputs are in sorted order
            .addInputs("denseCentre1", "denseCentre2", "inRight")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(mse).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(mse).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(mse).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    // Initialising the expected model is kept from the original test; only the
    // configurations are compared below.
    ComputationGraph expectedModel = new ComputationGraph(expectedConf);
    expectedModel.init();

    String expectedJson = expectedConf.toJson();
    String actualJson = unfrozen.getConfiguration().toJson();
    assertEquals(expectedJson, actualJson);
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Aggregations

MergeVertex (org.deeplearning4j.nn.conf.graph.MergeVertex)8 Test (org.junit.Test)8 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)6 ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration)5 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)5 ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph)5 SubsetVertex (org.deeplearning4j.nn.conf.graph.SubsetVertex)4 MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet)4 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)3 INDArray (org.nd4j.linalg.api.ndarray.INDArray)3 ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex)1 LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex)1 ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer)1 CnnToFeedForwardPreProcessor (org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor)1 FeedForwardToRnnPreProcessor (org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor)1 RnnToFeedForwardPreProcessor (org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor)1 FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer)1