Search in sources:

Example 56 with MemoryWorkspace

Usage of org.nd4j.linalg.api.memory.MemoryWorkspace in the nd4j project by deeplearning4j.

Source: class WorkspaceProviderTests, method testMultithreading1.

/**
 * Verifies that each thread obtains its own workspace from the workspace manager.
 *
 * <p>Twenty threads each call {@code getWorkspaceForCurrentThread()} and record the result.
 * The test then checks that:
 * <ul>
 *   <li>every thread recorded a workspace;</li>
 *   <li>no two recorded workspaces are the same instance;</li>
 *   <li>the test thread itself has no current workspace.</li>
 * </ul>
 */
@Test
public void testMultithreading1() throws Exception {
    // Thread-safe list: the worker threads append to it concurrently.
    final List<MemoryWorkspace> workspaces = new CopyOnWriteArrayList<>();
    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);

    Thread[] threads = new Thread[20];
    for (int x = 0; x < threads.length; x++) {
        threads[x] = new Thread(new Runnable() {

            @Override
            public void run() {
                MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread();
                workspaces.add(workspace);
            }
        });
        threads[x].start();
    }

    for (Thread thread : threads) {
        thread.join();
    }

    // A worker that threw would not have added its workspace. Fail with a clear assertion here
    // rather than an IndexOutOfBoundsException in the pairwise check below.
    assertEquals(threads.length, workspaces.size());

    // Identity comparison is intentional: each thread must receive a distinct instance.
    // The relation is symmetric, so each unordered pair is checked exactly once.
    for (int x = 0; x < workspaces.size(); x++) {
        for (int y = x + 1; y < workspaces.size(); y++) {
            assertFalse(workspaces.get(x) == workspaces.get(y));
        }
    }

    assertNull(Nd4j.getMemoryManager().getCurrentWorkspace());
}
Also used : MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) CopyOnWriteArrayList(java.util.concurrent.CopyOnWriteArrayList) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 57 with MemoryWorkspace

Usage of org.nd4j.linalg.api.memory.MemoryWorkspace in the nd4j project by deeplearning4j.

Source: class WorkspaceProviderTests, method testNestedWorkspaces9.

/**
 * Runs nine "WS_1" cycles under {@code delayedConfiguration}.
 *
 * <p>Each cycle allocates a progressively larger array (100, 200, ..., 900 elements). The test
 * then initializes the workspace and checks that its current size equals 300 elements of the
 * default data type.
 *
 * <p>NOTE(review): the expected 300 presumably follows from how {@code delayedConfiguration}
 * chooses its size (defined outside this method); confirm before changing it.
 */
@Test
public void testNestedWorkspaces9() throws Exception {
    for (int cycle = 1; cycle <= 9; cycle++) {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().getAndActivateWorkspace(delayedConfiguration, "WS_1")) {
            // Only the allocation matters here; the array itself is not inspected.
            Nd4j.create(cycle * 100);
        }
    }

    Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager()
            .getWorkspaceForCurrentThread(delayedConfiguration, "WS_1");
    ws1.initializeWorkspace();

    long expectedBytes = 300 * Nd4j.sizeOfDataType();
    assertEquals(expectedBytes, ws1.getCurrentSize());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 58 with MemoryWorkspace

Usage of org.nd4j.linalg.api.memory.MemoryWorkspace in the nd4j project by deeplearning4j.

Source: class WorkspaceProviderTests, method testNestedWorkspaces6.

/**
 * Checks array attachment across nested workspace scopes.
 *
 * <p>Arrays are created in three scopes under the "External" workspace:
 * <ul>
 *   <li>the nested "FeedForward" workspace;</li>
 *   <li>a borrowed scope of "External", whose array must report "External" as its parent
 *       workspace;</li>
 *   <li>a scope outside all workspaces, whose array must be detached.</li>
 * </ul>
 * Finally, "External" is expected to report a current size of 0.
 */
