
Example 1 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class BasicWorkspaceManager method destroyAllWorkspacesForCurrentThread.

/**
 * This method destroys all workspaces allocated in the current thread
 */
@Override
public void destroyAllWorkspacesForCurrentThread() {
    ensureThreadExistense();
    // Copy the values first: destroyWorkspace(workspace) may modify the backing map
    List<MemoryWorkspace> workspaces = new ArrayList<>(backingMap.get().values());
    for (MemoryWorkspace workspace : workspaces) {
        destroyWorkspace(workspace);
    }
    System.gc();
}
Also used : ArrayList(java.util.ArrayList) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace)
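
The copy made before the loop in Example 1 matters if destroying a workspace removes it from the same per-thread map. The following is a minimal stdlib-only sketch of that pattern, not ND4J code: the class SnapshotDestroyDemo, its BACKING_MAP, and the helpers fill and destroy are hypothetical stand-ins, and it assumes the registry is a plain HashMap.

```java
import java.util.ArrayList;
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class SnapshotDestroyDemo {
    // Hypothetical stand-in for a per-thread workspace registry (assumed to be a HashMap).
    static final ThreadLocal<Map<String, String>> BACKING_MAP =
            ThreadLocal.withInitial(HashMap::new);

    // Resets the current thread's registry to three dummy entries.
    static void fill() {
        Map<String, String> map = BACKING_MAP.get();
        map.clear();
        map.put("WS_1", "a");
        map.put("WS_2", "b");
        map.put("WS_3", "c");
    }

    // Assumption: destroying an entry removes it from the registry.
    static void destroy(String id) {
        BACKING_MAP.get().remove(id);
    }

    // Safe variant: iterate over a snapshot of the keys while mutating the map.
    static int destroyAllWithSnapshot() {
        List<String> ids = new ArrayList<>(BACKING_MAP.get().keySet());
        for (String id : ids) {
            destroy(id);
        }
        return ids.size();
    }

    // Unsafe variant: iterate over the live key view while mutating the map.
    // Returns true if the HashMap iterator failed fast.
    static boolean destroyAllWithoutSnapshotFails() {
        try {
            for (String id : BACKING_MAP.get().keySet()) {
                destroy(id);
            }
            return false;
        } catch (ConcurrentModificationException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        fill();
        System.out.println("destroyed with snapshot: " + destroyAllWithSnapshot());
        fill();
        System.out.println("live iteration failed: " + destroyAllWithoutSnapshotFails());
    }
}
```

With three entries in the map, the snapshot variant empties the registry, while the live-view variant hits a ConcurrentModificationException on the second iteration step.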

Example 2 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class BasicWorkspaceManager method destroyWorkspace.

/**
 * This method destroys the default workspace, if any
 */
@Override
public void destroyWorkspace() {
    ensureThreadExistense();
    MemoryWorkspace workspace = backingMap.get().get(MemoryWorkspace.DEFAULT_ID);
    // if (workspace != null)
    // workspace.destroyWorkspace();
    backingMap.get().remove(MemoryWorkspace.DEFAULT_ID);
}
Also used : MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace)

Example 3 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class WorkspaceProviderTests method testWorkspacesSerde1.

@Test
public void testWorkspacesSerde1() throws Exception {
    int[] shape = new int[] { 17, 57, 79 };
    INDArray array = Nd4j.create(shape).assign(1.0);
    INDArray restored = null;
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(bos);
    Nd4j.write(array, dos);
    try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(bigConfiguration, "WS_1")) {
        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
        DataInputStream dis = new DataInputStream(bis);
        restored = Nd4j.read(dis);
        assertEquals(array.length(), restored.length());
        // every element is 1.0, so the mean should match within a tight tolerance
        assertEquals(1.0f, restored.meanNumber().floatValue(), 1e-5f);
        // we want to ensure it's the same cached shapeInfo used here
        assertTrue(array.shapeInfoDataBuffer() == restored.shapeInfoDataBuffer());
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) ByteArrayInputStream(java.io.ByteArrayInputStream) DataOutputStream(java.io.DataOutputStream) ByteArrayOutputStream(java.io.ByteArrayOutputStream) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) DataInputStream(java.io.DataInputStream) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 4 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class WorkspaceProviderTests method testCircularLearning1.

@Test
public void testCircularLearning1() throws Exception {
    INDArray array1;
    INDArray array2;
    for (int i = 0; i < 2; i++) {
        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfiguration, "WSX")) {
            array1 = Nd4j.create(10).assign(1);
        }
        Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(circularConfiguration, "WSX");
        assertEquals(10 * 1024 * 1024L, workspace.getCurrentSize());
        log.info("Current step number: {}", workspace.getStepNumber());
        if (i == 0)
            assertEquals(0, workspace.getHostOffset());
        else if (i == 1)
            assertEquals(workspace.getInitialBlockSize(), workspace.getHostOffset());
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 5 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class WorkspaceProviderTests method testNestedWorkspaces7.

@Test
public void testNestedWorkspaces7() throws Exception {
    try (Nd4jWorkspace wsExternal = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "External")) {
        INDArray array1 = Nd4j.create(10);
        INDArray array2 = null;
        INDArray array3 = null;
        INDArray array4 = null;
        INDArray array5 = null;
        try (Nd4jWorkspace wsFeedForward = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "FeedForward")) {
            array2 = Nd4j.create(10);
            assertTrue(array2.isAttached());
            try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) {
                array3 = Nd4j.create(10);
                assertTrue(wsExternal == array3.data().getParentWorkspace());
                try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    array4 = Nd4j.create(10);
                }
                array5 = Nd4j.create(10);
                log.info("Workspace5: {}", array5.data().getParentWorkspace());
                assertTrue(null == array4.data().getParentWorkspace());
                assertFalse(array4.isAttached());
                assertTrue(wsExternal == array5.data().getParentWorkspace());
            }
            assertTrue(array3.isAttached());
            assertFalse(array4.isAttached());
            assertTrue(array5.isAttached());
        }
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)
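
The nesting that Example 5 asserts (array3 and array5 land in "External", array4 is detached) can be pictured as a per-thread stack of scopes opened and closed by try-with-resources. The sketch below is a simplified stdlib-only model of that idea, not ND4J's implementation: ScopeStackDemo, Scope, activate, scopeOut, and the "<detached>" marker are all hypothetical names introduced for illustration.

```java
import java.util.ArrayDeque;
import java.util.Deque;

final class ScopeStackDemo {
    // Marker for "no workspace active" (ArrayDeque does not accept null elements).
    static final String NONE = "<detached>";

    // Per-thread stack of active scope names; the top of the stack is the current scope.
    static final ThreadLocal<Deque<String>> ACTIVE = ThreadLocal.withInitial(ArrayDeque::new);

    // Opening a scope pushes its name; closing it pops, restoring the previous scope.
    static final class Scope implements AutoCloseable {
        Scope(String name) {
            ACTIVE.get().push(name);
        }

        @Override
        public void close() {
            ACTIVE.get().pop();
        }
    }

    static Scope activate(String name) {
        return new Scope(name);
    }

    // Models a temporary "no workspace" region inside other scopes.
    static Scope scopeOut() {
        return new Scope(NONE);
    }

    static String current() {
        String top = ACTIVE.get().peek();
        return top == null ? NONE : top;
    }

    // Records the current scope at each point where Example 5 creates an array.
    static String trace() {
        StringBuilder sb = new StringBuilder();
        try (Scope external = activate("External")) {
            sb.append(current()).append(',');              // array1
            try (Scope feedForward = activate("FeedForward")) {
                sb.append(current()).append(',');          // array2
                try (Scope borrowed = activate("External")) {
                    sb.append(current()).append(',');      // array3
                    try (Scope out = scopeOut()) {
                        sb.append(current()).append(',');  // array4
                    }
                    sb.append(current()).append(',');      // array5
                }
                sb.append(current()).append(',');
            }
        }
        sb.append(current());
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(trace());
    }
}
```

Running main prints External,FeedForward,External,<detached>,External,FeedForward,<detached>, which matches the attachment pattern the test checks: the borrowed scope resolves to the outer workspace, and the scoped-out region has none.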

Aggregations

MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace) 62
Test (org.junit.Test) 39
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest) 35
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 35
Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) 18
WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration) 14
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer) 6
AtomicLong (java.util.concurrent.atomic.AtomicLong) 5
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint) 4
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException) 3
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext) 3
ByteArrayInputStream (java.io.ByteArrayInputStream) 2
ByteArrayOutputStream (java.io.ByteArrayOutputStream) 2
DataInputStream (java.io.DataInputStream) 2
DataOutputStream (java.io.DataOutputStream) 2
File (java.io.File) 2
ArrayList (java.util.ArrayList) 2
Ignore (org.junit.Ignore) 2
IOException (java.io.IOException) 1
CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList) 1