Use of org.deeplearning4j.nn.conf.graph.ElementWiseVertex in project deeplearning4j (by deeplearning4j).
Example from class ComputationGraphConfigurationTest, method testJSONWithGraphNodes.
/**
 * Round-trips a graph configuration containing merge, subset and element-wise vertices through
 * JSON, and checks that both the re-serialized JSON and the deserialized configuration equal
 * the originals.
 */
@Test
public void testJSONWithGraphNodes() {
    // The layer sizes below are not meant to form a trainable network (e.g. dense1/dense2 declare
    // nIn(20) after a 2-wide subset); only the config -> JSON -> config round trip is tested.
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .graphBuilder()
            .addInputs("input1", "input2")
            .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input1")
            .addLayer("cnn2", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5).build(), "input2")
            .addVertex("merge1", new MergeVertex(), "cnn1", "cnn2")
            .addVertex("subset1", new SubsetVertex(0, 1), "merge1")
            .addLayer("dense1", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addLayer("dense2", new DenseLayer.Builder().nIn(20).nOut(5).build(), "subset1")
            .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", "dense2")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "add")
            .setOutputs("out")
            .build();

    String json = conf.toJson();
    ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);

    // Serializing the deserialized config must reproduce the same JSON text...
    assertEquals(json, conf2.toJson());
    // ...and the deserialized config must compare equal to the original.
    assertEquals(conf, conf2);
}
Use of org.deeplearning4j.nn.conf.graph.ElementWiseVertex in project deeplearning4j (by deeplearning4j).
Example from class TestGraphNodes, method testJSON.
/**
 * Verifies that a configuration using several graph vertex types (element-wise, merge,
 * preprocessor, subset, duplicate-to-time-series, last-time-step, stack, unstack) is equal to
 * itself after being serialized to JSON and parsed back.
 */
@Test
public void testJSON() {
    // The graph is deliberately nonsense; only config -> JSON -> config is under test here.
    ComputationGraphConfiguration original = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addVertex("v1", new ElementWiseVertex(ElementWiseVertex.Op.Add), "in")
            .addVertex("v2", new org.deeplearning4j.nn.conf.graph.MergeVertex(), "in", "in")
            .addVertex("v3", new PreprocessorVertex(new CnnToFeedForwardPreProcessor(1, 2, 1)), "in")
            .addVertex("v4", new org.deeplearning4j.nn.conf.graph.SubsetVertex(0, 1), "in")
            .addVertex("v5", new DuplicateToTimeSeriesVertex("in"), "in")
            .addVertex("v6", new LastTimeStepVertex("in"), "in")
            .addVertex("v7", new org.deeplearning4j.nn.conf.graph.StackVertex(), "in")
            .addVertex("v8", new org.deeplearning4j.nn.conf.graph.UnstackVertex(0, 1), "in")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in")
            .setOutputs("out")
            .build();

    ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(original.toJson());
    assertEquals(original, restored);
}
Aggregations