Search in sources:

Example 1 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in the project deeplearning4j by deeplearning4j.

Source: class TestGraphNodes, method testDuplicateToTimeSeriesVertex.

/**
 * Exercises DuplicateToTimeSeriesVertex in isolation.
 *
 * <p>The vertex is configured with "in3d" as its reference input. Its forward pass is expected to
 * copy the 2d input at each of the 7 time steps of that reference (the last dimension of "in3d").
 * Its backward pass is expected to sum the epsilon over the time dimension. Finally, the graph
 * configuration must survive a JSON round trip.
 */
@Test
public void testDuplicateToTimeSeriesVertex() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in2d", "in3d")
            .addVertex("duplicateTS", new DuplicateToTimeSeriesVertex("in3d"), "in2d")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "duplicateTS")
            .setOutputs("out")
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    INDArray features2d = Nd4j.rand(3, 5);
    INDArray series3d = Nd4j.rand(new int[] { 3, 2, 7 });
    // Both graph inputs are set; "in3d" is the vertex's reference input for the time axis.
    net.setInputs(features2d, series3d);

    // Expected forward output: features2d placed at every one of the 7 time steps.
    INDArray expectedForward = Nd4j.zeros(3, 5, 7);
    for (int t = 0; t < 7; t++) {
        INDArrayIndex[] timeStep = { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(t) };
        expectedForward.put(timeStep, features2d);
    }

    GraphVertex vertex = net.getVertex("duplicateTS");
    vertex.setInputs(features2d);
    assertEquals(expectedForward, vertex.doForward(true));

    // Backward: the epsilon is expected to be summed over dimension 2 (time).
    INDArray expectedBackward = expectedForward.sum(2);
    vertex.setEpsilon(expectedForward);
    INDArray[] inputGradients = vertex.doBackward(false).getSecond();
    assertEquals(expectedBackward, inputGradients[0]);

    // The configuration must serialize to JSON and deserialize to an equal object.
    String asJson = conf.toJson();
    assertEquals(conf, ComputationGraphConfiguration.fromJson(asJson));
}
Also used: GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), Test (org.junit.Test)

Example 2 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in the project deeplearning4j by deeplearning4j.

Source: class TestGraphNodes, method testSubsetNode.

/**
 * Exercises SubsetVertex selecting indices 4..7 (inclusive) along dimension 1.
 *
 * <p>For both 2d activations and 4d (CNN-shaped) activations, the forward pass must return exactly
 * that slice. The backward pass must write the epsilon back into the same slice and leave zeros
 * in the positions outside it (0..3 and 8..9).
 */
@Test
public void testSubsetNode() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex vertex = new SubsetVertex(null, "", -1, 4, 7);

    // 2d case: 5 rows x 10 columns.
    INDArray input = Nd4j.rand(5, 10);
    vertex.setInputs(input);
    INDArray selected = vertex.doForward(false);
    assertEquals(input.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true)), selected);

    vertex.setEpsilon(selected);
    INDArray grad = vertex.doBackward(false).getSecond()[0];
    // Columns 0..3 and 8..9 are outside the subset and must receive zero gradient.
    assertEquals(Nd4j.zeros(5, 4), grad.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)));
    assertEquals(selected, grad.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true)));
    assertEquals(Nd4j.zeros(5, 2), grad.get(NDArrayIndex.all(), NDArrayIndex.interval(8, 9, true)));

    // 4d (CNN) case: shape [5, 10, 3, 3]; the subset is still taken along dimension 1.
    input = Nd4j.rand(new int[] { 5, 10, 3, 3 });
    vertex.setInputs(input);
    selected = vertex.doForward(false);
    assertEquals(input.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true), NDArrayIndex.all(), NDArrayIndex.all()), selected);

    vertex.setEpsilon(selected);
    grad = vertex.doBackward(false).getSecond()[0];
    assertEquals(Nd4j.zeros(5, 4, 3, 3), grad.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(selected, grad.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2, 3, 3), grad.get(NDArrayIndex.all(), NDArrayIndex.interval(8, 9, true), NDArrayIndex.all(), NDArrayIndex.all()));
}
Also used: GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Test (org.junit.Test)

