Search in sources :

Example 1 with AllocationPoint

use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

the class SynchronousFlowController method registerActionAllWrite.

/**
 * Records a write action on every non-null operand and releases its lock.
 *
 * <p>For each operand the allocation point is looked up through the allocator, its
 * device-write tick is triggered, its last-write event is replaced with a new event
 * obtained from the events provider, and that new event is registered on the
 * context's old stream.
 *
 * @param context  CUDA context whose old stream the new write events are registered on
 * @param operands arrays that were written to; {@code null} entries are skipped
 */
@Override
public void registerActionAllWrite(CudaContext context, INDArray... operands) {
    for (int idx = 0; idx < operands.length; idx++) {
        final INDArray array = operands[idx];
        if (array != null) {
            final AllocationPoint point = allocator.getAllocationPoint(array);
            point.tickDeviceWrite();

            // Pass the previous write event to the provider before replacing it
            // (presumably so it can be reused — confirm against the provider).
            eventsProvider.storeEvent(point.getLastWriteEvent());
            point.setLastWriteEvent(eventsProvider.getEvent());
            point.getLastWriteEvent().register(context.getOldStream());

            point.releaseLock();
        }
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Example 2 with AllocationPoint

use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

the class SynchronousFlowController method prepareDelayedMemory.

/**
 * Prepares the data buffer and the shape-info buffer of {@code array} when the
 * configured memory model is {@code DELAYED}; does nothing for any other model.
 *
 * <p>The data buffer is handed to {@code prepareDelayedMemory(DataBuffer)} when its
 * allocation status is not {@code DEVICE}. The shape-info buffer, when its status is
 * {@code HOST}, is relocated through the constant handler and the resulting buffer is
 * set back on the array.
 *
 * @param array array to prepare; it is cast to {@code JCublasNDArray} when its
 *              shape-info buffer has to be replaced
 */
protected void prepareDelayedMemory(INDArray array) {
    if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
        // Resolve the allocation point of the array's DATA here. Looking it up from the
        // shape-info buffer would make the status check below test the wrong buffer.
        AllocationPoint pointData = allocator.getAllocationPoint(array);
        AllocationPoint pointShape = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
        if (pointData.getAllocationStatus() != AllocationStatus.DEVICE)
            prepareDelayedMemory(array.data());
        if (pointShape.getAllocationStatus() == AllocationStatus.HOST) {
            DataBuffer oShape = array.shapeInfoDataBuffer();
            DataBuffer nShape = Nd4j.getConstantHandler().relocateConstantSpace(oShape);
            // Relocation returned the very same buffer instance, so explicitly ask the
            // constant handler to move it.
            if (nShape == oShape)
                Nd4j.getConstantHandler().moveToConstantSpace(nShape);
            ((JCublasNDArray) array).setShapeInfoDataBuffer(nShape);
        }
    }
}
Also used : JCublasNDArray(org.nd4j.linalg.jcublas.JCublasNDArray) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 3 with AllocationPoint

use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

the class SynchronousFlowController method prepareActionAllWrite.

/**
 * Prepares every non-null operand for a write action and returns the CUDA context
 * obtained from the allocator's device context.
 *
 * <p>Per operand: the compressor's auto-decompress is invoked, the lock on the data
 * allocation point is acquired, data and shape-info buffers that report a different
 * non-negative device id than the allocator's are relocated, delayed memory is
 * prepared, and the context is set on the operand's allocation point.
 *
 * @param operands arrays about to be written; {@code null} entries are skipped
 * @return the CUDA context that was set on each operand's allocation point
 */
@Override
public CudaContext prepareActionAllWrite(INDArray... operands) {
    final CudaContext cudaContext = (CudaContext) allocator.getDeviceContext().getContext();
    final int currentDeviceId = allocator.getDeviceId();

    for (int idx = 0; idx < operands.length; idx++) {
        final INDArray array = operands[idx];
        if (array == null) {
            continue;
        }

        Nd4j.getCompressor().autoDecompress(array);

        final AllocationPoint dataPoint = allocator.getAllocationPoint(array);
        final AllocationPoint shapePoint = allocator.getAllocationPoint(array.shapeInfoDataBuffer());
        dataPoint.acquireLock();

        // Data reports another (non-negative) device id: relocate the underlying buffer.
        if (dataPoint.getDeviceId() != currentDeviceId && dataPoint.getDeviceId() >= 0) {
            final DataBuffer toRelocate;
            if (array.data().originalDataBuffer() == null) {
                toRelocate = array.data();
            } else {
                toRelocate = array.data().originalDataBuffer();
            }
            allocator.getMemoryHandler().relocateObject(toRelocate);
        }

        // Shape info reports another (non-negative) device id: relocate it via the constant handler.
        if (shapePoint.getDeviceId() != currentDeviceId && shapePoint.getDeviceId() >= 0) {
            final DataBuffer relocatedShape =
                    Nd4j.getConstantHandler().relocateConstantSpace(array.shapeInfoDataBuffer());
            ((JCublasNDArray) array).setShapeInfoDataBuffer(relocatedShape);
        }

        prepareDelayedMemory(array);

        // The allocation point is looked up again rather than reusing dataPoint,
        // since the steps above may have relocated the buffer.
        allocator.getAllocationPoint(array).setCurrentContext(cudaContext);
    }

    return cudaContext;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) JCublasNDArray(org.nd4j.linalg.jcublas.JCublasNDArray) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 4 with AllocationPoint

use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

the class SynchronousFlowController method registerAction.

/**
 * Records a finished action: a write on {@code result} and a read on each operand,
 * releasing the lock on every allocation point it touches.
 *
 * <p>For each point the previous event is passed to the events provider, a new event
 * obtained from the provider replaces it, and the new event is registered on the
 * context's old stream.
 *
 * <p>{@code null} values are skipped, mirroring {@code prepareAction}, which neither
 * locks nor configures a {@code null} result or operand.
 *
 * @param context  CUDA context whose old stream the new events are registered on
 * @param result   allocation point that was written to; may be {@code null}
 * @param operands allocation points that were read; {@code null} entries are skipped
 */
@Override
public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) {
    if (result != null) {
        eventsProvider.storeEvent(result.getLastWriteEvent());
        result.setLastWriteEvent(eventsProvider.getEvent());
        result.getLastWriteEvent().register(context.getOldStream());
        result.releaseLock();
    }
    for (AllocationPoint operand : operands) {
        // prepareAction never locked a null operand, so there is nothing to register or release.
        if (operand == null)
            continue;
        eventsProvider.storeEvent(operand.getLastReadEvent());
        operand.setLastReadEvent(eventsProvider.getEvent());
        operand.getLastReadEvent().register(context.getOldStream());
        operand.releaseLock();
    }
    // NOTE(review): the original author left the following call disabled; confirm
    // whether an explicit stream synchronization is required here.
    // context.syncOldStream();
}
Also used : AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Example 5 with AllocationPoint

