Usage of org.nd4j.jita.allocator.pointers.CudaPointer in the nd4j project (deeplearning4j):
class BasicContextPool, method acquireContextForDevice.
/**
 * Returns the CudaContext bound to the calling thread for the given device,
 * lazily creating and registering one if this thread has none yet.
 *
 * Fast path: if contextsPool already holds an entry for this thread id, it is
 * returned as-is. Otherwise a global lock is taken and either a new stream/context
 * is created (while the device is under the MAX_STREAMS_PER_DEVICE limit) or an
 * existing context for the device is picked at random and reused.
 *
 * @param deviceId CUDA device ordinal to acquire a context for
 * @return a CudaContext associated with the current thread
 * @throws RuntimeException wrapping any failure during context/handle creation
 */
@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
    /*
     We should check, if we have context for this specific thread/device
     If we don't have context for this thread - we should stick to one of existent contexts available at pool
     */
    Long threadId = Thread.currentThread().getId();
    if (!contextsPool.containsKey(threadId)) {
        try {
            // this is lockable thing, but since it locks once per thread initialization, performance impact won't be big
            lock.acquire();
            // NOTE(review): contextsPool is NOT re-checked here after acquiring the lock
            // (no double-checked locking). Since the key is the current thread's id and a
            // thread cannot race itself, this appears safe — but confirm no other code path
            // inserts entries for foreign thread ids.
            /*
            if (!cuPool.containsKey(deviceId)) {
            CUcontext cuContext = createNewContext(deviceId);
            cuPool.put(deviceId, cuContext);
            }
            int result = JCudaDriver.cuCtxSetCurrent(cuPool.get(deviceId));
            if (result != CUresult.CUDA_SUCCESS) {
            throw new RuntimeException("Failed to set context on assigner");
            }
            */
            // Lazily create the per-device map of contexts (slot index -> context).
            if (!contextsForDevices.containsKey(deviceId)) {
                contextsForDevices.put(deviceId, new ConcurrentHashMap<Integer, CudaContext>());
            }
            // if we hadn't hit MAX_STREAMS_PER_DEVICE limit - we add new stream. Otherwise we use random one.
            if (contextsForDevices.get(deviceId).size() < MAX_STREAMS_PER_DEVICE) {
                log.debug("Creating new context...");
                CudaContext context = createNewStream(deviceId);
                getDeviceBuffers(context, deviceId);
                if (contextsForDevices.get(deviceId).size() == 0) {
                    // if we have no contexts created - it's just awesome time to attach cuBLAS handle here
                    log.debug("Creating new cuBLAS handle for device [{}]...", deviceId);
                    // NOTE(review): createNewStream(deviceId) is called here only to obtain its
                    // underlying stream; the temporary CudaContext it returns is otherwise
                    // discarded — confirm this does not leak device resources.
                    cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
                    cublasHandle_t handle = createNewCublasHandle(cublasStream);
                    context.setHandle(handle);
                    context.setCublasStream(cublasStream);
                    // Cache the handle so later contexts on this device can reuse it.
                    cublasPool.put(deviceId, handle);
                    log.debug("Creating new cuSolver handle for device [{}]...", deviceId);
                    // NOTE(review): same discarded-temporary-context pattern as above.
                    cudaStream_t solverStream = createNewStream(deviceId).getOldStream();
                    cusolverDnHandle_t solverhandle = createNewSolverHandle(solverStream);
                    context.setSolverHandle(solverhandle);
                    context.setSolverStream(solverStream);
                    solverPool.put(deviceId, solverhandle);
                } else {
                    // just pick handle out there
                    log.debug("Reusing blas here...");
                    cublasHandle_t handle = cublasPool.get(deviceId);
                    context.setHandle(handle);
                    log.debug("Reusing solver here...");
                    cusolverDnHandle_t solverHandle = solverPool.get(deviceId);
                    context.setSolverHandle(solverHandle);
                    // TODO: actually we don't need this anymore
                    // cudaStream_t cublasStream = new cudaStream_t();
                    // JCublas2.cublasGetStream(handle, cublasStream);
                    // context.setCublasStream(cublasStream);
                }
                // we need this sync to finish memset
                context.syncOldStream();
                contextsPool.put(threadId, context);
                // Register under the next free slot index (current size) for this device.
                contextsForDevices.get(deviceId).put(contextsForDevices.get(deviceId).size(), context);
                return context;
            } else {
                // Device is at its stream limit: pick one of the existing contexts at random.
                // Slot keys 0..MAX_STREAMS_PER_DEVICE-1 are fully populated once this branch
                // is reachable, so the random key always hits an entry.
                Integer rand = RandomUtils.nextInt(0, MAX_STREAMS_PER_DEVICE);
                log.debug("Reusing context: " + rand);
                nativeOps.setDevice(new CudaPointer(deviceId));
                CudaContext context = contextsForDevices.get(deviceId).get(rand);
                contextsPool.put(threadId, context);
                return context;
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            lock.release();
        }
    }
    return contextsPool.get(threadId);
}
Usage of org.nd4j.jita.allocator.pointers.CudaPointer in the nd4j project (deeplearning4j):
class LimitedContextPool, method acquireContextForDevice.
/**
 * Acquires a CudaContext for the calling thread on the given device.
 *
 * If the thread already holds a context for this device it is returned directly.
 * Otherwise a context is polled from the per-device pool; when the pool is empty,
 * this method loops — invoking GC to encourage context recycling, growing the pool
 * in batches of 16 while under 3x the configured pool size, and sleeping briefly
 * once the hard cap is reached — until a context becomes available.
 *
 * @param deviceId CUDA device ordinal to acquire a context for
 * @return a CudaContext bound to the calling thread (never null)
 * @throws RuntimeException wrapping any failure while polling the pool
 */
