Search in sources:

Example 6 with AllocationPoint

Use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

In class SynchronousFlowController, method registerAction.

/**
 * Records the CUDA events produced by an operation so that later work can be ordered after it.
 *
 * <p>For {@code result}:
 * <ul>
 *   <li>the allocation point's device-write counter is ticked;</li>
 *   <li>its previous last-write event is passed to {@code eventsProvider.storeEvent};</li>
 *   <li>a new event from {@code eventsProvider} is registered on the context's old stream and
 *       becomes the new last-write event.</li>
 * </ul>
 * Every non-null operand gets the same treatment for its last-read event. Each allocation
 * point's lock is released only after its event bookkeeping has been completed.
 *
 * @param context  context whose old stream the new events are registered on
 * @param result   array written by the operation; if {@code null}, nothing is recorded
 * @param operands arrays read by the operation; {@code null} entries are skipped
 */
public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
    if (result == null)
        return;
    AllocationPoint point = allocator.getAllocationPoint(result);
    point.tickDeviceWrite();
    eventsProvider.storeEvent(point.getLastWriteEvent());
    point.setLastWriteEvent(eventsProvider.getEvent());
    point.getLastWriteEvent().register(context.getOldStream());
    point.releaseLock();
    for (INDArray operand : operands) {
        if (operand == null)
            continue;
        AllocationPoint pointOperand = allocator.getAllocationPoint(operand);
        eventsProvider.storeEvent(pointOperand.getLastReadEvent());
        pointOperand.setLastReadEvent(eventsProvider.getEvent());
        pointOperand.getLastReadEvent().register(context.getOldStream());
        // Release only after the read event has been swapped in, mirroring the result path above.
        // NOTE(review): presumably this lock guards the point's event state; releasing it first
        // would leave the swap unprotected.
        pointOperand.releaseLock();
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)

Example 7 with AllocationPoint

Use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

In class CudaZeroHandler, method memcpySpecial.

/**
 * Special memcpy version, addressing shapeInfoDataBuffer copies.
 *
 * <p>Copies {@code length} bytes from {@code srcPointer} into the host buffer of
 * {@code dstBuffer}, starting at byte offset {@code dstOffset}. If the destination's allocation
 * status is {@code DEVICE}, the same bytes are then copied from that host region into the device
 * buffer at the same offset.
 *
 * <p>Both copies are queued on the context's old stream, and that stream is synchronized once
 * before returning, so the call is blocking in every case. The destination's device-write counter
 * is ticked afterwards.
 *
 * @param dstBuffer  destination buffer; must be a {@code BaseCudaDataBuffer}
 * @param srcPointer source pointer; the first copy is issued with the host-to-host kind
 * @param length     number of bytes to copy
 * @param dstOffset  byte offset into the destination's host (and device) memory
 * @throws ND4JIllegalStateException if a native {@code memcpyAsync} call reports failure
 */
@Override
public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
    CudaContext context = getCudaContext();
    AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
    if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getOldStream()) == 0)
        throw new ND4JIllegalStateException("memcpyAsync failed");
    if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
        Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
        // Queued on the same stream as the H->H copy above, so it reads dP only after dP is filled.
        if (nativeOps.memcpyAsync(rDP, dP, length, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
            throw new ND4JIllegalStateException("memcpyAsync failed");
    }
    // One sync drains every copy queued above on the old stream.
    context.syncOldStream();
    point.tickDeviceWrite();
}
Also used: CudaContext (org.nd4j.linalg.jcublas.context.CudaContext), BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer), Pointer (org.bytedeco.javacpp.Pointer), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)

Example 8 with AllocationPoint

Use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

In class CudaZeroHandler, method memcpy.

/**
 * Synchronous version of memcpy.
 *
 * <p>Copies {@code srcBuffer.length() * srcBuffer.getElementSize()} bytes into the host buffer
 * of {@code dstBuffer}. The source is read from its device buffer when its allocation status is
 * {@code DEVICE}, and from its host buffer otherwise. If the destination's status is
 * {@code DEVICE`, the freshly written host bytes are then copied to its device buffer as well.
 *
 * <p>All copies are queued on the context's old stream, which is synchronized before returning.
 *
 * @param dstBuffer destination; must be a {@code BaseCudaDataBuffer}. No size check is done, so it
 *                  must hold at least as many bytes as {@code srcBuffer}.
 * @param srcBuffer source; must be a {@code BaseCudaDataBuffer}
 * @throws ND4JIllegalStateException if a native {@code memcpyAsync} call reports failure
 */
