Use of org.nd4j.linalg.api.ops.grid.GridPointers in the nd4j project by deeplearning4j.
From the class GridExecutionerTest, method testOpPointerizeReduce1.
/**
 * Checks the {@link GridPointers} produced by {@code pointerizeOp} for a
 * {@link Sum} reduction along dimension 1 of a 10x10 array.
 *
 * <p>The expected device pointers are looked up through the
 * {@link AtomicAllocator} and compared against the values reported by the
 * pointerized op.
 *
 * @throws Exception if executing or pointerizing the op fails
 */
@Test
public void testOpPointerizeReduce1() throws Exception {
    CudaGridExecutioner gridExecutioner = new CudaGridExecutioner();
    INDArray input = Nd4j.create(10, 10);
    Sum sumOp = new Sum(input);

    // The op has to be executed first so that Op.Z is initialised for the requested dimension.
    gridExecutioner.exec(sumOp, 1);

    GridPointers grid = gridExecutioner.pointerizeOp(sumOp, 1);

    // Op identity
    assertEquals(sumOp.opNum(), grid.getOpNum());
    assertEquals(Op.Type.REDUCE, grid.getType());

    // Resolve the pointers we expect to find inside the grid descriptor.
    AtomicAllocator allocator = AtomicAllocator.getInstance();
    CudaContext cudaContext = (CudaContext) allocator.getDeviceContext().getContext();

    Pointer expectedX = allocator.getPointer(input, cudaContext);
    Pointer expectedXShapeInfo = allocator.getPointer(input.shapeInfoDataBuffer(), cudaContext);

    INDArray reduced = sumOp.z();
    Pointer expectedZ = allocator.getPointer(reduced, cudaContext);
    Pointer expectedZShapeInfo = allocator.getPointer(reduced.shapeInfoDataBuffer(), cudaContext);

    DataBuffer dimensionBuffer = Nd4j.getConstantHandler().getConstantBuffer(new int[] {1});
    Pointer expectedDimensions = allocator.getPointer(dimensionBuffer, cudaContext);

    // Data pointers: X and Z are set, Y is absent for this op.
    assertEquals(expectedX, grid.getX());
    assertEquals(null, grid.getY());
    assertNotEquals(null, grid.getZ());
    assertEquals(expectedZ, grid.getZ());

    // Result length
    assertEquals(10, reduced.length());
    assertEquals(10, grid.getZLength());

    // Strides (getXStride/getYStride/getZStride) are intentionally not asserted:
    // per the original author, EWS is not of interest for this TAD-based operation.

    // Shape descriptors
    assertEquals(expectedXShapeInfo, grid.getXShapeInfo());
    assertEquals(null, grid.getYShapeInfo());
    assertEquals(expectedZShapeInfo, grid.getZShapeInfo());

    // Dimension argument
    assertEquals(expectedDimensions, grid.getDimensions());
    assertEquals(1, grid.getDimensionsLength());

    // TAD data must be present; no extra arguments are expected.
    assertNotEquals(null, grid.getTadShape());
    assertNotEquals(null, grid.getTadOffsets());
    assertEquals(null, grid.getExtraArgs());
}
Aggregations