Usage of org.deeplearning4j.nn.graph.vertex.GraphVertex in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class TestGraphNodes, method testDuplicateToTimeSeriesVertex.
@Test
public void testDuplicateToTimeSeriesVertex() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in2d", "in3d")
            .addVertex("duplicateTS", new DuplicateToTimeSeriesVertex("in3d"), "in2d")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "duplicateTS")
            .setOutputs("out")
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    final int timeSteps = 7;
    INDArray in2d = Nd4j.rand(3, 5);
    INDArray in3d = Nd4j.rand(new int[] { 3, 2, timeSteps });
    // NOTE(review): the vertex itself is only given "in2d" below; presumably it reads the
    // time-series length from the graph input named "in3d", which is why both are set here.
    graph.setInputs(in2d, in3d);

    // Expected forward result: in2d copied into every time step (dimension 2).
    INDArray expOut = Nd4j.zeros(3, 5, timeSteps);
    for (int t = 0; t < timeSteps; t++) {
        INDArrayIndex[] step = new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(t) };
        expOut.put(step, in2d);
    }

    GraphVertex duplicate = graph.getVertex("duplicateTS");
    duplicate.setInputs(in2d);
    INDArray forward = duplicate.doForward(true);
    assertEquals(expOut, forward);

    // Expected backward result: the epsilon summed over the time dimension.
    // Computed before doBackward so it reflects the epsilon exactly as it was set.
    INDArray expEpsilon = expOut.sum(2);
    duplicate.setEpsilon(expOut);
    INDArray epsilon2d = duplicate.doBackward(false).getSecond()[0];
    assertEquals(expEpsilon, epsilon2d);

    // The configuration must survive a JSON round trip unchanged.
    ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(conf.toJson());
    assertEquals(conf, restored);
}
Usage of org.deeplearning4j.nn.graph.vertex.GraphVertex in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class TestGraphNodes, method testSubsetNode.
@Test
public void testSubsetNode() {
    Nd4j.getRandom().setSeed(12345);
    // Keeps features 4..7 (inclusive) along dimension 1.
    GraphVertex subset = new SubsetVertex(null, "", -1, 4, 7);

    // Dense (2d) activations: 5 examples x 10 features.
    INDArray dense = Nd4j.rand(5, 10);
    subset.setInputs(dense);
    INDArray denseOut = subset.doForward(false);
    INDArray denseExpected = dense.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true));
    assertEquals(denseExpected, denseOut);

    // Backward: epsilon lands in columns 4..7; columns 0..3 and 8..9 stay zero.
    subset.setEpsilon(denseOut);
    INDArray denseEps = subset.doBackward(false).getSecond()[0];
    INDArray denseBefore = denseEps.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true));
    INDArray denseKept = denseEps.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true));
    INDArray denseAfter = denseEps.get(NDArrayIndex.all(), NDArrayIndex.interval(8, 9, true));
    assertEquals(Nd4j.zeros(5, 4), denseBefore);
    assertEquals(denseOut, denseKept);
    assertEquals(Nd4j.zeros(5, 2), denseAfter);

    // Same checks for CNN (4d) activations, subsetting along dimension 1.
    INDArray conv = Nd4j.rand(new int[] { 5, 10, 3, 3 });
    subset.setInputs(conv);
    INDArray convOut = subset.doForward(false);
    INDArray convExpected = conv.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true), NDArrayIndex.all(), NDArrayIndex.all());
    assertEquals(convExpected, convOut);

    subset.setEpsilon(convOut);
    INDArray convEps = subset.doBackward(false).getSecond()[0];
    INDArray convBefore = convEps.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true), NDArrayIndex.all(), NDArrayIndex.all());
    INDArray convKept = convEps.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true), NDArrayIndex.all(), NDArrayIndex.all());
    INDArray convAfter = convEps.get(NDArrayIndex.all(), NDArrayIndex.interval(8, 9, true), NDArrayIndex.all(), NDArrayIndex.all());
    assertEquals(Nd4j.zeros(5, 4, 3, 3), convBefore);
    assertEquals(convOut, convKept);
    assertEquals(Nd4j.zeros(5, 2, 3, 3), convAfter);
}
Usage of org.deeplearning4j.nn.graph.vertex.GraphVertex in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class TestGraphNodes, method testStackNode.
@Test
public void testStackNode() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex stack = new StackVertex(null, "", -1);

    final int rows = 5;
    INDArray[] parts = new INDArray[3];
    for (int k = 0; k < parts.length; k++) {
        parts[k] = Nd4j.rand(rows, 2);
    }
    stack.setInputs(parts[0], parts[1], parts[2]);

    // Forward: the inputs appear along dimension 0, one block of `rows` per input, in input order.
    INDArray stacked = stack.doForward(false);
    for (int k = 0; k < parts.length; k++) {
        INDArray block = stacked.get(NDArrayIndex.interval(k * rows, (k + 1) * rows), NDArrayIndex.all());
        assertEquals(parts[k], block);
    }

    // Backward: feeding the stacked output as epsilon yields one epsilon per input, equal to that input.
    stack.setEpsilon(stacked);
    INDArray[] epsilons = stack.doBackward(false).getSecond();
    for (int k = 0; k < parts.length; k++) {
        assertEquals(parts[k], epsilons[k]);
    }
}
Usage of org.deeplearning4j.nn.graph.vertex.GraphVertex in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class TestGraphNodes, method testUnstackNode.
/**
 * Checks UnstackVertex forward and backward passes for dense (2d) and CNN (4d) activations.
 *
 * <p>Three vertices each select one of three equal blocks along dimension 0. Forward must return
 * that block. Backward must place the epsilon back into that block and leave the other blocks zero.
 */
