Search in sources :

Example 86 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class CudaExecutioner, the method exec.

/**
 * Executes the given Accumulation (reduction) op along the specified dimension(s).
 *
 * @param op        the accumulation to execute; its result array is set as op.z()
 * @param dimension target dimension(s); negative values are normalized against
 *                  op.x().rank(), and Integer.MAX_VALUE denotes a whole-array reduction
 * @return the result array (op.z())
 * @throws ND4JIllegalStateException if a requested dimension exceeds the broadcast
 *         (max) rank of x/y, if TAD lengths mismatch for a pairwise reduction, or
 *         if a caller-supplied op.z() has a length that doesn't match the expected
 *         reduced shape
 */
@Override
public INDArray exec(Accumulation op, int... dimension) {
    long st = profilingHookIn(op);
    checkForCompression(op);
    // BUGFIX: validateDataType was previously invoked twice with identical
    // arguments (before and after the sort); the duplicate call is removed.
    validateDataType(Nd4j.dataType(), op);
    // Dimensions must be sorted for TAD computation downstream.
    Arrays.sort(dimension);
    // Lazily initialize the thread-local scratch PointerPointer.
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));
    // Validate the requested dimensions against the broadcast (max) shape of x and y.
    int[] maxShape = Shape.getMaxShape(op.x(), op.y());
    for (int i = 0; i < dimension.length; i++) if (dimension[i] >= maxShape.length && dimension[i] != Integer.MAX_VALUE)
        throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
    // Normalize negative dimensions to their positive equivalents.
    for (int i = 0; i < dimension.length; i++) {
        if (dimension[i] < 0)
            dimension[i] += op.x().rank();
    }
    // do op along all dimensions
    if (dimension.length == op.x().rank())
        dimension = new int[] { Integer.MAX_VALUE };
    int[] retShape;
    if (Shape.wholeArrayDimension(dimension))
        retShape = new int[] { 1, 1 };
    else
        retShape = ArrayUtil.removeIndex(maxShape, dimension);
    // ensure vector is proper shape (always at least rank 2: row or column vector)
    if (retShape.length == 1) {
        if (dimension[0] == 0)
            retShape = new int[] { 1, retShape[0] };
        else
            retShape = new int[] { retShape[0], 1 };
    } else if (retShape.length == 0) {
        retShape = new int[] { 1, 1 };
    }
    // Reducing a vector to an equally-sized result is a no-op; short-circuit.
    if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && op.y() == null)
        return op.noOp();
    INDArray ret = null;
    if (op.z() == null || op.z() == op.x()) {
        if (op.isComplexAccumulation()) {
            // Pairwise (all-vs-all) reduction: result matrix is [xTads x yTads].
            int xT = op.x().tensorssAlongDimension(dimension);
            int yT = op.y().tensorssAlongDimension(dimension);
            ret = Nd4j.create(xT, yT);
        } else {
            if (op.y() != null) {
                // Each x TAD must have the same length as y for pairwise reductions.
                val xT = op.x().tensorAlongDimension(0, dimension).lengthLong();
                val yT = op.y().lengthLong();
                if (xT != yT)
                    throw new ND4JIllegalStateException("Number of TADs along dimension doesn't match");
            }
            // Pre-fill the result with the op's neutral ("zero") value so the
            // native kernel can accumulate into it.
            if (0.0 + Math.abs(op.zeroDouble()) <= Nd4j.EPS_THRESHOLD) {
                ret = Nd4j.zeros(retShape);
            } else {
                if (op.x().data().dataType() == DataBuffer.Type.DOUBLE)
                    ret = Nd4j.valueArrayOf(retShape, op.zeroDouble());
                else if (op.x().data().dataType() == DataBuffer.Type.FLOAT)
                    ret = Nd4j.valueArrayOf(retShape, op.zeroFloat());
                else if (op.x().data().dataType() == DataBuffer.Type.HALF)
                    ret = Nd4j.valueArrayOf(retShape, op.zeroHalf());
            }
        }
        op.setZ(ret);
    } else {
        // compare length of the caller-supplied result against the expected shape
        if (op.z().lengthLong() != ArrayUtil.prodLong(retShape))
            throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
        // Reset the caller-supplied buffer to the op's neutral value before accumulating.
        if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
            op.z().assign(op.zeroDouble());
        } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
            op.z().assign(op.zeroFloat());
        } else if (op.x().data().dataType() == DataBuffer.Type.HALF) {
            op.z().assign(op.zeroHalf());
        }
        ret = op.z();
    }
    naiveExec(op, dimension);
    profilingHookOut(op, st);
    return op.z();
}
Also used : lombok.val(lombok.val) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Example 87 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class CudaExecutioner, the method thresholdDecode.

