Search in sources :

Example 26 with CublasPointer

use of org.nd4j.linalg.jcublas.CublasPointer in project nd4j by deeplearning4j.

the class JcublasLevel3 method dsyrk.

/**
 * Double-precision symmetric rank-k update, dispatched to {@code cublasDsyrk_v2}
 * on the stream of the CUDA context prepared for {@code C} and {@code A}.
 *
 * <p>A warning is logged when the globally configured data type is not
 * {@code DOUBLE}; the call is still carried out.
 *
 * @param Order not consulted by this implementation
 * @param Uplo  triangle selector, translated with {@code convertUplo}
 * @param Trans operation selector, translated with {@code convertTranspose}
 * @param N     forwarded unchanged to cuBLAS as {@code n}
 * @param K     forwarded unchanged to cuBLAS as {@code k}
 * @param alpha scalar passed to cuBLAS by pointer
 * @param A     input array; its device pointer is handed to cuBLAS
 * @param lda   leading dimension reported for {@code A}
 * @param beta  scalar passed to cuBLAS by pointer
 * @param C     array whose device pointer is handed to cuBLAS as the output operand
 * @param ldc   leading dimension reported for {@code C}
 */
@Override
protected void dsyrk(char Order, char Uplo, char Trans, int N, int K, double alpha, INDArray A, int lda, double beta, INDArray C, int ldc) {
    if (Nd4j.dataType() != DataBuffer.Type.DOUBLE) {
        logger.warn("DOUBLE syrk called");
    }
    // NOTE(review): presumably flushes ops still queued in the executioner before
    // cuBLAS is invoked directly - confirm against the executioner contract.
    Nd4j.getExecutioner().push();
    CudaContext ctx = allocator.getFlowController().prepareAction(C, A);
    CublasPointer aPointer = new CublasPointer(A, ctx);
    CublasPointer cPointer = new CublasPointer(C, ctx);
    cublasHandle_t handle = ctx.getHandle();
    // Binding the stream and issuing the call happen under the handle's monitor so that
    // another thread cannot rebind the handle's stream in between.
    synchronized (handle) {
        cublasSetStream_v2(new cublasContext(handle), new CUstream_st(ctx.getOldStream()));
        // Trans must be translated to the cuBLAS operation value, exactly as ssyrk does;
        // forwarding the raw character code hands cuBLAS an invalid operation enum.
        cublasDsyrk_v2(new cublasContext(handle), convertUplo(Uplo), convertTranspose(Trans), N, K, new DoublePointer(alpha), (DoublePointer) aPointer.getDevicePointer(), lda, new DoublePointer(beta), (DoublePointer) cPointer.getDevicePointer(), ldc);
    }
    allocator.registerAction(ctx, C, A);
    OpExecutionerUtil.checkForAny(C);
}
Also used : org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t(org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) DoublePointer(org.bytedeco.javacpp.DoublePointer) CublasPointer(org.nd4j.linalg.jcublas.CublasPointer)

Example 27 with CublasPointer

use of org.nd4j.linalg.jcublas.CublasPointer in project nd4j by deeplearning4j.

the class JcublasLevel3 method strsm.

/**
 * Single-precision triangular solve, dispatched to {@code cublasStrsm_v2} on the
 * stream of the CUDA context prepared for {@code B} and {@code A}.
 *
 * <p>A warning is logged when the globally configured data type is not
 * {@code FLOAT}; the call is still carried out. {@code Order} is not consulted.
 * The {@code Side}, {@code Uplo}, {@code TransA} and {@code Diag} selectors are
 * translated with the matching {@code convert*} helper; {@code M}, {@code N},
 * {@code lda} and {@code ldb} are forwarded unchanged.
 */