Example 3 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in the project deeplearning4j by deeplearning4j.

Source: class TestGraphNodes, method testStackNode.

/**
 * Exercises StackVertex with three [5, 2] inputs.
 *
 * <p>The forward pass must place the inputs one after another along dimension 0, giving a
 * [15, 2] result. The backward pass must split the epsilon back into the three original
 * [5, 2] pieces.
 */
@Test
public void testStackNode() {
    Nd4j.getRandom().setSeed(12345);
    // This is a StackVertex; the local was previously (misleadingly) named "unstack".
    GraphVertex stack = new StackVertex(null, "", -1);
    INDArray in1 = Nd4j.rand(5, 2);
    INDArray in2 = Nd4j.rand(5, 2);
    INDArray in3 = Nd4j.rand(5, 2);
    stack.setInputs(in1, in2, in3);

    INDArray out = stack.doForward(false);
    // Pin the full shape: the interval checks below alone would also pass on a larger array.
    assertArrayEquals(new int[] { 15, 2 }, out.shape());
    assertEquals(in1, out.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
    assertEquals(in2, out.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
    assertEquals(in3, out.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all()));

    // Feeding the stacked output back as epsilon must recover each input's slice unchanged.
    stack.setEpsilon(out);
    Pair<Gradient, INDArray[]> backward = stack.doBackward(false);
    assertEquals(in1, backward.getSecond()[0]);
    assertEquals(in2, backward.getSecond()[1]);
    assertEquals(in3, backward.getSecond()[2]);
}
Also used: Gradient (org.deeplearning4j.nn.gradient.Gradient), GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Test (org.junit.Test)

Example 4 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in the project deeplearning4j by deeplearning4j.

Source: class TestGraphNodes, method testUnstackNode.

/**
 * Exercises UnstackVertex with three vertices that pick pieces 0, 1 and 2 (out of 3) of an input
 * stacked along dimension 0.
 *
 * <p>For a [15, 2] input and a [15, 10, 3, 3] input, each vertex's forward pass must return rows
 * 5*k .. 5*k+4. Each backward pass must produce a full-size gradient with the epsilon in that row
 * range and zeros everywhere else.
 */
@Test
public void testUnstackNode() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex unstack0 = new UnstackVertex(null, "", -1, 0, 3);
    GraphVertex unstack1 = new UnstackVertex(null, "", -1, 1, 3);
    GraphVertex unstack2 = new UnstackVertex(null, "", -1, 2, 3);

    // 2d case: [15, 2] split into three [5, 2] pieces.
    INDArray in = Nd4j.rand(15, 2);
    unstack0.setInputs(in);
    unstack1.setInputs(in);
    unstack2.setInputs(in);
    INDArray out0 = unstack0.doForward(false);
    INDArray out1 = unstack1.doForward(false);
    INDArray out2 = unstack2.doForward(false);
    assertEquals(in.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()), out0);
    assertEquals(in.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()), out1);
    assertEquals(in.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all()), out2);

    unstack0.setEpsilon(out0);
    unstack1.setEpsilon(out1);
    unstack2.setEpsilon(out2);
    INDArray backward0 = unstack0.doBackward(false).getSecond()[0];
    INDArray backward1 = unstack1.doBackward(false).getSecond()[0];
    INDArray backward2 = unstack2.doBackward(false).getSecond()[0];
    // Each gradient holds its epsilon in its own row block and zeros in the other two.
    assertEquals(out0, backward0.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2), backward0.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2), backward0.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2), backward1.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
    assertEquals(out1, backward1.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2), backward1.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2), backward2.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 2), backward2.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all()));
    assertEquals(out2, backward2.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all()));

    // 4d (CNN) case: [15, 10, 3, 3] split into three [5, 10, 3, 3] pieces.
    in = Nd4j.rand(new int[] { 15, 10, 3, 3 });
    unstack0.setInputs(in);
    unstack1.setInputs(in);
    unstack2.setInputs(in);
    out0 = unstack0.doForward(false);
    out1 = unstack1.doForward(false);
    out2 = unstack2.doForward(false);
    assertEquals(in.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()), out0);
    assertEquals(in.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()), out1);
    assertEquals(in.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()), out2);

    unstack0.setEpsilon(out0);
    unstack1.setEpsilon(out1);
    unstack2.setEpsilon(out2);
    backward0 = unstack0.doBackward(false).getSecond()[0];
    backward1 = unstack1.doBackward(false).getSecond()[0];
    backward2 = unstack2.doBackward(false).getSecond()[0];
    // All 4d slices use four explicit indices (previously one check passed only two).
    assertEquals(out0, backward0.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 10, 3, 3), backward0.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 10, 3, 3), backward0.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 10, 3, 3), backward1.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(out1, backward1.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 10, 3, 3), backward1.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 10, 3, 3), backward2.get(NDArrayIndex.interval(0, 5), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(Nd4j.zeros(5, 10, 3, 3), backward2.get(NDArrayIndex.interval(5, 10), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
    assertEquals(out2, backward2.get(NDArrayIndex.interval(10, 15), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()));
}
Also used: GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Test (org.junit.Test)

Example 5 with GraphVertex

Use of org.deeplearning4j.nn.graph.vertex.GraphVertex in the project deeplearning4j by deeplearning4j.

Source: class TestGraphNodes, method testMergeNode.

/**
 * Exercises MergeVertex on a [3, 4] input and a [3, 6] input.
 *
 * <p>The forward pass must join them column-wise into a [3, 10] result (left in columns 0..3,
 * right in columns 4..9). The backward pass must split the epsilon back into the two original
 * pieces.
 */
@Test
public void testMergeNode() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex merge = new MergeVertex(null, "", -1);

    // Deterministic inputs: values 0..11 and 100..117, so the two halves never overlap in value.
    INDArray left = Nd4j.linspace(0, 11, 12).reshape(3, 4);
    INDArray right = Nd4j.linspace(0, 17, 18).reshape(3, 6).addi(100);
    merge.setInputs(left, right);

    INDArray merged = merge.doForward(false);
    int[] expectedShape = { 3, 10 };
    assertArrayEquals(expectedShape, merged.shape());
    assertEquals(left, merged.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4)));
    assertEquals(right, merged.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 10)));

    // Feeding the merged output back as epsilon must recover each input unchanged.
    merge.setEpsilon(merged);
    INDArray[] perInput = merge.doBackward(false).getSecond();
    assertEquals(left, perInput[0]);
    assertEquals(right, perInput[1]);
}
Also used: GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), Test (org.junit.Test)

Aggregations

GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex): 22; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 19; VertexIndices (org.deeplearning4j.nn.graph.vertex.VertexIndices): 9; Test (org.junit.Test): 9; Layer (org.deeplearning4j.nn.api.Layer): 8; IOutputLayer (org.deeplearning4j.nn.api.layers.IOutputLayer): 8; FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer): 8; FrozenLayer (org.deeplearning4j.nn.layers.FrozenLayer): 7; RecurrentLayer (org.deeplearning4j.nn.api.layers.RecurrentLayer): 6; Gradient (org.deeplearning4j.nn.gradient.Gradient): 4; ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 4; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 4; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 2; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 2; BaseOutputLayer (org.deeplearning4j.nn.conf.layers.BaseOutputLayer): 2; SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer): 2; Pair (org.deeplearning4j.berkeley.Pair): 1; Triple (org.deeplearning4j.berkeley.Triple): 1; MaskState (org.deeplearning4j.nn.api.MaskState): 1; DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex): 1