@Override
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
    CudaContext context = getCudaContext();
    AllocationPoint dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    AllocationPoint srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();
    // The byte count always comes from the source buffer.
    long bytes = srcBuffer.length() * srcBuffer.getElementSize();

    Pointer dP = new CudaPointer(dstPoint.getPointers().getHostPointer().address());
    // Read from wherever the source is allocated: device memory for DEVICE, host memory otherwise.
    Pointer sP = srcPoint.getAllocationStatus() == AllocationStatus.DEVICE
            ? new CudaPointer(srcPoint.getPointers().getDevicePointer().address())
            : new CudaPointer(srcPoint.getPointers().getHostPointer().address());
    // NOTE(review): dP is a host pointer, yet both source cases pass cudaMemcpyHostToDevice
    // (device->host / host->host would be the matching kinds). Kept unchanged because the native
    // memcpyAsync flag mapping is not visible here; confirm before changing the flag.
    if (nativeOps.memcpyAsync(dP, sP, bytes, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
        throw new ND4JIllegalStateException("memcpyAsync failed");
    }
    if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
        Pointer rDP = new CudaPointer(dstPoint.getPointers().getDevicePointer().address());
        if (nativeOps.memcpyAsync(rDP, dP, bytes, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    }
    dstPoint.tickDeviceWrite();
    // it has to be blocking call
    context.syncOldStream();
}
Also used: CudaContext (org.nd4j.linalg.jcublas.context.CudaContext), BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer), Pointer (org.bytedeco.javacpp.Pointer), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)

Example 9 with AllocationPoint

Use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

In class CudaZeroHandler, method getDevicePointer.

/**
 * Returns a typed device pointer for {@code buffer}, shifted by the buffer's offset in bytes.
 *
 * <p>PLEASE NOTE: Specific implementation, on systems without special devices can return HostPointer here
 *
 * <p>When the allocation point's status is {@code DEVICE} but {@code isActualOnDeviceSide()}
 * is false, the data is relocated from HOST to DEVICE first. The point's device-read counter is
 * ticked on every call. A HOST-to-DEVICE promotion path exists but is currently switched off.
 *
 * @param buffer  buffer whose device pointer is requested; must be a {@code BaseCudaDataBuffer}
 * @param context context handed to {@code relocate} when the device copy must be refreshed
 * @return the device pointer viewed as a double/float/int/short pointer according to the
 *         buffer's data type, or the plain pointer for any other type
 */
@Override
public org.bytedeco.javacpp.Pointer getDevicePointer(DataBuffer buffer, CudaContext context) {
    // TODO: It would be awesome to get rid of typecasting here
    AllocationPoint allocPoint = ((BaseCudaDataBuffer) buffer).getAllocationPoint();

    // Promotion of original (zero-offset) HOST buffers to DEVICE; deliberately disabled.
    final boolean promotionEnabled = false;
    boolean promotionCandidate = allocPoint.getAllocationStatus() == AllocationStatus.HOST
                    && buffer.offset() == 0 && promotionEnabled;
    if (promotionCandidate && allocPoint.getDeviceTicks() > configuration.getMinimumRelocationThreshold()) {
        long requiredMemory = AllocationUtils.getRequiredMemory(allocPoint.getShape());
        long threadId = Thread.currentThread().getId();
        boolean reserved = deviceMemoryTracker.reserveAllocationIfPossible(threadId, getDeviceId(), requiredMemory)
                        && pingDeviceForFreeMemory(getDeviceId(), requiredMemory);
        if (reserved) {
            promoteObject(buffer);
        }
    }

    // Refresh the device copy when the buffer lives on the device but is not actual there.
    boolean needsDeviceRefresh = allocPoint.getAllocationStatus() == AllocationStatus.DEVICE
                    && !allocPoint.isActualOnDeviceSide();
    if (needsDeviceRefresh) {
        relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, allocPoint, allocPoint.getShape(), context);
    }

    // Record device-side use of this buffer.
    allocPoint.tickDeviceRead();

    // Offset is converted from elements to bytes; length is passed for constructor compatibility.
    CudaPointer devPtr = new CudaPointer(allocPoint.getPointers().getDevicePointer(), buffer.length(),
                    (buffer.offset() * buffer.getElementSize()));
    switch (buffer.dataType()) {
        case DOUBLE:
            return devPtr.asDoublePointer();
        case FLOAT:
            return devPtr.asFloatPointer();
        case INT:
            return devPtr.asIntPointer();
        case HALF:
            return devPtr.asShortPointer();
        default:
            return devPtr;
    }
}
Also used: BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint), CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)

Example 10 with AllocationPoint

Use of org.nd4j.jita.allocator.impl.AllocationPoint in project nd4j by deeplearning4j.

In class CudaZeroHandler, method relocateObject.

