Search in sources :

Example 46 with CudaPointer

use of org.nd4j.jita.allocator.pointers.CudaPointer in project nd4j by deeplearning4j.

The method acquireContextForDevice of the class BasicContextPool.

/**
 * Returns the {@link CudaContext} associated with the calling thread, creating or
 * assigning one on the first call made by that thread.
 *
 * <p>If the thread already has an entry in {@code contextsPool}, that entry is returned
 * without locking. Otherwise, under {@code lock}:
 * <ul>
 *   <li>while fewer than {@code MAX_STREAMS_PER_DEVICE} contexts are registered for the
 *       device, a new context is created; the first context of a device also creates the
 *       cuBLAS and cuSolver handles that are stored in {@code cublasPool} /
 *       {@code solverPool}, and later contexts reuse those handles;</li>
 *   <li>once the limit is reached, a randomly chosen existing context of that device is
 *       handed out.</li>
 * </ul>
 *
 * @param deviceId id of the device the context is requested for
 * @return the context now mapped to the calling thread in {@code contextsPool}
 * @throws RuntimeException wrapping any exception raised while acquiring the lock or
 *         while creating / selecting the context
 */
@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
    Long threadId = Thread.currentThread().getId();
    // Fast path: this thread has already been given a context.
    if (contextsPool.containsKey(threadId)) {
        return contextsPool.get(threadId);
    }

    // Acquire OUTSIDE of the try/finally below: if acquire() fails, release() must not run,
    // otherwise a permit that was never obtained would be handed back.
    // NOTE(review): assumes lock behaves like java.util.concurrent.Semaphore - confirm.
    // This locks once per thread initialization, so the performance impact should be small.
    try {
        lock.acquire();
    } catch (Exception e) {
        if (e instanceof InterruptedException) {
            // keep the interruption visible to callers further up the stack
            Thread.currentThread().interrupt();
        }
        throw new RuntimeException(e);
    }

    try {
        if (!contextsForDevices.containsKey(deviceId)) {
            contextsForDevices.put(deviceId, new ConcurrentHashMap<Integer, CudaContext>());
        }
        // if we hadn't hit MAX_STREAMS_PER_DEVICE limit - we add new stream. Otherwise we use random one.
        if (contextsForDevices.get(deviceId).size() < MAX_STREAMS_PER_DEVICE) {
            log.debug("Creating new context...");
            CudaContext context = createNewStream(deviceId);
            getDeviceBuffers(context, deviceId);
            if (contextsForDevices.get(deviceId).size() == 0) {
                // first context for this device: create the cuBLAS handle and remember it in the pool
                log.debug("Creating new cuBLAS handle for device [{}]...", deviceId);
                cudaStream_t cublasStream = createNewStream(deviceId).getOldStream();
                cublasHandle_t handle = createNewCublasHandle(cublasStream);
                context.setHandle(handle);
                context.setCublasStream(cublasStream);
                cublasPool.put(deviceId, handle);
                // same for the cuSolver handle
                log.debug("Creating new cuSolver handle for device [{}]...", deviceId);
                cudaStream_t solverStream = createNewStream(deviceId).getOldStream();
                cusolverDnHandle_t solverhandle = createNewSolverHandle(solverStream);
                context.setSolverHandle(solverhandle);
                context.setSolverStream(solverStream);
                solverPool.put(deviceId, solverhandle);
            } else {
                // handles already exist for this device: just pick them from the pools
                // (the cuBLAS / cuSolver streams are not set on the context in this branch)
                log.debug("Reusing blas here...");
                cublasHandle_t handle = cublasPool.get(deviceId);
                context.setHandle(handle);
                log.debug("Reusing solver here...");
                cusolverDnHandle_t solverHandle = solverPool.get(deviceId);
                context.setSolverHandle(solverHandle);
            }
            // we need this sync to finish memset
            context.syncOldStream();
            contextsPool.put(threadId, context);
            // contexts are keyed by their insertion index: 0, 1, 2, ...
            contextsForDevices.get(deviceId).put(contextsForDevices.get(deviceId).size(), context);
            return context;
        } else {
            // limit reached: share one of the existing contexts, picked by random index
            Integer rand = RandomUtils.nextInt(0, MAX_STREAMS_PER_DEVICE);
            log.debug("Reusing context: " + rand);
            nativeOps.setDevice(new CudaPointer(deviceId));
            CudaContext context = contextsForDevices.get(deviceId).get(rand);
            contextsPool.put(threadId, context);
            return context;
        }
    } catch (Exception e) {
        throw new RuntimeException(e);
    } finally {
        // reached only when acquire() above succeeded
        lock.release();
    }
}
Also used : org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t(org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t) org.nd4j.jita.allocator.pointers.cuda.cudaStream_t(org.nd4j.jita.allocator.pointers.cuda.cudaStream_t) org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t(org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer)