@Test
public void testUnstackNode() {
    Nd4j.getRandom().setSeed(12345);
    final int stackSize = 3;
    GraphVertex[] unstack = new GraphVertex[stackSize];
    for (int i = 0; i < stackSize; i++) {
        unstack[i] = new UnstackVertex(null, "", -1, i, stackSize);
    }

    // Dense (2d) activations first, then the same checks for CNN (4d) activations.
    INDArray[] inputs = { Nd4j.rand(15, 2), Nd4j.rand(new int[] { 15, 10, 3, 3 }) };
    for (INDArray in : inputs) {
        int rank = in.rank();
        int blockSize = in.size(0) / stackSize;
        // Shape of a single block: same as the input except along dimension 0.
        int[] blockShape = in.shape().clone();
        blockShape[0] = blockSize;

        for (GraphVertex v : unstack) {
            v.setInputs(in);
        }

        for (int i = 0; i < stackSize; i++) {
            INDArray out = unstack[i].doForward(false);
            assertEquals(in.get(unstackSliceIndices(i * blockSize, (i + 1) * blockSize, rank)), out);

            unstack[i].setEpsilon(out);
            INDArray backward = unstack[i].doBackward(false).getSecond()[0];
            for (int j = 0; j < stackSize; j++) {
                // Index every dimension explicitly. The old 4d check for block 2 of vertex 1
                // passed only two indices, unlike every sibling assertion.
                INDArray block = backward.get(unstackSliceIndices(j * blockSize, (j + 1) * blockSize, rank));
                INDArray expected = (j == i) ? out : Nd4j.zeros(blockShape);
                assertEquals(expected, block);
            }
        }
    }
}

/**
 * Builds indices selecting rows {@code [from, to)} along dimension 0 and everything along each of
 * the remaining {@code rank - 1} dimensions. Fresh index objects are created on every call.
 */
private static INDArrayIndex[] unstackSliceIndices(int from, int to, int rank) {
    INDArrayIndex[] indices = new INDArrayIndex[rank];
    indices[0] = NDArrayIndex.interval(from, to);
    for (int d = 1; d < rank; d++) {
        indices[d] = NDArrayIndex.all();
    }
    return indices;
}
Usage of org.deeplearning4j.nn.graph.vertex.GraphVertex in the deeplearning4j project (deeplearning4j/deeplearning4j).
From class TestGraphNodes, method testMergeNode.
@Test
public void testMergeNode() {
    Nd4j.getRandom().setSeed(12345);
    GraphVertex merge = new MergeVertex(null, "", -1);

    // Two inputs with the same number of rows but different widths (4 and 6 columns).
    INDArray left = Nd4j.linspace(0, 11, 12).reshape(3, 4);
    INDArray right = Nd4j.linspace(0, 17, 18).reshape(3, 6).addi(100);
    merge.setInputs(left, right);

    // Forward: inputs are concatenated along dimension 1, left first.
    INDArray merged = merge.doForward(false);
    assertArrayEquals(new int[] { 3, 10 }, merged.shape());
    int split = left.columns();
    INDArray leftPart = merged.get(NDArrayIndex.all(), NDArrayIndex.interval(0, split));
    INDArray rightPart = merged.get(NDArrayIndex.all(), NDArrayIndex.interval(split, split + right.columns()));
    assertEquals(left, leftPart);
    assertEquals(right, rightPart);

    // Backward: the merged epsilon is split back into the original input widths.
    merge.setEpsilon(merged);
    INDArray[] epsilons = merge.doBackward(false).getSecond();
    assertEquals(left, epsilons[0]);
    assertEquals(right, epsilons[1]);
}
Aggregations