Search in sources:

Example 1 with GarbageResourceReference

Use of `org.nd4j.jita.allocator.garbage.GarbageResourceReference` in project nd4j by deeplearning4j.

From class `LimitedContextPool`, method `acquireContextForDevice`:

/**
 * Returns a CUDA context bound to the calling thread for the given device.
 *
 * <p>If the calling thread already holds a context whose device id equals {@code deviceId}, that
 * context is returned unchanged. Otherwise the device is selected via {@code nativeOps.setDevice},
 * a context is taken from the per-device pool, and it is bound to the calling thread.
 *
 * <p>If the pool is empty, this method loops until a context becomes available. Each iteration
 * triggers GC and waits up to one second for a context. If none arrives, it either grows the pool
 * by 16 contexts (while the tracked pool size is below three times the configured pool size) or
 * logs a warning, triggers GC again and sleeps for 500 ms.
 *
 * @param deviceId device to acquire a context for
 * @return a context bound to the calling thread whose device id has been set to {@code deviceId}
 * @throws RuntimeException wrapping any exception raised while waiting for or creating contexts;
 *     if that exception is an {@link InterruptedException}, the thread's interrupt status is
 *     restored before the wrapper is thrown
 */
@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
    long threadIdx = Thread.currentThread().getId();
    CudaContext context = acquired.get(threadIdx);
    // Compare numerically: '==' on two boxed Integers would compare references instead of values.
    if (context != null && context.getDeviceId() == deviceId.intValue()) {
        return context;
    }

    nativeOps.setDevice(new CudaPointer(deviceId));
    context = pool.get(deviceId).poll();

    while (context == null) {
        try {
            Nd4j.getMemoryManager().invokeGc();
            context = pool.get(deviceId).poll(1, TimeUnit.SECONDS);
            if (context == null) {
                if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize() * 3) {
                    addResourcesToPool(16);
                    // there's possible race condition, but we don't really care
                    currentPoolSize.addAndGet(16);
                } else {
                    log.warn("Can't allocate new context, sleeping...");
                    Nd4j.getMemoryManager().invokeGc();
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        // Keep the interrupt visible: the next timed poll() then fails fast and the
                        // caller gets an exception instead of this loop spinning indefinitely.
                        Thread.currentThread().interrupt();
                    }
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the flag that throwing InterruptedException cleared.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }

    bindContextToCurrentThread(context, deviceId, threadIdx);
    return context;
}

/**
 * Attaches a {@link GarbageResourceReference} for the current thread to {@code context}, records
 * {@code context} as the thread's acquired context, and stamps it with {@code deviceId}.
 *
 * <p>The garbage queue is chosen uniformly at random from {@code queueMap} keys in
 * {@code [0, collectors.size())}.
 *
 * @param context context freshly taken from the pool
 * @param deviceId device the context is being acquired for
 * @param threadIdx id of the current thread, used as the key in {@code acquired}
 */
private void bindContextToCurrentThread(CudaContext context, Integer deviceId, long threadIdx) {
    int col = RandomUtils.nextInt(0, collectors.size());
    GarbageResourceReference reference =
            new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue());
    context.attachReference(reference);
    acquired.put(threadIdx, context);
    context.setDeviceId(deviceId);
}
Also used: GarbageResourceReference (org.nd4j.jita.allocator.garbage.GarbageResourceReference), CudaContext (org.nd4j.linalg.jcublas.context.CudaContext), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)

Aggregations

GarbageResourceReference (org.nd4j.jita.allocator.garbage.GarbageResourceReference): 1 use; CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 1 use; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 1 use