@Override
protected void strsm(char Order, char Side, char Uplo, char TransA, char Diag, int M, int N, float alpha, INDArray A, int lda, INDArray B, int ldb) {
    if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
        logger.warn("FLOAT trsm called");
    }

    Nd4j.getExecutioner().push();

    final CudaContext cudaCtx = allocator.getFlowController().prepareAction(B, A);
    final CublasPointer matrixA = new CublasPointer(A, cudaCtx);
    final CublasPointer matrixB = new CublasPointer(B, cudaCtx);
    final cublasHandle_t blasHandle = cudaCtx.getHandle();

    // Stream binding and the solve are issued together under the handle's monitor.
    synchronized (blasHandle) {
        final cublasContext streamTarget = new cublasContext(blasHandle);
        final CUstream_st stream = new CUstream_st(cudaCtx.getOldStream());
        cublasSetStream_v2(streamTarget, stream);

        final cublasContext callTarget = new cublasContext(blasHandle);
        cublasStrsm_v2(
                callTarget,
                convertSideMode(Side),
                convertUplo(Uplo),
                convertTranspose(TransA),
                convertDiag(Diag),
                M,
                N,
                new FloatPointer(alpha),
                (FloatPointer) matrixA.getDevicePointer(),
                lda,
                (FloatPointer) matrixB.getDevicePointer(),
                ldb);
    }

    allocator.registerAction(cudaCtx, B, A);
    OpExecutionerUtil.checkForAny(B);
}
Also used : org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t(org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t) FloatPointer(org.bytedeco.javacpp.FloatPointer) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) CublasPointer(org.nd4j.linalg.jcublas.CublasPointer)

Example 28 with CublasPointer

use of org.nd4j.linalg.jcublas.CublasPointer in project nd4j by deeplearning4j.

the class JcublasLevel3 method sgemm.

/**
 * Single-precision general matrix multiply, dispatched to {@code cublasSgemm_v2}
 * on the stream of the CUDA context prepared for {@code C}, {@code A} and {@code B}.
 *
 * <p>A warning is logged when the globally configured data type is not
 * {@code FLOAT}; the call is still carried out. {@code Order} is not consulted.
 * {@code TransA} and {@code TransB} are translated with {@code convertTranspose};
 * the dimensions and leading dimensions are forwarded unchanged.
 */
@Override
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, INDArray B, int ldb, float beta, INDArray C, int ldc) {
    if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
        logger.warn("FLOAT gemm called");
    }

    Nd4j.getExecutioner().push();

    final CudaContext cudaCtx = allocator.getFlowController().prepareAction(C, A, B);
    final CublasPointer left = new CublasPointer(A, cudaCtx);
    final CublasPointer right = new CublasPointer(B, cudaCtx);
    final CublasPointer result = new CublasPointer(C, cudaCtx);
    final cublasHandle_t blasHandle = cudaCtx.getHandle();

    // Stream binding and the multiply are issued together under the handle's monitor.
    synchronized (blasHandle) {
        final cublasContext streamTarget = new cublasContext(blasHandle);
        final CUstream_st stream = new CUstream_st(cudaCtx.getOldStream());
        cublasSetStream_v2(streamTarget, stream);

        final cublasContext callTarget = new cublasContext(blasHandle);
        cublasSgemm_v2(
                callTarget,
                convertTranspose(TransA),
                convertTranspose(TransB),
                M,
                N,
                K,
                new FloatPointer(alpha),
                (FloatPointer) left.getDevicePointer(),
                lda,
                (FloatPointer) right.getDevicePointer(),
                ldb,
                new FloatPointer(beta),
                (FloatPointer) result.getDevicePointer(),
                ldc);
    }

    allocator.registerAction(cudaCtx, C, A, B);
    OpExecutionerUtil.checkForAny(C);
}
Also used : org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t(org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t) FloatPointer(org.bytedeco.javacpp.FloatPointer) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) CublasPointer(org.nd4j.linalg.jcublas.CublasPointer)

Example 29 with CublasPointer

use of org.nd4j.linalg.jcublas.CublasPointer in project nd4j by deeplearning4j.

the class JcublasLevel3 method ssyrk.

/**
 * Single-precision symmetric rank-k update, dispatched to {@code cublasSsyrk_v2}
 * on the stream of the CUDA context prepared for {@code C} and {@code A}.
 *
 * <p>A warning is logged when the globally configured data type is not
 * {@code FLOAT}; the call is still carried out. {@code Order} is not consulted.
 * {@code Uplo} and {@code Trans} are translated with {@code convertUplo} and
 * {@code convertTranspose}; {@code N}, {@code K}, {@code lda} and {@code ldc}
 * are forwarded unchanged.
 */
