Search in sources:

Example 6 with GridPointers

Use of `org.nd4j.linalg.api.ops.grid.GridPointers` in the nd4j project by deeplearning4j.

From the class `GridExecutionerTest`, method `testOpPointerizeReduce1`.

/**
 * Checks that {@code pointerizeOp} translates a {@link Sum} reduction along dimension 1 into
 * the expected {@link GridPointers}.
 *
 * <p>It compares the op number and type, the X/Z buffer and shape-info pointers, and the
 * dimensions buffer pointer and its length. It also checks that TAD shape/offset pointers are
 * present. The expected pointers are resolved through the same allocator and device context.
 *
 * @throws Exception if executing the op or resolving any pointer fails
 */
@Test
public void testOpPointerizeReduce1() throws Exception {
    final int dimension = 1;
    final CudaGridExecutioner gridExecutioner = new CudaGridExecutioner();
    final INDArray input = Nd4j.create(10, 10);
    final Sum sumOp = new Sum(input);

    // exec() has to run before pointerizeOp(): it initialises the op's Z for this dimension
    gridExecutioner.exec(sumOp, dimension);
    final GridPointers gp = gridExecutioner.pointerizeOp(sumOp, dimension);

    assertEquals(sumOp.opNum(), gp.getOpNum());
    assertEquals(Op.Type.REDUCE, gp.getType());

    // Build the expected device pointers with the same allocator/context used by the executioner
    final AtomicAllocator allocator = AtomicAllocator.getInstance();
    final CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();
    final INDArray result = sumOp.z();

    final Pointer expectedX = allocator.getPointer(input, ctx);
    final Pointer expectedXShape = allocator.getPointer(input.shapeInfoDataBuffer(), ctx);
    final Pointer expectedZ = allocator.getPointer(result, ctx);
    final Pointer expectedZShape = allocator.getPointer(result.shapeInfoDataBuffer(), ctx);
    final DataBuffer dimensionsBuffer =
            Nd4j.getConstantHandler().getConstantBuffer(new int[] { dimension });
    final Pointer expectedDimensions = allocator.getPointer(dimensionsBuffer, ctx);

    // Data pointers: X and Z are set, Y is unused by a plain reduction
    assertEquals(expectedX, gp.getX());
    assertEquals(null, gp.getY());
    assertNotEquals(null, gp.getZ());
    assertEquals(expectedZ, gp.getZ());

    // A 10x10 input reduced along dimension 1 yields 10 values
    assertEquals(10, result.length());
    assertEquals(10, gp.getZLength());

    // Element-wise strides are intentionally not asserted: this exercises a TAD-based
    // (along-dimension) operation, where EWS is not meaningful for the check.

    // Shape-info pointers
    assertEquals(expectedXShape, gp.getXShapeInfo());
    assertEquals(null, gp.getYShapeInfo());
    assertEquals(expectedZShape, gp.getZShapeInfo());

    // Dimension and TAD metadata
    assertEquals(expectedDimensions, gp.getDimensions());
    assertEquals(1, gp.getDimensionsLength());
    assertNotEquals(null, gp.getTadShape());
    assertNotEquals(null, gp.getTadOffsets());
    assertEquals(null, gp.getExtraArgs());
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), GridPointers (org.nd4j.linalg.api.ops.grid.GridPointers), CudaContext (org.nd4j.linalg.jcublas.context.CudaContext), Sum (org.nd4j.linalg.api.ops.impl.accum.Sum), Pointer (org.bytedeco.javacpp.Pointer), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer), Test (org.junit.Test)

Aggregations

Usage counts: GridPointers (org.nd4j.linalg.api.ops.grid.GridPointers): 6; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 5; Pointer (org.bytedeco.javacpp.Pointer): 3; Test (org.junit.Test): 3; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 3; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3; Sum (org.nd4j.linalg.api.ops.impl.accum.Sum): 2; AtomicAllocator (org.nd4j.jita.allocator.impl.AtomicAllocator): 1; InvertedPredicateMetaOp (org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp): 1; PredicateMetaOp (org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp): 1; ReduceMetaOp (org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp): 1; ScalarMultiplication (org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication): 1; LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper): 1