Use of org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp in project nd4j by deeplearning4j.
From the class CudaGridExecutioner, method processAsGridOp.
/**
 * Schedules a non-blocking op, pairing it with the previously parked op
 * (held in {@code lastOp}) whenever {@code getMetaOpType} says the two can
 * be expressed as a single MetaOp.
 *
 * <p>Three situations are distinguished:
 * <ul>
 *   <li>no op is parked: the incoming op is either parked via
 *       {@code enqueueOp} (TransformOp with a non-null {@code y} that passes
 *       {@code onCurrentDeviceXYZ}) or handed to {@code pushToGrid};</li>
 *   <li>an op is parked but cannot be combined ({@code NOT_APPLICABLE}):
 *       the parked op is pushed to the grid and the incoming op is treated
 *       as in the first case;</li>
 *   <li>an op is parked and can be combined: the pair is wrapped in the
 *       matching MetaOp and that wrapper is pushed to the grid.</li>
 * </ul>
 *
 * <p>This method is not expected to be called for blocking ops.
 *
 * @param op        the op to schedule
 * @param dimension dimensions forwarded to {@code getMetaOpType} and to the
 *                  {@code OpDescriptor} built for {@code op}
 * @throws UnsupportedOperationException if {@code getMetaOpType} yields a
 *         MetaType this method has no branch for
 */
protected void processAsGridOp(Op op, int... dimension) {
    final OpDescriptor pending = lastOp.get();

    // Nothing parked: decide between parking this op and pushing it right away.
    // (A data-type restriction for HALF was once considered here and is disabled.)
    if (pending == null) {
        if (op instanceof TransformOp && op.y() != null && onCurrentDeviceXYZ(op)) {
            enqueueOp(new OpDescriptor(op, dimension));
        } else {
            pushToGrid(new OpDescriptor(op, dimension), false);
        }
        return;
    }

    // Classify the (pending, op) pair first, then clear the parked slot;
    // every branch below either consumes `pending` or re-parks a new op.
    final MetaType pairing = getMetaOpType(op, dimension);
    lastOp.remove();

    switch (pairing) {
        case NOT_APPLICABLE: {
            // The pair cannot form a MetaOp: flush the parked op on its own...
            dequeueOp(pending);
            pushToGrid(pending, false);

            // ...and treat the incoming op exactly like the "nothing parked" path.
            // (Accepting ScalarOp here as well was once considered and is disabled.)
            final boolean parkable = (op instanceof TransformOp && op.y() != null) && onCurrentDeviceXYZ(op);
            final OpDescriptor current = new OpDescriptor(op, dimension);
            if (parkable) {
                enqueueOp(current);
            } else {
                pushToGrid(current, false);
            }
            break;
        }
        case PREDICATE: {
            final OpDescriptor current = new OpDescriptor(op, dimension);
            MetaOp combined = new PredicateMetaOp(pending, current);
            pushToGrid(new OpDescriptor(combined), false);
            break;
        }
        case INVERTED_PREDICATE: {
            final OpDescriptor current = new OpDescriptor(op, dimension);
            // Both descriptors are dequeued before being wrapped together.
            dequeueOp(pending);
            dequeueOp(current);
            MetaOp combined = new InvertedPredicateMetaOp(pending, current);
            pushToGrid(new OpDescriptor(combined), false);
            break;
        }
        case POSTULATE: {
            final OpDescriptor current = new OpDescriptor(op, dimension);
            MetaOp combined = new PostulateMetaOp(pending, current);
            pushToGrid(new OpDescriptor(combined), false);
            break;
        }
        default:
            throw new UnsupportedOperationException("Not supported MetaType: [" + pairing + "]");
    }
}
Use of org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp in project nd4j by deeplearning4j.
From the class CudaGridExecutioner, method exec.
/**
 * Executes a two-op {@link MetaOp} through the matching {@code nativeOps.execMeta*} entry point.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>make sure the {@code extraz} holder contains a {@code PointerPointer};</li>
 *   <li>call {@code prepareGrid(op)} and read the first two {@code GridPointers} entries;</li>
 *   <li>obtain a {@code CudaContext} from the flow controller for the first op's Z and Y;</li>
 *   <li>pick up scalar arguments from whichever sub-op is a {@code ScalarOp} (0.0 otherwise);</li>
 *   <li>dispatch on MetaOp kind, on strided-vs-shapeInfo layout, and on data type;</li>
 *   <li>register the action with the flow controller.</li>
 * </ol>
 *
 * <p>NOTE(review): for {@code ReduceMetaOp} only the FLOAT data type triggers a native call;
 * any other data type, and any other MetaOp kind, falls through without a launch.
 * Confirm this is intentional.
 *
 * @param op the MetaOp to execute
 */
