
Example 1 with LongBuffer

Use of org.nd4j.linalg.api.buffer.LongBuffer in project nd4j by deeplearning4j.

From the class CpuTADManager, method getTADOnlyShapeInfo.

@Override
public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
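    // Sort so that equivalent dimension sets produce the same cache key
    // (note: Arrays.sort mutates the caller's array)
    if (dimension != null && dimension.length > 1)
        Arrays.sort(dimension);
    // No dimensions, or the Integer.MAX_VALUE sentinel, means the whole array is a single TAD
    if (dimension == null || (dimension.length >= 1 && dimension[0] == Integer.MAX_VALUE)) {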
        return new Pair<>(array.shapeInfoDataBuffer(), null);
    } else {
        TadDescriptor descriptor = new TadDescriptor(array, dimension);
        if (!cache.containsKey(descriptor)) {
            int dimensionLength = dimension.length;
            // FIXME: this is fast triage, remove it later
            // dimensionLength <= 1 ? 2 : dimensionLength;
            int targetRank = array.rank();
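            long offsetLength;
            // TAD length = product of the array's sizes along the requested dimensions
            long tadLength = 1;
            for (int i = 0; i < dimensionLength; i++) {
                tadLength *= array.shape()[dimension[i]];
            }
            // One offset per TAD: total array length divided by the length of a single TAD
            offsetLength = array.lengthLong() / tadLength;
            // Shape info for the TAD (2 * rank + 4 ints) and the per-TAD offsets (longs)
            DataBuffer outputBuffer = new IntBuffer(targetRank * 2 + 4);
            DataBuffer offsetsBuffer = new LongBuffer(offsetLength);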
            DataBuffer dimensionBuffer = constantHandler.getConstantBuffer(dimension);
            Pointer dimensionPointer = dimensionBuffer.addressPointer();
            Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
            Pointer targetPointer = outputBuffer.addressPointer();
            Pointer offsetsPointer = offsetsBuffer.addressPointer();
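            // Native call fills outputBuffer with the TAD shape info and offsetsBuffer with each TAD's start offset
            nativeOps.tadOnlyShapeInfo((IntPointer) xShapeInfo, (IntPointer) dimensionPointer, dimension.length, (IntPointer) targetPointer, new LongPointerWrapper(offsetsPointer));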
            // If the line below will be uncommented, shapes from JVM will be used on native side
            // outputBuffer = array.tensorAlongDimension(0, dimension).shapeInfoDataBuffer();
            Pair<DataBuffer, DataBuffer> pair = new Pair<>(outputBuffer, offsetsBuffer);
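            // Cache only while fewer than MAX_ENTRIES pairs are stored; bytes tracks
            // approximate memory use (4 bytes per int of shape info, 8 bytes per long offset)
            if (counter.get() < MAX_ENTRIES) {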
                counter.incrementAndGet();
                cache.put(descriptor, pair);
                bytes.addAndGet((outputBuffer.length() * 4) + (offsetsBuffer.length() * 8));
            }
            return pair;
        }
        return cache.get(descriptor);
    }
}
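The size arithmetic in getTADOnlyShapeInfo can be checked in isolation. The sketch below is a hypothetical helper (TadMath is not part of nd4j) that assumes a plain shape array instead of an INDArray and reproduces three quantities from the method: the TAD length (product of the sizes along the requested dimensions), the number of TADs and therefore offsets (array length divided by TAD length), and the bytes added to the cache counter.

```java
import java.util.Arrays;

// Hypothetical helper (not part of nd4j) that mirrors the arithmetic in
// getTADOnlyShapeInfo using a plain shape array instead of an INDArray.
public class TadMath {

    // Elements per TAD: product of the sizes along the requested dimensions.
    static long tadLength(long[] shape, int[] dimension) {
        long len = 1;
        for (int d : dimension) {
            len *= shape[d];
        }
        return len;
    }

    // Total number of elements in the array.
    static long arrayLength(long[] shape) {
        long len = 1;
        for (long s : shape) {
            len *= s;
        }
        return len;
    }

    // Number of TADs, i.e. how many offsets the native call writes.
    static long numTads(long[] shape, int[] dimension) {
        return arrayLength(shape) / tadLength(shape, dimension);
    }

    // Length of the output IntBuffer allocated in the original: 2 * rank + 4 ints.
    static int shapeInfoLength(int rank) {
        return rank * 2 + 4;
    }

    // Bytes added to the cache counter: 4 per int of shape info, 8 per long offset.
    static long cachedBytes(int rank, long numTads) {
        return shapeInfoLength(rank) * 4L + numTads * 8L;
    }

    public static void main(String[] args) {
        long[] shape = {2, 3, 4};
        int[] dimension = {2, 1};
        Arrays.sort(dimension); // same normalization as the original: {1, 2}
        long tad = tadLength(shape, dimension);
        long count = numTads(shape, dimension);
        System.out.println("TAD length = " + tad);   // 3 * 4 = 12
        System.out.println("TAD count  = " + count); // 24 / 12 = 2
        System.out.println("bytes      = " + cachedBytes(shape.length, count)); // 10 * 4 + 2 * 8 = 56
    }
}
```

For a 2x3x4 array reduced along dimensions {1, 2}, each TAD holds 12 elements and there are 2 of them, so the offsets buffer needs 2 longs.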
Also used: LongBuffer (org.nd4j.linalg.api.buffer.LongBuffer), IntBuffer (org.nd4j.linalg.api.buffer.IntBuffer), LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper), IntPointer (org.bytedeco.javacpp.IntPointer), Pointer (org.bytedeco.javacpp.Pointer), TadDescriptor (org.nd4j.linalg.cache.TadDescriptor), Pair (org.nd4j.linalg.primitives.Pair), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
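The caching pattern in getTADOnlyShapeInfo (normalize the key, compute on a miss, store only while a counter is below MAX_ENTRIES) can be sketched without nd4j. In this minimal, hypothetical version, BoundedTadCache and its String key are stand-ins for the real cache and TadDescriptor, and the key is built from a sorted copy of the dimensions so the caller's array is left untouched.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Hypothetical sketch (not nd4j code) of a bounded cache keyed by shape
// plus sorted dimensions, following the pattern in getTADOnlyShapeInfo.
public class BoundedTadCache<V> {
    private final Map<String, V> cache = new ConcurrentHashMap<>();
    private final AtomicInteger counter = new AtomicInteger();
    private final int maxEntries;

    public BoundedTadCache(int maxEntries) {
        this.maxEntries = maxEntries;
    }

    // Stand-in for TadDescriptor: {2, 1} and {1, 2} map to the same key.
    static String key(long[] shape, int[] dimension) {
        int[] sorted = dimension.clone(); // sort a copy, not the caller's array
        Arrays.sort(sorted);
        return Arrays.toString(shape) + "|" + Arrays.toString(sorted);
    }

    public V get(long[] shape, int[] dimension, Supplier<V> compute) {
        String k = key(shape, dimension);
        V cached = cache.get(k);
        if (cached != null) {
            return cached; // cache hit
        }
        V value = compute.get();
        // As in the original, check-then-increment is not atomic, so under
        // contention a few extra entries may slip past the cap.
        if (counter.get() < maxEntries) {
            counter.incrementAndGet();
            cache.put(k, value);
        }
        return value;
    }

    public int size() {
        return cache.size();
    }

    public static void main(String[] args) {
        BoundedTadCache<String> c = new BoundedTadCache<>(1);
        long[] shape = {2, 3, 4};
        String a = c.get(shape, new int[]{2, 1}, () -> "tad-1-2");
        String b = c.get(shape, new int[]{1, 2}, () -> "recomputed");
        System.out.println(a + " " + b); // tad-1-2 tad-1-2 (second call is a hit)
        c.get(shape, new int[]{0}, () -> "tad-0"); // computed, not stored: cap reached
        System.out.println(c.size()); // 1
    }
}
```

Once the cap is reached, misses are still computed and returned, just not retained, which matches how the original keeps returning a fresh pair after MAX_ENTRIES.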
