Search in sources:

Example 1 with MergeVertex

Use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

From the class TestComputationGraphNetwork, method testPreprocessorAddition.

/**
 * Builds several graph configurations with declared input types and checks which
 * input preprocessors end up attached to each layer vertex, together with the nIn
 * values reported by the feed-forward style layers.
 */
@Test
public void testPreprocessorAddition() {
    // ----- Case 1: feed-forward input -> LSTM -> RNN output -----
    ComputationGraphConfiguration ffToRnn = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .setInputTypes(InputType.feedForward(5))
            .addLayer("rnn", new GravesLSTM.Builder().nOut(5).build(), "in")
            .addLayer("out", new RnnOutputLayer.Builder().nOut(5).build(), "rnn")
            .setOutputs("out")
            .build();

    // nIn was never given to the layer builders above, yet both layers are expected to report 5
    LayerVertex lstmVertex = (LayerVertex) ffToRnn.getVertices().get("rnn");
    FeedForwardLayer lstmLayer = (FeedForwardLayer) lstmVertex.getLayerConf().getLayer();
    assertEquals(5, lstmLayer.getNIn());
    LayerVertex rnnOutVertex = (LayerVertex) ffToRnn.getVertices().get("out");
    FeedForwardLayer rnnOutLayer = (FeedForwardLayer) rnnOutVertex.getLayerConf().getLayer();
    assertEquals(5, rnnOutLayer.getNIn());

    assertTrue(lstmVertex.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);
    assertNull(rnnOutVertex.getPreProcessor());

    // ----- Case 2: recurrent input -> dense -> RNN output -----
    ComputationGraphConfiguration rnnToFfToRnn = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .setInputTypes(InputType.recurrent(5))
            .addLayer("ff", new DenseLayer.Builder().nOut(5).build(), "in")
            .addLayer("out", new RnnOutputLayer.Builder().nOut(5).build(), "ff")
            .setOutputs("out")
            .build();

    LayerVertex denseAfterRnn = (LayerVertex) rnnToFfToRnn.getVertices().get("ff");
    FeedForwardLayer denseAfterRnnLayer = (FeedForwardLayer) denseAfterRnn.getLayerConf().getLayer();
    assertEquals(5, denseAfterRnnLayer.getNIn());
    LayerVertex rnnOutAfterDense = (LayerVertex) rnnToFfToRnn.getVertices().get("out");
    FeedForwardLayer rnnOutAfterDenseLayer = (FeedForwardLayer) rnnOutAfterDense.getLayerConf().getLayer();
    assertEquals(5, rnnOutAfterDenseLayer.getNIn());

    assertTrue(denseAfterRnn.getPreProcessor() instanceof RnnToFeedForwardPreProcessor);
    assertTrue(rnnOutAfterDense.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);

    // ----- Case 3: CNN -> pooling -> dense -> output -----
    ComputationGraphConfiguration cnnToDense = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .setInputTypes(InputType.convolutional(28, 28, 1))
            // (28-2+0)/2+1 = 14
            .addLayer("cnn",
                    new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).nOut(3).build(),
                    "in")
            // (14-2+0)/2+1 = 7
            .addLayer("pool",
                    new SubsamplingLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).build(),
                    "cnn")
            .addLayer("dense", new DenseLayer.Builder().nOut(10).build(), "pool")
            .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(5).build(), "dense")
            .setOutputs("out")
            .build();

    // CNN data into CNN-style layers: no preprocessor should have been added
    LayerVertex convVertex = (LayerVertex) cnnToDense.getVertices().get("cnn");
    assertNull(convVertex.getPreProcessor());
    LayerVertex poolVertex = (LayerVertex) cnnToDense.getVertices().get("pool");
    assertNull(poolVertex.getPreProcessor());

    // The dense layer needs the 3 x 7 x 7 activations flattened
    LayerVertex denseVertex = (LayerVertex) cnnToDense.getVertices().get("dense");
    assertTrue(denseVertex.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);
    CnnToFeedForwardPreProcessor flatten = (CnnToFeedForwardPreProcessor) denseVertex.getPreProcessor();
    assertEquals(3, flatten.getNumChannels());
    assertEquals(7, flatten.getInputHeight());
    assertEquals(7, flatten.getInputWidth());

    LayerVertex denseOutVertex = (LayerVertex) cnnToDense.getVertices().get("out");
    assertNull(denseOutVertex.getPreProcessor());

    // nIn of the dense layer = flattened CNN activations
    FeedForwardLayer denseLayer = (FeedForwardLayer) denseVertex.getLayerConf().getLayer();
    assertEquals(7 * 7 * 3, denseLayer.getNIn());

    // ----- Case 4: CNN->Dense and RNN->Dense branches, merged, then Dense->RNN output -----
    ComputationGraphConfiguration twoBranch = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("inCNN", "inRNN")
            .setInputTypes(InputType.convolutional(28, 28, 1), InputType.recurrent(5))
            // (28-2+0)/2+1 = 14
            .addLayer("cnn",
                    new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).nOut(3).build(),
                    "inCNN")
            // (14-2+0)/2+1 = 7
            .addLayer("pool",
                    new SubsamplingLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2).build(),
                    "cnn")
            .addLayer("dense", new DenseLayer.Builder().nOut(10).build(), "pool")
            .addLayer("dense2", new DenseLayer.Builder().nOut(10).build(), "inRNN")
            .addVertex("merge", new MergeVertex(), "dense", "dense2")
            .addLayer("out", new RnnOutputLayer.Builder().nOut(5).build(), "merge")
            .setOutputs("out")
            .build();

    // CNN data into CNN-style layers: no preprocessor expected
    LayerVertex branchConv = (LayerVertex) twoBranch.getVertices().get("cnn");
    assertNull(branchConv.getPreProcessor());
    LayerVertex branchPool = (LayerVertex) twoBranch.getVertices().get("pool");
    assertNull(branchPool.getPreProcessor());

    LayerVertex branchDense = (LayerVertex) twoBranch.getVertices().get("dense");
    assertTrue(branchDense.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);
    CnnToFeedForwardPreProcessor branchFlatten = (CnnToFeedForwardPreProcessor) branchDense.getPreProcessor();
    assertEquals(3, branchFlatten.getNumChannels());
    assertEquals(7, branchFlatten.getInputHeight());
    assertEquals(7, branchFlatten.getInputWidth());

    LayerVertex branchDense2 = (LayerVertex) twoBranch.getVertices().get("dense2");
    assertTrue(branchDense2.getPreProcessor() instanceof RnnToFeedForwardPreProcessor);
    LayerVertex mergedOut = (LayerVertex) twoBranch.getVertices().get("out");
    assertTrue(mergedOut.getPreProcessor() instanceof FeedForwardToRnnPreProcessor);

    // nIn checks for both branches and for the layer that follows the merge
    FeedForwardLayer branchDenseLayer = (FeedForwardLayer) branchDense.getLayerConf().getLayer();
    assertEquals(7 * 7 * 3, branchDenseLayer.getNIn());
    FeedForwardLayer branchDense2Layer = (FeedForwardLayer) branchDense2.getLayerConf().getLayer();
    assertEquals(5, branchDense2Layer.getNIn());
    // 10 + 10 out of the merge vertex -> 20 into the output layer vertex
    FeedForwardLayer mergedOutLayer = (FeedForwardLayer) mergedOut.getLayerConf().getLayer();
    assertEquals(20, mergedOutLayer.getNIn());

    // ----- Case 5: one input feeding two CNN layers that share a pooling layer -----
    ComputationGraphConfiguration sharedInput = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input")
            .setInputTypes(InputType.convolutional(28, 28, 1))
            .addLayer("cnn_1",
                    new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(3).build(),
                    "input")
            .addLayer("cnn_2",
                    new ConvolutionLayer.Builder(4, 4).stride(2, 2).padding(1, 1).nIn(1).nOut(3).build(),
                    "input")
            .addLayer("max_1",
                    new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).build(),
                    "cnn_1", "cnn_2")
            // nIn is left unset on the output layer (it would be 7 * 7 * 6)
            .addLayer("output", new OutputLayer.Builder().nOut(10).build(), "max_1")
            .setOutputs("output")
            .pretrain(false)
            .backprop(true)
            .build();

    // CNN data into CNN-style layers: no preprocessor expected
    LayerVertex firstConv = (LayerVertex) sharedInput.getVertices().get("cnn_1");
    assertNull(firstConv.getPreProcessor());
    LayerVertex secondConv = (LayerVertex) sharedInput.getVertices().get("cnn_2");
    assertNull(secondConv.getPreProcessor());
    assertNull(((LayerVertex) sharedInput.getVertices().get("max_1")).getPreProcessor());

    // Output layer should receive the 6-channel 7 x 7 activations through a CNN -> FF preprocessor
    LayerVertex finalVertex = (LayerVertex) sharedInput.getVertices().get("output");
    assertTrue(finalVertex.getPreProcessor() instanceof CnnToFeedForwardPreProcessor);
    CnnToFeedForwardPreProcessor finalFlatten = (CnnToFeedForwardPreProcessor) finalVertex.getPreProcessor();
    assertEquals(6, finalFlatten.getNumChannels());
    assertEquals(7, finalFlatten.getInputHeight());
    assertEquals(7, finalFlatten.getInputWidth());
}
Also used : LayerVertex(org.deeplearning4j.nn.conf.graph.LayerVertex) CnnToFeedForwardPreProcessor(org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) RnnToFeedForwardPreProcessor(org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor) FeedForwardToRnnPreProcessor(org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor) Test(org.junit.Test)

