Usage of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.
From class DirectShapeInfoProvider, method createShapeInformation:
/**
 * Returns shape information for the given layout, serving it from {@code shapeCache} when possible.
 *
 * <p>The offset is forced to 0 before the descriptor is built, so layouts that differ only in
 * offset share a single cached entry. New entries are added only while {@code counter} is below
 * {@code MAX_ENTRIES`; once the bound is reached, results are created uncached on every call.
 *
 * @param shape             dimensions of the array
 * @param stride            strides of the array
 * @param offset            ignored; always treated as 0
 * @param elementWiseStride element-wise stride
 * @param order             memory order character
 * @return the (possibly cached) shape-info buffer paired with its int[] form
 */
@Override
public Pair<DataBuffer, int[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    // We enforce offset to 0 in shapeBuffer, since we need it for cache efficiency + we don't actually use offset value @ native side
    offset = 0;
    ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);

    // Single lookup: a separate containsKey()/get() pair could observe a concurrent removal
    // in between and hand a null back to the caller.
    Pair<DataBuffer, int[]> cached = shapeCache.get(descriptor);
    if (cached != null) {
        return cached;
    }

    if (counter.get() < MAX_ENTRIES) {
        synchronized (this) {
            // Another thread may have inserted this descriptor while we waited for the lock.
            cached = shapeCache.get(descriptor);
            if (cached != null) {
                return cached;
            }
            // Re-check the bound under the lock; otherwise concurrent misses could all pass the
            // unsynchronized check above and overshoot MAX_ENTRIES.
            if (counter.get() < MAX_ENTRIES) {
                counter.incrementAndGet();
                Pair<DataBuffer, int[]> buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
                shapeCache.put(descriptor, buffer);
                // NOTE(review): "* 4 * 2" presumably means 4 bytes per int, counted twice
                // (e.g. host + device copies); confirm against the `bytes` metric's definition.
                bytes.addAndGet(buffer.getFirst().length() * 4 * 2);
                return buffer;
            }
        }
    }

    // Cache is full: build a fresh, uncached shape-info buffer.
    return super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
}
Usage of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.
From class ProtectedCudaShapeInfoProviderTest, method testPurge2:
/**
 * Verifies that identically-shaped arrays share one cached shape-info buffer, that purging
 * evicts it from the {@link ConstantProtector}, and that a later array of the same shape gets a
 * new buffer object with the same contents and the same device address.
 */
@Test
public void testPurge2() throws Exception {
    INDArray first = Nd4j.create(10, 10);
    DataBuffer firstShapeInfo = first.shapeInfoDataBuffer();
    INDArray twin = Nd4j.create(10, 10);
    DataBuffer twinShapeInfo = twin.shapeInfoDataBuffer();
    int[] expectedShapeInfo = firstShapeInfo.asInt();

    // Same shape => the very same cached buffer instance.
    assertTrue(firstShapeInfo == twinShapeInfo);

    ShapeDescriptor key = new ShapeDescriptor(first.shape(), first.stride(), 0, first.elementWiseStride(), first.ordering());
    ConstantProtector protector = ConstantProtector.getInstance();
    AllocationPoint firstPoint = AtomicAllocator.getInstance().getAllocationPoint(first.shapeInfoDataBuffer());
    assertTrue(protector.containsDataBuffer(0, key));

    // Purge: the descriptor must no longer be held for device 0.
    Nd4j.getMemoryManager().purgeCaches();
    assertFalse(protector.containsDataBuffer(0, key));

    // A new array of the same shape gets a different buffer object with identical contents...
    INDArray fresh = Nd4j.create(10, 10);
    DataBuffer freshShapeInfo = fresh.shapeInfoDataBuffer();
    assertFalse(firstShapeInfo == freshShapeInfo);
    AllocationPoint freshPoint = AtomicAllocator.getInstance().getAllocationPoint(fresh.shapeInfoDataBuffer());
    assertArrayEquals(expectedShapeInfo, freshShapeInfo.asInt());

    // ...and is expected to land at the same device address, because the purge resets offsets.
    assertEquals(firstPoint.getPointers().getDevicePointer().address(), freshPoint.getPointers().getDevicePointer().address());
}
Usage of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.
From class ProtectedCudaShapeInfoProviderTest, method testPurge1:
/**
 * Verifies that a freshly created array's shape descriptor is held by the provider's protector
 * for device 0, and that {@code purgeCaches()} removes it.
 */
@Test
public void testPurge1() throws Exception {
    INDArray array = Nd4j.create(10, 10);
    ProtectedCudaShapeInfoProvider provider = (ProtectedCudaShapeInfoProvider) ProtectedCudaShapeInfoProvider.getInstance();

    ShapeDescriptor beforePurge = new ShapeDescriptor(array.shape(), array.stride(), 0, array.elementWiseStride(), array.ordering());
    boolean cachedBefore = provider.protector.containsDataBuffer(0, beforePurge);
    assertEquals(true, cachedBefore);

    Nd4j.getMemoryManager().purgeCaches();

    // The descriptor is rebuilt after the purge, from the array as it is at that point.
    ShapeDescriptor afterPurge = new ShapeDescriptor(array.shape(), array.stride(), 0, array.elementWiseStride(), array.ordering());
    boolean cachedAfter = provider.protector.containsDataBuffer(0, afterPurge);
    assertEquals(false, cachedAfter);
}
Usage of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.
From class ProtectedCudaShapeInfoProvider, method createShapeInformation:
/**
 * Returns shape information for the given layout, cached per device in {@code protector}.
 *
 * <p>The offset is forced to 0 before the descriptor is built, so layouts that differ only in
 * offset share one entry. On a miss, the buffer is created by the superclass and marked constant.
 * Under the {@code IMMEDIATE} memory model it is also moved to constant space. It is then stored
 * in the protector under the current device id.
 *
 * @param shape             dimensions of the array
 * @param stride            strides of the array
 * @param offset            ignored; always treated as 0
 * @param elementWiseStride element-wise stride
 * @param order             memory order character
 * @return the cached (or newly created and cached) shape-info buffer paired with its int[] form
 */
@Override
public Pair<DataBuffer, int[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    // We enforce offset to 0 in shapeBuffer, since we need it for cache efficiency + we don't actually use offset value @ native side
    offset = 0;
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);

    // Fast path with a single lookup. A containsDataBuffer()/getDataBuffer() pair could be split
    // by a concurrent purgeCaches(), which clears the protector, and would then return null.
    Pair<DataBuffer, int[]> cached = protector.getDataBuffer(deviceId, descriptor);
    if (cached != null) {
        cacheHit.incrementAndGet();
        return cached;
    }

    synchronized (this) {
        // Another thread may have populated this entry while we waited for the lock.
        cached = protector.getDataBuffer(deviceId, descriptor);
        if (cached != null) {
            return cached;
        }

        Pair<DataBuffer, int[]> buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
        buffer.getFirst().setConstant(true);
        if (CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
            Nd4j.getConstantHandler().moveToConstantSpace(buffer.getFirst());
        }
        protector.persistDataBuffer(deviceId, descriptor, buffer);
        // NOTE(review): "* 4 * 2" presumably means 4 bytes per int, counted twice
        // (e.g. host + device copies); confirm against the `bytes` metric's definition.
        bytes.addAndGet(buffer.getFirst().length() * 4 * 2);
        cacheMiss.incrementAndGet();
        return buffer;
    }
}
Aggregations