Example 1 with ElementWiseVertex

Use of org.deeplearning4j.nn.conf.graph.ElementWiseVertex in the project deeplearning4j by deeplearning4j.

From the class ComputationGraphConfigurationTest, method testJSONWithGraphNodes.

@Test
public void testJSONWithGraphNodes() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder().addInputs("input1", "input2")
            .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1")
            .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2")
            .addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
            .addVertex("subset1", new SubsetVertex(0, 1), "merge1")
            .addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "add")
            .setOutputs("out").build();
    String json = conf.toJson();
    System.out.println(json);
    ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
    assertEquals(json, conf2.toJson());
    assertEquals(conf, conf2);
}
Also used: DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer), ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex), MergeVertex (org.deeplearning4j.nn.conf.graph.MergeVertex), ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer), SubsetVertex (org.deeplearning4j.nn.conf.graph.SubsetVertex), Test (org.junit.Test)
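In Example 1, ElementWiseVertex(ElementWiseVertex.Op.Add) combines the outputs of dense1 and dense2. Both layers have nOut(5), so their activations have matching shapes, which an element-wise combination needs. The following is a minimal plain-Java sketch of what "element-wise" means, assuming equally sized inputs. The `Op` enum and `combine` method are hypothetical illustrations, not DL4J's ElementWiseVertex (which works on INDArray activations); check the Op enum of your DL4J version for the exact set of supported operations.

```java
import java.util.Arrays;

public class ElementWiseSketch {
    // Hypothetical op names for illustration only; NOT the DL4J ElementWiseVertex.Op enum.
    enum Op { ADD, SUBTRACT, PRODUCT, AVERAGE }

    // Combines equally sized activation arrays position by position.
    static double[] combine(Op op, double[]... inputs) {
        int n = inputs[0].length;
        for (double[] in : inputs) {
            if (in.length != n) throw new IllegalArgumentException("Inputs must have the same shape");
        }
        // In this sketch, subtraction is only defined for exactly two inputs.
        if (op == Op.SUBTRACT && inputs.length != 2) {
            throw new IllegalArgumentException("Subtract takes exactly two inputs");
        }
        double[] out = inputs[0].clone();
        for (int k = 1; k < inputs.length; k++) {
            for (int i = 0; i < n; i++) {
                switch (op) {
                    case ADD, AVERAGE -> out[i] += inputs[k][i];
                    case SUBTRACT -> out[i] -= inputs[k][i];
                    case PRODUCT -> out[i] *= inputs[k][i];
                }
            }
        }
        // Average = sum of all inputs divided by the number of inputs.
        if (op == Op.AVERAGE) {
            for (int i = 0; i < n; i++) out[i] /= inputs.length;
        }
        return out;
    }

    public static void main(String[] args) {
        // Two made-up 5-unit activations, matching the nOut(5) of dense1 and dense2.
        double[] dense1 = {1.0, 2.0, 3.0, 4.0, 5.0};
        double[] dense2 = {0.5, 0.5, 0.5, 0.5, 0.5};
        System.out.println(Arrays.toString(combine(Op.ADD, dense1, dense2)));
        // prints [1.5, 2.5, 3.5, 4.5, 5.5]
    }
}
```

The output keeps the shape of the inputs, which is why an element-wise vertex, unlike MergeVertex, does not change the size seen by the next layer.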

Example 2 with ElementWiseVertex

Use of org.deeplearning4j.nn.conf.graph.ElementWiseVertex in the project deeplearning4j by deeplearning4j.

From the class TestGraphNodes, method testJSON.

@Test
public void testJSON() {
    // The config here is nonsense, but that doesn't matter for a config -> JSON -> config test
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
            .addVertex("v1", new ElementWiseVertex(ElementWiseVertex.Op.Add), "in")
            .addVertex("v2", new org.deeplearning4j.nn.conf.graph.MergeVertex(), "in", "in")
            .addVertex("v3", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(1, 2, 1)), "in")
            .addVertex("v4", new org.deeplearning4j.nn.conf.graph.SubsetVertex(0, 1), "in")
            .addVertex("v5", new DuplicateToTimeSeriesVertex("in"), "in")
            .addVertex("v6", new LastTimeStepVertex("in"), "in")
            .addVertex("v7", new org.deeplearning4j.nn.conf.graph.StackVertex(), "in")
            .addVertex("v8", new org.deeplearning4j.nn.conf.graph.UnstackVertex(0, 1), "in")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in").setOutputs("out").build();
    ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(conf.toJson());
    assertEquals(conf, conf2);
}
Also used: PreprocessorVertex (org.deeplearning4j.nn.conf.graph.PreprocessorVertex), CnnToFeedForwardPreProcessor (org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor), ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), LastTimeStepVertex (org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex), Test (org.junit.Test)
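Both tests follow a config -> JSON -> config round trip: serialize, parse back, and compare. Example 1 checks that the JSON is stable (`assertEquals(json, conf2.toJson())`) and that the parsed object equals the original; Example 2 checks only the latter. The object comparison is meaningful only if the configuration classes compare by value rather than by reference. A minimal sketch of the same pattern in plain Java with no DL4J dependency follows; `VertexConf` and its hand-written `toJson`/`fromJson` are hypothetical stand-ins, not DL4J classes, and the regex parser only handles this one fixed shape (no escaped quotes).

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class RoundTripSketch {
    // Hypothetical stand-in for a vertex configuration; NOT a DL4J class.
    // A record gets value-based equals()/hashCode(), which the object comparison relies on.
    record VertexConf(String name, String op) {
        String toJson() {
            return "{\"name\":\"" + name + "\",\"op\":\"" + op + "\"}";
        }

        // Matches exactly the shape produced by toJson(); values must not contain quotes.
        private static final Pattern JSON =
                Pattern.compile("\\{\"name\":\"([^\"]*)\",\"op\":\"([^\"]*)\"\\}");

        static VertexConf fromJson(String json) {
            Matcher m = JSON.matcher(json);
            if (!m.matches()) throw new IllegalArgumentException("Unexpected JSON: " + json);
            return new VertexConf(m.group(1), m.group(2));
        }
    }

    public static void main(String[] args) {
        VertexConf conf = new VertexConf("add", "Add");
        String json = conf.toJson();
        VertexConf conf2 = VertexConf.fromJson(json);
        // The same two checks as testJSONWithGraphNodes: stable JSON, equal objects.
        if (!json.equals(conf2.toJson())) throw new AssertionError("JSON changed on round trip");
        if (!conf.equals(conf2)) throw new AssertionError("config changed on round trip");
        System.out.println(json); // prints {"name":"add","op":"Add"}
    }
}
```

If `VertexConf` were an ordinary class without an `equals` override, the second check would fail even though the JSON round trip is lossless, because the default `equals` compares identity.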

Aggregations

ElementWiseVertex (org.deeplearning4j.nn.conf.graph.ElementWiseVertex): 2
Test (org.junit.Test): 2
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 1
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1
MergeVertex (org.deeplearning4j.nn.conf.graph.MergeVertex): 1
PreprocessorVertex (org.deeplearning4j.nn.conf.graph.PreprocessorVertex): 1
SubsetVertex (org.deeplearning4j.nn.conf.graph.SubsetVertex): 1
DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex): 1
LastTimeStepVertex (org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex): 1
ConvolutionLayer (org.deeplearning4j.nn.conf.layers.ConvolutionLayer): 1
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 1
CnnToFeedForwardPreProcessor (org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor): 1