Use of org.nd4j.linalg.jcublas.JCublasNDArray in project nd4j by deeplearning4j.
From the class SynchronousFlowController, method prepareDelayedMemory.
/**
 * Prepares the buffers of {@code array} when the configured memory model is
 * {@code DELAYED}; for any other memory model this method does nothing.
 *
 * <p>Two independent checks are made:
 * <ul>
 *   <li>if the data buffer's allocation status is not {@code DEVICE}, the data buffer
 *       is handed to {@code prepareDelayedMemory(DataBuffer)};</li>
 *   <li>if the shape-info buffer's allocation status is {@code HOST}, the shape info is
 *       passed through the constant handler and the resulting buffer is installed on
 *       the array.</li>
 * </ul>
 *
 * @param array array whose buffers are prepared; it is cast to {@link JCublasNDArray}
 *              when the shape-info buffer has to be replaced
 */
protected void prepareDelayedMemory(INDArray array) {
    if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
        // The data check must look at the data buffer's allocation point. Previously this
        // looked up the shape-info buffer a second time, so the decision to prepare the
        // data was driven by where the shape info lived.
        AllocationPoint pointData = allocator.getAllocationPoint(array.data());
        AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

        if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) {
            prepareDelayedMemory(array.data());
        }

        if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
            DataBuffer oShape = array.shapeInfoDataBuffer();
            DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);

            // Getting the very same instance back presumably means the handler did not
            // produce a relocated copy (TODO confirm), so move it explicitly.
            if (nShape == oShape) {
                Nd4j.getConstantHandler().moveToConstantSpace(nShape);
            }

            ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);
        }
    }
}
Use of org.nd4j.linalg.jcublas.JCublasNDArray in project nd4j by deeplearning4j.
From the class SynchronousFlowController, method prepareActionAllWrite.
/**
 * Prepares every non-null operand for an upcoming operation and returns the context
 * obtained from the allocator's device context.
 *
 * <p>For each operand, in order: it is auto-decompressed, the lock on its data
 * allocation point is acquired, its data and shape-info buffers are relocated when
 * their allocation points report a different, non-negative device id, delayed memory
 * is prepared, and the returned context is recorded on the operand's allocation point.
 * The acquired lock is not released inside this method.
 *
 * @param operands arrays taking part in the operation; {@code null} entries are skipped
 * @return the context that was set on every processed operand
 */
@Override
public CudaContext prepareActionAllWrite(INDArray... operands) {
    final CudaContext ctx = (CudaContext) allocator.getDeviceContext().getContext();
    final int currentDevice = allocator.getDeviceId();

    for (INDArray array : operands) {
        if (array != null) {
            Nd4j.getCompressor().autoDecompress(array);

            AllocationPoint dataPoint = allocator.getAllocationPoint(array);
            AllocationPoint shapePoint = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

            dataPoint.acquireLock();

            // Data sits on some other (valid) device: relocate the underlying buffer.
            boolean dataElsewhere = dataPoint.getDeviceId() != currentDevice && dataPoint.getDeviceId() >= 0;
            if (dataElsewhere) {
                // Prefer the original buffer when the array's buffer reports one.
                DataBuffer target = array.data().originalDataBuffer();
                if (target == null) {
                    target = array.data();
                }
                allocator.getMemoryHandler().relocateObject(target);
            }

            // Same test for the shape info, evaluated only after the data was handled.
            boolean shapeElsewhere = shapePoint.getDeviceId() != currentDevice && shapePoint.getDeviceId() >= 0;
            if (shapeElsewhere) {
                DataBuffer relocatedShape =
                        Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer());
                ((JCublasNDArray) array).setShapeInfoDataBuffer(relocatedShape);
            }

            prepareDelayedMemory(array);

            // The allocation point is looked up again rather than reusing dataPoint.
            allocator.getAllocationPoint(array).setCurrentContext(ctx);
        }
    }

    return ctx;
}
Use of org.nd4j.linalg.jcublas.JCublasNDArray in project nd4j by deeplearning4j.
From the class SynchronousFlowController, method prepareAction.
/**
 * Prepares the result array and every non-null operand for an upcoming operation and
 * returns the context obtained from the allocator's device context.
 *
 * <p>Each processed array is auto-decompressed, has the lock on its data allocation
 * point acquired, has its buffers relocated where required (see
 * {@link #lockAndRelocateForCurrentDevice(INDArray, int)}), has delayed memory prepared,
 * and finally gets the returned context recorded on its allocation point. The acquired
 * locks are not released inside this method.
 *
 * @param result   array receiving the outcome of the operation; may be {@code null},
 *                 in which case it is skipped
 * @param operands input arrays; {@code null} entries are skipped
 * @return the context that was set on every processed array
 */
@Override
public CudaContext prepareAction(INDArray result, INDArray... operands) {
    CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
    int cId = allocator.getDeviceId();

    if (result != null) {
        Nd4j.getCompressor().autoDecompress(result);
        // For the result, delayed memory is prepared BEFORE locking/relocation, whereas
        // operands do it afterwards. The existing order is kept deliberately.
        prepareDelayedMemory(result);
        lockAndRelocateForCurrentDevice(result, cId);
        allocator.getAllocationPoint(result).setCurrentContext(context);
    }

    for (INDArray operand : operands) {
        if (operand == null) {
            continue;
        }

        Nd4j.getCompressor().autoDecompress(operand);
        lockAndRelocateForCurrentDevice(operand, cId);
        prepareDelayedMemory(operand);
        allocator.getAllocationPoint(operand).setCurrentContext(context);
    }

    return context;
}

/**
 * Acquires the lock on the array's data allocation point, then relocates its data and
 * shape-info buffers if they are reported on a different, non-negative device id.
 *
 * <p>The data buffer is relocated only when cross-device access is disallowed by the
 * configuration or P2P is reported unavailable; the shape info is relocated regardless
 * of those settings.
 *
 * @param array array to lock and, if needed, relocate; cast to {@link JCublasNDArray}
 *              when its shape-info buffer is replaced
 * @param cId   device id the array is compared against
 */
private void lockAndRelocateForCurrentDevice(INDArray array, int cId) {
    AllocationPoint pointData = allocator.getAllocationPoint(array);
    AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

    pointData.acquireLock();

    if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0
                    && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed()
                                    || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
        // Relocate the original buffer when the array's buffer reports one.
        DataBuffer buffer = array.data().originalDataBuffer() == null ? array.data()
                        : array.data().originalDataBuffer();
        allocator.getMemoryHandler().relocateObject(buffer);
    }

    if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
        ((JCublasNDArray) array).setShapeInfoDataBuffer(
                        Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer()));
    }
}
Aggregations