Example 1 with ShapeDescriptor

Use of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.

From the class DirectShapeInfoProvider, method createShapeInformation.

@Override
public Pair<DataBuffer, int[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    // Offset is forced to 0 in the shape buffer: this improves cache efficiency, and the offset value is not used on the native side anyway
    offset = 0;
    ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);
    if (!shapeCache.containsKey(descriptor)) {
        if (counter.get() < MAX_ENTRIES) {
            synchronized (this) {
                // Re-check under the lock: another thread may have cached this descriptor while we waited
                if (!shapeCache.containsKey(descriptor)) {
                    counter.incrementAndGet();
                    Pair<DataBuffer, int[]> buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
                    shapeCache.put(descriptor, buffer);
                    bytes.addAndGet(buffer.getFirst().length() * 4 * 2);
                    return buffer;
                } else {
                    return shapeCache.get(descriptor);
                }
            }
        } else {
            // Cache is full: build the shape information without caching it
            return super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
        }
    }
    return shapeCache.get(descriptor);
}
Also used: ShapeDescriptor (org.nd4j.linalg.api.shape.ShapeDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
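Example 1 combines two ideas that work independently of nd4j: a key object whose equality is defined by the contents of its arrays, and a size-bounded cache filled with double-checked locking. The sketch below reproduces that pattern in plain Java, assuming the cache map is thread-safe (here a ConcurrentHashMap). `ShapeKey` and `BoundedCache` are hypothetical stand-ins, not nd4j classes, and the cached values are plain objects rather than `DataBuffer`s.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

public class BoundedShapeCacheDemo {

    // Hypothetical stand-in for ShapeDescriptor: equality is based on array contents, not array identity.
    static final class ShapeKey {
        private final int[] shape;
        private final int[] stride;
        private final long offset;
        private final int elementWiseStride;
        private final char order;

        ShapeKey(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
            this.shape = shape.clone();   // defensive copies keep the key immutable
            this.stride = stride.clone();
            this.offset = offset;
            this.elementWiseStride = elementWiseStride;
            this.order = order;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof ShapeKey)) return false;
            ShapeKey k = (ShapeKey) o;
            return offset == k.offset && elementWiseStride == k.elementWiseStride && order == k.order
                    && Arrays.equals(shape, k.shape) && Arrays.equals(stride, k.stride);
        }

        @Override
        public int hashCode() {
            int h = Arrays.hashCode(shape);
            h = 31 * h + Arrays.hashCode(stride);
            h = 31 * h + Long.hashCode(offset);
            h = 31 * h + elementWiseStride;
            return 31 * h + order;
        }
    }

    // Hypothetical bounded cache following the same control flow as Example 1.
    static final class BoundedCache<V> {
        private final Map<ShapeKey, V> cache = new ConcurrentHashMap<>();
        private final AtomicInteger counter = new AtomicInteger();
        private final int maxEntries;

        BoundedCache(int maxEntries) { this.maxEntries = maxEntries; }

        V get(ShapeKey key, Supplier<V> factory) {
            V cached = cache.get(key);              // fast path: no lock on a hit
            if (cached != null) return cached;
            if (counter.get() >= maxEntries) {
                return factory.get();               // cache full: return an uncached value
            }
            synchronized (this) {
                cached = cache.get(key);            // re-check under the lock
                if (cached == null) {
                    cached = factory.get();
                    cache.put(key, cached);
                    counter.incrementAndGet();
                }
                return cached;
            }
        }

        int size() { return cache.size(); }
    }

    public static void main(String[] args) {
        BoundedCache<Object> cache = new BoundedCache<>(1);   // room for a single entry
        ShapeKey a = new ShapeKey(new int[]{10, 10}, new int[]{10, 1}, 0, 1, 'c');
        ShapeKey b = new ShapeKey(new int[]{10, 10}, new int[]{10, 1}, 0, 1, 'c');
        Object first = cache.get(a, Object::new);
        Object second = cache.get(b, Object::new);
        System.out.println(first == second);   // true: b equals a, so the cached entry is reused
        ShapeKey c = new ShapeKey(new int[]{5}, new int[]{1}, 0, 1, 'c');
        cache.get(c, Object::new);             // cache is full, so this value is not stored
        System.out.println(cache.size());      // 1
    }
}
```

A Java record with `int[]` components would not work as this key: record equality compares array fields by reference, which is why the sketch overrides `equals` and `hashCode` with `Arrays.equals` and `Arrays.hashCode`.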

Example 2 with ShapeDescriptor

Use of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.

From the class ProtectedCudaShapeInfoProviderTest, method testPurge2.

@Test
public void testPurge2() throws Exception {
    INDArray arrayA = Nd4j.create(10, 10);
    DataBuffer shapeInfoA = arrayA.shapeInfoDataBuffer();
    INDArray arrayE = Nd4j.create(10, 10);
    DataBuffer shapeInfoE = arrayE.shapeInfoDataBuffer();
    int[] arrayShapeA = shapeInfoA.asInt();
    // Two arrays with the same shape share one cached shape-info buffer
    assertSame(shapeInfoA, shapeInfoE);
    ShapeDescriptor descriptor = new ShapeDescriptor(arrayA.shape(), arrayA.stride(), 0, arrayA.elementWiseStride(), arrayA.ordering());
    ConstantProtector protector = ConstantProtector.getInstance();
    AllocationPoint pointA = AtomicAllocator.getInstance().getAllocationPoint(arrayA.shapeInfoDataBuffer());
    assertTrue(protector.containsDataBuffer(0, descriptor));
    // Purge all caches, including the constant shape-info buffers
    Nd4j.getMemoryManager().purgeCaches();
    assertFalse(protector.containsDataBuffer(0, descriptor));
    INDArray arrayB = Nd4j.create(10, 10);
    DataBuffer shapeInfoB = arrayB.shapeInfoDataBuffer();
    // After the purge a new buffer instance is created, with the same contents
    assertNotSame(shapeInfoA, shapeInfoB);
    AllocationPoint pointB = AtomicAllocator.getInstance().getAllocationPoint(arrayB.shapeInfoDataBuffer());
    assertArrayEquals(arrayShapeA, shapeInfoB.asInt());
    // Device pointers should be equal, because the purge resets the offsets
    assertEquals(pointA.getPointers().getDevicePointer().address(), pointB.getPointers().getDevicePointer().address());
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint), ShapeDescriptor (org.nd4j.linalg.api.shape.ShapeDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer), Test (org.junit.Test)
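testPurge2 checks two properties of the constant cache: before a purge, arrays with the same shape share one shape-info buffer instance; after `purgeCaches()`, the entry is gone and a new buffer with equal contents is built. A minimal plain-Java sketch of those semantics follows. `ConstantCache` is hypothetical, keys are shape lists, and the "buffer" is just the shape copied into an `int[]`, not a real nd4j shape-info layout. It does not model the device-pointer check at the end of the test.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class PurgeSemanticsDemo {

    // Hypothetical constant cache: one shared buffer per key until purge() is called.
    static final class ConstantCache {
        private final Map<List<Integer>, int[]> buffers = new ConcurrentHashMap<>();

        int[] shapeInfo(List<Integer> shape) {
            // computeIfAbsent builds the buffer once and returns the same instance afterwards
            return buffers.computeIfAbsent(shape, s -> s.stream().mapToInt(Integer::intValue).toArray());
        }

        boolean contains(List<Integer> shape) { return buffers.containsKey(shape); }

        void purge() { buffers.clear(); }
    }

    public static void main(String[] args) {
        ConstantCache cache = new ConstantCache();
        int[] a = cache.shapeInfo(List.of(10, 10));
        int[] e = cache.shapeInfo(List.of(10, 10));
        System.out.println(a == e);                           // true: same cached instance
        cache.purge();
        System.out.println(cache.contains(List.of(10, 10)));  // false: entry dropped
        int[] b = cache.shapeInfo(List.of(10, 10));
        System.out.println(a == b);                           // false: a fresh buffer was built
        System.out.println(Arrays.equals(a, b));              // true: with identical contents
    }
}
```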

Example 3 with ShapeDescriptor

Use of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.

From the class ProtectedCudaShapeInfoProviderTest, method testPurge1.

@Test
public void testPurge1() throws Exception {
    INDArray array = Nd4j.create(10, 10);
    ProtectedCudaShapeInfoProvider provider = (ProtectedCudaShapeInfoProvider) ProtectedCudaShapeInfoProvider.getInstance();
    ShapeDescriptor descriptor = new ShapeDescriptor(array.shape(), array.stride(), 0, array.elementWiseStride(), array.ordering());
    assertTrue(provider.protector.containsDataBuffer(0, descriptor));
    Nd4j.getMemoryManager().purgeCaches();
    assertFalse(provider.protector.containsDataBuffer(0, descriptor));
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), ShapeDescriptor (org.nd4j.linalg.api.shape.ShapeDescriptor), Test (org.junit.Test)
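In testPurge1, `containsDataBuffer` is queried with a `ShapeDescriptor` constructed from the array's properties rather than with the instance that was used when the buffer was stored, so the lookup can only succeed if descriptors compare by value. The self-contained snippet below shows why this matters for array-based keys: a raw `int[]` key uses identity equality, while a `List<Integer>` compares element by element. The class and method names are illustrative only.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class FreshKeyLookupDemo {

    // Raw arrays inherit identity-based equals/hashCode, so a freshly built array never matches.
    static boolean arrayKeyHit() {
        Map<int[], String> byArray = new HashMap<>();
        byArray.put(new int[]{10, 10}, "cached");
        return byArray.containsKey(new int[]{10, 10});
    }

    // List equality is element-wise, so a freshly built key with the same contents matches.
    static boolean listKeyHit() {
        Map<List<Integer>, String> byList = new HashMap<>();
        byList.put(List.of(10, 10), "cached");
        return byList.containsKey(List.of(10, 10));
    }

    public static void main(String[] args) {
        System.out.println(arrayKeyHit());  // false
        System.out.println(listKeyHit());   // true
    }
}
```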

Example 4 with ShapeDescriptor

Use of org.nd4j.linalg.api.shape.ShapeDescriptor in the nd4j project by deeplearning4j.

From the class ProtectedCudaShapeInfoProvider, method createShapeInformation.

@Override
public Pair<DataBuffer, int[]> createShapeInformation(int[] shape, int[] stride, long offset, int elementWiseStride, char order) {
    // Offset is forced to 0 in the shape buffer: this improves cache efficiency, and the offset value is not used on the native side anyway
    offset = 0;
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ShapeDescriptor descriptor = new ShapeDescriptor(shape, stride, offset, elementWiseStride, order);
    if (!protector.containsDataBuffer(deviceId, descriptor)) {
        Pair<DataBuffer, int[]> buffer = null;
        synchronized (this) {
            // Re-check under the lock: another thread may have created the buffer meanwhile
            if (!protector.containsDataBuffer(deviceId, descriptor)) {
                // Cache miss: build the shape information and mark it as constant
                buffer = super.createShapeInformation(shape, stride, offset, elementWiseStride, order);
                buffer.getFirst().setConstant(true);
                // With the IMMEDIATE memory model, the buffer is moved to constant space right away
                if (CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
                    Nd4j.getConstantHandler().moveToConstantSpace(buffer.getFirst());
                }
                protector.persistDataBuffer(deviceId, descriptor, buffer);
                bytes.addAndGet(buffer.getFirst().length() * 4 * 2);
                cacheMiss.incrementAndGet();
            } else {
                buffer = protector.getDataBuffer(deviceId, descriptor);
            }
        }
        return buffer;
    } else {
        // Cache hit
        cacheHit.incrementAndGet();
    }
    return protector.getDataBuffer(deviceId, descriptor);
}
Also used: ShapeDescriptor (org.nd4j.linalg.api.shape.ShapeDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
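Example 4 keeps a separate set of constant buffers per CUDA device, uses double-checked locking on a miss, and counts hits and misses. The sketch below mirrors that structure without nd4j or CUDA. `PerDeviceCache` and its methods are hypothetical, and the device ID is passed in explicitly instead of being read from the allocator.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;

public class PerDeviceCacheDemo {

    // Hypothetical analogue of the protector: one map of constant values per device.
    static final class PerDeviceCache<K, V> {
        private final Map<Integer, Map<K, V>> perDevice = new ConcurrentHashMap<>();
        final AtomicLong hits = new AtomicLong();
        final AtomicLong misses = new AtomicLong();

        private Map<K, V> device(int deviceId) {
            return perDevice.computeIfAbsent(deviceId, id -> new ConcurrentHashMap<>());
        }

        boolean contains(int deviceId, K key) { return device(deviceId).containsKey(key); }

        V get(int deviceId, K key, Supplier<V> factory) {
            Map<K, V> map = device(deviceId);
            V value = map.get(key);
            if (value != null) {
                hits.incrementAndGet();
                return value;
            }
            synchronized (this) {
                value = map.get(key);              // another thread may have stored it meanwhile
                if (value == null) {
                    value = factory.get();
                    map.put(key, value);
                    misses.incrementAndGet();
                } else {
                    hits.incrementAndGet();
                }
                return value;
            }
        }

        void purge() { perDevice.clear(); }
    }

    public static void main(String[] args) {
        PerDeviceCache<String, Object> cache = new PerDeviceCache<>();
        Object d0 = cache.get(0, "10x10 c-order", Object::new);  // miss on device 0
        cache.get(0, "10x10 c-order", Object::new);               // hit on device 0
        Object d1 = cache.get(1, "10x10 c-order", Object::new);  // miss: device 1 has its own entry
        System.out.println(d0 == d1);                             // false
        System.out.println(cache.hits.get() + " hit(s), " + cache.misses.get() + " miss(es)"); // 1 hit(s), 2 miss(es)
        cache.purge();
        System.out.println(cache.contains(0, "10x10 c-order"));   // false
    }
}
```

One deliberate difference: when the re-check under the lock finds an entry that another thread has just stored, this sketch counts it as a hit, whereas the original method increments neither counter on that path.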

Aggregations

ShapeDescriptor (org.nd4j.linalg.api.shape.ShapeDescriptor): 4
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 3
Test (org.junit.Test): 2
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 1