
Example 1 with ArrayDescriptor

Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.

From the class ProtectedCudaConstantHandler, method ensureMaps.

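private void ensureMaps(Integer deviceId) {
    // Double-checked initialization: a cheap lookup first, then a re-check inside
    // synchronized (this) so each device's maps are created only once.
    if (!buffersCache.containsKey(deviceId)) {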
        if (flowController == null)
            flowController = AtomicAllocator.getInstance().getFlowController();
        try {
            synchronized (this) {
                if (!buffersCache.containsKey(deviceId)) {
                    // TODO: this op call should be checked
                    // nativeOps.setDevice(new CudaPointer(deviceId));
                    buffersCache.put(deviceId, new ConcurrentHashMap<ArrayDescriptor, DataBuffer>());
                    constantOffsets.put(deviceId, new AtomicLong(0));
                    deviceLocks.put(deviceId, new Semaphore(1));
                    Pointer cAddr = NativeOpsHolder.getInstance().getDeviceNativeOps().getConstantSpace();
                    // logger.info("constant pointer: {}", cAddr.address() );
                    deviceAddresses.put(deviceId, cAddr);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) ArrayDescriptor(org.nd4j.linalg.cache.ArrayDescriptor) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer) Pointer(org.bytedeco.javacpp.Pointer) Semaphore(java.util.concurrent.Semaphore) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CudaIntDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer) CudaHalfDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaHalfDataBuffer) CudaFloatDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer) CudaDoubleDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)
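ensureMaps uses double-checked locking to fill several separate per-device maps. One alternative worth considering (not what nd4j does) is to bundle the per-device state into a single object and create it through ConcurrentHashMap.computeIfAbsent, which applies the initializer at most once per key. A minimal sketch: DeviceState and DeviceRegistry are hypothetical names, and the device pointer the real handler stores is omitted.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical holder for the per-device state that ensureMaps() spreads
// across buffersCache, constantOffsets and deviceLocks.
final class DeviceState {
    final Map<Object, Object> cache = new ConcurrentHashMap<>();
    final AtomicLong constantOffset = new AtomicLong(0);
    final Semaphore lock = new Semaphore(1);
}

// Hypothetical registry: one DeviceState per device id, created lazily.
final class DeviceRegistry {
    private final Map<Integer, DeviceState> states = new ConcurrentHashMap<>();
    // Counts how many DeviceState objects were actually created.
    final AtomicLong initCount = new AtomicLong();

    DeviceState ensure(int deviceId) {
        // computeIfAbsent runs the initializer atomically, at most once per key,
        // so there is no check-then-act gap and no explicit synchronized block.
        return states.computeIfAbsent(deviceId, id -> {
            initCount.incrementAndGet();
            return new DeviceState();
        });
    }
}
```

Grouping the state in one object also means a device can never be observed with some maps populated and others still missing.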

Example 2 with ArrayDescriptor

Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.

From the class ProtectedCudaConstantHandler, method getConstantBuffer.

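/**
 * This method returns a DataBuffer whose contents equal the input array.
 *
 * PLEASE NOTE: This method assumes that you will never change values within the returned DataBuffer.
 *
 * @param array values to store in the constant buffer
 * @return a DataBuffer holding {@code array}; cached per device while constant memory has room
 */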
@Override
public DataBuffer getConstantBuffer(int[] array) {
    // logger.info("getConstantBuffer(int[]) called");
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ensureMaps(deviceId);
    if (!buffersCache.get(deviceId).containsKey(descriptor)) {
        // we create new databuffer
        // logger.info("Creating new constant buffer...");
        DataBuffer buffer = Nd4j.createBufferDetached(array);
        if (constantOffsets.get(deviceId).get() + (array.length * 4) < MAX_CONSTANT_LENGTH) {
            buffer.setConstant(true);
            // move the data into device constant memory, then cache the buffer
            moveToConstantSpace(buffer);
            buffersCache.get(deviceId).put(descriptor, buffer);
            bytes.addAndGet(array.length * 4);
        }
        return buffer;
    }
    return buffersCache.get(deviceId).get(descriptor);
}
Also used : ArrayDescriptor(org.nd4j.linalg.cache.ArrayDescriptor) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CudaIntDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer) CudaHalfDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaHalfDataBuffer) CudaFloatDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer) CudaDoubleDataBuffer(org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)
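The capacity test in getConstantBuffer(int[]) reads constantOffsets but does not advance it in this excerpt (presumably moveToConstantSpace does, which is not shown), so as far as this excerpt shows, two threads could both pass the check before either reserves space. A sketch, assuming a single per-device byte budget, that makes the check and the reservation one atomic compare-and-set step. ConstantSpaceBudget and tryReserveInts are hypothetical names; the strict `<` comparison and 4 bytes per int follow the original.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical per-device budget for constant memory. In the original,
// MAX_CONSTANT_LENGTH is a field of the handler; here it is a constructor argument.
final class ConstantSpaceBudget {
    private final long maxBytes;
    private final AtomicLong used = new AtomicLong(0);

    ConstantSpaceBudget(long maxBytes) {
        this.maxBytes = maxBytes;
    }

    /** Reserves room for an int[] of the given length; returns false if it would not fit. */
    boolean tryReserveInts(int length) {
        long needed = (long) length * Integer.BYTES; // 4 bytes per int, as in array.length * 4
        while (true) {
            long current = used.get();
            // Same strict test as the original: offset + size < MAX_CONSTANT_LENGTH
            if (current + needed >= maxBytes) {
                return false;
            }
            // Reserve only if no other thread moved the offset since we read it.
            if (used.compareAndSet(current, current + needed)) {
                return true;
            }
            // Lost a race with another reservation: re-read and retry.
        }
    }

    long usedBytes() {
        return used.get();
    }
}
```

Usage: with a 16-byte budget, reserving two ints succeeds (0 + 8 < 16), a second pair fails (8 + 8 is not < 16), and a single int still fits (8 + 4 < 16).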

Example 3 with ArrayDescriptor

Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.

From the class ConstantBuffersCache, method getConstantBuffer.

@Override
public DataBuffer getConstantBuffer(double[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    if (!buffersCache.containsKey(descriptor)) {
        DataBuffer buffer = Nd4j.createBufferDetached(array);
        if (counter.get() < MAX_ENTRIES) {
            counter.incrementAndGet();
            buffersCache.put(descriptor, buffer);
            bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
        }
        return buffer;
    }
    return buffersCache.get(descriptor);
}
Also used : ArrayDescriptor(org.nd4j.linalg.cache.ArrayDescriptor) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)
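ConstantBuffersCache keys its map by ArrayDescriptor, which is presumably what lets lookups match on array contents, and it stops admitting entries once counter reaches MAX_ENTRIES; past that point each call gets a fresh, uncached buffer. A sketch of that pattern, assuming a plain double[] copy in place of a DataBuffer. DoubleKey and CappedDoubleCache are hypothetical, and as in the original the cap is soft under concurrency because the count check and the increment are separate steps.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical stand-in for ArrayDescriptor: equality and hashing by content.
final class DoubleKey {
    private final double[] values;

    DoubleKey(double[] values) {
        this.values = values.clone(); // defensive copy so the key cannot change later
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof DoubleKey && Arrays.equals(values, ((DoubleKey) o).values);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(values);
    }
}

// Hypothetical capped cache mirroring the MAX_ENTRIES logic.
final class CappedDoubleCache {
    private final int maxEntries;
    private final AtomicInteger counter = new AtomicInteger();
    private final Map<DoubleKey, double[]> cache = new ConcurrentHashMap<>();

    CappedDoubleCache(int maxEntries) {
        this.maxEntries = maxEntries;
    }

    double[] get(double[] array) {
        DoubleKey key = new DoubleKey(array);
        double[] cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        double[] buffer = array.clone(); // stands in for Nd4j.createBufferDetached(array)
        if (counter.get() < maxEntries) {
            // If another thread cached the same contents first, reuse its buffer.
            double[] previous = cache.putIfAbsent(key, buffer);
            if (previous != null) {
                return previous;
            }
            counter.incrementAndGet();
        }
        return buffer; // over the cap: a fresh buffer that is not cached
    }

    int size() {
        return cache.size();
    }
}
```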

Example 4 with ArrayDescriptor

Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.

From the class ConstantBuffersCache, method getConstantBuffer.

@Override
public DataBuffer getConstantBuffer(int[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    if (!buffersCache.containsKey(descriptor)) {
        DataBuffer buffer = Nd4j.createBufferDetached(array);
        // int arrays with fewer than 4 elements are always cached, even past MAX_ENTRIES:
        // they are almost always dimension arrays, and we don't want to recreate them over and over
        if (counter.get() < MAX_ENTRIES || array.length < 4) {
            counter.incrementAndGet();
            buffersCache.put(descriptor, buffer);
            bytes.addAndGet(array.length * 4);
        }
        return buffer;
    }
    return buffersCache.get(descriptor);
}
Also used : ArrayDescriptor(org.nd4j.linalg.cache.ArrayDescriptor) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)
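The int[] overload differs only in its admission rule: `counter.get() < MAX_ENTRIES || array.length < 4` means arrays of up to three elements are cached even after the cap is reached, and they still increment the counter. A sketch isolating that rule; SmallArrayFriendlyCache and admit are hypothetical names, and Arrays.toString serves as a simple content key purely for illustration.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical cache reproducing the int[] admission rule.
final class SmallArrayFriendlyCache {
    private final int maxEntries;
    private final AtomicInteger counter = new AtomicInteger();
    private final Map<String, int[]> cache = new ConcurrentHashMap<>();

    SmallArrayFriendlyCache(int maxEntries) {
        this.maxEntries = maxEntries;
    }

    // Mirrors: counter.get() < MAX_ENTRIES || array.length < 4
    static boolean admit(int cachedEntries, int maxEntries, int length) {
        return cachedEntries < maxEntries || length < 4;
    }

    int[] get(int[] array) {
        String key = Arrays.toString(array); // illustrative content-based key
        int[] cached = cache.get(key);
        if (cached != null) {
            return cached;
        }
        int[] buffer = array.clone(); // stands in for Nd4j.createBufferDetached(array)
        if (admit(counter.get(), maxEntries, array.length)) {
            int[] previous = cache.putIfAbsent(key, buffer);
            if (previous != null) {
                return previous;
            }
            counter.incrementAndGet(); // small arrays count too, as in the original
        }
        return buffer;
    }

    int size() {
        return cache.size();
    }
}
```

Usage: with a cap of one, a five-element array fills the cache; a four-element array afterwards is not cached, while a two-element dimension array still is.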

Example 5 with ArrayDescriptor

Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.

From the class ConstantBuffersCache, method getConstantBuffer.

@Override
public DataBuffer getConstantBuffer(float[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    if (!buffersCache.containsKey(descriptor)) {
        DataBuffer buffer = Nd4j.createBufferDetached(array);
        if (counter.get() < MAX_ENTRIES) {
            counter.incrementAndGet();
            buffersCache.put(descriptor, buffer);
            bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
        }
        return buffer;
    }
    return buffersCache.get(descriptor);
}
Also used : ArrayDescriptor(org.nd4j.linalg.cache.ArrayDescriptor) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)
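The float[] overload mirrors the double[] one. A wrapper such as ArrayDescriptor is needed at all because Java arrays inherit identity-based equals and hashCode from Object, so two arrays with the same contents never match as map keys. A small demonstration; FloatKey and KeyDemo are hypothetical names.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical content-based key for float arrays.
final class FloatKey {
    private final float[] values;

    FloatKey(float[] values) {
        this.values = values.clone();
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof FloatKey && Arrays.equals(values, ((FloatKey) o).values);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(values);
    }
}

final class KeyDemo {
    // Raw arrays as keys: lookup with a new but equal array misses (identity equality).
    static boolean rawArrayHit() {
        Map<float[], String> byArray = new HashMap<>();
        byArray.put(new float[] {1f, 2f}, "cached");
        return byArray.containsKey(new float[] {1f, 2f});
    }

    // Wrapped keys: lookup with equal contents hits (content equality).
    static boolean wrappedKeyHit() {
        Map<FloatKey, String> byContent = new HashMap<>();
        byContent.put(new FloatKey(new float[] {1f, 2f}), "cached");
        return byContent.containsKey(new FloatKey(new float[] {1f, 2f}));
    }
}
```

With this sketch's key, Arrays.equals compares floats the way Float.equals does, so NaN matches NaN while 0.0f and -0.0f are distinct keys.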

Aggregations

ArrayDescriptor (org.nd4j.linalg.cache.ArrayDescriptor): 8
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 7
CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer): 4
CudaFloatDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer): 4
CudaHalfDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaHalfDataBuffer): 4
CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer): 4
Semaphore (java.util.concurrent.Semaphore): 1
AtomicLong (java.util.concurrent.atomic.AtomicLong): 1
Pointer (org.bytedeco.javacpp.Pointer): 1
Test (org.junit.Test): 1
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 1
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 1