@Override
public void exec(MetaOp op) {
    if (extraz.get() == null) {
        extraz.set(new PointerPointer(32));
    }

    prepareGrid(op);

    final GridPointers first = op.getGridDescriptor().getGridPointers().get(0);
    final GridPointers second = op.getGridDescriptor().getGridPointers().get(1);

    // Only the first op's buffers are passed here: for MetaOps the second op
    // shares the same X & Z by definition.
    final CudaContext context = AtomicAllocator.getInstance().getFlowController()
                    .prepareAction(first.getOpZ(), first.getOpY());

    final PointerPointer extras = extraz.get().put(null, context.getOldStream());

    // Scalars default to 0.0 when the corresponding sub-op is not a ScalarOp.
    final double scalarA = op.getFirstOp() instanceof ScalarOp
                    ? ((ScalarOp) op.getFirstOp()).scalar().doubleValue()
                    : 0.0;
    final double scalarB = op.getSecondOp() instanceof ScalarOp
                    ? ((ScalarOp) op.getSecondOp()).scalar().doubleValue()
                    : 0.0;

    /*
        TODO: launch can be either strided or shapeInfo-based; it doesn't really matter for us.
        We just need to pass all pointers.
     */

    // FIXME: this is a bad hack, reconsider it.
    // Y comes from the second op's pointers when that op has a Y operand, else from the first.
    final GridPointers yGrid = op.getSecondOp().y() != null ? second : first;

    if (op instanceof PredicateMetaOp || op instanceof InvertedPredicateMetaOp) {
        // Strided kernels are used only when X and Y share ordering and both strides are >= 1;
        // everything else goes through the shapeInfo-based kernels.
        // NOTE(review): earlier inline comments hinted that X/Y may be null and a stride may be -1 — confirm.
        final boolean strided = yGrid.getYOrder() == yGrid.getXOrder()
                        && yGrid.getXStride() >= 1
                        && yGrid.getYStride() >= 1;
        final DataBuffer.Type dtype = first.getDtype();

        if (strided) {
            if (dtype == DataBuffer.Type.FLOAT) {
                nativeOps.execMetaPredicateStridedFloat(
                                extras,
                                first.getType().ordinal(),
                                first.getOpNum(),
                                second.getType().ordinal(),
                                second.getOpNum(),
                                first.getXLength(),
                                (FloatPointer) first.getX(),
                                first.getXStride(),
                                (FloatPointer) yGrid.getY(),
                                yGrid.getYStride(),
                                (FloatPointer) second.getZ(),
                                second.getZStride(),
                                (FloatPointer) first.getExtraArgs(),
                                (FloatPointer) second.getExtraArgs(),
                                (float) scalarA,
                                (float) scalarB);
            } else if (dtype == DataBuffer.Type.DOUBLE) {
                nativeOps.execMetaPredicateStridedDouble(
                                extras,
                                first.getType().ordinal(),
                                first.getOpNum(),
                                second.getType().ordinal(),
                                second.getOpNum(),
                                first.getXLength(),
                                (DoublePointer) first.getX(),
                                first.getXStride(),
                                (DoublePointer) yGrid.getY(),
                                yGrid.getYStride(),
                                (DoublePointer) second.getZ(),
                                second.getZStride(),
                                (DoublePointer) first.getExtraArgs(),
                                (DoublePointer) second.getExtraArgs(),
                                scalarA,
                                scalarB);
            } else {
                // Any data type other than FLOAT/DOUBLE is routed to the Half kernels.
                nativeOps.execMetaPredicateStridedHalf(
                                extras,
                                first.getType().ordinal(),
                                first.getOpNum(),
                                second.getType().ordinal(),
                                second.getOpNum(),
                                first.getXLength(),
                                (ShortPointer) first.getX(),
                                first.getXStride(),
                                (ShortPointer) yGrid.getY(),
                                yGrid.getYStride(),
                                (ShortPointer) second.getZ(),
                                second.getZStride(),
                                (ShortPointer) first.getExtraArgs(),
                                (ShortPointer) second.getExtraArgs(),
                                (float) scalarA,
                                (float) scalarB);
            }
        } else {
            if (dtype == DataBuffer.Type.FLOAT) {
                nativeOps.execMetaPredicateShapeFloat(
                                extras,
                                first.getType().ordinal(),
                                first.getOpNum(),
                                second.getType().ordinal(),
                                second.getOpNum(),
                                first.getXLength(),
                                (FloatPointer) first.getX(),
                                (IntPointer) first.getXShapeInfo(),
                                (FloatPointer) yGrid.getY(),
                                (IntPointer) yGrid.getYShapeInfo(),
                                (FloatPointer) second.getZ(),
                                (IntPointer) second.getZShapeInfo(),
                                (FloatPointer) first.getExtraArgs(),
                                (FloatPointer) second.getExtraArgs(),
                                (float) scalarA,
                                (float) scalarB);
            } else if (dtype == DataBuffer.Type.DOUBLE) {
                nativeOps.execMetaPredicateShapeDouble(
                                extras,
                                first.getType().ordinal(),
                                first.getOpNum(),
                                second.getType().ordinal(),
                                second.getOpNum(),
                                first.getXLength(),
                                (DoublePointer) first.getX(),
                                (IntPointer) first.getXShapeInfo(),
                                (DoublePointer) yGrid.getY(),
                                (IntPointer) yGrid.getYShapeInfo(),
                                (DoublePointer) second.getZ(),
                                (IntPointer) second.getZShapeInfo(),
                                (DoublePointer) first.getExtraArgs(),
                                (DoublePointer) second.getExtraArgs(),
                                scalarA,
                                scalarB);
            } else {
                // Any data type other than FLOAT/DOUBLE is routed to the Half kernels.
                nativeOps.execMetaPredicateShapeHalf(
                                extras,
                                first.getType().ordinal(),
                                first.getOpNum(),
                                second.getType().ordinal(),
                                second.getOpNum(),
                                first.getXLength(),
                                (ShortPointer) first.getX(),
                                (IntPointer) first.getXShapeInfo(),
                                (ShortPointer) yGrid.getY(),
                                (IntPointer) yGrid.getYShapeInfo(),
                                (ShortPointer) second.getZ(),
                                (IntPointer) second.getZShapeInfo(),
                                (ShortPointer) first.getExtraArgs(),
                                (ShortPointer) second.getExtraArgs(),
                                (float) scalarA,
                                (float) scalarB);
            }
        }
    } else if (op instanceof ReduceMetaOp && first.getDtype() == DataBuffer.Type.FLOAT) {
        // Only the FLOAT variant is wired up for reduce MetaOps.
        nativeOps.execMetaPredicateReduceFloat(
                        extras,
                        first.getType().ordinal(),
                        first.getOpNum(),
                        second.getType().ordinal(),
                        second.getOpNum(),
                        (FloatPointer) first.getX(),
                        (IntPointer) first.getXShapeInfo(),
                        (FloatPointer) second.getY(),
                        (IntPointer) second.getYShapeInfo(),
                        (FloatPointer) second.getZ(),
                        (IntPointer) second.getZShapeInfo(),
                        (IntPointer) second.getDimensions(),
                        second.getDimensionsLength(),
                        (IntPointer) second.getTadShape(),
                        new LongPointerWrapper(second.getTadOffsets()),
                        (FloatPointer) first.getExtraArgs(),
                        (FloatPointer) second.getExtraArgs(),
                        (float) scalarA,
                        0.0f,
                        false);
    }

    // Registered for the same buffers that were passed to prepareAction above.
    AtomicAllocator.getInstance().getFlowController().registerAction(context, first.getOpZ(), first.getOpY());
}
Aggregations