
Example 1 with TadDescriptor

Use of org.nd4j.linalg.cache.TadDescriptor in project nd4j by deeplearning4j.

From the class CpuTADManager, method getTADOnlyShapeInfo:

@Override
public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
    if (dimension != null && dimension.length > 1)
        Arrays.sort(dimension);
    if (dimension == null || (dimension.length >= 1 && dimension[0] == Integer.MAX_VALUE)) {
        return new Pair<>(array.shapeInfoDataBuffer(), null);
    } else {
        TadDescriptor descriptor = new TadDescriptor(array, dimension);
        if (!cache.containsKey(descriptor)) {
            int dimensionLength = dimension.length;
            // FIXME: this is fast triage, remove it later
            // dimensionLength <= 1 ? 2 : dimensionLength;
            int targetRank = array.rank();
            long offsetLength;
            long tadLength = 1;
            for (int i = 0; i < dimensionLength; i++) {
                tadLength *= array.shape()[dimension[i]];
            }
            offsetLength = array.lengthLong() / tadLength;
            DataBuffer outputBuffer = new IntBuffer(targetRank * 2 + 4);
            DataBuffer offsetsBuffer = new LongBuffer(offsetLength);
            DataBuffer dimensionBuffer = constantHandler.getConstantBuffer(dimension);
            Pointer dimensionPointer = dimensionBuffer.addressPointer();
            Pointer xShapeInfo = array.shapeInfoDataBuffer().addressPointer();
            Pointer targetPointer = outputBuffer.addressPointer();
            Pointer offsetsPointer = offsetsBuffer.addressPointer();
            nativeOps.tadOnlyShapeInfo((IntPointer) xShapeInfo, (IntPointer) dimensionPointer, dimension.length, (IntPointer) targetPointer, new LongPointerWrapper(offsetsPointer));
            // If the line below is uncommented, shapes from the JVM will be used on the native side
            // outputBuffer = array.tensorAlongDimension(0, dimension).shapeInfoDataBuffer();
            Pair<DataBuffer, DataBuffer> pair = new Pair<>(outputBuffer, offsetsBuffer);
            if (counter.get() < MAX_ENTRIES) {
                counter.incrementAndGet();
                cache.put(descriptor, pair);
                bytes.addAndGet((outputBuffer.length() * 4) + (offsetsBuffer.length() * 8));
            }
            return pair;
        }
        return cache.get(descriptor);
    }
}
Also used: LongBuffer (org.nd4j.linalg.api.buffer.LongBuffer), IntBuffer (org.nd4j.linalg.api.buffer.IntBuffer), LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper), IntPointer (org.bytedeco.javacpp.IntPointer), Pointer (org.bytedeco.javacpp.Pointer), TadDescriptor (org.nd4j.linalg.cache.TadDescriptor), Pair (org.nd4j.linalg.primitives.Pair), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
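
The loop in the CPU implementation above determines how the array decomposes into TADs (tensors along a dimension): tadLength is the product of the array's extents along the requested dimensions, and offsetLength is the total element count divided by tadLength. The same split drives the cache accounting at the end (4 bytes per int of shape info, 8 bytes per long offset). Below is a minimal, dependency-free sketch of that arithmetic; the shape and dimension values are illustrative, not taken from the source.

import java.util.Arrays;

public class TadArithmetic {
    public static void main(String[] args) {
        long[] shape = {4, 3, 5};     // hypothetical array shape
        int[] dimension = {1};        // take tensors along dimension 1

        // tadLength: number of elements in each tensor along the given dimensions
        long tadLength = 1;
        for (int d : dimension) {
            tadLength *= shape[d];
        }

        // total number of elements in the array
        long totalLength = Arrays.stream(shape).reduce(1, (a, b) -> a * b);

        // offsetLength: how many such tensors the array contains,
        // i.e. how many offsets the native tadOnlyShapeInfo call must produce
        long offsetLength = totalLength / tadLength;

        System.out.println("tadLength    = " + tadLength);    // 3
        System.out.println("offsetLength = " + offsetLength); // 20
    }
}

For shape [4, 3, 5] and dimension 1, the array splits into 20 TADs of 3 elements each, so the offsets buffer holds 20 longs.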

Example 2 with TadDescriptor

Use of org.nd4j.linalg.cache.TadDescriptor in project nd4j by deeplearning4j.

From the class DeviceTADManager, method getTADOnlyShapeInfo:

@Override
public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
    /*
     * We check whether this TAD is already cached; if it isn't, we create the
     * new TAD shape info and push it to constant memory.
     */
    if (dimension != null && dimension.length > 1)
        Arrays.sort(dimension);
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    // log.info("Requested TAD for device [{}], dimensions: [{}]", deviceId, Arrays.toString(dimension));
    // extract the dimensions and shape buffer for comparison
    TadDescriptor descriptor = new TadDescriptor(array, dimension);
    if (!tadCache.get(deviceId).containsKey(descriptor)) {
        log.trace("Creating new TAD...");
        // create the TAD with the shape information and corresponding offsets
        // note that we use native code to get access to the shape information.
        Pair<DataBuffer, DataBuffer> buffers = super.getTADOnlyShapeInfo(array, dimension);
        /*
         * Store the buffers in constant memory.
         * The main implementation of this right now is CUDA.
         *
         * Explanation from http://cuda-programming.blogspot.jp/2013/01/what-is-constant-memory-in-cuda.html:
         * the CUDA language makes available another kind of memory known as constant memory.
         * As the name may indicate, constant memory is used for data that will not change
         * over the course of a kernel execution.
         *
         * Why constant memory? NVIDIA hardware provides 64KB of constant memory that it
         * treats differently than standard global memory. In some situations, using constant
         * memory rather than global memory reduces the required memory bandwidth.
         *
         * NOTE: we use 48KB of it via these methods.
         *
         * Note also that we use the {@link AtomicAllocator}, which is the CUDA memory manager,
         * to move the current host-side data buffer to constant memory. We do this so the
         * device can access the shape information.
         */
        if (buffers.getFirst() != array.shapeInfoDataBuffer())
            AtomicAllocator.getInstance().moveToConstant(buffers.getFirst());
        /**
         * @see {@link org.nd4j.jita.constant.ProtectedCudaConstantHandler}
         */
        if (buffers.getSecond() != null)
            AtomicAllocator.getInstance().moveToConstant(buffers.getSecond());
        // At this point the buffer is valid on the host side;
        // we just need to replace the device pointer with the constant-memory pointer.
        tadCache.get(deviceId).put(descriptor, buffers);
        bytes.addAndGet((buffers.getFirst().length() * 4));
        if (buffers.getSecond() != null)
            bytes.addAndGet(buffers.getSecond().length() * 8);
        log.trace("Using TAD from cache...");
    }
    return tadCache.get(deviceId).get(descriptor);
}
Also used: TadDescriptor (org.nd4j.linalg.cache.TadDescriptor), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
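
Callers generally do not instantiate either TAD manager directly; the backend's executioner exposes its manager, and repeat lookups with an equal TadDescriptor are served from the cache populated above. Below is a minimal usage sketch, assuming an nd4j backend of this era where Nd4j.getExecutioner().getTADManager() is available; the shape and dimension values are illustrative.

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

public class TadLookup {
    public static void main(String[] args) {
        INDArray array = Nd4j.create(4, 3, 5);

        // The backend decides whether this is a CpuTADManager (nd4j-native)
        // or a DeviceTADManager (nd4j-cuda).
        TADManager tadManager = Nd4j.getExecutioner().getTADManager();

        // The first call computes and caches the TAD shape info;
        // later calls with an equal TadDescriptor hit the cache.
        Pair<DataBuffer, DataBuffer> tad = tadManager.getTADOnlyShapeInfo(array, new int[] {1});

        DataBuffer tadShapeInfo = tad.getFirst();  // shape info describing one TAD
        DataBuffer tadOffsets = tad.getSecond();   // offset of each TAD into the array's buffer

        System.out.println("TAD shape info: " + tadShapeInfo);
        System.out.println("TAD offsets:    " + tadOffsets);
    }
}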

Aggregations

DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 2 uses
TadDescriptor (org.nd4j.linalg.cache.TadDescriptor): 2 uses
IntPointer (org.bytedeco.javacpp.IntPointer): 1 use
Pointer (org.bytedeco.javacpp.Pointer): 1 use
IntBuffer (org.nd4j.linalg.api.buffer.IntBuffer): 1 use
LongBuffer (org.nd4j.linalg.api.buffer.LongBuffer): 1 use
Pair (org.nd4j.linalg.primitives.Pair): 1 use
LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper): 1 use