Use of org.nd4j.linalg.primitives.Pair in project nd4j by deeplearning4j.
In class ConcatTestsC, method testConcat3dv2:
/**
 * Verifies {@code Nd4j.concat} on three 3d arrays along dimensions 0, 1 and 2.
 *
 * <p>For each dimension, the expected result is assembled with {@code put} into a
 * pre-sized array. {@code Nd4j.concat} is then checked against it for every combination
 * of inputs returned by {@code NDArrayCreationUtil.getAll3dTestArraysWithShape}. Those
 * inputs presumably cover the same shape in different orderings/views; confirm this
 * against NDArrayCreationUtil.
 */
@Test
// NOTE(review): ignored with no recorded reason; document why or re-enable.
@Ignore
public void testConcat3dv2() {
    INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4);

    // ConcatV2, dim 0: [2,3,4] + [1,3,4] + [1,3,4] -> [4,3,4]
    INDArray second = Nd4j.linspace(24, 35, 12).reshape('c', 1, 3, 4);
    INDArray third = Nd4j.linspace(36, 47, 12).reshape('c', 1, 3, 4);
    INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4);
    exp.put(new INDArrayIndex[] { NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all() }, first);
    exp.put(new INDArrayIndex[] { NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all() }, second);
    exp.put(new INDArrayIndex[] { NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all() }, third);
    assertConcat3dAllLayouts(0, exp,
            first, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4),
            second, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4),
            third, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4));

    // ConcatV2, dim 1: [2,3,4] + [2,1,4] + [2,2,4] -> [2,6,4]
    second = Nd4j.linspace(24, 31, 8).reshape('c', 2, 1, 4);
    third = Nd4j.linspace(32, 47, 16).reshape('c', 2, 2, 4);
    exp = Nd4j.create(2, 3 + 1 + 2, 4);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all() }, first);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all() }, second);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all() }, third);
    assertConcat3dAllLayouts(1, exp,
            first, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4),
            second, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 1, 4),
            third, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 2, 4));

    // ConcatV2, dim 2: [2,3,4] + [2,3,2] + [2,3,1] -> [2,3,7]
    second = Nd4j.linspace(24, 35, 12).reshape('c', 2, 3, 2);
    third = Nd4j.linspace(36, 41, 6).reshape('c', 2, 3, 1);
    exp = Nd4j.create(2, 3, 4 + 2 + 1);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, first);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6) }, second);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6) }, third);
    assertConcat3dAllLayouts(2, exp,
            first, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4),
            second, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 2),
            third, NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 1));
}

/**
 * Asserts that {@code Nd4j.concat(dim, ...)} equals {@code exp} for every combination of
 * candidate arrays from {@code firsts}, {@code seconds} and {@code thirds}.
 *
 * <p>Before each concat, each candidate array is overwritten in place (via {@code assign})
 * with the values of {@code first}, {@code second} and {@code third} respectively. Every
 * combination therefore has the same logical content, whatever the layout of the
 * candidate. The failure message names the layout combination, using the String half of
 * each {@code Pair}.
 *
 * @param dim     dimension to concatenate along
 * @param exp     expected concatenation result
 * @param first   source values for the first operand
 * @param firsts  candidate arrays (with descriptions) for the first operand
 * @param second  source values for the second operand
 * @param seconds candidate arrays (with descriptions) for the second operand
 * @param third   source values for the third operand
 * @param thirds  candidate arrays (with descriptions) for the third operand
 */
private void assertConcat3dAllLayouts(int dim, INDArray exp,
        INDArray first, List<Pair<INDArray, String>> firsts,
        INDArray second, List<Pair<INDArray, String>> seconds,
        INDArray third, List<Pair<INDArray, String>> thirds) {
    for (Pair<INDArray, String> f : firsts) {
        for (Pair<INDArray, String> s : seconds) {
            for (Pair<INDArray, String> t : thirds) {
                INDArray f2 = f.getFirst().assign(first);
                INDArray s2 = s.getFirst().assign(second);
                INDArray t2 = t.getFirst().assign(third);
                // Single evaluation: a retry-on-mismatch here would hide nondeterministic failures.
                INDArray concat = Nd4j.concat(dim, f2, s2, t2);
                assertEquals("concat dim=" + dim + " first=" + f.getSecond()
                        + ", second=" + s.getSecond() + ", third=" + t.getSecond(), exp, concat);
            }
        }
    }
}
Aggregations