
Example 6 with TADManager

Use of org.nd4j.linalg.cache.TADManager in project nd4j by deeplearning4j.

Method pullRows of class CpuNDArrayFactory.

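/**
 * Produces a concatenated array consisting of tensors fetched from the source array
 * along the given dimension at the specified indexes.
 *
 * @param source          source tensor (2D)
 * @param sourceDimension dimension of the source tensor along which tensors are taken (0 or 1)
 * @param indexes         indexes of the tensors to fetch from the source array; must be non-empty
 * @param order           ordering of the returned array
 * @return new array holding the fetched tensors
 */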
@Override
public INDArray pullRows(INDArray source, int sourceDimension, int[] indexes, char order) {
    if (indexes == null || indexes.length < 1)
        throw new IllegalStateException("Indexes can't be null or zero-length");
    int[] shape = null;
    if (sourceDimension == 1)
        shape = new int[] { indexes.length, source.shape()[sourceDimension] };
    else if (sourceDimension == 0)
        shape = new int[] { source.shape()[sourceDimension], indexes.length };
    else
        throw new UnsupportedOperationException("2D input is expected");
    INDArray ret = Nd4j.createUninitialized(shape, order);
    Nd4j.getCompressor().autoDecompress(source);
    PointerPointer dummy = new PointerPointer(new Pointer[] { null });
    TADManager tadManager = Nd4j.getExecutioner().getTADManager();
    Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(source, new int[] { sourceDimension });
    Pair<DataBuffer, DataBuffer> zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[] { sourceDimension });
    Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer();
    Pointer zTadShapeInfo = zTadBuffers.getFirst().addressPointer();
    IntPointer pIndex = new IntPointer(indexes);
    DataBuffer offsets = tadBuffers.getSecond();
    Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer();
    DataBuffer zOffsets = zTadBuffers.getSecond();
    Pointer zTadOffsets = zOffsets == null ? null : zOffsets.addressPointer();
    // Dispatch to the native routine that matches the data type of the result buffer.
    if (ret.data().dataType() == DataBuffer.Type.DOUBLE) {
        nativeOps.pullRowsDouble(dummy,
                (DoublePointer) source.data().addressPointer(),
                (IntPointer) source.shapeInfoDataBuffer().addressPointer(),
                (DoublePointer) ret.data().addressPointer(),
                (IntPointer) ret.shapeInfoDataBuffer().addressPointer(),
                indexes.length, pIndex,
                (IntPointer) hostTadShapeInfo, new LongPointerWrapper(hostTadOffsets),
                (IntPointer) zTadShapeInfo, new LongPointerWrapper(zTadOffsets));
    } else if (ret.data().dataType() == DataBuffer.Type.FLOAT) {
        nativeOps.pullRowsFloat(dummy,
                (FloatPointer) source.data().addressPointer(),
                (IntPointer) source.shapeInfoDataBuffer().addressPointer(),
                (FloatPointer) ret.data().addressPointer(),
                (IntPointer) ret.shapeInfoDataBuffer().addressPointer(),
                indexes.length, pIndex,
                (IntPointer) hostTadShapeInfo, new LongPointerWrapper(hostTadOffsets),
                (IntPointer) zTadShapeInfo, new LongPointerWrapper(zTadOffsets));
    } else {
        // Any other type is treated as half precision.
        nativeOps.pullRowsHalf(dummy,
                (ShortPointer) source.data().addressPointer(),
                (IntPointer) source.shapeInfoDataBuffer().addressPointer(),
                (ShortPointer) ret.data().addressPointer(),
                (IntPointer) ret.shapeInfoDataBuffer().addressPointer(),
                indexes.length, pIndex,
                (IntPointer) hostTadShapeInfo, new LongPointerWrapper(hostTadOffsets),
                (IntPointer) zTadShapeInfo, new LongPointerWrapper(zTadOffsets));
    }
    return ret;
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) LongPointerWrapper(org.nd4j.nativeblas.LongPointerWrapper) TADManager(org.nd4j.linalg.cache.TADManager) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) CompressedDataBuffer(org.nd4j.linalg.compression.CompressedDataBuffer)
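
The native call hides what pullRows actually computes, so the following plain-Java sketch may help. It mirrors only the shape logic visible in the method above (sourceDimension 1 yields indexes.length rows, sourceDimension 0 yields indexes.length columns). The class PullRowsSketch and its use of double[][] are assumptions made for illustration; this is not ND4J's implementation, and it assumes a non-empty rectangular source.

```java
import java.util.Arrays;

public class PullRowsSketch {

    /**
     * Hypothetical plain-Java stand-in for pullRows on a rectangular 2D array.
     * sourceDimension == 1 copies whole rows, sourceDimension == 0 copies whole columns.
     */
    static double[][] pullRows(double[][] source, int sourceDimension, int[] indexes) {
        if (indexes == null || indexes.length < 1)
            throw new IllegalStateException("Indexes can't be null or zero-length");
        int rows = source.length;
        int cols = source[0].length; // assumes at least one row
        if (sourceDimension == 1) {
            // Result shape: { indexes.length, cols }
            double[][] ret = new double[indexes.length][];
            for (int i = 0; i < indexes.length; i++)
                ret[i] = Arrays.copyOf(source[indexes[i]], cols);
            return ret;
        } else if (sourceDimension == 0) {
            // Result shape: { rows, indexes.length }
            double[][] ret = new double[rows][indexes.length];
            for (int r = 0; r < rows; r++)
                for (int i = 0; i < indexes.length; i++)
                    ret[r][i] = source[r][indexes[i]];
            return ret;
        }
        throw new UnsupportedOperationException("2D input is expected");
    }

    public static void main(String[] args) {
        double[][] source = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
        // prints [[7.0, 8.0, 9.0], [1.0, 2.0, 3.0]]
        System.out.println(Arrays.deepToString(pullRows(source, 1, new int[] { 2, 0 })));
        // prints [[2.0], [5.0], [8.0]]
        System.out.println(Arrays.deepToString(pullRows(source, 0, new int[] { 1 })));
    }
}
```

Indexes may repeat or appear in any order; each entry simply selects one row or column of the source.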

Example 7 with TADManager

Use of org.nd4j.linalg.cache.TADManager in project nd4j by deeplearning4j.

Method testDelayedTAD1 of class DelayedMemoryTest.

@Test
public void testDelayedTAD1() throws Exception {
    TADManager tadManager = new DeviceTADManager();
    INDArray array = Nd4j.create(128, 256);
    Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(array, new int[] { 0 });
    DataBuffer tadBuffer = tadBuffers.getFirst();
    DataBuffer offBuffer = tadBuffers.getSecond();
    AllocationPoint pointTad = AtomicAllocator.getInstance().getAllocationPoint(tadBuffer);
    AllocationPoint pointOff = AtomicAllocator.getInstance().getAllocationPoint(offBuffer);
    assertEquals(AllocationStatus.CONSTANT, pointTad.getAllocationStatus());
    assertEquals(AllocationStatus.DEVICE, pointOff.getAllocationStatus());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) TADManager(org.nd4j.linalg.cache.TADManager) DeviceTADManager(org.nd4j.jita.allocator.tad.DeviceTADManager) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) Test(org.junit.Test)
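
The test above only checks where the two buffers returned by getTADOnlyShapeInfo are allocated: the first (TAD shape information) in constant memory, the second (TAD offsets) in device memory. As a rough illustration of what "TAD offsets" means for the 128 x 256 array used there, the sketch below computes the start offset of each tensor along a dimension for a 2D array stored in row-major order. The class TadOffsetsSketch and its methods are hypothetical, and the sketch does not reproduce the buffer layout that DeviceTADManager produces.

```java
public class TadOffsetsSketch {

    /** Number of tensors along `dimension` of a 2D shape: one per index of the other dimension. */
    static int numTads(int[] shape, int dimension) {
        return shape[1 - dimension];
    }

    /** Length of each tensor along `dimension`. */
    static int tadLength(int[] shape, int dimension) {
        return shape[dimension];
    }

    /** Start offset of each tensor in a row-major flat buffer (an assumption of this sketch). */
    static long[] tadOffsets(int[] shape, int dimension) {
        int[] strides = { shape[1], 1 };  // row-major strides for a 2D array
        int other = 1 - dimension;        // the dimension that enumerates the tensors
        long[] offsets = new long[shape[other]];
        for (int i = 0; i < offsets.length; i++)
            offsets[i] = (long) i * strides[other];
        return offsets;
    }

    public static void main(String[] args) {
        int[] shape = { 128, 256 };
        // Along dimension 0: 256 tensors of length 128, starting at offsets 0, 1, 2, ...
        System.out.println(numTads(shape, 0) + " tensors of length " + tadLength(shape, 0));
        // Along dimension 1: 128 tensors of length 256, starting at offsets 0, 256, 512, ...
        System.out.println(numTads(shape, 1) + " tensors of length " + tadLength(shape, 1));
        System.out.println(tadOffsets(shape, 1)[2]); // prints 512
    }
}
```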

Aggregations

DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer) 7
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 7
TADManager (org.nd4j.linalg.cache.TADManager) 7
CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer) 5
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 4
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext) 4
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint) 3
AtomicAllocator (org.nd4j.jita.allocator.impl.AtomicAllocator) 3
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer) 3
CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer) 3
CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer) 3
Test (org.junit.Test) 2
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner) 2
LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper) 2
DeviceTADManager (org.nd4j.jita.allocator.tad.DeviceTADManager) 1
Pair (org.nd4j.linalg.primitives.Pair) 1