Usage of org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp in project nd4j by deeplearning4j.
In class CudaGridExecutioner, method exec.
/**
 * Executes a fused {@link MetaOp} (two ops combined into a single native launch) on the CUDA backend.
 *
 * <p>Dispatch, as implemented below:
 * <ul>
 * <li>{@code PredicateMetaOp} / {@code InvertedPredicateMetaOp}: FLOAT, DOUBLE and HALF (any dtype that is neither
 * FLOAT nor DOUBLE takes the HALF path). The strided kernel is used when the Y grid's X and Y orders match and both
 * strides are at least 1; otherwise the shapeInfo-based kernel is used.</li>
 * <li>{@code ReduceMetaOp}: FLOAT only.</li>
 * </ul>
 *
 * <p>If the first/second op is a {@link ScalarOp}, its scalar is forwarded to the kernel; otherwise 0.0 is passed.
 *
 * @param op the meta op to execute
 * @throws UnsupportedOperationException if {@code op} is not one of the meta op types listed above, or is a
 *         {@code ReduceMetaOp} whose data type is not FLOAT. Such ops used to be skipped silently, leaving the
 *         result array unwritten.
 */
@Override
public void exec(MetaOp op) {
    boolean isPredicate = op instanceof PredicateMetaOp || op instanceof InvertedPredicateMetaOp;
    boolean isReduce = op instanceof ReduceMetaOp;
    if (!isPredicate && !isReduce)
        throw new UnsupportedOperationException("Unsupported MetaOp type: " + op.getClass().getSimpleName());

    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));

    prepareGrid(op);

    GridPointers first = op.getGridDescriptor().getGridPointers().get(0);
    GridPointers second = op.getGridDescriptor().getGridPointers().get(1);

    // Only the FLOAT native reduce entry point is wired up below; reject other dtypes before acquiring a CUDA
    // context instead of silently launching nothing.
    if (!isPredicate && first.getDtype() != DataBuffer.Type.FLOAT)
        throw new UnsupportedOperationException(
                        "ReduceMetaOp is only supported for FLOAT data, got: " + first.getDtype());

    // Only the first op's buffers are synchronized: for MetaOps the second op shares the same X & Z by definition.
    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(first.getOpZ(), first.getOpY());
    PointerPointer extras = extraz.get().put(null, context.getOldStream());

    double scalarA = op.getFirstOp() instanceof ScalarOp ? ((ScalarOp) op.getFirstOp()).scalar().doubleValue() : 0.0;
    double scalarB = op.getSecondOp() instanceof ScalarOp ? ((ScalarOp) op.getSecondOp()).scalar().doubleValue() : 0.0;

    // FIXME: this is bad hack, reconsider this one
    // Y is taken from the second op's grid pointers when that op has a Y operand, otherwise from the first op's.
    GridPointers yGrid = op.getSecondOp().y() != null ? second : first;

    if (isPredicate) {
        // The launch can be either strided or shapeInfo-based; both variants just need all pointers passed through.
        boolean strided = yGrid.getYOrder() == yGrid.getXOrder() && yGrid.getXStride() >= 1 && yGrid.getYStride() >= 1;

        if (first.getDtype() == DataBuffer.Type.FLOAT) {
            if (strided) {
                nativeOps.execMetaPredicateStridedFloat(extras, first.getType().ordinal(), first.getOpNum(),
                                second.getType().ordinal(), second.getOpNum(), first.getXLength(),
                                (FloatPointer) first.getX(), first.getXStride(),
                                (FloatPointer) yGrid.getY(), yGrid.getYStride(),
                                (FloatPointer) second.getZ(), second.getZStride(),
                                (FloatPointer) first.getExtraArgs(), (FloatPointer) second.getExtraArgs(),
                                (float) scalarA, (float) scalarB);
            } else {
                nativeOps.execMetaPredicateShapeFloat(extras, first.getType().ordinal(), first.getOpNum(),
                                second.getType().ordinal(), second.getOpNum(), first.getXLength(),
                                (FloatPointer) first.getX(), (IntPointer) first.getXShapeInfo(),
                                (FloatPointer) yGrid.getY(), (IntPointer) yGrid.getYShapeInfo(),
                                (FloatPointer) second.getZ(), (IntPointer) second.getZShapeInfo(),
                                (FloatPointer) first.getExtraArgs(), (FloatPointer) second.getExtraArgs(),
                                (float) scalarA, (float) scalarB);
            }
        } else if (first.getDtype() == DataBuffer.Type.DOUBLE) {
            if (strided) {
                nativeOps.execMetaPredicateStridedDouble(extras, first.getType().ordinal(), first.getOpNum(),
                                second.getType().ordinal(), second.getOpNum(), first.getXLength(),
                                (DoublePointer) first.getX(), first.getXStride(),
                                (DoublePointer) yGrid.getY(), yGrid.getYStride(),
                                (DoublePointer) second.getZ(), second.getZStride(),
                                (DoublePointer) first.getExtraArgs(), (DoublePointer) second.getExtraArgs(),
                                scalarA, scalarB);
            } else {
                nativeOps.execMetaPredicateShapeDouble(extras, first.getType().ordinal(), first.getOpNum(),
                                second.getType().ordinal(), second.getOpNum(), first.getXLength(),
                                (DoublePointer) first.getX(), (IntPointer) first.getXShapeInfo(),
                                (DoublePointer) yGrid.getY(), (IntPointer) yGrid.getYShapeInfo(),
                                (DoublePointer) second.getZ(), (IntPointer) second.getZShapeInfo(),
                                (DoublePointer) first.getExtraArgs(), (DoublePointer) second.getExtraArgs(),
                                scalarA, scalarB);
            }
        } else {
            // Every remaining dtype is treated as HALF (16-bit values passed through ShortPointer).
            if (strided) {
                nativeOps.execMetaPredicateStridedHalf(extras, first.getType().ordinal(), first.getOpNum(),
                                second.getType().ordinal(), second.getOpNum(), first.getXLength(),
                                (ShortPointer) first.getX(), first.getXStride(),
                                (ShortPointer) yGrid.getY(), yGrid.getYStride(),
                                (ShortPointer) second.getZ(), second.getZStride(),
                                (ShortPointer) first.getExtraArgs(), (ShortPointer) second.getExtraArgs(),
                                (float) scalarA, (float) scalarB);
            } else {
                nativeOps.execMetaPredicateShapeHalf(extras, first.getType().ordinal(), first.getOpNum(),
                                second.getType().ordinal(), second.getOpNum(), first.getXLength(),
                                (ShortPointer) first.getX(), (IntPointer) first.getXShapeInfo(),
                                (ShortPointer) yGrid.getY(), (IntPointer) yGrid.getYShapeInfo(),
                                (ShortPointer) second.getZ(), (IntPointer) second.getZShapeInfo(),
                                (ShortPointer) first.getExtraArgs(), (ShortPointer) second.getExtraArgs(),
                                (float) scalarA, (float) scalarB);
            }
        }
    } else {
        // ReduceMetaOp, FLOAT (validated above). Y comes from the second (reduce) op; its scalar slot is fixed at 0.
        nativeOps.execMetaPredicateReduceFloat(extras, first.getType().ordinal(), first.getOpNum(),
                        second.getType().ordinal(), second.getOpNum(),
                        (FloatPointer) first.getX(), (IntPointer) first.getXShapeInfo(),
                        (FloatPointer) second.getY(), (IntPointer) second.getYShapeInfo(),
                        (FloatPointer) second.getZ(), (IntPointer) second.getZShapeInfo(),
                        (IntPointer) second.getDimensions(), second.getDimensionsLength(),
                        (IntPointer) second.getTadShape(), new LongPointerWrapper(second.getTadOffsets()),
                        (FloatPointer) first.getExtraArgs(), (FloatPointer) second.getExtraArgs(),
                        (float) scalarA, 0.0f, false);
    }

    AtomicAllocator.getInstance().getFlowController().registerAction(context, first.getOpZ(), first.getOpY());
}
Usage of org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp in project nd4j by deeplearning4j.
In class MetaOpTests, method testPredicateReduce1.
/**
 * Fuses a scalar add (+2) with a Max reduction over dimension 1 into a single ReduceMetaOp and checks the
 * reduction's output array.
 *
 * <p>The input comes from {@code Nd4j.create(5, 5)}. The expected value of 2 in every slot assumes that array
 * starts zero-filled; confirm this if the creation defaults change.
 *
 * @throws Exception propagated from the executioner
 */
@Test
public void testPredicateReduce1() throws Exception {
    CudaGridExecutioner gridExecutioner = new CudaGridExecutioner();

    INDArray input = Nd4j.create(5, 5);
    INDArray expected = Nd4j.create(new float[] {2f, 2f, 2f, 2f, 2f});

    ScalarAdd addTwo = new ScalarAdd(input, 2.0f);
    Max rowMax = new Max(input);

    OpDescriptor scalarDescriptor = new OpDescriptor(addTwo);
    OpDescriptor reduceDescriptor = new OpDescriptor(rowMax, new int[] {1});
    gridExecutioner.buildZ(rowMax, reduceDescriptor.getDimensions());

    ReduceMetaOp fused = new ReduceMetaOp(scalarDescriptor, reduceDescriptor);
    gridExecutioner.prepareGrid(fused);
    gridExecutioner.exec(fused);

    INDArray actual = rowMax.z();
    assertNotEquals(null, actual);
    assertEquals(expected, actual);
}
Aggregations