Use of org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer in project nd4j by deeplearning4j.
Class CudaZeroHandler, method getHostPointer.
/**
 * Resolves the host-side pointer backing the given buffer.
 *
 * PLEASE NOTE: the pointer returned by this method always lies within OS memory space.
 *
 * Before the pointer is built, {@code synchronizeThreadDevice} is invoked with the current thread id,
 * the device id recorded on the buffer's allocation point, and the allocation point itself.
 *
 * @param buffer buffer to resolve; it is cast to {@code BaseCudaDataBuffer} to reach its allocation point
 * @return a pointer over the allocation's host pointer, shifted by {@code offset * elementSize}; it is
 *         viewed as a double/float/int/short pointer for DOUBLE/FLOAT/INT/HALF buffers and returned as
 *         the plain {@code CudaPointer} for any other data type
 * @throws RuntimeException if the allocation point holds no host pointer
 */
@Override
public org.bytedeco.javacpp.Pointer getHostPointer(DataBuffer buffer) {
    final AllocationPoint allocation = ((BaseCudaDataBuffer) buffer).getAllocationPoint();

    // No host pointer means there is nothing to hand out: log the allocation state, then fail.
    if (null == allocation.getPointers().getHostPointer()) {
        log.info("DevicePointer: " + allocation.getPointers().getDevicePointer());
        log.info("HostPointer: " + allocation.getPointers().getHostPointer());
        log.info("AllocStatus: " + allocation.getAllocationStatus());
        throw new RuntimeException("pointer is null");
    }

    final long callerThreadId = Thread.currentThread().getId();
    synchronizeThreadDevice(callerThreadId, allocation.getDeviceId(), allocation);

    // The host pointer is read again here, after synchronization.
    // The offset is applied if needed; length is specified for constructor compatibility purposes.
    final CudaPointer hostView = new CudaPointer(allocation.getPointers().getHostPointer(), buffer.length(),
                    (buffer.offset() * buffer.getElementSize()));

    // Pick the typed view that matches the buffer's element type.
    final org.bytedeco.javacpp.Pointer typedView;
    switch (buffer.dataType()) {
        case DOUBLE:
            typedView = hostView.asDoublePointer();
            break;
        case FLOAT:
            typedView = hostView.asFloatPointer();
            break;
        case INT:
            typedView = hostView.asIntPointer();
            break;
        case HALF:
            // HALF buffers are exposed through a short pointer
            typedView = hostView.asShortPointer();
            break;
        default:
            typedView = hostView;
            break;
    }
    return typedView;
}
Use of org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer in project nd4j by deeplearning4j.
Class CudaZeroHandler, method memcpyDevice.
/**
 * Copies data into the device memory of {@code dstBuffer} via {@code nativeOps.memcpyAsync}, using the
 * {@code cudaMemcpyDeviceToDevice} copy kind on the context's old stream.
 *
 * After a successful call the destination allocation point is ticked as written on the device.
 *
 * @param dstBuffer  destination buffer; it is cast to {@code BaseCudaDataBuffer} to reach its allocation point
 * @param srcPointer source pointer handed to the native copy as-is
 * @param length     amount to copy, passed through unchanged (presumably bytes - confirm against nativeOps)
 * @param dstOffset  offset added directly to the raw device address of the destination
 * @param context    CUDA context whose old stream is used for the copy
 * @throws ND4JIllegalStateException if {@code memcpyAsync} returns 0
 */
@Override
public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset, CudaContext context) {
    final AllocationPoint target = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();

    // The destination is addressed as the raw device address plus the caller-supplied offset.
    final long destinationAddress = target.getPointers().getDevicePointer().address() + dstOffset;
    final Pointer destination = new CudaPointer(destinationAddress);

    // A zero return value from the native call is treated as failure.
    final boolean copyFailed = nativeOps.memcpyAsync(destination, srcPointer, length,
                    CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0;
    if (copyFailed) {
        throw new ND4JIllegalStateException("memcpyAsync failed");
    }

    target.tickDeviceWrite();
}
Aggregations