Usage of org.nd4j.linalg.api.iter.NdIndexIterator in the nd4j project (deeplearning4j).
From class StaticShapeTests, method testBufferToIntShapeStrideMethods.
@Test
public void testBufferToIntShapeStrideMethods() {
    // Checks that the IntBuffer and DataBuffer overloads of the Shape helpers agree with
    // INDArray.shape()/stride() and with each other, across a variety of test arrays:
    //   Shape.rank(IntBuffer), Shape.rank(DataBuffer)
    //   Shape.shape(IntBuffer), Shape.shape(DataBuffer)
    //   Shape.size(IntBuffer,int), Shape.size(DataBuffer,int)
    //   Shape.stride(IntBuffer,int), Shape.stride(DataBuffer,int)
    //   Shape.offset(IntBuffer) vs Shape.offset(DataBuffer)
    //   Shape.getOffset(IntBuffer/DataBuffer, int[]) and, for ranks 2-4 only,
    //   the fixed-arity getOffset(..., i, j[, k[, l]]) overloads
    // NOTE(review): isRowVectorShape(...) and the whole-array Shape.stride(IntBuffer) /
    // Shape.stride(DataBuffer) overloads are NOT exercised by this test.
    List<List<Pair<INDArray, String>>> lists = new ArrayList<>();
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345));
    lists.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 3, 4, 5));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 4, 5, 6));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 1, 5, 1));
    lists.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 3, 4, 5, 6, 7));
    lists.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 3, 4, 5, 6, 7, 8));
    // shapes[i] is the expected shape of every array in lists.get(i); the two must stay parallel.
    int[][] shapes = new int[][] { { 3, 4 }, { 1, 4 }, { 3, 1 }, { 3, 4, 5 }, { 3, 4, 5, 6 }, { 3, 1, 5, 1 }, { 3, 4, 5, 6, 7 }, { 3, 4, 5, 6, 7, 8 } };
    // Guard against the two collections drifting apart when a case is added or removed.
    assertEquals(shapes.length, lists.size());
    for (int i = 0; i < shapes.length; i++) {
        List<Pair<INDArray, String>> list = lists.get(i);
        int[] shape = shapes[i];
        for (Pair<INDArray, String> p : list) {
            INDArray arr = p.getFirst();
            assertArrayEquals(shape, arr.shape());
            int[] thisStride = arr.stride();
            IntBuffer ib = arr.shapeInfo();
            DataBuffer db = arr.shapeInfoDataBuffer();
            // Check rank/shape/size/stride: both buffer views must match the INDArray itself
            assertEquals(shape.length, Shape.rank(ib));
            assertEquals(shape.length, Shape.rank(db));
            assertArrayEquals(shape, Shape.shape(ib));
            assertArrayEquals(shape, Shape.shape(db));
            for (int j = 0; j < shape.length; j++) {
                assertEquals(shape[j], Shape.size(ib, j));
                assertEquals(shape[j], Shape.size(db, j));
                assertEquals(thisStride[j], Shape.stride(ib, j));
                assertEquals(thisStride[j], Shape.stride(db, j));
            }
            // Check base offset: the two buffer views must agree
            assertEquals(Shape.offset(ib), Shape.offset(db));
            // Check offset calculation at every index of the array:
            // the int[] overload is the reference; the fixed-arity overloads must agree with it.
            NdIndexIterator iter = new NdIndexIterator(shape);
            while (iter.hasNext()) {
                int[] next = iter.next();
                long offset1 = Shape.getOffset(ib, next);
                assertEquals(offset1, Shape.getOffset(db, next));
                switch (shape.length) {
                    case 2:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1]));
                        break;
                    case 3:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2]));
                        break;
                    case 4:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2], next[3]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2], next[3]));
                        break;
                    case 5:
                    case 6:
                        // No 5d and 6d fixed-arity getOffset overloads to compare against
                        break;
                    default:
                        // Only reachable if a new rank is added to 'shapes' without updating this switch
                        throw new IllegalStateException("Unexpected rank in test data: " + shape.length);
                }
            }
        }
    }
}
Aggregations