@Test
public void testNestedWorkspaces6() throws Exception {
    try (Nd4jWorkspace wsExternal = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(firstConfiguration, "External")) {
        // Allocation made while only "External" is active; the array itself is not inspected.
        Nd4j.create(10);

        // Declared out here because both are checked after their scopes close.
        INDArray borrowedArray = null;
        INDArray detachedArray = null;

        try (Nd4jWorkspace wsFeedForward = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(firstConfiguration, "FeedForward")) {
            INDArray feedForwardArray = Nd4j.create(10);
            assertTrue(feedForwardArray.isAttached());

            // Borrow the outer workspace's scope while "FeedForward" is still open.
            try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) {
                borrowedArray = Nd4j.create(10);
                assertTrue(wsExternal == borrowedArray.data().getParentWorkspace());
            }
            assertTrue(borrowedArray.isAttached());

            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                detachedArray = Nd4j.create(10);
            }
            assertFalse(detachedArray.isAttached());
        }

        assertEquals(0, wsExternal.getCurrentSize());
        log.info("------");
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 59 with MemoryWorkspace

Usage of org.nd4j.linalg.api.memory.MemoryWorkspace in the nd4j project by deeplearning4j.

Source: class WorkspaceProviderTests, method testVariableInput1.

/**
 * Exercises the "ADSI" workspace (built from {@code adsiConfiguration}) with allocations whose
 * size changes between cycles.
 *
 * <p>After each cycle the test checks the initial block size, current size, host/device
 * offsets, number of pinned allocations, cycle count and step number.
 *
 * <p>NOTE(review): the expected values encode the allocation/learning policy of
 * {@code adsiConfiguration}, which is defined outside this method. Confirm against that
 * configuration before changing any constant.
 */
@Test
public void testVariableInput1() throws Exception {
    Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(adsiConfiguration, "ADSI");
    INDArray array1 = null;
    // NOTE(review): array2 is never used in this test.
    INDArray array2 = null;
    // Cycle 1: allocate an 8x128x100 array.
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(adsiConfiguration, "ADSI")) {
        // The first allocation is smaller than the 8x128x200 one made in cycle 3.
        array1 = Nd4j.create(8, 128, 100);
    }
    // Expected block size: requiredMemory * 1.3, truncated to long, then padded by
    // (8 - size % 8). This adds a full 8 bytes when the value is already a multiple of 8.
    // The product is evaluated in int arithmetic before widening to long.
    long requiredMemory = 8 * 128 * 100 * Nd4j.sizeOfDataType();
    long shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8));
    assertEquals(shiftedSize, workspace.getInitialBlockSize());
    // The whole workspace is expected to span four such blocks.
    assertEquals(shiftedSize * 4, workspace.getCurrentSize());
    assertEquals(0, workspace.getHostOffset());
    assertEquals(0, workspace.getDeviceOffset());
    assertEquals(1, workspace.getCyclesCount());
    assertEquals(0, workspace.getStepNumber());
    // Cycle 2: same shape as cycle 1.
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(adsiConfiguration, "ADSI")) {
        // allocating same shape
        array1 = Nd4j.create(8, 128, 100);
    }
    // After cycle 2 both offsets are expected to equal one initial block.
    assertEquals(workspace.getInitialBlockSize(), workspace.getHostOffset());
    assertEquals(workspace.getInitialBlockSize(), workspace.getDeviceOffset());
    assertEquals(2, workspace.getCyclesCount());
    assertEquals(0, workspace.getStepNumber());
    // Cycle 3: twice the size of the earlier allocations.
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(adsiConfiguration, "ADSI")) {
        // allocating bigger shape
        array1 = Nd4j.create(8, 128, 200).assign(1.0);
    }
    // offsets should be intact, allocation happened as pinned
    assertEquals(workspace.getInitialBlockSize(), workspace.getHostOffset());
    assertEquals(workspace.getInitialBlockSize(), workspace.getDeviceOffset());
    assertEquals(1, workspace.getNumberOfPinnedAllocations());
    assertEquals(3, workspace.getCyclesCount());
    assertEquals(0, workspace.getStepNumber());
    // Cycle 4: back to the smaller shape. The pinned-allocation count is still expected to grow.
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(adsiConfiguration, "ADSI")) {
        // allocating same shape
        array1 = Nd4j.create(8, 128, 100);
    }
    assertEquals(2, workspace.getNumberOfPinnedAllocations());
    assertEquals(0, workspace.getStepNumber());
    assertEquals(4, workspace.getCyclesCount());
    // Cycle 5: the step number is expected to advance to 1 here.
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(adsiConfiguration, "ADSI")) {
        // allocating same shape
        array1 = Nd4j.create(8, 128, 100);
    }
    assertEquals(3, workspace.getNumberOfPinnedAllocations());
    assertEquals(1, workspace.getStepNumber());
    assertEquals(5, workspace.getCyclesCount());
    // Cycles 6-17: twelve more cycles with the smaller shape.
    for (int i = 0; i < 12; i++) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(adsiConfiguration, "ADSI")) {
            // allocating same shape
            array1 = Nd4j.create(8, 128, 100);
        }
    }
    // Now we know that workspace was reallocated and offset was shifted to the end of workspace
    assertEquals(4, workspace.getStepNumber());
    // The reallocated size is expected to be based on the larger 8x128x200 allocation from
    // cycle 3, with the same 1.3x factor, padding and four-block total.
    requiredMemory = 8 * 128 * 200 * Nd4j.sizeOfDataType();
    shiftedSize = ((long) (requiredMemory * 1.3)) + (8 - (((long) (requiredMemory * 1.3)) % 8));
    assertEquals(shiftedSize * 4, workspace.getCurrentSize());
    assertEquals(workspace.getCurrentSize(), workspace.getHostOffset());
    assertEquals(workspace.getCurrentSize(), workspace.getDeviceOffset());
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 60 with MemoryWorkspace