Example 2 with MergeVertex

Use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

From the class ComputationGraphConfigurationTest, method testJSONWithGraphNodes.

/**
 * Round-trips a graph configuration containing merge, subset and element-wise
 * vertices through JSON and checks that both the JSON text and the configuration
 * object compare equal afterwards.
 */
@Test
public void testJSONWithGraphNodes() {
    ComputationGraphConfiguration original = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input1", "input2")
            .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1")
            .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2")
            .addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
            .addVertex("subset1", new SubsetVertex(0, 1), "merge1")
            .addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "add")
            .setOutputs("out")
            .build();

    // Serialize, print for inspection, then rebuild from the serialized form
    String serialized = original.toJson();
    System.out.println(serialized);
    ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(serialized);

    // Both the textual and the object representation must survive the round trip
    assertEquals(serialized, restored.toJson());
    assertEquals(original, restored);
}
Also used : DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ElementWiseVertex(org.deeplearning4j.nn.conf.graph.ElementWiseVertex) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Example 3 with MergeVertex

Use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

From the class TransferLearningHelperTest, method testFitUnFrozen.

/**
 * Trains two copies of the same three-output graph and compares their parameters:
 * a clone with "denseCentre0".."denseCentre2" marked as frozen is fitted on the raw
 * data, while the original model is fitted through a TransferLearningHelper
 * (constructed with "denseCentre2") on the featurized data.
 */
