Usage of org.nd4j.linalg.api.ops.executioner.OpStatus in the nd4j project (deeplearning4j).
Class NativeOpExecutioner, method exec.
/**
 * Executes the given CustomOp through the native backend.
 *
 * <p>Data pointers and shape-info pointers for every input and output argument are collected.
 * Integer and floating-point op arguments are copied into native pointers. The native entry
 * point ({@code execCustomOpFloat}, {@code execCustomOpDouble} or {@code execCustomOpHalf}) is
 * selected from {@code Nd4j.dataType()}.
 *
 * <p>PLEASE NOTE: You're responsible for input/output validation beyond the null checks
 * performed here.
 *
 * @param op the custom op to execute; must declare output arguments unless it is an in-place call
 * @throws ND4JIllegalStateException if a non-in-place op has no outputs, an output argument is
 *     null, the native call reports a status other than {@code ND4J_STATUS_OK}, or the global
 *     data type is not FLOAT, DOUBLE or HALF
 * @throws NullPointerException if any input argument is null
 */
public void exec(@NonNull CustomOp op) {
    if (op.numOutputArguments() == 0 && !op.isInplaceCall())
        throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");

    val hash = op.opHash();

    // Collect native data/shape pointers for all inputs.
    val inputShapes = getInputShapes(op.numInputArguments());
    val inputBuffers = getInputBuffers(op.numInputArguments());
    int cnt = 0;
    for (val in : op.inputArguments()) {
        if (in == null) {
            throw new NullPointerException("Input argument is null");
        }
        inputBuffers.put(cnt, in.data().addressPointer());
        inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
    }

    // Validate all outputs before touching any native output pointer.
    val outputArgs = op.outputArguments();
    for (int i = 0; i < outputArgs.length; i++) {
        if (outputArgs[i] == null)
            throw new ND4JIllegalStateException("Op output arguments must not be null!");
    }
    val outputShapes = getOutputShapes(op.numOutputArguments());
    val outputBuffers = getOutputBuffers(op.numOutputArguments());
    cnt = 0;
    for (val out : outputArgs) {
        outputBuffers.put(cnt, out.data().addressPointer());
        outputShapes.put(cnt++, out.shapeInfoDataBuffer().addressPointer());
    }

    // Integer arguments; the pointer is null when the op declares none.
    val iArgs = op.numIArguments() > 0 ? new IntPointer(op.numIArguments()) : null;
    if (iArgs != null) {
        cnt = 0;
        for (val i : op.iArgs()) iArgs.put(cnt++, i);
    }

    // Read the global data type once so all branch checks see the same value.
    val dataType = Nd4j.dataType();
    if (dataType == DataBuffer.Type.FLOAT) {
        val tArgs = op.numTArguments() > 0 ? new FloatPointer(op.numTArguments()) : null;
        if (tArgs != null) {
            cnt = 0;
            for (val t : op.tArgs()) tArgs.put(cnt++, (float) t);
        }
        val status = OpStatus.byNumber(loop.execCustomOpFloat(null, hash, inputBuffers, inputShapes, op.numInputArguments(), outputBuffers, outputShapes, op.numOutputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), op.isInplaceCall()));
        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);
    } else if (dataType == DataBuffer.Type.DOUBLE) {
        val tArgs = op.numTArguments() > 0 ? getDoublePointerFrom(tArgsPointer, op.numTArguments()) : null;
        if (tArgs != null) {
            cnt = 0;
            for (val t : op.tArgs()) tArgs.put(cnt++, t);
        }
        OpStatus status;
        try {
            status = OpStatus.byNumber(loop.execCustomOpDouble(null, hash, inputBuffers, inputShapes, op.numInputArguments(), outputBuffers, outputShapes, op.numOutputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), op.isInplaceCall()));
        } catch (Exception e) {
            log.error("Failed to execute. Please see above message (printed out from c++) for a possible cause of error.");
            throw e;
        }
        // Previously the status was computed but never checked in this branch, so native
        // failures were silently ignored; handle it the same way as the FLOAT/HALF branches.
        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);
    } else if (dataType == DataBuffer.Type.HALF) {
        val tArgs = op.numTArguments() > 0 ? getShortPointerFrom(halfArgsPointer, op.numTArguments()) : null;
        if (tArgs != null) {
            cnt = 0;
            for (val t : op.tArgs()) tArgs.put(cnt++, ArrayUtil.toHalf(t));
        }
        val status = OpStatus.byNumber(loop.execCustomOpHalf(null, hash, inputBuffers, inputShapes, op.numInputArguments(), outputBuffers, outputShapes, op.numOutputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), op.isInplaceCall()));
        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);
    } else {
        // Previously any other data type fell through and the op was silently never executed.
        throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute: unsupported data type " + dataType);
    }
}
Aggregations