Example 47 with CudaPointer

use of org.nd4j.jita.allocator.pointers.CudaPointer in project nd4j by deeplearning4j.

The method acquireContextForDevice of the class LimitedContextPool.

/**
 * Returns a {@link CudaContext} for the calling thread on the requested device.
 *
 * <p>If the thread already holds a context in {@code acquired} whose device id matches,
 * that context is returned. Otherwise the device is selected via {@code nativeOps} and a
 * context is taken from the per-device pool. When the pool is empty the method loops:
 * it requests a GC, waits up to one second for a context, and if none arrives either
 * grows the pool by 16 (while {@code currentPoolSize} is below three times the configured
 * pool size) or sleeps 500 ms before retrying.
 *
 * <p>Interruption of the waiting thread (during the timed poll or the sleep) restores the
 * thread's interrupt flag and is reported as a {@link RuntimeException}; it is no longer
 * silently ignored while sleeping.
 *
 * @param deviceId id of the device the context is requested for
 * @return a context registered for the calling thread in {@code acquired}
 * @throws RuntimeException wrapping any exception raised while waiting for a context,
 *         including interruption
 */
@Override
public CudaContext acquireContextForDevice(Integer deviceId) {
    long threadIdx = Thread.currentThread().getId();
    CudaContext context = acquired.get(threadIdx);
    // Explicit unboxing keeps this a numeric comparison regardless of whether
    // getDeviceId() returns a primitive or a boxed value.
    if (context != null && deviceId.intValue() == context.getDeviceId()) {
        return context;
    }

    nativeOps.setDevice(new CudaPointer(deviceId));
    context = pool.get(deviceId).poll();
    if (context != null) {
        bindContextToCurrentThread(context, deviceId, threadIdx);
        return context;
    }

    // Pool is empty: keep trying until a context becomes available.
    do {
        try {
            Nd4j.getMemoryManager().invokeGc();
            context = pool.get(deviceId).poll(1, TimeUnit.SECONDS);
            if (context != null) {
                bindContextToCurrentThread(context, deviceId, threadIdx);
            } else if (currentPoolSize.get() < CudaEnvironment.getInstance().getConfiguration().getPoolSize() * 3) {
                addResourcesToPool(16);
                // there's possible race condition, but we don't really care
                currentPoolSize.addAndGet(16);
            } else {
                log.warn("Can't allocate new context, sleeping...");
                Nd4j.getMemoryManager().invokeGc();
                Thread.sleep(500);
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // keep the interruption visible to callers further up the stack
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    } while (context == null);
    return context;
}

/**
 * Attaches a new GarbageResourceReference (built for the current thread and a randomly
 * chosen queue) to the context, records the context for the thread in {@code acquired},
 * and stamps the device id on the context.
 */
private void bindContextToCurrentThread(CudaContext context, Integer deviceId, long threadIdx) {
    int col = RandomUtils.nextInt(0, collectors.size());
    // Result is discarded; the call is kept from the original code.
    // NOTE(review): may serve only as an index check - confirm before removing.
    collectors.get(col);
    GarbageResourceReference reference = new GarbageResourceReference(Thread.currentThread(), queueMap.get(col), context, deviceId.intValue());
    context.attachReference(reference);
    acquired.put(threadIdx, context);
    context.setDeviceId(deviceId);
}
Also used : GarbageResourceReference(org.nd4j.jita.allocator.garbage.GarbageResourceReference) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer)

Aggregations

CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)47 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)27 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)20 Pointer (org.bytedeco.javacpp.Pointer)18 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)18 INDArray (org.nd4j.linalg.api.ndarray.INDArray)15 org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t (org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t)12 GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner)11 DoublePointer (org.bytedeco.javacpp.DoublePointer)10 FloatPointer (org.bytedeco.javacpp.FloatPointer)10 IntPointer (org.bytedeco.javacpp.IntPointer)10 CUstream_st (org.bytedeco.javacpp.cuda.CUstream_st)10 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)10 CublasPointer (org.nd4j.linalg.jcublas.CublasPointer)10 BlasException (org.nd4j.linalg.api.blas.BlasException)8 BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer)7 AllocationShape (org.nd4j.jita.allocator.impl.AllocationShape)4 AtomicAllocator (org.nd4j.jita.allocator.impl.AtomicAllocator)4 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)4 INDArrayIndex (org.nd4j.linalg.indexing.INDArrayIndex)4