Example 6 with PointersPair

Use of org.nd4j.linalg.api.memory.pointers.PointersPair in project nd4j by deeplearning4j.

From the class Nd4jWorkspace, method alloc.

public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataBuffer.Type type, boolean initialize) {
    /*
            just two options here:
            1) hostOffset + requiredMemory <= currentSize: we just return the workspace pointer + offset
            2) otherwise, reset and retry if the workspace is circular, or go for either an external (spilled) or a pinned allocation
         */
    // we enforce 8 byte alignment to ensure CUDA doesn't blame us:
    // round requiredMemory up to the next multiple of 8 by adding the missing bytes
    long div = requiredMemory % 8;
    if (div != 0)
        requiredMemory += 8 - div;
    long numElements = requiredMemory / Nd4j.sizeOfDataType(type);
    // shortcut made to skip workspace
    if (!isUsed.get()) {
        if (disabledCounter.incrementAndGet() % 10 == 0)
            log.warn("Workspace was turned off, and wasn't enabled after {} allocations", disabledCounter.get());
        PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
        externalAllocations.add(new PointersPair(pointer, null));
        return pointer;
    }
    /*
            Trimmed mode is possible in cyclic workspace mode. It is used in AsyncDataSetIterator, MQ, etc.
            The idea is simple: if one of the datasets coming out of the iterator is larger than expected, the workspace should be reallocated to match that size.
            So we switch to trimmed mode, all allocations become "pinned", and eventually the workspace is reallocated.
         */
    boolean trimmer = (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && requiredMemory + cycleAllocations.get() > initialBlockSize.get() && initialBlockSize.get() > 0) || trimmedMode.get();
    if (trimmer && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !trimmedMode.get()) {
        trimmedMode.set(true);
        trimmedStep.set(stepsCount.get());
    }
    // if size is enough - allocate from workspace
    if (hostOffset.get() + requiredMemory <= currentSize.get() && !trimmer) {
        // bump-pointer allocation: reserve requiredMemory (padded for 8-byte alignment above) and hand out the previous offset
        cycleAllocations.addAndGet(requiredMemory);
        long prevOffset = hostOffset.getAndAdd(requiredMemory);
        deviceOffset.set(hostOffset.get());
        PagedPointer ptr = workspace.getHostPointer().withOffset(prevOffset, numElements);
        if (isDebug.get())
            log.info("Workspace [{}]: Allocating array of {} bytes, capacity of {} elements, prevOffset: {}; currentOffset: {}; address: {}", id, requiredMemory, numElements, prevOffset, hostOffset.get(), ptr.address());
        if (initialize)
            Pointer.memset(ptr, 0, requiredMemory);
        return ptr;
    } else {
        // in case of circular mode - we just reset offsets, and start from the beginning of the workspace
        if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer) {
            reset();
            resetPlanned.set(true);
            return alloc(requiredMemory, kind, type, initialize);
        }
        // updating respective counters
        if (!trimmer)
            spilledAllocationsSize.addAndGet(requiredMemory);
        else
            pinnedAllocationsSize.addAndGet(requiredMemory);
        if (isDebug.get())
            log.info("Workspace [{}]: step: {}, spilled {} bytes, capacity of {} elements", id, stepsCount.get(), requiredMemory, numElements);
        switch(workspaceConfiguration.getPolicySpill()) {
            case REALLOCATE:
            case EXTERNAL:
                cycleAllocations.addAndGet(requiredMemory);
                if (!trimmer) {
                    externalCount.incrementAndGet();
                    PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
                    externalAllocations.add(new PointersPair(pointer, null));
                    return pointer;
                } else {
                    pinnedCount.incrementAndGet();
                    PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
                    pinnedAllocations.add(new PointersPair(stepsCount.get(), requiredMemory, pointer, null));
                    return pointer;
                }
            case FAIL:
            default:
                {
                    throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
                }
        }
    }
}
Also used: PointersPair (org.nd4j.linalg.api.memory.pointers.PointersPair), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)
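
The alignment step at the top of alloc is meant to round requiredMemory up to a multiple of 8 bytes. The standalone sketch below (the class AlignmentDemo and its method names are hypothetical, not part of ND4J) compares adding the missing 8 - div bytes with the equivalent bit-mask form, and notes why adding the remainder itself (requiredMemory += div) does not generally produce an aligned size.

```java
// Hypothetical helper, not ND4J API: rounds byte counts up to 8-byte alignment.
public class AlignmentDemo {

    // Round up by adding the missing (8 - remainder) bytes, not the remainder itself.
    public static long alignTo8(long requiredMemory) {
        long div = requiredMemory % 8;
        if (div != 0) {
            requiredMemory += 8 - div;
        }
        return requiredMemory;
    }

    // Equivalent for non-negative sizes: add 7, then clear the low 3 bits.
    public static long alignTo8Bitwise(long requiredMemory) {
        return (requiredMemory + 7) & ~7L;
    }

    public static void main(String[] args) {
        long[] sizes = {1, 8, 13, 20, 64};
        for (long size : sizes) {
            System.out.printf("%d -> %d (bitwise: %d)%n", size, alignTo8(size), alignTo8Bitwise(size));
        }
        // Adding the remainder instead would turn 13 into 18, which is not a multiple of 8.
    }
}
```

Both forms give 16 for a 13-byte request; the bit-mask version avoids the branch but relies on the size being non-negative.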

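
At its core, alloc behaves like a bump-pointer (arena) allocator with two fallbacks: a circular reset when ResetPolicy.ENDOFBUFFER_REACHED is configured, and a spill policy when the request still does not fit. The following is a minimal single-threaded sketch under simplifying assumptions: MiniWorkspace, its SpillPolicy enum, the circular flag and the external list are hypothetical names, plain offsets stand in for PagedPointer, and trimmed/pinned mode is not modeled.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified model of the allocation path in Nd4jWorkspace.alloc.
// Offsets into a fixed-size block stand in for real pointers; this is not ND4J API.
public class MiniWorkspace {

    // Mirrors two of the spill policies handled in alloc (REALLOCATE/pinned mode is left out).
    public enum SpillPolicy { EXTERNAL, FAIL }

    private final long capacity;           // size of the pre-allocated block, in bytes
    private final boolean circular;        // models ResetPolicy.ENDOFBUFFER_REACHED
    private final SpillPolicy spillPolicy;
    private long offset = 0;               // models hostOffset; no synchronization, single thread only
    // Stand-in for externalAllocations: buffers allocated outside the block.
    public final List<byte[]> external = new ArrayList<>();

    public MiniWorkspace(long capacity, boolean circular, SpillPolicy spillPolicy) {
        this.capacity = capacity;
        this.circular = circular;
        this.spillPolicy = spillPolicy;
    }

    // Returns the offset inside the block, or -1 if the request was spilled to an external buffer.
    public long alloc(long requiredMemory) {
        long aligned = (requiredMemory + 7) & ~7L;  // round up to a multiple of 8
        if (offset + aligned <= capacity) {
            long prevOffset = offset;               // bump the pointer, hand out the old offset
            offset += aligned;
            return prevOffset;
        }
        // Circular mode: rewind to the start of the block and retry, but only if the
        // block is not already empty and the request could fit into an empty block.
        if (circular && offset > 0 && aligned <= capacity) {
            offset = 0;
            return alloc(requiredMemory);
        }
        switch (spillPolicy) {
            case EXTERNAL:
                external.add(new byte[(int) aligned]);
                return -1;
            case FAIL:
            default:
                throw new IllegalStateException("Can't allocate memory: Workspace is full");
        }
    }

    public static void main(String[] args) {
        MiniWorkspace ws = new MiniWorkspace(64, true, SpillPolicy.EXTERNAL);
        System.out.println(ws.alloc(20));  // 0: 20 bytes round up to 24
        System.out.println(ws.alloc(30));  // 24: 30 bytes round up to 32, offset is now 56
        System.out.println(ws.alloc(16));  // 0: 56 + 16 > 64, so the block is rewound first
        System.out.println(ws.alloc(100)); // -1: larger than the whole block, spilled externally
    }
}
```

Rewinding reuses memory that was handed out earlier in the cycle, so it is only safe when callers no longer need those older arrays; that is the contract a cyclic workspace relies on. The offset > 0 guard also keeps the retry from recursing forever on a request that cannot fit.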