Usage of org.nd4j.linalg.api.memory.MemoryWorkspace in the nd4j project by deeplearning4j.

Source: class WorkspaceProviderTests, method testReallocate1.

/**
 * Checks that a "WS_1" workspace built from {@code reallocateConfiguration} grows to fit larger
 * allocations. The test runs three cycles:
 * <ol>
 *   <li>A 100-element cycle; after initialization the workspace must span 100 elements.</li>
 *   <li>A 1000-element cycle; 1000 elements must be recorded as the max per-cycle allocation,
 *       and re-initialization must size the workspace to match.</li>
 *   <li>A 500-element array filled with 1.0 must then have a mean of 1.0.</li>
 * </ol>
 */
@Test
public void testReallocate1() throws Exception {
    // Cycle 1: small allocation; only its size matters.
    try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
        Nd4j.create(100);
    }

    Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager()
            .getWorkspaceForCurrentThread(reallocateConfiguration, "WS_1");
    ws1.initializeWorkspace();
    assertEquals(100 * Nd4j.sizeOfDataType(), ws1.getCurrentSize());

    // Cycle 2: ten times larger than cycle 1.
    try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
        Nd4j.create(1000);
    }
    assertEquals(1000 * Nd4j.sizeOfDataType(), ws1.getMaxCycleAllocations());
    ws1.initializeWorkspace();
    assertEquals(1000 * Nd4j.sizeOfDataType(), ws1.getCurrentSize());

    // Cycle 3: the reallocated workspace must hold more than 100 elements with correct contents.
    try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().getAndActivateWorkspace(reallocateConfiguration, "WS_1")) {
        INDArray ones = Nd4j.create(500).assign(1.0);
        double mean = ones.meanNumber().doubleValue();
        assertEquals(1.0, mean, 0.01);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)62 Test (org.junit.Test)39 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)35 INDArray (org.nd4j.linalg.api.ndarray.INDArray)35 Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace)18 WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration)14 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)6 AtomicLong (java.util.concurrent.atomic.AtomicLong)5 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)4 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)3 ByteArrayInputStream (java.io.ByteArrayInputStream)2 ByteArrayOutputStream (java.io.ByteArrayOutputStream)2 DataInputStream (java.io.DataInputStream)2 DataOutputStream (java.io.DataOutputStream)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2 Ignore (org.junit.Ignore)2 IOException (java.io.IOException)1 CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)1