Search in sources:

Example 1 with ReduceMetaOp

use of org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp in project nd4j by deeplearning4j.

the class CudaGridExecutioner method exec.

/**
 * Executes a fused pair of ops (a {@link MetaOp}) through a single native call.
 * <p>
 * The grid is prepared first, then the pointer sets of both ops are read from the op's grid
 * descriptor (index 0 and index 1). Predicate-style meta ops are routed to the strided or the
 * shapeInfo-based native entry point matching the first op's data type; {@link ReduceMetaOp}
 * is routed to the reduce entry point.
 *
 * @param op meta op to execute
 */
@Override
public void exec(MetaOp op) {
    // Create the extras holder on first use.
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));

    prepareGrid(op);

    GridPointers gridA = op.getGridDescriptor().getGridPointers().get(0);
    GridPointers gridB = op.getGridDescriptor().getGridPointers().get(1);

    // Only the first op is handed to the flow controller: for MetaOps the second op shares the same X & Z by definition.
    CudaContext ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(gridA.getOpZ(), gridA.getOpY());
    PointerPointer extraPointers = extraz.get().put(null, ctx.getOldStream());

    // Scalar operands stay at 0.0 unless the corresponding op is a ScalarOp.
    double scalarA = op.getFirstOp() instanceof ScalarOp
            ? ((ScalarOp) op.getFirstOp()).scalar().doubleValue()
            : 0.0;
    double scalarB = op.getSecondOp() instanceof ScalarOp
            ? ((ScalarOp) op.getSecondOp()).scalar().doubleValue()
            : 0.0;

    /*
        TODO: the launch can be either strided or shapeInfo-based, it does not really matter here;
        all pointers just have to be passed through.
     */

    // FIXME: this is a hack, reconsider it. Y-related arguments come from the second op whenever it has a Y operand.
    GridPointers ySource = op.getSecondOp().y() != null ? gridB : gridA;

    boolean predicateLike = op instanceof PredicateMetaOp || op instanceof InvertedPredicateMetaOp;
    if (predicateLike) {
        DataBuffer.Type dtype = gridA.getDtype();
        int typeA = gridA.getType().ordinal();
        int typeB = gridB.getType().ordinal();

        // The strided entry points are used only when X and Y share ordering and both strides are >= 1.
        boolean strided = ySource.getYOrder() == ySource.getXOrder()
                && ySource.getXStride() >= 1
                && ySource.getYStride() >= 1;

        // NOTE(review): the original carried "can be null" / "can be -1" remarks on some of these arguments,
        // presumably the Y pointer and the Y stride - confirm before relying on it.
        if (dtype == DataBuffer.Type.FLOAT && strided) {
            nativeOps.execMetaPredicateStridedFloat(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    gridA.getXLength(),
                    (FloatPointer) gridA.getX(), gridA.getXStride(),
                    (FloatPointer) ySource.getY(), ySource.getYStride(),
                    (FloatPointer) gridB.getZ(), gridB.getZStride(),
                    (FloatPointer) gridA.getExtraArgs(), (FloatPointer) gridB.getExtraArgs(),
                    (float) scalarA, (float) scalarB);
        } else if (dtype == DataBuffer.Type.FLOAT) {
            nativeOps.execMetaPredicateShapeFloat(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    gridA.getXLength(),
                    (FloatPointer) gridA.getX(), (IntPointer) gridA.getXShapeInfo(),
                    (FloatPointer) ySource.getY(), (IntPointer) ySource.getYShapeInfo(),
                    (FloatPointer) gridB.getZ(), (IntPointer) gridB.getZShapeInfo(),
                    (FloatPointer) gridA.getExtraArgs(), (FloatPointer) gridB.getExtraArgs(),
                    (float) scalarA, (float) scalarB);
        } else if (dtype == DataBuffer.Type.DOUBLE && strided) {
            nativeOps.execMetaPredicateStridedDouble(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    gridA.getXLength(),
                    (DoublePointer) gridA.getX(), gridA.getXStride(),
                    (DoublePointer) ySource.getY(), ySource.getYStride(),
                    (DoublePointer) gridB.getZ(), gridB.getZStride(),
                    (DoublePointer) gridA.getExtraArgs(), (DoublePointer) gridB.getExtraArgs(),
                    scalarA, scalarB);
        } else if (dtype == DataBuffer.Type.DOUBLE) {
            nativeOps.execMetaPredicateShapeDouble(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    gridA.getXLength(),
                    (DoublePointer) gridA.getX(), (IntPointer) gridA.getXShapeInfo(),
                    (DoublePointer) ySource.getY(), (IntPointer) ySource.getYShapeInfo(),
                    (DoublePointer) gridB.getZ(), (IntPointer) gridB.getZShapeInfo(),
                    (DoublePointer) gridA.getExtraArgs(), (DoublePointer) gridB.getExtraArgs(),
                    scalarA, scalarB);
        } else if (strided) {
            // Every data type other than FLOAT and DOUBLE goes through the Half entry points.
            nativeOps.execMetaPredicateStridedHalf(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    gridA.getXLength(),
                    (ShortPointer) gridA.getX(), gridA.getXStride(),
                    (ShortPointer) ySource.getY(), ySource.getYStride(),
                    (ShortPointer) gridB.getZ(), gridB.getZStride(),
                    (ShortPointer) gridA.getExtraArgs(), (ShortPointer) gridB.getExtraArgs(),
                    (float) scalarA, (float) scalarB);
        } else {
            nativeOps.execMetaPredicateShapeHalf(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    gridA.getXLength(),
                    (ShortPointer) gridA.getX(), (IntPointer) gridA.getXShapeInfo(),
                    (ShortPointer) ySource.getY(), (IntPointer) ySource.getYShapeInfo(),
                    (ShortPointer) gridB.getZ(), (IntPointer) gridB.getZShapeInfo(),
                    (ShortPointer) gridA.getExtraArgs(), (ShortPointer) gridB.getExtraArgs(),
                    (float) scalarA, (float) scalarB);
        }
    } else if (op instanceof ReduceMetaOp) {
        // NOTE(review): only FLOAT is dispatched here; any other data type falls through without a native call.
        if (gridA.getDtype() == DataBuffer.Type.FLOAT) {
            int typeA = gridA.getType().ordinal();
            int typeB = gridB.getType().ordinal();

            // Y, Z, dimensions and TAD data are always taken from the second op; the second scalar is passed as 0.0f.
            nativeOps.execMetaPredicateReduceFloat(extraPointers,
                    typeA, gridA.getOpNum(),
                    typeB, gridB.getOpNum(),
                    (FloatPointer) gridA.getX(), (IntPointer) gridA.getXShapeInfo(),
                    (FloatPointer) gridB.getY(), (IntPointer) gridB.getYShapeInfo(),
                    (FloatPointer) gridB.getZ(), (IntPointer) gridB.getZShapeInfo(),
                    (IntPointer) gridB.getDimensions(), gridB.getDimensionsLength(),
                    (IntPointer) gridB.getTadShape(), new LongPointerWrapper(gridB.getTadOffsets()),
                    (FloatPointer) gridA.getExtraArgs(), (FloatPointer) gridB.getExtraArgs(),
                    (float) scalarA, 0.0f, false);
        }
    }

    // Mirror of the prepareAction call above, using the same first-op operands.
    AtomicAllocator.getInstance().getFlowController().registerAction(ctx, gridA.getOpZ(), gridA.getOpY());
}
Also used : ReduceMetaOp(org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp) GridPointers(org.nd4j.linalg.api.ops.grid.GridPointers) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) InvertedPredicateMetaOp(org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp) PredicateMetaOp(org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp) LongPointerWrapper(org.nd4j.nativeblas.LongPointerWrapper)

