Example 11 with BlasException

Use of org.nd4j.linalg.api.blas.BlasException in project nd4j by deeplearning4j.

From the class CpuLapack, method dgeqrf.

@Override
public void dgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO) {
    INDArray tau = Nd4j.create(N);
    int status = LAPACKE_dgeqrf(getColumnOrder(A), M, N,
            (DoublePointer) A.data().addressPointer(), getLda(A),
            (DoublePointer) tau.data().addressPointer());
    if (status != 0) {
        throw new BlasException("Failed to execute dgeqrf", status);
    }
    // Copy R (the upper triangle of the factored A) into the result,
    // then zero the entries below the diagonal, which hold reflector data
    if (R != null) {
        R.assign(A.get(NDArrayIndex.interval(0, A.columns()), NDArrayIndex.all()));
        INDArrayIndex[] ix = new INDArrayIndex[2];
        for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) {
            ix[0] = NDArrayIndex.point(i);
            ix[1] = NDArrayIndex.interval(0, i);
            R.put(ix, 0);
        }
    }
    // Overwrite A with the explicit Q built from the reflectors and tau
    status = LAPACKE_dorgqr(getColumnOrder(A), M, N, N,
            (DoublePointer) A.data().addressPointer(), getLda(A),
            (DoublePointer) tau.data().addressPointer());
    if (status != 0) {
        throw new BlasException("Failed to execute dorgqr", status);
    }
}
Also used: BlasException (org.nd4j.linalg.api.blas.BlasException), INDArray (org.nd4j.linalg.api.ndarray.INDArray), INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex), DoublePointer (org.bytedeco.javacpp.DoublePointer)
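
The R-extraction step in dgeqrf can be illustrated without ND4J or native LAPACK. LAPACK's dgeqrf leaves R on and above the diagonal of the input matrix and stores Householder reflector data below it, so recovering R means keeping the leading N-by-N upper triangle and zeroing everything under the diagonal. The following is a minimal sketch of only that step, assuming a row-major double[][] with at least as many rows as columns; the class and method names are hypothetical and are not part of nd4j.

```java
import java.util.Arrays;

public class UpperTriangleSketch {

    /**
     * Returns the leading n-by-n upper triangle of a packed m-by-n matrix
     * (assumes m >= n), with every entry below the diagonal set to zero.
     */
    public static double[][] extractUpperTriangle(double[][] packed) {
        int n = packed[0].length;
        double[][] r = new double[n][n];
        for (int i = 0; i < n; i++) {
            // Row i keeps columns i..n-1; columns 0..i-1 stay zero.
            for (int j = i; j < n; j++) {
                r[i][j] = packed[i][j];
            }
        }
        return r;
    }

    public static void main(String[] args) {
        // Made-up packed output: R on and above the diagonal,
        // placeholder reflector values below it.
        double[][] packed = {
            {2.0, 3.0},
            {0.5, 4.0},
            {0.25, 0.75}
        };
        // prints [[2.0, 3.0], [0.0, 4.0]]
        System.out.println(Arrays.deepToString(extractUpperTriangle(packed)));
    }
}
```

The method in the listing reaches the same result in two passes (copy the top block, then overwrite the sub-diagonal entries with 0) because it works through INDArray views rather than element loops.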

Example 12 with BlasException

Use of org.nd4j.linalg.api.blas.BlasException in project nd4j by deeplearning4j.

From the class CpuLapack, method sgesvd.

// =========================
// U S V' DECOMP  (aka SVD)
@Override
public void sgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT, INDArray INFO) {
    INDArray superb = Nd4j.create(M < N ? M : N);
    int status = LAPACKE_sgesvd(getColumnOrder(A), jobu, jobvt, M, N,
            (FloatPointer) A.data().addressPointer(), getLda(A),
            (FloatPointer) S.data().addressPointer(),
            U == null ? null : (FloatPointer) U.data().addressPointer(), U == null ? 1 : getLda(U),
            VT == null ? null : (FloatPointer) VT.data().addressPointer(), VT == null ? 1 : getLda(VT),
            (FloatPointer) superb.data().addressPointer());
    if (status != 0) {
        throw new BlasException("Failed to execute sgesvd", status);
    }
}
Also used: BlasException (org.nd4j.linalg.api.blas.BlasException), INDArray (org.nd4j.linalg.api.ndarray.INDArray), FloatPointer (org.bytedeco.javacpp.FloatPointer)
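
Both methods in this listing treat any nonzero status as a failure and pass the raw code to BlasException. As a rough illustration of what such a code can mean, the sketch below maps a LAPACK-style status to a message. It relies on the general LAPACK convention (0 is success, a negative value -i flags the i-th argument, positive values are routine-specific) and on the two memory-error codes defined in lapacke.h. The helper and its class name are hypothetical and do not exist in nd4j.

```java
public class LapackStatusSketch {

    // Values as defined in lapacke.h for allocation failures inside LAPACKE.
    private static final int WORK_MEMORY_ERROR = -1010;
    private static final int TRANSPOSE_MEMORY_ERROR = -1011;

    /** Hypothetical helper: turns a LAPACK-style status code into a readable message. */
    public static String describe(String routine, int status) {
        if (status == 0) {
            return routine + ": success";
        }
        if (status == WORK_MEMORY_ERROR || status == TRANSPOSE_MEMORY_ERROR) {
            return routine + ": memory allocation failed inside LAPACKE";
        }
        if (status < 0) {
            // LAPACK convention: -i means the i-th argument had an illegal value.
            return routine + ": argument " + (-status) + " had an illegal value";
        }
        // Positive codes are routine-specific, e.g. a convergence failure in an SVD.
        return routine + ": computation failed with routine-specific code " + status;
    }

    public static void main(String[] args) {
        System.out.println(describe("sgesvd", 0));
        System.out.println(describe("sgesvd", -4));
        System.out.println(describe("sgesvd", 2));
    }
}
```

A message like this could be built before constructing the exception, so that the log distinguishes a bad argument from a numerical failure.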

Aggregations

BlasException (org.nd4j.linalg.api.blas.BlasException) 12
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 12
DoublePointer (org.bytedeco.javacpp.DoublePointer) 10
FloatPointer (org.bytedeco.javacpp.FloatPointer) 10
IntPointer (org.bytedeco.javacpp.IntPointer) 8
Pointer (org.bytedeco.javacpp.Pointer) 8
CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st) 8
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer) 8
org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t (org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t) 8
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer) 8
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner) 8
CublasPointer (org.nd4j.linalg.jcublas.CublasPointer) 8
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext) 8
INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex) 6