Usage of org.nd4j.jita.allocator.garbage.GarbageResourceReference in the nd4j project by deeplearning4j.
Class LimitedContextPool, method acquireContextForDevice.
/**
 * Returns a {@link CudaContext} for {@code deviceId} bound to the calling thread.
 *
 * <p>If the calling thread already holds a context for the same device, that context is
 * returned unchanged. Otherwise the native device is switched to {@code deviceId} and a context
 * is taken from that device's pool. When the pool is empty this method blocks: on each attempt it
 * invokes GC and waits up to one second for a context. While the tracked pool size is below three
 * times the configured pool size it grows the pool by 16 contexts; beyond that it logs a warning,
 * invokes GC again and sleeps 500 ms before retrying.
 *
 * <p>A newly acquired context gets a {@link GarbageResourceReference} attached (using a randomly
 * chosen collector queue), is stored as the calling thread's acquired context, and has its device
 * id set to {@code deviceId}.
 *
 * @param deviceId device to acquire a context for
 * @return a context for {@code deviceId}; never {@code null}
 * @throws RuntimeException if the thread is interrupted while waiting (its interrupt status is
 *     restored before throwing), or wrapping any checked exception raised while waiting
 */
@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
    long threadIdx = Thread.currentThread().getId();
    CudaContext context = acquired.get(threadIdx);
    // equals() compares by value whether getDeviceId() returns int or Integer; == on two
    // Integer references would compare identity instead
    if (context != null && deviceId.equals(context.getDeviceId())) {
        return context;
    }

    nativeOps.setDevice(new CudaPointer(deviceId));

    // fast path: non-blocking poll; only enter the blocking retry loop if the pool is empty
    context = pool.get(deviceId).poll();
    try {
        while (context == null) {
            Nd4j.getMemoryManager().invokeGc();
            context = pool.get(deviceId).poll(1, TimeUnit.SECONDS);
            if (context == null) {
                if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize() * 3) {
                    addResourcesToPool(16);
                    // check-then-act race: concurrent callers may each grow the pool, so it can
                    // overshoot the limit; this is tolerated
                    currentPoolSize.addAndGet(16);
                } else {
                    log.warn("Can't allocate new context, sleeping...");
                    Nd4j.getMemoryManager().invokeGc();
                    Thread.sleep(500);
                }
            }
        }
    } catch (InterruptedException e) {
        // restore the interrupt status so code further up the stack can still observe it
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } catch (RuntimeException e) {
        throw e;
    } catch (Exception e) {
        throw new RuntimeException(e);
    }

    // bind the acquired context to the calling thread (shared by fast and slow paths)
    int col = RandomUtils.nextInt(0, collectors.size());
    GarbageResourceReference reference =
            new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue());
    context.attachReference(reference);
    acquired.put(threadIdx, context);
    context.setDeviceId(deviceId);
    return context;
}
Aggregations