// Decodes a threshold-compressed INT buffer back into the dense target array.
// Layout assumption (from the reads below): encoded.data() holds INT values with
// element 0 = compressed length and element 1 = original (decoded) length.
@Override
public INDArray thresholdDecode(INDArray encoded, INDArray target) {
    DataBuffer buffer = encoded.data();
    // Only INT-encoded buffers are supported by the native decode kernels.
    if (buffer.dataType() != DataBuffer.Type.INT)
        throw new UnsupportedOperationException();
    long compressedLength = buffer.getInt(0);
    long originalLength = buffer.getInt(1);
    // The target must be exactly the size recorded in the encoded header.
    if (target.lengthLong() != originalLength)
        throw new ND4JIllegalStateException("originalLength [" + originalLength + "] stored in encoded array doesn't match target length [" + target.lengthLong() + "]");
    DataBuffer result = target.data();
    CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
    // Lazily initialize the thread-local scratch PointerPointer.
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));
    // Slot 1 carries the CUDA stream for the native call.
    PointerPointer extras = extraz.get().put(1, context.getOldStream());
    // Dispatch to the native decode routine matching the global data type.
    if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
        nativeOps.decodeThresholdFloat(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, (FloatPointer) AtomicAllocator.getInstance().getPointer(result));
    } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
        nativeOps.decodeThresholdDouble(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, (DoublePointer) AtomicAllocator.getInstance().getPointer(result));
    } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
        nativeOps.decodeThresholdHalf(extras, AtomicAllocator.getInstance().getPointer(buffer), compressedLength, (ShortPointer) AtomicAllocator.getInstance().getPointer(result));
    }
    // Mark the device copy as the most recent write so later host reads re-sync.
    AtomicAllocator.getInstance().getAllocationPoint(result).tickDeviceWrite();
    return target;
}
Also used : CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) BaseDataBuffer(org.nd4j.linalg.api.buffer.BaseDataBuffer)

Example 88 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class CudaExecutioner, the method invoke.

/**
 * Executes a ScalarOp on device: z[i] = op(x[i], scalar).
 *
 * @param op the scalar op; op.x() is the input, op.z() the output,
 *           op.scalar() the scalar operand
 * @return always null (dimension-wise scalar application is delegated to intercept())
 * @throws ND4JIllegalStateException if op.x() and op.z() lengths differ
 */
protected CudaContext invoke(ScalarOp op) {
    long st = profilingHookIn(op);
    checkForCompression(op);
    validateDataType(Nd4j.dataType(), op);
    // Scalar ops write element-wise into z, so x and z must have equal length.
    // BUGFIX: the message previously said "op.Y length" although op.z() is what
    // is compared (and whose shape is printed).
    if (op.x().length() != op.z().length())
        throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" + Arrays.toString(op.z().shapeInfoDataBuffer().asInt()) + "]");
    // Lazily initialize the thread-local scratch PointerPointer.
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));
    if (CudaEnvironment.getInstance().getConfiguration().isDebug())
        lastOp.set(op.opName());
    // Dimension-wise scalar application is handled by a dedicated path.
    if (op.getDimension() != null) {
        intercept(op, op.getDimension());
        return null;
    }
    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
    Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
    Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
    // Resolve device pointers for data, shape info and extra args.
    Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
    Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
    Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(), context) : null;
    Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
    Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);
    PointerPointer xShapeInfoHostPointer = extraz.get().put(AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, null, null);
    // Fast path (element-wise stride, matching ordering) vs. shape-info path,
    // per data type.
    if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
        if (op.x().elementWiseStride() >= 1 && op.z().ordering() == op.x().ordering()) {
            nativeOps.execScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, op.x().elementWiseStride(), (DoublePointer) z, op.z().elementWiseStride(), op.scalar().doubleValue(), (DoublePointer) extraArgs, op.n());
        } else {
            nativeOps.execScalarDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) z, (IntPointer) zShapeInfo, op.scalar().doubleValue(), (DoublePointer) extraArgs);
        }
    } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
        if (op.x().elementWiseStride() >= 1 && op.z().ordering() == op.x().ordering()) {
            nativeOps.execScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, op.x().elementWiseStride(), (FloatPointer) z, op.z().elementWiseStride(), op.scalar().floatValue(), (FloatPointer) extraArgs, op.n());
        } else {
            nativeOps.execScalarFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) z, (IntPointer) zShapeInfo, op.scalar().floatValue(), (FloatPointer) extraArgs);
        }
    } else {
        // HALF: scalar is passed as float; the native side converts.
        if (op.x().elementWiseStride() >= 1 && op.z().ordering() == op.x().ordering()) {
            nativeOps.execScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, op.x().elementWiseStride(), (ShortPointer) z, op.z().elementWiseStride(), op.scalar().floatValue(), (ShortPointer) extraArgs, op.n());
        } else {
            nativeOps.execScalarHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) z, (IntPointer) zShapeInfo, op.scalar().floatValue(), (ShortPointer) extraArgs);
        }
    }
    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
    profilingHookOut(op, st);
    return null;
}
Also used : CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer) PagedPointer(org.nd4j.linalg.api.memory.pointers.PagedPointer)