@Test
public void testFitUnFrozen() {
    NeuralNetConfiguration.Builder baseBuilder = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .seed(124)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);

    ComputationGraphConfiguration graphConf = baseBuilder.graphBuilder()
            .addInputs("inCentre", "inRight")
            // centre branch
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(),
                    "denseCentre3")
            // left branch, fed by a subset of denseCentre1
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(),
                    "denseLeft0")
            // right branch: denseCentre2 merged with the second input
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(),
                    "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    ComputationGraph modelToTune = new ComputationGraph(graphConf);
    modelToTune.init();

    // Random features and labels, 10 examples each
    INDArray featuresRight = Nd4j.rand(10, 2);
    INDArray featuresCentre = Nd4j.rand(10, 10);
    INDArray labelsLeft = Nd4j.rand(10, 6);
    INDArray labelsRight = Nd4j.rand(10, 5);
    INDArray labelsCentre = Nd4j.rand(10, 4);
    INDArray[] allFeatures = new INDArray[] { featuresCentre, featuresRight };
    INDArray[] allLabels = new INDArray[] { labelsLeft, labelsCentre, labelsRight };
    MultiDataSet rawData = new MultiDataSet(allFeatures, allLabels);

    // Reference model: a clone with the first three centre layers marked as frozen
    ComputationGraph modelIdentical = modelToTune.clone();
    modelIdentical.getVertex("denseCentre0").setLayerAsFrozen();
    modelIdentical.getVertex("denseCentre1").setLayerAsFrozen();
    modelIdentical.getVertex("denseCentre2").setLayerAsFrozen();

    TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
    MultiDataSet featurized = helper.featurize(rawData);

    // Before any fitting the two models must agree on denseRight0
    assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());

    modelIdentical.fit(rawData);
    helper.fitFeaturized(featurized);

    // Centre branch parameters
    String[] centreLayers = { "denseCentre0", "denseCentre1", "denseCentre2", "denseCentre3", "outCentre" };
    for (String layerName : centreLayers) {
        assertEquals(modelIdentical.getLayer(layerName).params(), modelToTune.getLayer(layerName).params());
    }

    // Right branch: configuration JSON and parameters
    assertEquals(modelIdentical.getLayer("denseRight").conf().toJson(),
            modelToTune.getLayer("denseRight").conf().toJson());
    assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params());
    assertEquals(modelIdentical.getLayer("denseRight0").conf().toJson(),
            modelToTune.getLayer("denseRight0").conf().toJson());
    // NOTE: the post-fit parameter comparison for denseRight0 is disabled in this test:
    //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params());

    // Remaining right-branch and left-branch parameters
    String[] remainingLayers = { "denseRight1", "outRight", "denseLeft0", "outLeft" };
    for (String layerName : remainingLayers) {
        assertEquals(modelIdentical.getLayer(layerName).params(), modelToTune.getLayer(layerName).params());
    }

    log.info(modelIdentical.summary());
    log.info(helper.unfrozenGraph().summary());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 4 with MergeVertex

