Usage of org.nd4j.linalg.cache.TADManager in the nd4j project by deeplearning4j.
Class CpuNDArrayFactory, method pullRows.
/**
 * Produces a new array made of the tensors (TADs) fetched from {@code source} along
 * {@code sourceDimension} at the given {@code indexes}, placed in the order the indexes are given.
 *
 * <p>With {@code sourceDimension == 1} the result has shape
 * {@code [indexes.length, source.shape()[1]]} (rows are pulled); with {@code sourceDimension == 0}
 * it has shape {@code [source.shape()[0], indexes.length]} (columns are pulled).
 *
 * @param source source tensor; it is passed through {@code Nd4j.getCompressor().autoDecompress}
 *        before its data is read
 * @param sourceDimension dimension along which tensors are taken; only 0 and 1 are supported
 * @param indexes indexes of the tensors to fetch; each must lie in
 *        {@code [0, source.length() / source.shape()[sourceDimension])}
 * @param order ordering of the returned array
 * @return a newly allocated array holding the pulled tensors
 * @throws IllegalStateException if {@code indexes} is null, empty, or contains an out-of-range index
 * @throws UnsupportedOperationException if {@code sourceDimension} is not 0 or 1, or the result
 *         data type is not DOUBLE, FLOAT or HALF
 */
@Override
public INDArray pullRows(INDArray source, int sourceDimension, int[] indexes, char order) {
    if (indexes == null || indexes.length < 1)
        throw new IllegalStateException("Indexes can't be null or zero-length");

    int[] shape;
    if (sourceDimension == 1)
        shape = new int[] { indexes.length, source.shape()[sourceDimension] };
    else if (sourceDimension == 0)
        shape = new int[] { source.shape()[sourceDimension], indexes.length };
    else
        throw new UnsupportedOperationException("2D input is expected");

    // Nothing on the Java side bounds-checks the indexes before they are handed to the native
    // pullRows kernels, so reject invalid ones here (before allocating the result) rather than
    // letting native code index outside the source buffer. The number of tensors along a single
    // dimension is the total length divided by the size of that dimension.
    int tadLength = source.shape()[sourceDimension];
    long numTads = tadLength == 0 ? 0 : source.length() / tadLength;
    for (int index : indexes) {
        if (index < 0 || index >= numTads)
            throw new IllegalStateException("Index " + index + " is out of range [0, " + numTads
                            + ") for dimension " + sourceDimension);
    }

    INDArray ret = Nd4j.createUninitialized(shape, order);

    Nd4j.getCompressor().autoDecompress(source);

    PointerPointer dummy = new PointerPointer(new Pointer[] { null });

    TADManager tadManager = Nd4j.getExecutioner().getTADManager();
    Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(source, new int[] { sourceDimension });
    Pair<DataBuffer, DataBuffer> zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[] { sourceDimension });

    IntPointer hostTadShapeInfo = (IntPointer) tadBuffers.getFirst().addressPointer();
    IntPointer zTadShapeInfo = (IntPointer) zTadBuffers.getFirst().addressPointer();

    // The offsets buffers may be null; in that case the wrapper is built around a null pointer.
    DataBuffer offsets = tadBuffers.getSecond();
    LongPointerWrapper hostTadOffsets = new LongPointerWrapper(offsets == null ? null : offsets.addressPointer());
    DataBuffer zOffsets = zTadBuffers.getSecond();
    LongPointerWrapper zTadOffsets = new LongPointerWrapper(zOffsets == null ? null : zOffsets.addressPointer());

    IntPointer pIndex = new IntPointer(indexes);
    IntPointer sourceShapeInfo = (IntPointer) source.shapeInfoDataBuffer().addressPointer();
    IntPointer retShapeInfo = (IntPointer) ret.shapeInfoDataBuffer().addressPointer();

    // Dispatch on the result's data type; the source buffer is assumed to share it.
    DataBuffer.Type dataType = ret.data().dataType();
    if (dataType == DataBuffer.Type.DOUBLE) {
        nativeOps.pullRowsDouble(dummy, (DoublePointer) source.data().addressPointer(), sourceShapeInfo,
                        (DoublePointer) ret.data().addressPointer(), retShapeInfo, indexes.length, pIndex,
                        hostTadShapeInfo, hostTadOffsets, zTadShapeInfo, zTadOffsets);
    } else if (dataType == DataBuffer.Type.FLOAT) {
        nativeOps.pullRowsFloat(dummy, (FloatPointer) source.data().addressPointer(), sourceShapeInfo,
                        (FloatPointer) ret.data().addressPointer(), retShapeInfo, indexes.length, pIndex,
                        hostTadShapeInfo, hostTadOffsets, zTadShapeInfo, zTadOffsets);
    } else if (dataType == DataBuffer.Type.HALF) {
        nativeOps.pullRowsHalf(dummy, (ShortPointer) source.data().addressPointer(), sourceShapeInfo,
                        (ShortPointer) ret.data().addressPointer(), retShapeInfo, indexes.length, pIndex,
                        hostTadShapeInfo, hostTadOffsets, zTadShapeInfo, zTadOffsets);
    } else {
        throw new UnsupportedOperationException("pullRows: unsupported data type " + dataType);
    }

    return ret;
}
Usage of org.nd4j.linalg.cache.TADManager in the nd4j project by deeplearning4j.
Class DelayedMemoryTest, method testDelayedTAD1.
@Test
public void testDelayedTAD1() throws Exception {
    // Verifies where a DeviceTADManager places the TAD buffers for dimension 0 of a 128x256 array:
    // the TAD shape-info buffer is expected to have CONSTANT allocation status, the offsets DEVICE.
    TADManager manager = new DeviceTADManager();
    INDArray matrix = Nd4j.create(128, 256);

    Pair<DataBuffer, DataBuffer> buffers = manager.getTADOnlyShapeInfo(matrix, new int[] { 0 });

    AtomicAllocator allocator = AtomicAllocator.getInstance();
    AllocationPoint shapeInfoPoint = allocator.getAllocationPoint(buffers.getFirst());
    AllocationPoint offsetsPoint = allocator.getAllocationPoint(buffers.getSecond());

    assertEquals(AllocationStatus.CONSTANT, shapeInfoPoint.getAllocationStatus());
    assertEquals(AllocationStatus.DEVICE, offsetsPoint.getAllocationStatus());
}
Aggregations