@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
    long threadIdx = Thread.currentThread().getId();
    CudaContext context = acquired.get(threadIdx);
    // FIX: compare boxed values with equals() — '==' on Integer operands is a
    // reference comparison and only happens to hold for the Integer cache
    // (-128..127), so larger device ids would never match a cached context.
    // equals() is correct whether getDeviceId() returns int or Integer.
    if (context != null && deviceId.equals(context.getDeviceId())) {
        return context;
    }
    // log.info("Setting device to {}", deviceId);
    nativeOps.setDevice(new CudaPointer(deviceId));
    context = pool.get(deviceId).poll();
    if (context != null) {
        // Attach a garbage-collection reference on a randomly chosen collector queue,
        // so the context is reclaimed when the owning thread dies.
        int col = RandomUtils.nextInt(0, collectors.size());
        // NOTE(review): the return value of this get() is discarded — looks like a
        // no-op left over from a refactor; confirm before removing.
        collectors.get(col);
        GarbageResourceReference reference = new GarbageResourceReference(Thread.currentThread(),
                        queueMap.get(col), context, deviceId.intValue());
        context.attachReference(reference);
        acquired.put(threadIdx, context);
        context.setDeviceId(deviceId);
        return context;
    } else {
        // Pool exhausted: loop until a context is freed or the pool is grown.
        do {
            try {
                Nd4j.getMemoryManager().invokeGc();
                context = pool.get(deviceId).poll(1, TimeUnit.SECONDS);
                if (context != null) {
                    int col = RandomUtils.nextInt(0, collectors.size());
                    collectors.get(col);
                    GarbageResourceReference reference = new GarbageResourceReference(Thread.currentThread(),
                                    queueMap.get(col), context, deviceId.intValue());
                    context.attachReference(reference);
                    acquired.put(threadIdx, context);
                    context.setDeviceId(deviceId);
                } else {
                    if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize() * 3) {
                        addResourcesToPool(16);
                        // there's possible race condition, but we don't really care
                        currentPoolSize.addAndGet(16);
                    } else {
                        log.warn("Can't allocate new context, sleeping...");
                        Nd4j.getMemoryManager().invokeGc();
                        try {
                            Thread.sleep(500);
                        } catch (InterruptedException e) {
                            // FIX: don't swallow interruption — restore the flag so the
                            // subsequent poll() observes it and the loop can terminate.
                            Thread.currentThread().interrupt();
                        }
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } while (context == null);
        return context;
    }
}
End of aggregated CudaPointer usage examples.