Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.
The class ProtectedCudaConstantHandler, method ensureMaps.
/**
 * Lazily creates the per-device bookkeeping for {@code deviceId}, if it does not exist yet:
 * <ul>
 *   <li>the constant-buffer cache,</li>
 *   <li>the constant-space offset counter,</li>
 *   <li>a single-permit device semaphore,</li>
 *   <li>the constant-space base pointer.</li>
 * </ul>
 *
 * <p>Callers skip the lock whenever {@code buffersCache} already contains the device. Therefore
 * the {@code buffersCache} entry is published LAST, after every other per-device map has been
 * filled. Otherwise a concurrent caller could see the cache entry and go on to read
 * {@code constantOffsets}/{@code deviceLocks}/{@code deviceAddresses} before they exist. Publishing
 * last also means a failure while initializing leaves the device uninitialized, so the next call
 * retries instead of working with half-populated state.
 * NOTE(review): this ordering argument assumes these fields are concurrent maps (e.g.
 * ConcurrentHashMap, whose put happens-before a get that observes it) — confirm against the field
 * declarations.
 *
 * @param deviceId device whose maps should be initialized
 * @throws RuntimeException wrapping any exception raised while initializing the maps
 */
private void ensureMaps(Integer deviceId) {
    if (!buffersCache.containsKey(deviceId)) {
        // Unsynchronized lazy init; presumably getFlowController() returns the same instance on
        // every call, so a racing double assignment is harmless — TODO confirm.
        if (flowController == null)
            flowController = AtomicAllocator.getInstance().getFlowController();
        try {
            synchronized (this) {
                // Re-check under the lock: another thread may have initialized this device meanwhile.
                if (!buffersCache.containsKey(deviceId)) {
                    // TODO: this op call should be checked
                    // nativeOps.setDevice(new CudaPointer(deviceId));
                    constantOffsets.put(deviceId, new AtomicLong(0));
                    deviceLocks.put(deviceId, new Semaphore(1));
                    Pointer cAddr = NativeOpsHolder.getInstance().getDeviceNativeOps().getConstantSpace();
                    deviceAddresses.put(deviceId, cAddr);
                    // Publish last: the unsynchronized fast-path check above keys on buffersCache.
                    buffersCache.put(deviceId, new ConcurrentHashMap<ArrayDescriptor, DataBuffer>());
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.
The class ProtectedCudaConstantHandler, method getConstantBuffer.
/**
 * Returns a DataBuffer whose contents equal {@code array}, reusing this device's cached buffer
 * when an equal array was cached earlier.
 *
 * <p>On a cache miss a detached buffer is created. If the device's constant-space offset plus the
 * array's size in bytes stays below MAX_CONSTANT_LENGTH, the buffer is:
 * <ul>
 *   <li>marked constant,</li>
 *   <li>moved to constant space,</li>
 *   <li>cached, with its size added to the byte counter.</li>
 * </ul>
 * Otherwise the new buffer is returned without being cached.
 * NOTE(review): the offset counter is not advanced here; presumably moveToConstantSpace does
 * that — verify.
 *
 * <p>PLEASE NOTE: never change values inside the returned DataBuffer; it may be shared through
 * the cache and may live in constant memory.
 *
 * @param array values the buffer should hold
 * @return the cached buffer for an equal array on the current device, or a newly created one
 */
@Override
public DataBuffer getConstantBuffer(int[] array) {
    ArrayDescriptor key = new ArrayDescriptor(array);
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ensureMaps(deviceId);

    // Fast path: an equal array was already cached for this device.
    if (buffersCache.get(deviceId).containsKey(key))
        return buffersCache.get(deviceId).get(key);

    // Cache miss: build a detached buffer, then try to promote it into constant space.
    DataBuffer result = Nd4j.createBufferDetached(array);
    int requiredBytes = array.length * 4; // 4 bytes per int
    boolean fitsInConstantSpace =
            constantOffsets.get(deviceId).get() + requiredBytes < MAX_CONSTANT_LENGTH;
    if (fitsInConstantSpace) {
        result.setConstant(true);
        moveToConstantSpace(result);
        buffersCache.get(deviceId).put(key, result);
        bytes.addAndGet(requiredBytes);
    }
    return result;
}
Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.
The class ConstantBuffersCache, method getConstantBuffer.
/**
 * Returns a DataBuffer holding {@code array}, served from the cache when an equal array was
 * cached before.
 *
 * <p>A freshly created buffer is added to the cache only while the entry counter is below
 * MAX_ENTRIES. Its size (length times Nd4j.sizeOfDataType()) is added to the byte counter.
 * Callers must not modify the result, since it may be shared via the cache.
 *
 * @param array values the buffer should hold
 * @return the cached buffer for an equal array, or a newly created one
 */
@Override
public DataBuffer getConstantBuffer(double[] array) {
    ArrayDescriptor key = new ArrayDescriptor(array);
    if (buffersCache.containsKey(key))
        return buffersCache.get(key);

    DataBuffer fresh = Nd4j.createBufferDetached(array);
    boolean hasRoom = counter.get() < MAX_ENTRIES;
    if (hasRoom) {
        counter.incrementAndGet();
        buffersCache.put(key, fresh);
        bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
    }
    return fresh;
}
Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.
The class ConstantBuffersCache, method getConstantBuffer.
/**
 * Returns a DataBuffer holding {@code array}, served from the cache when an equal array was
 * cached before.
 *
 * <p>A freshly created buffer is cached while the entry counter is below MAX_ENTRIES. Arrays with
 * at most 3 elements ({@code array.length < 4}) are cached even past that limit: they are almost
 * always dimension arrays, and recreating them over and over is wasteful. Callers must not modify
 * the result, since it may be shared via the cache.
 *
 * @param array values the buffer should hold
 * @return the cached buffer for an equal array, or a newly created one
 */
@Override
public DataBuffer getConstantBuffer(int[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    if (!buffersCache.containsKey(descriptor)) {
        DataBuffer buffer = Nd4j.createBufferDetached(array);
        // Always cache int arrays with length < 4 (i.e. up to 3 elements), even when the cache is
        // full: ~99.9% of the time they are dimension arrays, which we don't want to recreate.
        if (counter.get() < MAX_ENTRIES || array.length < 4) {
            counter.incrementAndGet();
            buffersCache.put(descriptor, buffer);
            bytes.addAndGet(array.length * 4); // 4 bytes per int
        }
        return buffer;
    }
    return buffersCache.get(descriptor);
}
Use of org.nd4j.linalg.cache.ArrayDescriptor in project nd4j by deeplearning4j.
The class ConstantBuffersCache, method getConstantBuffer.
/**
 * Returns a DataBuffer holding {@code array}, reusing a cached buffer for an equal array when one
 * exists.
 *
 * <p>On a miss the new buffer is cached only while the entry counter is below MAX_ENTRIES. Its
 * size (length times Nd4j.sizeOfDataType()) is then added to the byte counter. Callers must not
 * modify the result, since it may be shared via the cache.
 *
 * @param array values the buffer should hold
 * @return the cached buffer for an equal array, or a newly created one
 */
@Override
public DataBuffer getConstantBuffer(float[] array) {
    ArrayDescriptor descriptorKey = new ArrayDescriptor(array);
    DataBuffer result;
    if (buffersCache.containsKey(descriptorKey)) {
        result = buffersCache.get(descriptorKey);
    } else {
        result = Nd4j.createBufferDetached(array);
        if (counter.get() < MAX_ENTRIES) {
            counter.incrementAndGet();
            buffersCache.put(descriptorKey, result);
            bytes.addAndGet(array.length * Nd4j.sizeOfDataType());
        }
    }
    return result;
}
Aggregations