Search in sources:

Example 1: usage of JCudaBuffer

Use of org.nd4j.linalg.jcublas.buffer.JCudaBuffer in the nd4j project (by deeplearning4j).

From class CudaArgs, method argsAndReference.

/**
 * Translates raw kernel arguments into the values handed to a CUDA kernel launch.
 *
 * <p>Each element of {@code kernelParams} is mapped, position for position, as follows:
 * <ul>
 *   <li>a {@link JCudaBuffer} is wrapped in a new {@link CublasPointer} built with
 *       {@code context}, and is replaced by that pointer's device pointer;</li>
 *   <li>an {@link INDArray} is wrapped in a new {@link CublasPointer} built with
 *       {@code context}, and is replaced by that pointer's device pointer. The
 *       (array, pointer) pair is also recorded in the returned multimap;</li>
 *   <li>any other value, including {@code null}, is copied through unchanged.</li>
 * </ul>
 * The {@link JCudaBuffer} check is applied before the {@link INDArray} check.
 *
 * @param context      CUDA context passed to every {@link CublasPointer} created here
 * @param kernelParams raw kernel arguments; a {@code null} array throws
 *                     {@link NullPointerException}
 * @return the translated argument array together with the INDArray-to-pointer mapping
 */
public static ArgsAndReferences argsAndReference(CudaContext context, Object... kernelParams) {
    final int count = kernelParams.length;
    final Object[] translated = new Object[count];
    final Multimap<INDArray, CublasPointer> pointersByArray = ArrayListMultimap.create();

    for (int idx = 0; idx < count; idx++) {
        final Object param = kernelParams[idx];

        if (param instanceof JCudaBuffer) {
            // NOTE(review): unlike the INDArray branch, this CublasPointer is not recorded
            // anywhere. If CublasPointer holds resources that must be released, confirm
            // who is responsible for freeing it.
            translated[idx] = new CublasPointer((JCudaBuffer) param, context).getDevicePointer();
            continue;
        }

        if (param instanceof INDArray) {
            final INDArray ndArray = (INDArray) param;
            final CublasPointer devicePtr = new CublasPointer(ndArray, context);
            translated[idx] = devicePtr.getDevicePointer();
            // Record the pointer so callers can later relate it back to its source array.
            pointersByArray.put(ndArray, devicePtr);
            continue;
        }

        translated[idx] = param;
    }

    return new ArgsAndReferences(translated, pointersByArray);
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), JCudaBuffer (org.nd4j.linalg.jcublas.buffer.JCudaBuffer), CublasPointer (org.nd4j.linalg.jcublas.CublasPointer)

Aggregations

INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1; CublasPointer (org.nd4j.linalg.jcublas.CublasPointer): 1; JCudaBuffer (org.nd4j.linalg.jcublas.buffer.JCudaBuffer): 1