Use of org.nd4j.jita.workspace.CudaWorkspace in the nd4j project by deeplearning4j.
The allocateMemory method of the AtomicAllocator class.
/**
 * Creates a new {@link AllocationPoint} for the given buffer, backs it with memory and
 * registers it in the allocations map under a freshly generated object id.
 * <p>
 * Buffers that report {@code isAttached()} are served from the current workspace (which is
 * cast to {@link CudaWorkspace}); every other buffer is delegated to the memory handler.
 * <p>
 * PLEASE NOTE: Do not use this method, unless you're 100% sure what you're doing
 *
 * @param buffer         buffer the allocation is made for; it is cast to {@code BaseDataBuffer}
 *                       when the garbage-collection reference is created
 * @param requiredMemory shape descriptor used to compute the byte size and data type of the chunk
 * @param location       requested allocation status; only forwarded to the memory handler, i.e.
 *                       it is not consulted for attached buffers
 * @param initialize     forwarded unchanged to the underlying alloc calls
 *                       (presumably requests initialized memory - TODO confirm)
 * @return the newly created and registered allocation point
 */
@Override
public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
    final AllocationPoint point = new AllocationPoint();
    useTracker.set(System.currentTimeMillis());

    // the tracker value doubles as the tracking code for this allocation
    final Long objectId = objectsTracker.getAndIncrement();
    point.setObjectId(objectId);
    point.setShape(requiredMemory);
    // NOTE: attaching the buffer to the point, and flagging CudaIntDataBuffer instances
    // as constant, are both intentionally disabled at this level.

    // pick one of the GC queues at random and hook the buffer's reference into it
    final int gcThreads = configuration.getNumberOfGcThreads();
    final int queueIndex = RandomUtils.nextInt(0, gcThreads);
    final GarbageBufferReference gcReference =
            new GarbageBufferReference((BaseDataBuffer) buffer, queueMap.get(queueIndex), point);
    point.attachReference(gcReference);

    // no device assigned until the attached path below says otherwise
    point.setDeviceId(-1);

    final PointersPair pointers;
    if (!buffer.isAttached()) {
        // we stay naive on PointersPair here: at this level we don't know which pointers
        // are set, MemoryHandler is responsible for that
        pointers = memoryHandler.alloc(location, point, requiredMemory, initialize);
    } else {
        final long bytes = AllocationUtils.getRequiredMemory(requiredMemory);

        // workaround for init order
        getMemoryHandler().getCudaContext();

        point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());
        final CudaWorkspace currentWorkspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace();

        pointers = new PointersPair();
        final PagedPointer devicePtr =
                currentWorkspace.alloc(bytes, MemoryKind.DEVICE, requiredMemory.getDataType(), initialize);
        final PagedPointer hostPtr =
                currentWorkspace.alloc(bytes, MemoryKind.HOST, requiredMemory.getDataType(), initialize);

        // when no device chunk came back, the host chunk stands in as the device pointer
        final boolean hasDeviceChunk = devicePtr != null;
        pointers.setHostPointer(hostPtr);
        pointers.setDevicePointer(hasDeviceChunk ? devicePtr : hostPtr);
        point.setAllocationStatus(hasDeviceChunk ? AllocationStatus.DEVICE : AllocationStatus.HOST);

        point.setAttached(true);
    }
    point.setPointers(pointers);

    allocationsMap.put(objectId, point);
    return point;
}
Aggregations