Example 89 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class CudaExecutioner, the method exec.

// Executes a BroadcastOp (z = x op y broadcast along the given dimension(s))
// on device and returns op.z().
@Override
public INDArray exec(BroadcastOp op, int... dimension) {
    long st = profilingHookIn(op);
    checkForCompression(op);
    validateDataType(Nd4j.dataType(), op);
    // Lazily initialize the thread-local scratch PointerPointer.
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));
    // Dimensions must be sorted for TAD computation.
    Arrays.sort(dimension);
    for (int i = 0; i < dimension.length; i++) if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE)
        throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
    if (CudaEnvironment.getInstance().getConfiguration().isDebug())
        lastOp.set(op.opName());
    Pointer hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
    Pointer hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());
    // Resolve device pointers for the three operands.
    Pointer x = AtomicAllocator.getInstance().getPointer(op.x(), context);
    Pointer y = AtomicAllocator.getInstance().getPointer(op.y(), context);
    Pointer z = AtomicAllocator.getInstance().getPointer(op.z(), context);
    Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
    // TAD (tensor-along-dimension) shape/offset buffers for x.
    Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
    Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
    Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);
    DataBuffer offsets = tadBuffers.getSecond();
    Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);
    Pointer devTadShapeInfoZ = null;
    Pointer devTadOffsetsZ = null;
    // that's the place where we're going to have second TAD in place
    Pair<DataBuffer, DataBuffer> tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);
    devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context);
    devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context);
    // }
    // extraz.get().put
    // new PointerPointer
    PointerPointer xShapeInfoHostPointer = extraz.get().put(AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(), context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(), hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets, devTadShapeInfoZ, devTadOffsetsZ);
    // Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context);
    // The dimension array is passed to the kernel via a cached constant buffer.
    Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);
    // Dispatch to the native broadcast routine matching x's data type.
    if (op.x().data().dataType() == DataBuffer.Type.DOUBLE) {
        nativeOps.execBroadcastDouble(xShapeInfoHostPointer, op.opNum(), (DoublePointer) x, (IntPointer) xShapeInfo, (DoublePointer) y, (IntPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (DoublePointer) z, (IntPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer) dimensionPointer, dimension.length);
    } else if (op.x().data().dataType() == DataBuffer.Type.FLOAT) {
        nativeOps.execBroadcastFloat(xShapeInfoHostPointer, op.opNum(), (FloatPointer) x, (IntPointer) xShapeInfo, (FloatPointer) y, (IntPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (FloatPointer) z, (IntPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer) dimensionPointer, dimension.length);
    } else {
        nativeOps.execBroadcastHalf(xShapeInfoHostPointer, op.opNum(), (ShortPointer) x, (IntPointer) xShapeInfo, (ShortPointer) y, (IntPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context), (ShortPointer) z, (IntPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), (IntPointer) dimensionPointer, dimension.length);
    }
    AtomicAllocator.getInstance().registerAction(context, op.z(), op.x(), op.y());
    profilingHookOut(op, st);
    return op.z();
}
Also used : CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer) PagedPointer(org.nd4j.linalg.api.memory.pointers.PagedPointer) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) BaseDataBuffer(org.nd4j.linalg.api.buffer.BaseDataBuffer)

Example 90 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class CudaExecutioner, the method exec.

/**
 * This method executes given CustomOp.
 *
 * PLEASE NOTE: You're responsible for input/output validation
 * PLEASE NOTE: right now this operations are executing on CPU
 *
 * im2col, col2im and pooling2d are special-cased as legacy transform kernels
 * (op numbers 37, 36 and 71 respectively); everything else goes through the
 * generic native custom-op dispatcher.
 *
 * @param op the custom op to execute
 * @throws ND4JIllegalStateException if a non-inplace call has no output
 *                                   arguments, or if native execution returns
 *                                   a non-OK status
 */
