Search in sources:

Example 1 with CudaWorkspace

Use of org.nd4j.jita.workspace.CudaWorkspace in project nd4j by deeplearning4j.

From the class AtomicAllocator, method allocateMemory.

/**
 * Allocates a chunk of memory for {@code buffer} and registers the resulting
 * {@link AllocationPoint} with this allocator.
 * <p>
 * The method branches on whether the buffer is attached to a workspace:
 * <ul>
 *   <li>Attached: both a device chunk and a host chunk are requested from the current
 *   {@link CudaWorkspace}. If the device request returns {@code null}, the host pointer
 *   is reused as the device pointer and the point is marked {@link AllocationStatus#HOST}.
 *   Otherwise it is marked {@link AllocationStatus#DEVICE}.</li>
 *   <li>Not attached: allocation is delegated to the memory handler for the requested
 *   {@code location}.</li>
 * </ul>
 * <p>
 * PLEASE NOTE: Do not use this method, unless you're 100% sure what you're doing
 *
 * @param buffer         the data buffer the memory is allocated for
 * @param requiredMemory describes the required chunk; the data type and the required
 *                       memory size are read from it
 * @param location       target location; only consulted when the buffer is not attached
 *                       to a workspace
 * @param initialize     passed through unchanged to the workspace / memory handler allocation
 * @return the newly created allocation point, already registered in {@code allocationsMap}
 */
@Override
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
    final AllocationPoint allocationPoint = new AllocationPoint();
    useTracker.set(System.currentTimeMillis());

    // Tracking code for memory tracking; also used as the key in allocationsMap.
    final Long trackingId = objectsTracker.getAndIncrement();
    allocationPoint.setObjectId(trackingId);
    allocationPoint.setShape(requiredMemory);

    // Pick a random GC bucket in [0, numberOfGcThreads) and register the buffer reference with its queue.
    final int gcBucket = RandomUtils.nextInt(0, configuration.getNumberOfGcThreads());
    allocationPoint.attachReference(
            new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(gcBucket), allocationPoint));
    allocationPoint.setDeviceId(-1);

    final PointersPair pointers;
    if (!buffer.isAttached()) {
        // This level does not know which pointers get populated; the memory handler decides.
        pointers = memoryHandler.alloc(location, allocationPoint, requiredMemory, initialize);
    } else {
        final long bytes = AllocationUtils.getRequiredMemory(requiredMemory);

        // Workaround for init order: touch the CUDA context before allocating from the workspace.
        getMemoryHandler().getCudaContext();
        allocationPoint.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());

        // NOTE(review): assumes an attached buffer implies a current CudaWorkspace on this thread — confirm.
        final CudaWorkspace currentWorkspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace();
        pointers = new PointersPair();
        final PagedPointer devicePtr = currentWorkspace.alloc(bytes, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize);
        final PagedPointer hostPtr = currentWorkspace.alloc(bytes, MemoryKind.HOST, requiredMemory.getDataType(), initialize);
        pointers.setHostPointer(hostPtr);

        // With no device chunk available, the host chunk stands in for the device pointer.
        final boolean onDevice = devicePtr != null;
        pointers.setDevicePointer(onDevice ? devicePtr : hostPtr);
        allocationPoint.setAllocationStatus(onDevice ? AllocationStatus.DEVICE : AllocationStatus.HOST);

        allocationPoint.setAttached(true);
    }
    allocationPoint.setPointers(pointers);

    allocationsMap.put(trackingId, allocationPoint);
    return allocationPoint;
}
Also used: PointersPair (org.nd4j.jita.allocator.pointers.PointersPair), AtomicLong (java.util.concurrent.atomic.AtomicLong), CudaWorkspace (org.nd4j.jita.workspace.CudaWorkspace), PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer), GarbageBufferReference (org.nd4j.jita.allocator.garbage.GarbageBufferReference)

Aggregations

AtomicLong (java.util.concurrent.atomic.AtomicLong): 1, GarbageBufferReference (org.nd4j.jita.allocator.garbage.GarbageBufferReference): 1, PointersPair (org.nd4j.jita.allocator.pointers.PointersPair): 1, CudaWorkspace (org.nd4j.jita.workspace.CudaWorkspace): 1, PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 1