Use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

From the class TransferLearningComplex, method testSimplerMergeBackProp.

/**
 * Compares three graphs over five fit iterations: the original two-input model with
 * "denseCentre0" marked as frozen, a model produced by TransferLearning.GraphBuilder
 * with "denseCentre0" as feature extractor, and a smaller model that takes the
 * denseCentre0 activations directly as an input.
 */
@Test
public void testSimplerMergeBackProp() {
    NeuralNetConfiguration.Builder baseBuilder = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);

    /*
                inCentre                inRight
                   |                        |
             denseCentre0               denseRight0
                   |                        |
                   |------ mergeRight ------|
                                |
                              outRight

    */
    ComputationGraphConfiguration fullConf = baseBuilder.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(),
                    "mergeRight")
            .setOutputs("outRight")
            .build();
    ComputationGraph modelToTune = new ComputationGraph(fullConf);
    modelToTune.init();

    // Two random 2x2 feature arrays and one random 2x2 label array
    INDArray[] rawFeatures = new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 2) };
    INDArray[] rawLabels = new INDArray[] { Nd4j.rand(2, 2) };
    MultiDataSet randData = new MultiDataSet(rawFeatures, rawLabels);

    // Same data, but with the first input replaced by the denseCentre0 activations
    INDArray denseCentre0 = modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0");
    INDArray[] featurizedInputs = new INDArray[] { denseCentre0, randData.getFeatures(1) };
    MultiDataSet otherRandData = new MultiDataSet(featurizedInputs, randData.getLabels());

    // Smaller graph in which "denseCentre0" is an input rather than a layer
    ComputationGraphConfiguration reducedConf = baseBuilder.graphBuilder()
            .addInputs("denseCentre0", "inRight")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(),
                    "mergeRight")
            .setOutputs("outRight")
            .build();
    ComputationGraph modelOther = new ComputationGraph(reducedConf);
    modelOther.init();
    // Start the reduced model from the same parameters as the full one
    modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").params());
    modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").params());

    modelToTune.getVertex("denseCentre0").setLayerAsFrozen();
    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune)
            .setFeatureExtractor("denseCentre0")
            .build();

    for (int iteration = 0; iteration < 5; iteration++) {
        if (iteration == 0) {
            // Before any fitting: activations out of the merge must be equivalent
            INDArray mergeFromTuned = modelToTune.feedForward(randData.getFeatures(), false).get("mergeRight");
            INDArray mergeFromOther = modelOther.feedForward(otherRandData.getFeatures(), false).get("mergeRight");
            assertEquals(mergeFromTuned, mergeFromOther);

            INDArray mergeFromNow = modelNow.feedForward(randData.getFeatures(), false).get("mergeRight");
            INDArray mergeFromOtherAgain =
                    modelOther.feedForward(otherRandData.getFeatures(), false).get("mergeRight");
            assertEquals(mergeFromNow, mergeFromOtherAgain);
        }

        modelOther.fit(otherRandData);
        modelToTune.fit(randData);
        modelNow.fit(randData);

        // Activations out of the frozen vertex must still equal the input given to the reduced model
        assertEquals(otherRandData.getFeatures(0),
                modelNow.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(otherRandData.getFeatures(0),
                modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));

        // Trainable layers must match the reduced model in both other graphs
        assertEquals(modelOther.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
        assertEquals(modelOther.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());
        assertEquals(modelOther.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
        assertEquals(modelOther.getLayer("outRight").params(), modelToTune.getLayer("outRight").params());
    }
}
Also used : DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 5 with MergeVertex

Use of org.deeplearning4j.nn.conf.graph.MergeVertex in project deeplearning4j by deeplearning4j.

From the class TransferLearningComplex, method testLessSimpleMergeBackProp.

/**
 * Two-output variant of the merge back-propagation check: the original model with
 * "denseCentre0" marked as frozen and a model built by TransferLearning.GraphBuilder
 * with "denseCentre0" as feature extractor are fitted on the same data five times
 * and compared after every iteration.
 */
@Test
public void testLessSimpleMergeBackProp() {
    NeuralNetConfiguration.Builder baseBuilder = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);

    /*
                inCentre                inRight
                   |                        |
             denseCentre0               denseRight0
                   |                        |
                   |------ mergeRight ------|
                   |            |
                 outCentre     outRight

    */
    ComputationGraphConfiguration graphConf = baseBuilder.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
            .addLayer("outCentre",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(2).nOut(2).build(),
                    "denseCentre0")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(3).nOut(2).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(),
                    "mergeRight")
            // the two outputs are registered through two separate setOutputs calls
            .setOutputs("outRight")
            .setOutputs("outCentre")
            .build();
    ComputationGraph modelToTune = new ComputationGraph(graphConf);
    modelToTune.init();
    modelToTune.getVertex("denseCentre0").setLayerAsFrozen();

    // Random inputs (2x2 and 2x3) and two random 2x2 label arrays
    INDArray[] rawFeatures = new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 3) };
    INDArray[] rawLabels = new INDArray[] { Nd4j.rand(2, 2), Nd4j.rand(2, 2) };
    MultiDataSet randData = new MultiDataSet(rawFeatures, rawLabels);

    // Same data, but with the first input replaced by the denseCentre0 activations
    INDArray denseCentre0 = modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0");
    INDArray[] featurizedInputs = new INDArray[] { denseCentre0, randData.getFeatures(1) };
    MultiDataSet otherRandData = new MultiDataSet(featurizedInputs, randData.getLabels());

    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToTune)
            .setFeatureExtractor("denseCentre0")
            .build();
    assertTrue(modelNow.getLayer("denseCentre0") instanceof FrozenLayer);

    for (int iteration = 0; iteration < 5; iteration++) {
        if (iteration == 0) {
            // Before any fitting: activations out of the merge must be equivalent.
            // NOTE(review): modelNow is fed otherRandData here, whose first array already holds
            // the denseCentre0 activations, while modelNow still contains a denseCentre0 layer.
            // Confirm this is intended rather than randData.
            INDArray mergeFromTuned = modelToTune.feedForward(randData.getFeatures(), false).get("mergeRight");
            INDArray mergeFromNow = modelNow.feedForward(otherRandData.getFeatures(), false).get("mergeRight");
            assertEquals(mergeFromTuned, mergeFromNow);
        }

        modelToTune.fit(randData);
        modelNow.fit(randData);

        // Activations out of the frozen vertex must still equal the activations captured before fitting
        assertEquals(otherRandData.getFeatures(0),
                modelNow.feedForward(randData.getFeatures(), false).get("denseCentre0"));
        assertEquals(otherRandData.getFeatures(0),
                modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0"));

        // Trainable layers must stay in step between the two models
        assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params());
        assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params());
        assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params());
    }
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) FrozenLayer(org.deeplearning4j.nn.layers.FrozenLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Aggregations

MergeVertex (org.deeplearning4j.nn.conf.graph.MergeVertex): 8; Test (org.junit.Test): 8; DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 6; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 5; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 5; ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 5; SubsetVertex (org.deeplearning4j.nn.conf.graph.SubsetVertex): 4; MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 4; OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 3; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3; ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex): 1; LayerVertex (org.deeplearning4j.nn.conf.graph.LayerVertex): 1; ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 1; CnnToFeedForwardPreProcessor (org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor): 1; FeedForwardToRnnPreProcessor (org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor): 1; RnnToFeedForwardPreProcessor (org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor): 1; FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer): 1