@Override
protected void ssyrk(char Order, char Uplo, char Trans, int N, int K, float alpha, INDArray A, int lda, float beta, INDArray C, int ldc) {
    if (Nd4j.dataType() != DataBuffer.Type.FLOAT) {
        logger.warn("FLOAT syrk called");
    }

    Nd4j.getExecutioner().push();

    final CudaContext cudaCtx = allocator.getFlowController().prepareAction(C, A);
    final CublasPointer input = new CublasPointer(A, cudaCtx);
    final CublasPointer output = new CublasPointer(C, cudaCtx);
    final cublasHandle_t blasHandle = cudaCtx.getHandle();

    // Stream binding and the update are issued together under the handle's monitor.
    synchronized (blasHandle) {
        final cublasContext streamTarget = new cublasContext(blasHandle);
        final CUstream_st stream = new CUstream_st(cudaCtx.getOldStream());
        cublasSetStream_v2(streamTarget, stream);

        final cublasContext callTarget = new cublasContext(blasHandle);
        cublasSsyrk_v2(
                callTarget,
                convertUplo(Uplo),
                convertTranspose(Trans),
                N,
                K,
                new FloatPointer(alpha),
                (FloatPointer) input.getDevicePointer(),
                lda,
                new FloatPointer(beta),
                (FloatPointer) output.getDevicePointer(),
                ldc);
    }

    allocator.registerAction(cudaCtx, C, A);
    OpExecutionerUtil.checkForAny(C);
}
Also used : org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t(org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t) FloatPointer(org.bytedeco.javacpp.FloatPointer) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) CublasPointer(org.nd4j.linalg.jcublas.CublasPointer)

Example 30 with CublasPointer

use of org.nd4j.linalg.jcublas.CublasPointer in project nd4j by deeplearning4j.

the class CudaArgs method argsAndReference.

/**
 * Builds the argument array for a kernel call from the raw parameters.
 *
 * <p>Each {@link JCudaBuffer} or {@link INDArray} parameter is wrapped in a new
 * {@link CublasPointer} bound to {@code context}, and its device pointer takes the
 * parameter's place in the result array. Pointers created for {@code INDArray}
 * parameters are additionally recorded against their array; pointers created for
 * buffers are not recorded. Any other parameter is copied through unchanged.
 *
 * @param context      CUDA context used to construct every {@code CublasPointer}
 * @param kernelParams raw kernel parameters, in call order
 * @return the translated parameters (same length and order as {@code kernelParams})
 *         together with the array-to-pointer multimap
 */
public static ArgsAndReferences argsAndReference(CudaContext context, Object... kernelParams) {
    final Object[] resolved = new Object[kernelParams.length];
    final Multimap<INDArray, CublasPointer> pointersByArray = ArrayListMultimap.create();

    int slot = 0;
    for (Object param : kernelParams) {
        if (param instanceof JCudaBuffer) {
            // Buffers only contribute their device pointer; nothing is tracked for them.
            final CublasPointer bufferPointer = new CublasPointer((JCudaBuffer) param, context);
            resolved[slot] = bufferPointer.getDevicePointer();
        } else if (param instanceof INDArray) {
            final INDArray ndArray = (INDArray) param;
            final CublasPointer arrayPointer = new CublasPointer(ndArray, context);
            resolved[slot] = arrayPointer.getDevicePointer();
            // A multimap is used, so the same array may map to several pointers.
            pointersByArray.put(ndArray, arrayPointer);
        } else {
            resolved[slot] = param;
        }
        slot++;
    }

    return new ArgsAndReferences(resolved, pointersByArray);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) JCudaBuffer(org.nd4j.linalg.jcublas.buffer.JCudaBuffer) CublasPointer(org.nd4j.linalg.jcublas.CublasPointer)

Aggregations

CublasPointer (org.nd4j.linalg.jcublas.CublasPointer)36 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)35 org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t (org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t)25 CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st)24 DoublePointer (org.bytedeco.javacpp.DoublePointer)20 FloatPointer (org.bytedeco.javacpp.FloatPointer)19 IntPointer (org.bytedeco.javacpp.IntPointer)12 INDArray (org.nd4j.linalg.api.ndarray.INDArray)11 Pointer (org.bytedeco.javacpp.Pointer)10 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)10 org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t (org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t)10 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)10 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)10 BlasException (org.nd4j.linalg.api.blas.BlasException)8 INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex)4 ShortPointer (org.bytedeco.javacpp.ShortPointer)1 JCudaBuffer (org.nd4j.linalg.jcublas.buffer.JCudaBuffer)1