Use of org.nd4j.linalg.jcublas.buffer.JCudaBuffer in the nd4j project by deeplearning4j.
The class CudaArgs, method argsAndReference.
/**
 * Translates raw kernel arguments into the values handed to a CUDA kernel launch.
 *
 * <p>Arguments are handled in order. Each {@link JCudaBuffer} or {@link INDArray} argument is
 * wrapped in a new {@link CublasPointer} bound to {@code context}, and that pointer's device
 * pointer takes the argument's slot. Any other argument is copied through as-is.
 *
 * <p>For {@link INDArray} arguments, the created {@link CublasPointer} is also recorded in a
 * multimap keyed by the array, so the same array passed twice yields two entries. Pointers
 * created for {@link JCudaBuffer} arguments are not recorded anywhere.
 *
 * @param context the CUDA context each {@link CublasPointer} is created against
 * @param kernelParams the kernel arguments, in launch order
 * @return the translated arguments (same length and order as {@code kernelParams}) together with
 *     the array-to-pointer multimap
 */
public static ArgsAndReferences argsAndReference(CudaContext context, Object... kernelParams) {
    final Object[] translated = new Object[kernelParams.length];
    final Multimap<INDArray, CublasPointer> pointersByArray = ArrayListMultimap.create();

    int slot = 0;
    for (Object param : kernelParams) {
        Object value = param;
        // JCudaBuffer is checked before INDArray on purpose; keep this order.
        if (param instanceof JCudaBuffer) {
            // NOTE(review): this CublasPointer is not retained, so the caller cannot reach it
            // through the result. Confirm it needs no explicit release.
            value = new CublasPointer((JCudaBuffer) param, context).getDevicePointer();
        } else if (param instanceof INDArray) {
            INDArray array = (INDArray) param;
            CublasPointer wrapped = new CublasPointer(array, context);
            value = wrapped.getDevicePointer();
            pointersByArray.put(array, wrapped);
        }
        translated[slot++] = value;
    }

    return new ArgsAndReferences(translated, pointersByArray);
}
Aggregations