Search in sources:

Example 16 using NdIndexIterator

Use of org.nd4j.linalg.api.iter.NdIndexIterator in the nd4j project by deeplearning4j.

Method testBufferToIntShapeStrideMethods of the class StaticShapeTests.

/**
 * Cross-checks the IntBuffer-based and DataBuffer-based static helpers on {@code Shape}
 * against each other and against {@code INDArray.shape()} / {@code INDArray.stride()}.
 *
 * <p>For every array returned by the NDArrayCreationUtil factories for each test shape, this
 * verifies rank, shape, per-dimension size and stride, base offset, and element offsets for
 * every index produced by {@code NdIndexIterator}.
 */
@Test
public void testBufferToIntShapeStrideMethods() {
    // Asserted here:
    //   Shape.rank(IntBuffer), Shape.rank(DataBuffer)
    //   Shape.shape(IntBuffer), Shape.shape(DataBuffer)
    //   Shape.size(IntBuffer,int), Shape.size(DataBuffer,int)
    //   Shape.stride(IntBuffer,int), Shape.stride(DataBuffer,int)
    //   Shape.offset(IntBuffer) vs Shape.offset(DataBuffer)
    //   Shape.getOffset(buffer, int[]) for all ranks, plus the explicit 2d/3d/4d index overloads
    // NOTE(review): isRowVectorShape(IntBuffer/DataBuffer) and the whole-array
    // Shape.stride(IntBuffer)/Shape.stride(DataBuffer) are NOT asserted by this test.
    List<List<Pair<INDArray, String>>> lists = new ArrayList<>();
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345));
    lists.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 3, 4, 5));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 4, 5, 6));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 1, 5, 1));
    lists.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 3, 4, 5, 6, 7));
    lists.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 3, 4, 5, 6, 7, 8));
    // shapes[i] is the expected shape of every array in lists.get(i); the two must stay in lockstep.
    int[][] shapes = new int[][] { { 3, 4 }, { 1, 4 }, { 3, 1 }, { 3, 4, 5 }, { 3, 4, 5, 6 }, { 3, 1, 5, 1 }, { 3, 4, 5, 6, 7 }, { 3, 4, 5, 6, 7, 8 } };
    assertEquals("lists and shapes must have the same number of entries", shapes.length, lists.size());
    for (int i = 0; i < shapes.length; i++) {
        List<Pair<INDArray, String>> list = lists.get(i);
        int[] shape = shapes[i];
        int rank = shape.length;
        for (Pair<INDArray, String> p : list) {
            INDArray arr = p.getFirst();
            // Description of the generated array, used to identify the failing case in assertion messages
            String msg = p.getSecond();
            assertArrayEquals(msg, shape, arr.shape());
            int[] thisStride = arr.stride();
            IntBuffer ib = arr.shapeInfo();
            DataBuffer db = arr.shapeInfoDataBuffer();
            // Check shape calculation
            assertEquals(msg, rank, Shape.rank(ib));
            assertEquals(msg, rank, Shape.rank(db));
            assertArrayEquals(msg, shape, Shape.shape(ib));
            assertArrayEquals(msg, shape, Shape.shape(db));
            for (int j = 0; j < rank; j++) {
                assertEquals(msg, shape[j], Shape.size(ib, j));
                assertEquals(msg, shape[j], Shape.size(db, j));
                assertEquals(msg, thisStride[j], Shape.stride(ib, j));
                assertEquals(msg, thisStride[j], Shape.stride(db, j));
            }
            // Check base offset: both buffer views of the shape info must agree
            assertEquals(msg, Shape.offset(ib), Shape.offset(db));
            // Check offset calculation for every index in the array
            NdIndexIterator iter = new NdIndexIterator(shape);
            while (iter.hasNext()) {
                int[] next = iter.next();
                long offset1 = Shape.getOffset(ib, next);
                assertEquals(msg, offset1, Shape.getOffset(db, next));
                switch (rank) {
                    case 2:
                        assertEquals(msg, offset1, Shape.getOffset(ib, next[0], next[1]));
                        assertEquals(msg, offset1, Shape.getOffset(db, next[0], next[1]));
                        break;
                    case 3:
                        assertEquals(msg, offset1, Shape.getOffset(ib, next[0], next[1], next[2]));
                        assertEquals(msg, offset1, Shape.getOffset(db, next[0], next[1], next[2]));
                        break;
                    case 4:
                        assertEquals(msg, offset1, Shape.getOffset(ib, next[0], next[1], next[2], next[3]));
                        assertEquals(msg, offset1, Shape.getOffset(db, next[0], next[1], next[2], next[3]));
                        break;
                    case 5:
                    case 6:
                        // No 5 and 6d getOffset overloads; only the int[] form (checked above) applies
                        break;
                    default:
                        throw new IllegalStateException("Unexpected rank " + rank + " for test array: " + msg);
                }
            }
        }
    }
}
Also used: NdIndexIterator(org.nd4j.linalg.api.iter.NdIndexIterator) ArrayList(java.util.ArrayList) INDArray(org.nd4j.linalg.api.ndarray.INDArray) IntBuffer(java.nio.IntBuffer) List(java.util.List) Pair(org.nd4j.linalg.primitives.Pair) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

NdIndexIterator (org.nd4j.linalg.api.iter.NdIndexIterator)16 INDArray (org.nd4j.linalg.api.ndarray.INDArray)11 Test (org.junit.Test)7 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)4 Assign (org.nd4j.linalg.api.ops.impl.transforms.Assign)3 ArrayList (java.util.ArrayList)2 List (java.util.List)2 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)2 Field (java.lang.reflect.Field)1 IntBuffer (java.nio.IntBuffer)1 NoSuchElementException (java.util.NoSuchElementException)1 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)1 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)1 UniformDistribution (org.deeplearning4j.nn.conf.distribution.UniformDistribution)1 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)1 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)1 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)1 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)1 SDVariable (org.nd4j.autodiff.samediff.SDVariable)1 IMax (org.nd4j.linalg.api.ops.impl.indexaccum.IMax)1