/**
 * Moves a {@code DEVICE}-allocated buffer onto the current thread's device.
 *
 * <p>Buffers whose status is not {@code DEVICE}, and buffers already on the current device, are
 * left untouched. Otherwise the host copy is synchronized first, and then:
 * <ul>
 *   <li>attached buffer, no current workspace: device memory is allocated again and filled from
 *       host, host memory is allocated again, and the point is marked detached;</li>
 *   <li>attached buffer inside a workspace: a new buffer is created, the data is copied into it,
 *       and its host/device pointers are adopted by this allocation point;</li>
 *   <li>detached, non-constant buffer: the old device memory is freed and removed from the memory
 *       tracker, then new device memory is allocated and filled from host;</li>
 *   <li>constant buffer: rejected.</li>
 * </ul>
 *
 * @param buffer buffer to relocate
 * @throws RuntimeException if host synchronization fails or the buffer is constant
 * @throws ND4JIllegalStateException if the host-to-device copy reports failure
 */
@Override
public synchronized void relocateObject(DataBuffer buffer) {
    AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
    // we don't relocate non-DEVICE buffers (i.e HOST or CONSTANT)
    if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
        return;
    int deviceId = getDeviceId();
    if (dstPoint.getDeviceId() >= 0 && dstPoint.getDeviceId() == deviceId) {
        return;
    }
    // FIXME: cross-thread access, might cause problems
    if (!dstPoint.isActualOnHostSide())
        AtomicAllocator.getInstance().synchronizeHostData(buffer);
    if (!dstPoint.isActualOnHostSide())
        throw new RuntimeException("Buffer synchronization failed");
    if (buffer.isAttached() || dstPoint.isAttached()) {
        // if this buffer is Attached, we just relocate to new workspace
        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (workspace == null) {
            // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
            alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
            copyHostToDeviceBlocking(dstPoint, buffer);
            // updating host pointer now
            alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
            // marking it as detached
            dstPoint.setAttached(false);
            // marking it as proper on device
            dstPoint.tickHostRead();
            dstPoint.tickDeviceWrite();
        } else {
            // NOTE(review): Nd4j.createBuffer is expected to handle workspaces itself, i.e. to
            // allocate nBuffer inside the current workspace; TODO confirm.
            BaseCudaDataBuffer nBuffer = (BaseCudaDataBuffer) Nd4j.createBuffer(buffer.length());
            Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
            dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
            dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
            dstPoint.setDeviceId(deviceId);
            dstPoint.tickDeviceRead();
            dstPoint.tickHostRead();
        }
        return;
    }
    if (buffer.isConstant()) {
        // we can't relocate or modify buffers
        throw new RuntimeException("Can't relocateObject() for constant buffer");
    } else {
        memoryProvider.free(dstPoint);
        deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
        // we replace original device pointer with new one
        alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
        copyHostToDeviceBlocking(dstPoint, buffer);
        dstPoint.tickDeviceRead();
        dstPoint.tickHostRead();
    }
}

/**
 * Copies the whole host buffer of {@code point} into its device buffer on the current context's
 * special stream, and waits for that stream to finish.
 *
 * @param point  allocation point whose host and device pointers are used
 * @param buffer buffer supplying the byte count ({@code length() * getElementSize()})
 * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call reports failure
 */
private void copyHostToDeviceBlocking(AllocationPoint point, DataBuffer buffer) {
    CudaContext context = getCudaContext();
    if (nativeOps.memcpyAsync(point.getDevicePointer(), point.getHostPointer(), buffer.length() * buffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0)
        throw new ND4JIllegalStateException("memcpyAsync failed");
    context.syncSpecialStream();
}
Also used: CudaContext (org.nd4j.linalg.jcublas.context.CudaContext), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint), MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)

Aggregations

AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 67; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 33; Test (org.junit.Test): 31; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 24; CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 15; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 11; ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 11; AtomicAllocator (org.nd4j.jita.allocator.impl.AtomicAllocator): 7; BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer): 7; Pointer (org.bytedeco.javacpp.Pointer): 6; AllocationShape (org.nd4j.jita.allocator.impl.AllocationShape): 5; PointersPair (org.nd4j.jita.allocator.pointers.PointersPair): 5; MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace): 4; JCublasNDArray (org.nd4j.linalg.jcublas.JCublasNDArray): 3; CudaDoubleDataBuffer (org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer): 3; CompressedDataBuffer (org.nd4j.linalg.compression.CompressedDataBuffer): 2; DeviceLocalNDArray (org.nd4j.linalg.util.DeviceLocalNDArray): 2; DataInputStream (java.io.DataInputStream): 1; DataOutputStream (java.io.DataOutputStream): 1; FileInputStream (java.io.FileInputStream): 1