Search in sources:

Example 6 using CudaDoubleDataBuffer

Use of org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer in the nd4j project by deeplearning4j.

From the class JCublasNDArrayFactory, method convertDataEx.

/*
    public DataBuffer convertToHalfs(DataBuffer buffer) {
        DataBuffer halfsBuffer = new CudaHalfDataBuffer(buffer.length());
    
        AtomicAllocator allocator = AtomicAllocator.getInstance();
    
        AllocationPoint pointSrc = allocator.getAllocationPoint(buffer);
        AllocationPoint pointDst = allocator.getAllocationPoint(halfsBuffer);
    
        CudaContext context =  allocator.getFlowController().prepareAction(pointDst, pointSrc);
    
        PointerPointer extras = new PointerPointer(
                null, // not used for conversion
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer());
    
        Pointer x = AtomicAllocator.getInstance().getPointer(buffer, context);
        Pointer z = AtomicAllocator.getInstance().getPointer(halfsBuffer, context);
    
        if (buffer.dataType() == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().convertFloatsToHalfs(extras, x, (int) buffer.length(), z);
            pointDst.tickDeviceWrite();
        } else if (buffer.dataType() == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().convertDoublesToHalfs(extras, x, (int) buffer.length(), z);
            pointDst.tickDeviceWrite();
        } else if (buffer.dataType() == DataBuffer.Type.HALF) {
            log.info("Buffer is already HALF-precision");
            return buffer;
        } else {
            throw new UnsupportedOperationException("Conversion INT->HALF isn't supported yet.");
        }
    
        allocator.getFlowController().registerAction(context, pointDst, pointSrc);
    
        return halfsBuffer;
    }
    
    public DataBuffer restoreFromHalfs(DataBuffer buffer) {
        if (buffer.dataType() != DataBuffer.Type.HALF)
            throw new IllegalStateException("Input DataBuffer should contain Halfs");
    
        DataBuffer outputBuffer = null;
    
    
    
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            outputBuffer = new CudaFloatDataBuffer(buffer.length());
    
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            outputBuffer = new CudaDoubleDataBuffer(buffer.length());
    
        } else throw new UnsupportedOperationException("DataType ["+Nd4j.dataType()+"] isn't supported yet");
    
        AtomicAllocator allocator = AtomicAllocator.getInstance();
    
        AllocationPoint pointSrc = allocator.getAllocationPoint(buffer);
        AllocationPoint pointDst = allocator.getAllocationPoint(outputBuffer);
    
        CudaContext context =  allocator.getFlowController().prepareAction(pointDst, pointSrc);
    
        PointerPointer extras = new PointerPointer(
                null, // not used for conversion
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer());
    
        Pointer x = AtomicAllocator.getInstance().getPointer(buffer, context);
        Pointer z = AtomicAllocator.getInstance().getPointer(outputBuffer, context);
    
        if (Nd4j.dataType() == DataBuffer.Type.FLOAT) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().convertHalfsToFloats(extras, x, (int) buffer.length(), z);
            pointDst.tickDeviceWrite();
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().convertHalfsToDoubles(extras, x, (int) buffer.length(), z);
            pointDst.tickDeviceWrite();
        } else if (Nd4j.dataType() == DataBuffer.Type.HALF) {
            log.info("Buffer is already HALF-precision");
            return buffer;
        }
    
        allocator.getFlowController().registerAction(context, pointDst, pointSrc);
    
        return outputBuffer;
    }
    */
/**
 * Converts the contents of {@code source} from {@code typeSrc} to {@code typeDst}, in place.
 *
 * <p>The actual conversion is delegated to the {@code DataBuffer}-level {@code convertDataEx}
 * overload. The buffer it returns replaces the array's data. The array's compressed flag is then
 * set according to whether that buffer is a {@link CompressedDataBuffer}.
 *
 * @param typeSrc the declared type of the data currently held by {@code source}
 * @param source the array to convert; must not be a view
 * @param typeDst the type to convert the data to
 * @return {@code source} itself, now backed by the converted buffer
 * @throws UnsupportedOperationException if {@code source} is a view
 */
@Override
public INDArray convertDataEx(DataBuffer.TypeEx typeSrc, INDArray source, DataBuffer.TypeEx typeDst) {
    // Views are rejected outright; the message points callers at dup() to get a standalone copy.
    if (source.isView()) {
        throw new UnsupportedOperationException("Impossible to compress View. Consider using dup() before. ");
    }

    final DataBuffer converted = convertDataEx(typeSrc, source.data(), typeDst);
    source.setData(converted);

    // Record on the array whether its backing data is now in compressed form.
    final boolean compressed = converted instanceof CompressedDataBuffer;
    source.markAsCompressed(compressed);

    return source;
}
Also used: CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer), CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer), CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)

Example 7 using CudaDoubleDataBuffer

Use of org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer in the nd4j project by deeplearning4j.

From the class AllocationUtils, method getPointersBuffer.

/**
 * Copies an array of raw 64-bit pointer values into a newly allocated buffer.
 *
 * <p>A {@code CudaDoubleDataBuffer} with one element per pointer is used purely as
 * 8-byte-per-element storage. The bytes copied in are the raw {@code long} values, not doubles;
 * this relies on {@code double} and {@code long} having the same width. The copy is performed
 * with {@code AtomicAllocator.memcpyBlocking} at destination offset 0.
 *
 * @param pointers the pointer values to copy
 * @return a buffer holding exactly {@code pointers.length} 8-byte entries
 */
public static DataBuffer getPointersBuffer(long[] pointers) {
    CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(pointers.length);
    // Widen before multiplying: an int product overflows once pointers.length >= 2^28.
    long byteCount = (long) pointers.length * Long.BYTES;
    AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(pointers), byteCount, 0);
    return tempX;
}
Also used: LongPointer (org.bytedeco.javacpp.LongPointer), CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer)

Aggregations

Usage counts — CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer): 7; AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 5; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 5; AtomicAllocator (org.nd4j.jita.allocator.impl.AtomicAllocator): 4; CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 4; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 4; CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer): 4; CudaIntDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer): 4; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 3; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3; TADManager (org.nd4j.linalg.cache.TADManager): 2; LongPointer (org.bytedeco.javacpp.LongPointer): 1; GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 1; LongPointerWrapper (org.nd4j.nativeblas.LongPointerWrapper): 1