Search in sources:

Example 1 with SubsetVertex

Use of org.deeplearning4j.nn.conf.graph.SubsetVertex in the deeplearning4j project (by deeplearning4j).

From class ComputationGraphConfigurationTest, method testJSONWithGraphNodes.

/**
 * Round-trips a graph configuration through JSON. The graph mixes layers with several vertex
 * types (MergeVertex, SubsetVertex and an element-wise Add vertex). The test serializes the
 * configuration, parses it back, and requires that both the re-serialized JSON and the parsed
 * configuration equal the original. Only the configuration is exercised; no network is built.
 */
@Test
public void testJSONWithGraphNodes() {
    ComputationGraphConfiguration original = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input1", "input2")
            // One convolution branch per input
            .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1")
            .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2")
            // Merge both branches, then take a subset of the merged activations
            .addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
            .addVertex("subset1", new SubsetVertex(0, 1), "merge1")
            // Two dense layers read the subset; their outputs are summed element-wise
            .addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "add")
            .setOutputs("out")
            .build();

    String serialized = original.toJson();
    System.out.println(serialized);

    ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(serialized);
    assertEquals(serialized, restored.toJson());
    assertEquals(original, restored);
}
Also used: DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ElementWiseVertex(org.deeplearning4j.nn.conf.graph.ElementWiseVertex) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ConvolutionLayer(org.deeplearning4j.nn.conf.layers.ConvolutionLayer) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Example 2 with SubsetVertex

Use of org.deeplearning4j.nn.conf.graph.SubsetVertex in the deeplearning4j project (by deeplearning4j).

From class TransferLearningHelperTest, method testFitUnFrozen.

/**
 * Checks that training through a TransferLearningHelper on featurized data gives the same
 * parameters as training a clone of the same graph directly.
 *
 * <p>In the clone, denseCentre0, denseCentre1 and denseCentre2 are frozen explicitly. The
 * helper is built with "denseCentre2" as its boundary and is fitted on the output of
 * {@code featurize}. After one fit on each side, the compared layers must have identical
 * parameters and, where checked, identical layer configuration JSON.
 */
@Test
public void testFitUnFrozen() {
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .seed(124)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    ComputationGraphConfiguration graphConf = baseConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            // Centre column
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            // Left branch: subset (indices 0-3) of denseCentre1's activations
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            // Right branch: denseCentre2 path merged with the second input
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    ComputationGraph modelToTune = new ComputationGraph(graphConf);
    modelToTune.init();

    INDArray inRight = Nd4j.rand(10, 2);
    INDArray inCentre = Nd4j.rand(10, 10);
    INDArray outLeft = Nd4j.rand(10, 6);
    INDArray outRight = Nd4j.rand(10, 5);
    INDArray outCentre = Nd4j.rand(10, 4);
    MultiDataSet origData = new MultiDataSet(
            new INDArray[] {inCentre, inRight},
            new INDArray[] {outLeft, outCentre, outRight});

    // Reference model: a clone with the centre feature layers frozen in place
    ComputationGraph modelIdentical = modelToTune.clone();
    for (String frozenName : new String[] {"denseCentre0", "denseCentre1", "denseCentre2"}) {
        modelIdentical.getVertex(frozenName).setLayerAsFrozen();
    }

    // Model under test: the helper trains only the part of the graph after denseCentre2
    TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
    MultiDataSet featurizedDataSet = helper.featurize(origData);
    assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params());

    modelIdentical.fit(origData);
    helper.fitFeaturized(featurizedDataSet);

    for (String layerName : new String[] {"denseCentre0", "denseCentre1", "denseCentre2", "denseCentre3", "outCentre"}) {
        assertEquals(modelIdentical.getLayer(layerName).params(), modelToTune.getLayer(layerName).params());
    }
    assertEquals(modelIdentical.getLayer("denseRight").conf().toJson(), modelToTune.getLayer("denseRight").conf().toJson());
    assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params());
    assertEquals(modelIdentical.getLayer("denseRight0").conf().toJson(), modelToTune.getLayer("denseRight0").conf().toJson());
    // NOTE(review): denseRight0 parameters are not compared after fitting. The assertion
    // below was left disabled upstream; the reason is not visible here.
    //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params());
    for (String layerName : new String[] {"denseRight1", "outRight", "denseLeft0", "outLeft"}) {
        assertEquals(modelIdentical.getLayer(layerName).params(), modelToTune.getLayer(layerName).params());
    }

    log.info(modelIdentical.summary());
    log.info(helper.unfrozenGraph().summary());
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 3 with SubsetVertex

