
Example 6 with BaseCudaDataBuffer

Use of org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer in project nd4j by deeplearning4j.

From class CudaZeroHandler, method getHostPointer.

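/**
 * PLEASE NOTE: This method always returns a pointer within OS (host) memory space.
 *
 * @param buffer the CUDA-backed data buffer whose host memory is requested
 * @return a host pointer at the buffer's offset, typed according to buffer.dataType()
 */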
@Override
public org.bytedeco.javacpp.Pointer getHostPointer(DataBuffer buffer) {
    AllocationPoint dstPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();
    // fail fast (after logging the pointer state) if there is no host allocation
    if (dstPoint.getPointers().getHostPointer() == null) {
        log.info("DevicePointer: " + dstPoint.getPointers().getDevicePointer());
        log.info("HostPointer: " + dstPoint.getPointers().getHostPointer());
        log.info("AllocStatus: " + dstPoint.getAllocationStatus());
        throw new RuntimeException("pointer is null");
    }
    // dstPoint.tickHostWrite();
    // dstPoint.tickHostRead();
    // log.info("Requesting host pointer for {}", buffer);
    // getCudaContext().syncOldStream();
    synchronizeThreadDevice(Thread.currentThread().getId(), dstPoint.getDeviceId(), dstPoint);
    // return pointer with byte offset applied if needed; length is specified for constructor compatibility purposes
    CudaPointer p = new CudaPointer(dstPoint.getPointers().getHostPointer(), buffer.length(), (buffer.offset() * buffer.getElementSize()));
    switch(buffer.dataType()) {
        case DOUBLE:
            return p.asDoublePointer();
        case FLOAT:
            return p.asFloatPointer();
        case INT:
            return p.asIntPointer();
        case HALF:
            return p.asShortPointer();
        default:
            return p;
    }
}

Also used: BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)
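The core of getHostPointer is address arithmetic: the element offset is turned into a byte offset (buffer.offset() * buffer.getElementSize()), and the resulting pointer is reinterpreted according to the buffer's data type. The sketch below reproduces only that idea, using java.nio buffers as a stand-in for javacpp pointers so it runs without a GPU. It is not nd4j's API: the DataType enum, the element sizes, and typedView are hypothetical names chosen for illustration. As in the original, HALF values are exposed as raw 16-bit shorts rather than converted.

```java
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;

public class HostViewSketch {
    // Hypothetical data-type enum for this sketch (not nd4j's own type).
    public enum DataType { DOUBLE, FLOAT, INT, HALF, BYTE }

    // Element size in bytes for each hypothetical type.
    public static int elementSize(DataType type) {
        switch (type) {
            case DOUBLE: return 8;
            case FLOAT:
            case INT: return 4;
            case HALF: return 2;
            default: return 1;
        }
    }

    // Returns a typed view of `host` starting at element index `offset`.
    public static Buffer typedView(ByteBuffer host, long offset, DataType type) {
        // Mirrors the null check in getHostPointer: no host allocation, no view.
        if (host == null) {
            throw new IllegalStateException("pointer is null");
        }
        // Element offset -> byte offset, like buffer.offset() * buffer.getElementSize().
        int byteOffset = Math.toIntExact(offset * elementSize(type));
        ByteBuffer shifted = host.duplicate();
        shifted.position(byteOffset);
        // slice() resets byte order to big-endian, so restore the host buffer's order.
        ByteBuffer slice = shifted.slice().order(host.order());
        switch (type) {
            case DOUBLE: return slice.asDoubleBuffer();
            case FLOAT:  return slice.asFloatBuffer();
            case INT:    return slice.asIntBuffer();
            case HALF:   return slice.asShortBuffer(); // raw 16-bit values, no conversion
            default:     return slice;
        }
    }

    public static void main(String[] args) {
        ByteBuffer host = ByteBuffer.allocate(4 * 8).order(ByteOrder.nativeOrder());
        host.asDoubleBuffer().put(new double[] {1.0, 2.0, 3.0, 4.0});
        DoubleBuffer view = (DoubleBuffer) typedView(host, 2, DataType.DOUBLE);
        System.out.println(view.get(0)); // prints 3.0
    }
}
```

In the original, the third CudaPointer constructor argument carries the same byte offset that the position/slice step applies here.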

Example 7 with BaseCudaDataBuffer

Use of org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer in project nd4j by deeplearning4j.

From class CudaZeroHandler, method memcpyDevice.

@Override
public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset, CudaContext context) {
    // log.info("Memcpy device: {} bytes ", length);
    AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
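    // destination = device address of dstBuffer, shifted by dstOffset bytes
    Pointer dP = new CudaPointer((point.getPointers().getDevicePointer().address()) + dstOffset);
    // asynchronous device-to-device copy of `length` bytes on the context's stream; 0 signals failure
    if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0) {
        throw new ND4JIllegalStateException("memcpyAsync failed");
    }
    // record that the device copy was written most recently
    point.tickDeviceWrite();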
}
Also used: BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer), Pointer (org.bytedeco.javacpp.Pointer), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)
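memcpyDevice writes straight into device memory, so after the copy the host copy of the buffer may be stale; the call to point.tickDeviceWrite() appears to record that the device side now holds the newest data. The sketch below models this bookkeeping under a simplifying assumption: two byte arrays stand in for host and device memory, and a logical clock orders writes. FakeAllocationPoint, isActualOnHostSide, and isActualOnDeviceSide are hypothetical names for this illustration, only loosely modeled on AllocationPoint; the real allocator's state tracking is more involved.

```java
public class DeviceCopySketch {
    // Hypothetical, simplified stand-in for an allocation point: two byte arrays
    // play the roles of host and device memory, and a logical clock orders writes.
    public static final class FakeAllocationPoint {
        public final byte[] host;
        public final byte[] device;
        private long clock = 0;
        private long hostWriteTick = 0;
        private long deviceWriteTick = 0;

        public FakeAllocationPoint(int bytes) {
            host = new byte[bytes];
            device = new byte[bytes];
        }

        public void tickHostWrite() { hostWriteTick = ++clock; }
        public void tickDeviceWrite() { deviceWriteTick = ++clock; }

        // A side is "actual" when no later write happened on the other side.
        public boolean isActualOnHostSide() { return hostWriteTick >= deviceWriteTick; }
        public boolean isActualOnDeviceSide() { return deviceWriteTick >= hostWriteTick; }
    }

    // Copies `length` bytes from `src` into the device array at byte offset `dstOffset`,
    // then marks the device side as holding the newest data.
    public static void memcpyDevice(FakeAllocationPoint dst, byte[] src, int length, int dstOffset) {
        // Bounds check stands in for the native call's failure status in the original.
        if (length < 0 || dstOffset < 0 || length > src.length
                || (long) dstOffset + length > dst.device.length) {
            throw new IllegalStateException("memcpy out of bounds");
        }
        System.arraycopy(src, 0, dst.device, dstOffset, length);
        dst.tickDeviceWrite();
    }

    public static void main(String[] args) {
        FakeAllocationPoint point = new FakeAllocationPoint(8);
        memcpyDevice(point, new byte[] {7, 8, 9}, 3, 4);
        System.out.println(point.isActualOnHostSide()); // prints false
    }
}
```

A logical clock rather than boolean dirty flags keeps the "which side is newer" question answerable after any sequence of host and device writes.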

Aggregations

AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 7
BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer): 7
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 6
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 5
Pointer (org.bytedeco.javacpp.Pointer): 4
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 4
MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace): 1