Example 6 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

The class WorkspaceProviderTests, method testReallocate3.

@Test
public void testReallocate3() throws Exception {
    MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(reallocateUnspecifiedConfiguration, "WS_1");
    for (int i = 1; i <= 10; i++) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateUnspecifiedConfiguration, "WS_1")) {
            INDArray array = Nd4j.create(100 * i);
        }
        if (i == 3) {
            // force the workspace to apply the size learned so far; it should match the largest allocation seen yet
            workspace.initializeWorkspace();
            assertEquals("Failed on iteration " + i, 100 * i * Nd4j.sizeOfDataType(), workspace.getCurrentSize());
        }
    }
    log.info("-----------------------------");
    for (int i = 10; i > 0; i--) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateUnspecifiedConfiguration, "WS_1")) {
            INDArray array = Nd4j.create(100 * i);
        }
    }
    // after the descending loop the workspace should retain the peak requirement from i == 10
    workspace.initializeWorkspace();
    assertEquals("Failed on final", 100 * 10 * Nd4j.sizeOfDataType(), workspace.getCurrentSize());
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
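
The test references a reallocateUnspecifiedConfiguration field that is defined elsewhere in WorkspaceProviderTests and not shown on this page. A minimal sketch of what such a configuration could look like, assuming the enums from org.nd4j.linalg.api.memory.enums; the exact builder values here are an assumption, not the test's actual field:

WorkspaceConfiguration reallocateUnspecifiedConfiguration = WorkspaceConfiguration.builder()
        .initialSize(0)                           // start empty and learn the size over time
        .policyAllocation(AllocationPolicy.STRICT)
        .policyLearning(LearningPolicy.OVER_TIME) // grow the workspace across iterations
        .policySpill(SpillPolicy.REALLOCATE)      // reallocate on overflow instead of spilling externally
        .build();

Under a configuration like this, getCurrentSize() tracks the largest allocation the workspace has had to satisfy, which is what both assertions in the test check.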

Example 7 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

The class WorkspaceProviderTests, method testNewWorkspace1.

@Test
public void testNewWorkspace1() throws Exception {
    MemoryWorkspace workspace1 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread();
    assertNotEquals(null, workspace1);
    MemoryWorkspace workspace2 = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread();
    assertEquals(workspace1, workspace2);
}
Also used: MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
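
Because the lookup is keyed to the calling thread, a natural companion check is that a different thread sees a different default workspace. A short sketch, not part of the original test; it assumes java.util.concurrent.atomic.AtomicReference is imported:

final MemoryWorkspace main = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread();
final AtomicReference<MemoryWorkspace> other = new AtomicReference<>();
// each thread is bound to its own default workspace instance
Thread t = new Thread(() -> other.set(Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread()));
t.start();
t.join();
assertNotEquals(main, other.get());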

Example 8 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

The class WorkspaceProviderTests, method testWorkspacesSerde3.

@Test
public void testWorkspacesSerde3() throws Exception {
    INDArray array = Nd4j.create(10).assign(1.0);
    INDArray restored = null;
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(bos);
    Nd4j.write(array, dos);
    try (Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_1")) {
        try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            workspace.enableDebug(true);
            ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
            DataInputStream dis = new DataInputStream(bis);
            restored = Nd4j.read(dis);
            assertEquals(0, workspace.getHostOffset());
            assertEquals(array.length(), restored.length());
            assertEquals(1.0f, restored.meanNumber().floatValue(), 1.0f);
            // we want to ensure it's the same cached shapeInfo used here
            assertTrue(array.shapeInfoDataBuffer() == restored.shapeInfoDataBuffer());
        }
    }
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), ByteArrayInputStream (java.io.ByteArrayInputStream), DataOutputStream (java.io.DataOutputStream), ByteArrayOutputStream (java.io.ByteArrayOutputStream), MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), DataInputStream (java.io.DataInputStream), Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
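
The scopeOutOfWorkspaces() pattern shown above is also the usual way to build an array that must outlive the surrounding workspace. A minimal sketch reusing the basicConfiguration field from the test; the "WS_TMP" id is hypothetical:

INDArray detached;
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "WS_TMP")) {
    INDArray attached = Nd4j.create(10).assign(1.0); // backed by workspace memory
    try (MemoryWorkspace scope = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
        detached = attached.dup(); // the copy is allocated outside the workspace
    }
}
// detached stays valid here; the memory behind 'attached' was reclaimed when the workspace closed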

Example 9 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

The class CudaZeroHandler, method relocateObject.

@Override
public synchronized void relocateObject(DataBuffer buffer) {
    AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(buffer);
    // we don't relocate non-DEVICE buffers (i.e. HOST or CONSTANT)
    if (dstPoint.getAllocationStatus() != AllocationStatus.DEVICE)
        return;
    int deviceId = getDeviceId();
    if (dstPoint.getDeviceId() >= 0 && dstPoint.getDeviceId() == deviceId) {
        return;
    }
    // FIXME: cross-thread access, might cause problems
    if (!dstPoint.isActualOnHostSide())
        AtomicAllocator.getInstance().synchronizeHostData(buffer);
    if (!dstPoint.isActualOnHostSide())
        throw new RuntimeException("Buffer synchronization failed");
    if (buffer.isAttached() || dstPoint.isAttached()) {
        // if this buffer is attached, relocate it into the current workspace, or detach it if no workspace is active
        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        if (workspace == null) {
            // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
            alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
            CudaContext context = getCudaContext();
            if (nativeOps.memcpyAsync(dstPoint.getDevicePointer(), dstPoint.getHostPointer(), buffer.length() * buffer.getElementSize(), 1, context.getSpecialStream()) == 0)
                throw new ND4JIllegalStateException("memcpyAsync failed");
            context.syncSpecialStream();
            // updating host pointer now
            alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
            // marking it as detached
            dstPoint.setAttached(false);
            // marking it as proper on device
            dstPoint.tickHostRead();
            dstPoint.tickDeviceWrite();
        } else {
            // this call will automagically take care of workspaces, so the copy will land
            // either in the current workspace or in plain device memory
            // log.info("Relocating to deviceId [{}], workspace [{}]...", deviceId, workspace.getId());
            BaseCudaDataBuffer nBuffer = (BaseCudaDataBuffer) Nd4j.createBuffer(buffer.length());
            Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
            dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
            dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
            dstPoint.setDeviceId(deviceId);
            dstPoint.tickDeviceRead();
            dstPoint.tickHostRead();
        }
        return;
    }
    if (buffer.isConstant()) {
        // we can't relocate or modify buffers
        throw new RuntimeException("Can't relocateObject() for constant buffer");
    } else {
        // log.info("Free relocateObject: deviceId: {}, pointer: {}", deviceId, dstPoint.getPointers().getDevicePointer().address());
        memoryProvider.free(dstPoint);
        deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
        // we replace original device pointer with new one
        alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
        // log.info("Pointer after alloc: {}", dstPoint.getPointers().getDevicePointer().address());
        CudaContext context = getCudaContext();
        if (nativeOps.memcpyAsync(dstPoint.getDevicePointer(), dstPoint.getHostPointer(), buffer.length() * buffer.getElementSize(), 1, context.getSpecialStream()) == 0)
            throw new ND4JIllegalStateException("memcpyAsync failed");
        context.syncSpecialStream();
        dstPoint.tickDeviceRead();
        dstPoint.tickHostRead();
    }
}
Also used: CudaContext (org.nd4j.linalg.jcublas.context.CudaContext), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), BaseCudaDataBuffer (org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer), AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint), MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)
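
relocateObject() is an internal JITA allocator hook; application code normally triggers cross-device moves through the AffinityManager instead. A hedged sketch of that public-API analogue (target device 1 is an assumption and only meaningful on a multi-GPU host):

INDArray source = Nd4j.create(100);
if (Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
    // copies the array's buffers to the target device, synchronizing host data first
    INDArray onDevice1 = Nd4j.getAffinityManager().replicateToDevice(1, source);
}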

Example 10 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

The class DoubleDataBufferTest, method testReallocationWorkspace.

@Test
public void testReallocationWorkspace() {
    WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L).policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.NONE).build();
    MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID");
    DataBuffer buffer = Nd4j.createBuffer(new double[] { 1, 2, 3, 4 });
    double[] old = buffer.asDouble();
    assertTrue(buffer.isAttached());
    assertEquals(4, buffer.capacity());
    buffer.reallocate(6);
    assertEquals(6, buffer.capacity());
    assertArrayEquals(old, buffer.asDouble(), 1e-1);
    workspace.close();
}
Also used: WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration), MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
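
reallocate() is not limited to workspace-attached buffers. A short sketch under the assumption that the same call also grows a detached buffer allocated outside any active workspace:

DataBuffer plain = Nd4j.createBuffer(new double[] { 1, 2, 3, 4 });
// with no workspace active, the buffer is not attached to one
assertFalse(plain.isAttached());
plain.reallocate(6);
assertEquals(6, plain.capacity()); // the original four values are preserved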

Aggregations

MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace): 62
Test (org.junit.Test): 39
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 35
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 35
Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace): 18
WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration): 14
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 6
AtomicLong (java.util.concurrent.atomic.AtomicLong): 5
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 4
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 3
ByteArrayInputStream (java.io.ByteArrayInputStream): 2
ByteArrayOutputStream (java.io.ByteArrayOutputStream): 2
DataInputStream (java.io.DataInputStream): 2
DataOutputStream (java.io.DataOutputStream): 2
File (java.io.File): 2
ArrayList (java.util.ArrayList): 2
Ignore (org.junit.Ignore): 2
IOException (java.io.IOException): 1
CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList): 1