Example 2 with ReduceMetaOp

use of org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp in project nd4j by deeplearning4j.

the class MetaOpTests method testPredicateReduce1.

/**
 * Fuses a scalar addition with a max-reduction along dimension 1 into one {@link ReduceMetaOp}
 * and verifies the reduced output.
 * <p>
 * The input is a 5x5 array (presumably zero-initialised by {@code Nd4j.create} - confirm), 2.0 is
 * added to it, and the reduction is expected to yield five values of 2.
 *
 * @throws Exception if grid preparation or execution fails
 */
@Test
public void testPredicateReduce1() throws Exception {
    CudaGridExecutioner gridExecutioner = new CudaGridExecutioner();

    // Input data and the expected reduction result.
    INDArray input = Nd4j.create(5, 5);
    INDArray expected = Nd4j.create(new float[] { 2f, 2f, 2f, 2f, 2f });

    // The two ops to be fused: scalar add first, max-reduction second.
    ScalarAdd addOp = new ScalarAdd(input, 2.0f);
    Max maxOp = new Max(input);

    OpDescriptor addDescriptor = new OpDescriptor(addOp);
    OpDescriptor maxDescriptor = new OpDescriptor(maxOp, new int[] { 1 });

    // The reduction's output array has to exist before the meta op is built and run.
    gridExecutioner.buildZ(maxOp, maxDescriptor.getDimensions());

    ReduceMetaOp fusedOp = new ReduceMetaOp(addDescriptor, maxDescriptor);
    gridExecutioner.prepareGrid(fusedOp);
    gridExecutioner.exec(fusedOp);

    INDArray actual = maxOp.z();
    assertNotEquals(null, actual);
    assertEquals(expected, actual);
}
Also used : ReduceMetaOp(org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp) ScalarAdd(org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Max(org.nd4j.linalg.api.ops.impl.accum.Max) OpDescriptor(org.nd4j.linalg.api.ops.grid.OpDescriptor) Test(org.junit.Test)

Aggregations

ReduceMetaOp (org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp)2 Test (org.junit.Test)1 INDArray (org.nd4j.linalg.api.ndarray.INDArray)1 GridPointers (org.nd4j.linalg.api.ops.grid.GridPointers)1 OpDescriptor (org.nd4j.linalg.api.ops.grid.OpDescriptor)1 Max (org.nd4j.linalg.api.ops.impl.accum.Max)1 InvertedPredicateMetaOp (org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp)1 PredicateMetaOp (org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp)1 ScalarAdd (org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd)1 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)1 LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper)1