Use of org.deeplearning4j.nn.conf.graph.SubsetVertex in the deeplearning4j project (by deeplearning4j).

From class TestComputationGraphNetwork, method printSummary.

/**
 * Prints the summary of a two-input, three-output graph. It then prints the summary of the
 * graph produced by {@code TransferLearning.GraphBuilder} with
 * {@code setFeatureExtractor("denseCentre2")}. The test has no assertions; it only exercises
 * {@code summary()} on both graphs.
 */
@Test
public void printSummary() {
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    ComputationGraphConfiguration graphConf = baseConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            // Centre column
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            // Left branch off denseCentre1
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            // Right branch merging denseCentre2's path with the second input
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    ComputationGraph original = new ComputationGraph(graphConf);
    original.init();
    System.out.println(original.summary());

    ComputationGraph withFeatureExtractor = new TransferLearning.GraphBuilder(original)
            .setFeatureExtractor("denseCentre2")
            .build();
    System.out.println(withFeatureExtractor.summary());
}
Also used: MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Example 4 with SubsetVertex

Use of org.deeplearning4j.nn.conf.graph.SubsetVertex in the deeplearning4j project (by deeplearning4j).

From class TransferLearningHelperTest, method tesUnfrozenSubset (method name as spelled in the source).

/**
 * Checks the graph returned by {@code TransferLearningHelper#unfrozenGraph()} when the helper
 * is built with "denseCentre2" as its boundary. The configuration JSON must equal a
 * hand-built configuration in which denseCentre1 and denseCentre2 appear as graph inputs
 * alongside inRight.
 */
@Test
public void tesUnfrozenSubset() {
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .seed(124)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    /*
                             (inCentre)                        (inRight)
                                |                                |
                            denseCentre0                         |
                                |                                |
                 ,--------  denseCentre1                       denseRight0
                /               |                                |
        subsetLeft(0-3)    denseCentre2 ---- denseRight ----  mergeRight
              |                 |                                |
         denseLeft0        denseCentre3                        denseRight1
              |                 |                                |
          (outLeft)         (outCentre)                        (outRight)

         */
    ComputationGraphConfiguration graphConf = baseConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();

    ComputationGraph modelToTune = new ComputationGraph(graphConf);
    modelToTune.init();

    TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2");
    ComputationGraph modelSubset = helper.unfrozenGraph();

    // Expected unfrozen graph: the layers up to denseCentre2 are gone, and their outputs
    // are consumed as graph inputs instead.
    ComputationGraphConfiguration expectedConf = baseConf.graphBuilder()
            .addInputs("denseCentre1", "denseCentre2", "inRight") // inputs are in sorted order
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();
    // Build and initialize the expected graph; the resulting instance is not otherwise used.
    new ComputationGraph(expectedConf).init();

    assertEquals(expectedConf.toJson(), modelSubset.getConfiguration().toJson());
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) Test(org.junit.Test)

Aggregations

Number of examples above that use each class: MergeVertex (org.deeplearning4j.nn.conf.graph.MergeVertex): 4; SubsetVertex (org.deeplearning4j.nn.conf.graph.SubsetVertex): 4; Test (org.junit.Test): 4; DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 3; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 2; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 2; OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 2; ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 2; ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex): 1; ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 1; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1; MultiDataSet (org.nd4j.linalg.dataset.MultiDataSet): 1