Search in sources :

Example 1 with JCublasNDArray

use of org.nd4j.linalg.jcublas.JCublasNDArray in project nd4j by deeplearning4j.

From the class SynchronousFlowController, method prepareDelayedMemory.

/**
 * Prepares the buffers of {@code array} when the delayed memory model is configured.
 *
 * <p>Does nothing unless {@code configuration.getMemoryModel()} is
 * {@code Configuration.MemoryModel.DELAYED}. Otherwise:
 * <ul>
 *   <li>if the allocation point of the array's data is not in
 *       {@code AllocationStatus.DEVICE}, the data buffer is handed to
 *       {@code prepareDelayedMemory(array.data())};</li>
 *   <li>if the allocation point of the shape-info buffer is in
 *       {@code AllocationStatus.HOST}, the shape-info buffer is passed through the
 *       constant handler and the resulting buffer is installed on the array.</li>
 * </ul>
 *
 * @param array array whose data and shape-info buffers are inspected; it is cast to
 *              {@code JCublasNDArray} when the shape-info buffer has to be replaced
 */
protected void prepareDelayedMemory(INDArray array) {
    if (configuration.getMemoryModel() != Configuration.MemoryModel.DELAYED) {
        return;
    }

    // The data allocation point is resolved from the array itself, the same way the
    // prepareAction* methods of this class resolve it. Previously both lookups used
    // shapeInfoDataBuffer(), so the data check below inspected the shape buffer.
    AllocationPoint pointData = allocator.getAllocationPoint(array);
    AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());

    if (pointData.getAllocationStatus() != AllocationStatus.DEVICE) {
        prepareDelayedMemory(array.data());
    }

    if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
        DataBuffer oShape = array.shapeInfoDataBuffer();
        DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);

        // Same instance returned: presumably no relocation took place, so the buffer is
        // moved to constant space explicitly — TODO confirm against the constant handler.
        if (nShape == oShape) {
            Nd4j.getConstantHandler().moveToConstantSpace(nShape);
        }
        ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);
    }
}
Also used : JCublasNDArray(org.nd4j.linalg.jcublas.JCublasNDArray) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 2 with JCublasNDArray

use of org.nd4j.linalg.jcublas.JCublasNDArray in project nd4j by deeplearning4j.

From the class SynchronousFlowController, method prepareActionAllWrite.

/**
 * Prepares every non-null operand for an action and returns the context in use.
 *
 * <p>For each operand, in order: it is auto-decompressed, the lock of its data
 * allocation point is acquired, its data buffer and shape-info buffer are relocated
 * when their allocation point reports a non-negative device id different from
 * {@code allocator.getDeviceId()}, delayed memory is prepared, and the context is
 * stored on the operand's allocation point.
 *
 * @param operands arrays to prepare; {@code null} entries are skipped
 * @return the context obtained from {@code allocator.getDeviceContext()}
 */
@Override
public CudaContext prepareActionAllWrite(INDArray... operands) {
    final CudaContext cudaContext = (CudaContext) allocator.getDeviceContext().getContext();
    final int currentDevice = allocator.getDeviceId();

    for (INDArray array : operands) {
        if (array != null) {
            Nd4j.getCompressor().autoDecompress(array);

            // Both allocation points are looked up before the lock is taken.
            AllocationPoint dataPoint = allocator.getAllocationPoint(array);
            AllocationPoint shapePoint = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
            dataPoint.acquireLock();

            boolean dataOnOtherDevice =
                    dataPoint.getDeviceId() != currentDevice && dataPoint.getDeviceId() >= 0;
            if (dataOnOtherDevice) {
                // Relocate the original buffer when the data buffer reports one.
                DataBuffer target;
                if (array.data().originalDataBuffer() == null) {
                    target = array.data();
                } else {
                    target = array.data().originalDataBuffer();
                }
                allocator.getMemoryHandler().relocateObject(target);
            }

            boolean shapeOnOtherDevice =
                    shapePoint.getDeviceId() != currentDevice && shapePoint.getDeviceId() >= 0;
            if (shapeOnOtherDevice) {
                JCublasNDArray cudaArray = (JCublasNDArray) array;
                DataBuffer relocatedShape =
                        Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer());
                cudaArray.setShapeInfoDataBuffer(relocatedShape);
            }

            prepareDelayedMemory(array);
            allocator.getAllocationPoint(array).setCurrentContext(cudaContext);
        }
    }

    return cudaContext;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) JCublasNDArray(org.nd4j.linalg.jcublas.JCublasNDArray) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 3 with JCublasNDArray