public void exec(CustomOp op) {
    // Flush any previously queued device work before touching buffers.
    Nd4j.getExecutioner().commit();
    if (op.opName().equalsIgnoreCase("im2col")) {
        val dtype = Nd4j.dataType();
        val xArr = op.inputArguments()[0];
        val zArr = op.outputArguments()[0];
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(zArr, xArr);
        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));
        // Slots: 0 = host x shape, 1 = stream, 2 = device id, 3..6 = scratch
        // buffers, 7 = unused, 8 = host z shape.
        PointerPointer xShapeHost = extraz.get().put(
                AddressRetriever.retrieveHostPointer(xArr.shapeInfoDataBuffer()),
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                null,
                AddressRetriever.retrieveHostPointer(zArr.shapeInfoDataBuffer()));
        val x = AtomicAllocator.getInstance().getPointer(xArr, context);
        val z = AtomicAllocator.getInstance().getPointer(zArr, context);
        val xShape = AtomicAllocator.getInstance().getPointer(xArr.shapeInfoDataBuffer(), context);
        val zShape = AtomicAllocator.getInstance().getPointer(zArr.shapeInfoDataBuffer(), context);
        // Optional first tArg is the zero-padding value; defaults to 0.0.
        double zeroPad = 0.0;
        if (op.tArgs() != null && op.tArgs().length > 0) {
            zeroPad = op.tArgs()[0];
        }
        // im2col takes 9 integer args plus the zero-pad value, packed as doubles.
        val extrass = new double[] { op.iArgs()[0], op.iArgs()[1], op.iArgs()[2], op.iArgs()[3], op.iArgs()[4], op.iArgs()[5], op.iArgs()[6], op.iArgs()[7], op.iArgs()[8], zeroPad };
        val extraArgsBuff = Nd4j.getConstantHandler().getConstantBuffer(extrass);
        val extraArgs = AtomicAllocator.getInstance().getPointer(extraArgsBuff, context);
        // 37 = legacy transform op number for im2col.
        if (dtype == DataBuffer.Type.DOUBLE) {
            nativeOps.execTransformDouble(xShapeHost, 37, (DoublePointer) x, (IntPointer) xShape, (DoublePointer) z, (IntPointer) zShape, (DoublePointer) extraArgs);
        } else if (dtype == DataBuffer.Type.FLOAT) {
            nativeOps.execTransformFloat(xShapeHost, 37, (FloatPointer) x, (IntPointer) xShape, (FloatPointer) z, (IntPointer) zShape, (FloatPointer) extraArgs);
        } else if (dtype == DataBuffer.Type.HALF) {
            nativeOps.execTransformHalf(xShapeHost, 37, (ShortPointer) x, (IntPointer) xShape, (ShortPointer) z, (IntPointer) zShape, (ShortPointer) extraArgs);
        }
        // AtomicAllocator.getInstance().getAllocationPoint(zArr).tickDeviceWrite();
        AtomicAllocator.getInstance().getFlowController().registerAction(context, zArr, xArr);
        return;
    } else if (op.opName().equalsIgnoreCase("col2im")) {
        val dtype = Nd4j.dataType();
        val xArr = op.inputArguments()[0];
        val zArr = op.outputArguments()[0];
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(zArr, xArr);
        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));
        // Same slot layout as the im2col branch above.
        PointerPointer xShapeHost = extraz.get().put(
                AddressRetriever.retrieveHostPointer(xArr.shapeInfoDataBuffer()),
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                null,
                AddressRetriever.retrieveHostPointer(zArr.shapeInfoDataBuffer()));
        val x = AtomicAllocator.getInstance().getPointer(xArr, context);
        val z = AtomicAllocator.getInstance().getPointer(zArr, context);
        val xShape = AtomicAllocator.getInstance().getPointer(xArr.shapeInfoDataBuffer(), context);
        val zShape = AtomicAllocator.getInstance().getPointer(zArr.shapeInfoDataBuffer(), context);
        // col2im takes 8 integer args, packed as doubles.
        val extrass = new double[] { op.iArgs()[0], op.iArgs()[1], op.iArgs()[2], op.iArgs()[3], op.iArgs()[4], op.iArgs()[5], op.iArgs()[6], op.iArgs()[7] };
        val extraArgsBuff = Nd4j.getConstantHandler().getConstantBuffer(extrass);
        val extraArgs = AtomicAllocator.getInstance().getPointer(extraArgsBuff, context);
        // 36 = legacy transform op number for col2im.
        if (dtype == DataBuffer.Type.DOUBLE) {
            nativeOps.execTransformDouble(xShapeHost, 36, (DoublePointer) x, (IntPointer) xShape, (DoublePointer) z, (IntPointer) zShape, (DoublePointer) extraArgs);
        } else if (dtype == DataBuffer.Type.FLOAT) {
            nativeOps.execTransformFloat(xShapeHost, 36, (FloatPointer) x, (IntPointer) xShape, (FloatPointer) z, (IntPointer) zShape, (FloatPointer) extraArgs);
        } else if (dtype == DataBuffer.Type.HALF) {
            nativeOps.execTransformHalf(xShapeHost, 36, (ShortPointer) x, (IntPointer) xShape, (ShortPointer) z, (IntPointer) zShape, (ShortPointer) extraArgs);
        }
        // AtomicAllocator.getInstance().getAllocationPoint(zArr).tickDeviceWrite();
        AtomicAllocator.getInstance().getFlowController().registerAction(context, zArr, xArr);
        return;
    } else if (op.opName().equalsIgnoreCase("pooling2d")) {
        val dtype = Nd4j.dataType();
        val xArr = op.inputArguments()[0];
        val zArr = op.outputArguments()[0];
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(zArr, xArr);
        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));
        // Same slot layout as the im2col branch above.
        PointerPointer xShapeHost = extraz.get().put(
                AddressRetriever.retrieveHostPointer(xArr.shapeInfoDataBuffer()),
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                null,
                AddressRetriever.retrieveHostPointer(zArr.shapeInfoDataBuffer()));
        val x = AtomicAllocator.getInstance().getPointer(xArr, context);
        val z = AtomicAllocator.getInstance().getPointer(zArr, context);
        val xShape = AtomicAllocator.getInstance().getPointer(xArr.shapeInfoDataBuffer(), context);
        val zShape = AtomicAllocator.getInstance().getPointer(zArr.shapeInfoDataBuffer(), context);
        // pooling2d takes 9 integer args, packed as doubles.
        val extrass = new double[] { op.iArgs()[0], op.iArgs()[1], op.iArgs()[2], op.iArgs()[3], op.iArgs()[4], op.iArgs()[5], op.iArgs()[6], op.iArgs()[7], op.iArgs()[8] };
        val extraArgsBuff = Nd4j.getConstantHandler().getConstantBuffer(extrass);
        val extraArgs = AtomicAllocator.getInstance().getPointer(extraArgsBuff, context);
        // 71 = legacy transform op number for pooling2d.
        if (dtype == DataBuffer.Type.DOUBLE) {
            nativeOps.execTransformDouble(xShapeHost, 71, (DoublePointer) x, (IntPointer) xShape, (DoublePointer) z, (IntPointer) zShape, (DoublePointer) extraArgs);
        } else if (dtype == DataBuffer.Type.FLOAT) {
            nativeOps.execTransformFloat(xShapeHost, 71, (FloatPointer) x, (IntPointer) xShape, (FloatPointer) z, (IntPointer) zShape, (FloatPointer) extraArgs);
        } else if (dtype == DataBuffer.Type.HALF) {
            nativeOps.execTransformHalf(xShapeHost, 71, (ShortPointer) x, (IntPointer) xShape, (ShortPointer) z, (IntPointer) zShape, (ShortPointer) extraArgs);
        }
        // AtomicAllocator.getInstance().getAllocationPoint(zArr).tickDeviceWrite();
        AtomicAllocator.getInstance().getFlowController().registerAction(context, zArr, xArr);
        return;
    }
    // Generic custom-op path.
    Nd4j.getExecutioner().commit();
    CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
    if (extraz.get() == null)
        extraz.set(new PointerPointer(32));
    PointerPointer extras = extraz.get().put(new CudaPointer(1), context.getOldStream(), context.getBufferScalar(), context.getBufferReduction());
    val outputArgs = op.outputArguments();
    val inputArgs = op.inputArguments();
    if (outputArgs.length == 0 && !op.isInplaceCall())
        throw new ND4JIllegalStateException("You can't execute non-inplace CustomOp without outputs being specified");
    val hash = op.opHash();
    // Buffers/shapes are laid out [host..., device...], hence the 2x capacity.
    val inputShapes = new PointerPointer<>(inputArgs.length * 2);
    val inputBuffers = new PointerPointer<>(inputArgs.length * 2);
    int cnt = 0;
    for (val in : inputArgs) {
        val hp = AtomicAllocator.getInstance().getHostPointer(in.shapeInfoDataBuffer());
        inputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(in));
        inputShapes.put(cnt, hp);
        val dp = AtomicAllocator.getInstance().getPointer(in.shapeInfoDataBuffer(), context);
        inputBuffers.put(cnt + inputArgs.length, AtomicAllocator.getInstance().getPointer(in, context));
        inputShapes.put(cnt + inputArgs.length, dp);
        // Inplace ops mutate their inputs on host, so mark them dirty.
        if (op.isInplaceCall())
            AtomicAllocator.getInstance().getAllocationPoint(in).tickHostWrite();
        cnt++;
    }
    val outputShapes = new PointerPointer<>(outputArgs.length * 2);
    val outputBuffers = new PointerPointer<>(outputArgs.length * 2);
    cnt = 0;
    for (val out : outputArgs) {
        outputBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(out));
        outputShapes.put(cnt, AtomicAllocator.getInstance().getHostPointer(out.shapeInfoDataBuffer()));
        outputBuffers.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out, context));
        outputShapes.put(cnt + outputArgs.length, AtomicAllocator.getInstance().getPointer(out.shapeInfoDataBuffer(), context));
        AtomicAllocator.getInstance().getAllocationPoint(out).tickHostWrite();
        cnt++;
    }
    if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
        val tArgs = op.tArgs().length > 0 ? new FloatPointer(op.tArgs().length) : null;
        val iArgs = op.iArgs().length > 0 ? new IntPointer(op.iArgs().length) : null;
        cnt = 0;
        for (val t : op.tArgs()) tArgs.put(cnt++, (float) t);
        cnt = 0;
        for (val i : op.iArgs()) iArgs.put(cnt++, i);
        val status = OpStatus.byNumber(nativeOps.execCustomOpFloat(extras, hash, inputBuffers, inputShapes, inputArgs.length, outputBuffers, outputShapes, outputArgs.length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, op.isInplaceCall()));
        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);
    } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
        val tArgs = op.tArgs().length > 0 ? new DoublePointer(op.tArgs().length) : null;
        val iArgs = op.iArgs().length > 0 ? new IntPointer(op.iArgs().length) : null;
        cnt = 0;
        for (val t : op.tArgs()) tArgs.put(cnt++, t);
        // BUGFIX: reset cnt before filling iArgs. Previously this branch (unlike
        // the FLOAT and HALF branches) kept counting from tArgs.length, writing
        // iArgs values at wrong offsets past the allocated IntPointer whenever
        // tArgs was non-empty.
        cnt = 0;
        for (val i : op.iArgs()) iArgs.put(cnt++, i);
        val status = OpStatus.byNumber(nativeOps.execCustomOpDouble(extras, hash, inputBuffers, inputShapes, inputArgs.length, outputBuffers, outputShapes, outputArgs.length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, op.isInplaceCall()));
        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);
    } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
        val tArgs = op.tArgs().length > 0 ? new ShortPointer(op.tArgs().length) : null;
        val iArgs = op.iArgs().length > 0 ? new IntPointer(op.iArgs().length) : null;
        cnt = 0;
        for (val t : op.tArgs()) tArgs.put(cnt++, ArrayUtil.toHalf((float) t));
        cnt = 0;
        for (val i : op.iArgs()) iArgs.put(cnt++, i);
        val status = OpStatus.byNumber(nativeOps.execCustomOpHalf(extras, hash, inputBuffers, inputShapes, inputArgs.length, outputBuffers, outputShapes, outputArgs.length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, op.isInplaceCall()));
        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);
    }
// AtomicAllocator.getInstance().getFlowController().prepareActionAllWrite(op.outputArguments());
}
Also used : lombok.val(lombok.val) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)116 lombok.val (lombok.val)26 INDArray (org.nd4j.linalg.api.ndarray.INDArray)23 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)21 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)19 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)17 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)15 PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)12 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)8 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)7 IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)6 Pointer (org.bytedeco.javacpp.Pointer)5 ArrayList (java.util.ArrayList)4 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)4 Aeron (io.aeron.Aeron)3 FragmentAssembler (io.aeron.FragmentAssembler)3 MediaDriver (io.aeron.driver.MediaDriver)3 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)3 Slf4j (lombok.extern.slf4j.Slf4j)3 CloseHelper (org.agrona.CloseHelper)3