Search in sources:

Example 1 with OpStatus

Use of org.nd4j.linalg.api.ops.executioner.OpStatus in project nd4j by deeplearning4j.

From the `exec` method of the class `NativeOpExecutioner`:

/**
 * Executes the given {@link CustomOp} through the native backend.
 *
 * <p>The method packs the following into native pointers:
 * <ul>
 *   <li>the data-buffer and shape-info addresses of every input and output array;</li>
 *   <li>the op's integer arguments;</li>
 *   <li>the op's floating-point arguments.</li>
 * </ul>
 * It then dispatches to the float, double or half native entry point selected by
 * {@code Nd4j.dataType()}.
 *
 * <p>PLEASE NOTE: You're responsible for input/output validation
 *
 * @param op the custom op to execute; it must have output arguments unless it is an in-place call
 * @throws ND4JIllegalStateException in any of these cases:
 *     <ul>
 *       <li>the op is not in-place and has no outputs;</li>
 *       <li>an output argument is null;</li>
 *       <li>the global data type is not FLOAT, DOUBLE or HALF;</li>
 *       <li>the native call returns a status other than {@code ND4J_STATUS_OK}.</li>
 *     </ul>
 * @throws NullPointerException if an input argument is null
 */
public void exec(@NonNull CustomOp op) {
    if (op.numOutputArguments() == 0 && !op.isInplaceCall())
        throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");

    val hash = op.opHash();

    // Inputs: one data-buffer address and one shape-info address per input array.
    val inputShapes = getInputShapes(op.numInputArguments());
    val inputBuffers = getInputBuffers(op.numInputArguments());
    int cnt = 0;
    for (val in : op.inputArguments()) {
        if (in == null) {
            throw new NullPointerException("Input argument is null");
        }
        inputBuffers.put(cnt, in.data().addressPointer());
        inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
    }

    // Validate every output before writing any of their addresses.
    val outputArgs = op.outputArguments();
    for (val out : outputArgs) {
        if (out == null)
            throw new ND4JIllegalStateException("Op output arguments must not be null!");
    }
    val outputShapes = getOutputShapes(op.numOutputArguments());
    val outputBuffers = getOutputBuffers(op.numOutputArguments());
    cnt = 0;
    for (val out : outputArgs) {
        outputBuffers.put(cnt, out.data().addressPointer());
        outputShapes.put(cnt++, out.shapeInfoDataBuffer().addressPointer());
    }

    // Integer arguments. The pointer is null when the op reports none.
    val iArgs = op.numIArguments() > 0 ? new IntPointer(op.numIArguments()) : null;
    cnt = 0;
    for (val i : op.iArgs()) iArgs.put(cnt++, i);

    // Reject unsupported global types up front. Previously they fell through every branch
    // and the op silently did not run.
    val dataType = Nd4j.dataType();
    if (dataType != DataBuffer.Type.FLOAT && dataType != DataBuffer.Type.DOUBLE && dataType != DataBuffer.Type.HALF)
        throw new ND4JIllegalStateException("Unsupported data type for CustomOp execution: " + dataType);

    val tArgs1 = op.tArgs();
    final OpStatus status;
    try {
        if (dataType == DataBuffer.Type.FLOAT) {
            // NOTE(review): FLOAT allocates a fresh FloatPointer per call, whereas DOUBLE/HALF
            // reuse tArgsPointer/halfArgsPointer via helpers.
            val tArgs = op.numTArguments() > 0 ? new FloatPointer(op.numTArguments()) : null;
            cnt = 0;
            for (val t : tArgs1) tArgs.put(cnt++, (float) t);
            status = OpStatus.byNumber(loop.execCustomOpFloat(null, hash, inputBuffers, inputShapes, op.numInputArguments(), outputBuffers, outputShapes, op.numOutputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), op.isInplaceCall()));
        } else if (dataType == DataBuffer.Type.DOUBLE) {
            val tArgs = op.numTArguments() > 0 ? getDoublePointerFrom(tArgsPointer, op.numTArguments()) : null;
            cnt = 0;
            for (val t : tArgs1) tArgs.put(cnt++, t);
            status = OpStatus.byNumber(loop.execCustomOpDouble(null, hash, inputBuffers, inputShapes, op.numInputArguments(), outputBuffers, outputShapes, op.numOutputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), op.isInplaceCall()));
        } else {
            // Only HALF remains after the data-type check above.
            val tArgs = op.numTArguments() > 0 ? getShortPointerFrom(halfArgsPointer, op.numTArguments()) : null;
            cnt = 0;
            for (val t : tArgs1) tArgs.put(cnt++, ArrayUtil.toHalf(t));
            status = OpStatus.byNumber(loop.execCustomOpHalf(null, hash, inputBuffers, inputShapes, op.numInputArguments(), outputBuffers, outputShapes, op.numOutputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), op.isInplaceCall()));
        }
    } catch (RuntimeException e) {
        // Originally only the DOUBLE path logged this hint. It now covers every dispatch path.
        log.error("Failed to execute. Please see above message (printed out from c++) for a possible cause of error.");
        throw e;
    }

    // Checked for every data type. The DOUBLE path previously ignored a failing status.
    if (status != OpStatus.ND4J_STATUS_OK)
        throw new ND4JIllegalStateException("Op execution failed: " + status);
}
Also used: lombok.val (lombok.val), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), OpStatus (org.nd4j.linalg.api.ops.executioner.OpStatus)

Aggregations

lombok.val (lombok.val): 1 usage; OpStatus (org.nd4j.linalg.api.ops.executioner.OpStatus): 1 usage; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 1 usage