Search in sources:

Example 66 using Pair

Use of org.nd4j.linalg.primitives.Pair in the project nd4j by deeplearning4j.

From the class ConcatTestsC, method testConcat3dv2.

/**
 * Checks {@code Nd4j.concat} on rank-3 arrays along each of dimensions 0, 1 and 2.
 *
 * <p>For each dimension, the expected result is assembled by writing three source arrays into
 * slices of a freshly created target with {@code put}. The concatenation is then checked for
 * every combination of the test arrays returned by
 * {@code NDArrayCreationUtil.getAll3dTestArraysWithShape}, each first filled with the source
 * values via {@code assign}.
 */
@Test
@Ignore
public void testConcat3dv2() {
    INDArray first = Nd4j.linspace(1, 24, 24).reshape('c', 2, 3, 4);

    // ConcatV2, dim 0: [2,3,4] + [1,3,4] + [1,3,4] -> [4,3,4]
    INDArray second = Nd4j.linspace(24, 35, 12).reshape('c', 1, 3, 4);
    INDArray third = Nd4j.linspace(36, 47, 12).reshape('c', 1, 3, 4);
    INDArray exp = Nd4j.create(2 + 1 + 1, 3, 4);
    exp.put(new INDArrayIndex[] { NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all() }, first);
    exp.put(new INDArrayIndex[] { NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all() }, second);
    exp.put(new INDArrayIndex[] { NDArrayIndex.point(3), NDArrayIndex.all(), NDArrayIndex.all() }, third);
    assertConcatForAllTestArrays(0, exp, first, second, third,
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4),
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4),
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 1, 3, 4));

    // ConcatV2, dim 1: [2,3,4] + [2,1,4] + [2,2,4] -> [2,6,4]
    second = Nd4j.linspace(24, 31, 8).reshape('c', 2, 1, 4);
    third = Nd4j.linspace(32, 47, 16).reshape('c', 2, 2, 4);
    exp = Nd4j.create(2, 3 + 1 + 2, 4);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.interval(0, 3), NDArrayIndex.all() }, first);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.point(3), NDArrayIndex.all() }, second);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.interval(4, 6), NDArrayIndex.all() }, third);
    assertConcatForAllTestArrays(1, exp, first, second, third,
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4),
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 1, 4),
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 2, 4));

    // ConcatV2, dim 2: [2,3,4] + [2,3,2] + [2,3,1] -> [2,3,7]
    second = Nd4j.linspace(24, 35, 12).reshape('c', 2, 3, 2);
    third = Nd4j.linspace(36, 41, 6).reshape('c', 2, 3, 1);
    exp = Nd4j.create(2, 3, 4 + 2 + 1);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4) }, first);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(4, 6) }, second);
    exp.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(6) }, third);
    assertConcatForAllTestArrays(2, exp, first, second, third,
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 4),
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 2),
            NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 2, 3, 1));
}

/**
 * Asserts that concatenating every (first, second, third) combination of test arrays along
 * {@code dimension} equals {@code expected}.
 *
 * <p>Each test array is overwritten in place with the corresponding source values via
 * {@code assign} before concatenation. The test arrays' descriptions ({@code Pair.getSecond()})
 * go into the failure message, so a failing combination can be identified.
 *
 * @param dimension the axis passed to {@code Nd4j.concat}
 * @param expected  the array every concatenation must equal
 * @param first     values copied into each array of {@code firsts}
 * @param second    values copied into each array of {@code seconds}
 * @param third     values copied into each array of {@code thirds}
 * @param firsts    test arrays (with descriptions) for the first operand
 * @param seconds   test arrays (with descriptions) for the second operand
 * @param thirds    test arrays (with descriptions) for the third operand
 */
private static void assertConcatForAllTestArrays(int dimension, INDArray expected,
        INDArray first, INDArray second, INDArray third,
        List<Pair<INDArray, String>> firsts,
        List<Pair<INDArray, String>> seconds,
        List<Pair<INDArray, String>> thirds) {
    for (Pair<INDArray, String> f : firsts) {
        for (Pair<INDArray, String> s : seconds) {
            for (Pair<INDArray, String> t : thirds) {
                INDArray f2 = f.getFirst().assign(first);
                INDArray s2 = s.getFirst().assign(second);
                INDArray t2 = t.getFirst().assign(third);
                INDArray concat = Nd4j.concat(dimension, f2, s2, t2);
                assertEquals("concat along dimension " + dimension + " of [" + f.getSecond()
                        + "], [" + s.getSecond() + "], [" + t.getSecond() + "]", expected, concat);
            }
        }
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) Pair(org.nd4j.linalg.primitives.Pair) Ignore(org.junit.Ignore) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

Pair (org.nd4j.linalg.primitives.Pair): 66; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 63; Test (org.junit.Test): 24; BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 9; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 5; ArrayList (java.util.ArrayList): 4; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 4; Ignore (org.junit.Ignore): 3; List (java.util.List): 2; RealMatrix (org.apache.commons.math3.linear.RealMatrix): 2; IntPointer (org.bytedeco.javacpp.IntPointer): 2; Pointer (org.bytedeco.javacpp.Pointer): 2; OpExecutionerUtil (org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil): 2; LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper): 2; IntBuffer (java.nio.IntBuffer): 1; Random (java.util.Random): 1; AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 1; Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 1; BlockRealMatrix (org.apache.commons.math3.linear.BlockRealMatrix): 1; LUDecomposition (org.apache.commons.math3.linear.LUDecomposition): 1