use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

the class SynchronousFlowController method prepareAction.

/**
 * Acquires the lock on the result and on every operand, and sets the current CUDA
 * context on each of them.
 *
 * @param result   allocation point that will be written; ignored when {@code null}
 * @param operands allocation points that will be read; {@code null} entries are skipped
 * @return the CUDA context obtained from the allocator's device context
 */
@Override
public CudaContext prepareAction(AllocationPoint result, AllocationPoint... operands) {
    final CudaContext cudaContext = (CudaContext) allocator.getDeviceContext().getContext();

    if (result != null) {
        result.acquireLock();
        result.setCurrentContext(cudaContext);
    }

    for (int idx = 0; idx < operands.length; idx++) {
        final AllocationPoint point = operands[idx];
        if (point != null) {
            point.acquireLock();
            point.setCurrentContext(cudaContext);
        }
    }

    return cudaContext;
}
Also used : CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Aggregations

AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 67, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 33, Test (org.junit.Test): 31, CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 24, CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 15, DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 11, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 11, AtomicAllocator (org.nd4j.jita.allocator.impl.AtomicAllocator): 7, BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer): 7, Pointer (org.bytedeco.javacpp.Pointer): 6, AllocationShape (org.nd4j.jita.allocator.impl.AllocationShape): 5, PointersPair (org.nd4j.jita.allocator.pointers.PointersPair): 5, MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace): 4, JCublasNDArray (org.nd4j.linalg.jcublas.JCublasNDArray): 3, CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer): 3, CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer): 2, DeviceLocalNDArray (org.nd4j.linalg.util.DeviceLocalNDArray): 2, DataInputStream (java.io.DataInputStream): 1, DataOutputStream (java.io.DataOutputStream): 1, FileInputStream (java.io.FileInputStream): 1