Example usage of org.nd4j.linalg.cache.ArrayDescriptor in the nd4j project (deeplearning4j).
Class ProtectedCudaConstantHandler, method getConstantBuffer (float[] overload).
/**
 * Returns a DataBuffer whose contents equal the given array, reusing a cached buffer for the
 * current device when one with equal contents (per {@code ArrayDescriptor}) already exists.
 *
 * <p>On a cache miss a new buffer is created with {@code Nd4j.createBufferDetached(array)}.
 * If it fits in the remaining constant space of the current device, it is marked constant,
 * moved to constant space, cached, and its size is added to {@code bytes}. Otherwise the plain
 * buffer is returned and is NOT cached, so later calls with equal contents create a new buffer.
 *
 * <p>PLEASE NOTE: This method assumes that you'll never change values within the returned
 * DataBuffer: a cached instance is handed out to every caller asking for equal contents.
 *
 * @param array values to place into the buffer
 * @return a DataBuffer holding the values of {@code array}; possibly a shared cached instance
 */
@Override
public DataBuffer getConstantBuffer(float[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ensureMaps(deviceId);

    // Single lookup instead of containsKey() + get(); only non-null buffers are ever put below.
    // NOTE(review): lookup and insertion are not atomic here; confirm callers serialize access.
    DataBuffer cached = buffersCache.get(deviceId).get(descriptor);
    if (cached != null) {
        return cached;
    }

    DataBuffer buffer = Nd4j.createBufferDetached(array);

    // Size is computed in long arithmetic: array.length * sizeOfDataType() can overflow int for
    // very large arrays, yielding a negative size that would wrongly pass the capacity check.
    // NOTE(review): uses the global Nd4j.sizeOfDataType() rather than 4 bytes per float;
    // presumably createBufferDetached converts to the global data type - confirm.
    long requiredBytes = (long) array.length * Nd4j.sizeOfDataType();
    if (constantOffsets.get(deviceId).get() + requiredBytes < MAX_CONSTANT_LENGTH) {
        buffer.setConstant(true);
        // now we move data to constant memory, and keep happy
        moveToConstantSpace(buffer);
        buffersCache.get(deviceId).put(descriptor, buffer);
        bytes.addAndGet(requiredBytes);
    }
    return buffer;
}
Example usage of org.nd4j.linalg.cache.ArrayDescriptor in the nd4j project (deeplearning4j).
Class ProtectedCudaConstantHandler, method getConstantBuffer (double[] overload).
/**
 * Returns a DataBuffer whose contents equal the given array, reusing a cached buffer for the
 * current device when one with equal contents (per {@code ArrayDescriptor}) already exists.
 *
 * <p>On a cache miss a new buffer is created with {@code Nd4j.createBufferDetached(array)}.
 * If it fits in the remaining constant space of the current device, it is marked constant,
 * moved to constant space, cached, and its size is added to {@code bytes}. Otherwise the plain
 * buffer is returned and is NOT cached, so later calls with equal contents create a new buffer.
 *
 * <p>PLEASE NOTE: This method assumes that you'll never change values within the returned
 * DataBuffer: a cached instance is handed out to every caller asking for equal contents.
 *
 * @param array values to place into the buffer
 * @return a DataBuffer holding the values of {@code array}; possibly a shared cached instance
 */
@Override
public DataBuffer getConstantBuffer(double[] array) {
    ArrayDescriptor descriptor = new ArrayDescriptor(array);
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ensureMaps(deviceId);

    // Single lookup instead of containsKey() + get(); only non-null buffers are ever put below.
    // NOTE(review): lookup and insertion are not atomic here; confirm callers serialize access.
    DataBuffer cached = buffersCache.get(deviceId).get(descriptor);
    if (cached != null) {
        return cached;
    }

    DataBuffer buffer = Nd4j.createBufferDetached(array);

    // Size is computed in long arithmetic: array.length * sizeOfDataType() can overflow int for
    // very large arrays, yielding a negative size that would wrongly pass the capacity check.
    // NOTE(review): uses the global Nd4j.sizeOfDataType() rather than 8 bytes per double;
    // presumably createBufferDetached converts to the global data type - confirm.
    long requiredBytes = (long) array.length * Nd4j.sizeOfDataType();
    if (constantOffsets.get(deviceId).get() + requiredBytes < MAX_CONSTANT_LENGTH) {
        buffer.setConstant(true);
        // now we move data to constant memory, and keep happy
        moveToConstantSpace(buffer);
        buffersCache.get(deviceId).put(descriptor, buffer);
        bytes.addAndGet(requiredBytes);
    }
    return buffer;
}
Example usage of org.nd4j.linalg.cache.ArrayDescriptor in the nd4j project (deeplearning4j).
Class BasicTADManagerTest, test method testArrayDesriptor1.
/**
 * Verifies that ArrayDescriptors built from arrays with the same elements in a different order
 * are not equal to each other, i.e. element order is part of the descriptor's identity.
 */
@Test
public void testArrayDesriptor1() throws Exception {
    // All six permutations of {2, 3, 4}: identical element multiset, different order.
    ArrayDescriptor[] descriptors = new ArrayDescriptor[] {
            new ArrayDescriptor(new int[] {2, 3, 4}),
            new ArrayDescriptor(new int[] {2, 4, 3}),
            new ArrayDescriptor(new int[] {3, 2, 4}),
            new ArrayDescriptor(new int[] {3, 4, 2}),
            new ArrayDescriptor(new int[] {4, 2, 3}),
            new ArrayDescriptor(new int[] {4, 3, 2})
    };

    // Check every unordered pair exactly once (the hand-written list previously skipped some).
    for (int i = 0; i < descriptors.length; i++) {
        for (int j = i + 1; j < descriptors.length; j++) {
            assertNotEquals(descriptors[i], descriptors[j]);
        }
    }
}
Aggregations