use of org.nd4j.linalg.jcublas.JCublasNDArray in project nd4j by deeplearning4j.

From the class SynchronousFlowController, method prepareAction.

/**
 * Prepares the result array and every non-null operand for an action and returns the
 * context in use.
 *
 * <p>Each array is auto-decompressed, locked and relocated where needed (see
 * {@code lockAndRelocateForDevice}), has its delayed memory prepared, and finally has
 * the context stored on its allocation point.
 *
 * <p>NOTE(review): for {@code result}, delayed memory is prepared before the
 * lock/relocation step, whereas for operands it is prepared after it. That ordering
 * is kept exactly as it was; whether the asymmetry is intentional cannot be told
 * from this method alone.
 *
 * @param result   array receiving the result; may be {@code null}, in which case it is skipped
 * @param operands input arrays; {@code null} entries are skipped
 * @return the context obtained from {@code allocator.getDeviceContext()}
 */
@Override
public CudaContext prepareAction(INDArray result, INDArray... operands) {
    CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
    int cId = allocator.getDeviceId();

    if (result != null) {
        Nd4j.getCompressor().autoDecompress(result);
        prepareDelayedMemory(result);
        lockAndRelocateForDevice(result, cId);
        allocator.getAllocationPoint(result).setCurrentContext(context);
    }

    for (INDArray operand : operands) {
        if (operand == null) {
            continue;
        }
        Nd4j.getCompressor().autoDecompress(operand);
        lockAndRelocateForDevice(operand, cId);
        prepareDelayedMemory(operand);
        allocator.getAllocationPoint(operand).setCurrentContext(context);
    }
    return context;
}

/**
 * Acquires the lock of the array's data allocation point and relocates its buffers
 * when they are registered on a different device.
 *
 * <p>The data buffer is relocated only if its allocation point reports a non-negative
 * device id other than {@code cId} and, in addition, either cross-device access is
 * disabled in the configuration or P2P is reported as unavailable. The shape-info
 * buffer is relocated whenever its allocation point reports a non-negative device id
 * other than {@code cId}.
 *
 * @param array array to lock and relocate; cast to {@code JCublasNDArray} when the
 *              shape-info buffer has to be replaced
 * @param cId   id of the current device, as returned by {@code allocator.getDeviceId()}
 */
private void lockAndRelocateForDevice(INDArray array, int cId) {
    // Both allocation points are looked up before the lock is taken.
    AllocationPoint pointData = allocator.getAllocationPoint(array);
    AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
    pointData.acquireLock();

    // The environment/P2P lookups stay on the right of && so they are only evaluated
    // when the device ids actually differ.
    if (pointData.getDeviceId() != cId && pointData.getDeviceId() >= 0
            && (!CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed()
                    || !NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable())) {
        // Relocate the original buffer when the data buffer reports one.
        DataBuffer buffer = array.data().originalDataBuffer() == null
                ? array.data()
                : array.data().originalDataBuffer();
        allocator.getMemoryHandler().relocateObject(buffer);
    }

    if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
        ((JCublasNDArray) array).setShapeInfoDataBuffer(
                Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer()));
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) JCublasNDArray(org.nd4j.linalg.jcublas.JCublasNDArray) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Aggregations

AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 3; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 3; JCublasNDArray (org.nd4j.linalg.jcublas.JCublasNDArray): 3; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 2; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 2