Search in sources:

Example 46 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

From class BasicWorkspaceTests, method testOutOfScope1.

/**
 * Checks that an array created under {@code scopeOutOfWorkspaces()} is not attached to the
 * enclosing workspace, that "EXT" is the current workspace again once that scope closes, and
 * that the host offset of "EXT" only grows for arrays allocated while it is active.
 */
@Test
public void testOutOfScope1() throws Exception {
    try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) {
        INDArray attached = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});

        // Bytes for one 5-element array, padded the way this test expects the workspace
        // to pad an allocation. Computed once and reused for both offset checks.
        long rawBytes = 5 * Nd4j.sizeOfDataType();
        long paddedBytes = rawBytes + rawBytes % 8;
        assertEquals(paddedBytes, wsOne.getHostOffset());

        // Allocated outside of every workspace, so it must not be attached to "EXT".
        INDArray detached;
        try (MemoryWorkspace outOfScope = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            detached = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
        }
        assertFalse(detached.isAttached());

        // Closing the out-of-workspace scope must make "EXT" current again.
        log.info("Current workspace: {}", Nd4j.getMemoryManager().getCurrentWorkspace());
        assertTrue(wsOne == Nd4j.getMemoryManager().getCurrentWorkspace());

        // After a second attached allocation the offset must equal exactly two padded
        // arrays, i.e. the detached array took no space in "EXT".
        INDArray secondAttached = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
        assertEquals(paddedBytes * 2, wsOne.getHostOffset());

        // Attached and detached arrays can be combined: (1+2+3+4+5) * 2 == 30.
        attached.addi(detached);
        assertEquals(30.0f, attached.sumNumber().floatValue(), 0.01f);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jWorkspace(org.nd4j.linalg.memory.abstracts.Nd4jWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 47 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

From class EndlessWorkspaceTests, method testPerf1.

/**
 * Micro-benchmark of a workspace scope. Each iteration activates "WS_1", allocates two
 * uninitialized arrays, runs one in-place {@code addi} and closes the scope. At the end it
 * logs the 90th-percentile latency of the whole block and of the {@code addi} call alone.
 *
 * <p>Samples are stored in preallocated primitive {@code long[]} buffers rather than
 * {@code List<Long>}. This keeps the timed loop from boxing two million values, whose
 * garbage would otherwise add GC noise to later samples.
 */
@Test
public void testPerf1() throws Exception {
    final int iterations = 1000000;

    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(WorkspaceConfiguration.builder().initialSize(50000L).build());
    // Obtained before the timed loop; kept as in the original setup.
    MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS_1");
    INDArray tmp = Nd4j.create(64 * 64 + 1);

    long[] results = new long[iterations];
    long[] resultsOp = new long[iterations];
    for (int i = 0; i < iterations; i++) {
        long time1 = System.nanoTime();
        long time3 = 0;
        long time4 = 0;
        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS_1")) {
            INDArray array = Nd4j.createUninitialized(64 * 64 + 1);
            INDArray arrayx = Nd4j.createUninitialized(64 * 64 + 1);
            time3 = System.nanoTime();
            arrayx.addi(1.01);
            time4 = System.nanoTime();
        }
        long time2 = System.nanoTime();
        results[i] = time2 - time1;
        resultsOp[i] = time4 - time3;
    }

    java.util.Arrays.sort(results);
    java.util.Arrays.sort(resultsOp);
    // Index of the 90th-percentile sample.
    int pos = (int) (iterations * 0.9);
    log.info("Block: {} ns; Op: {} ns;", results[pos], resultsOp[pos]);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) AtomicLong(java.util.concurrent.atomic.AtomicLong) ArrayList(java.util.ArrayList) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 48 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

From class EndlessWorkspaceTests, method endlessTest1.

/**
 * This test checks for allocations within single workspace, without any spills.
 *
 * <p>The default workspace is configured with an initial size of {@code 100 * 1024 * 1024}
 * and periodic GC is disabled. The test then loops forever: each pass allocates a
 * 1024*1024-element array in the default workspace, adds 1.0 to it and asserts the mean is
 * about 1.0. Every 1000 passes it logs the time spent in the allocation call. There is no
 * exit condition, so the test never finishes on its own.
 *
 * @throws Exception
 */
@Test
public void endlessTest1() throws Exception {
    WorkspaceConfiguration config = WorkspaceConfiguration.builder().initialSize(100 * 1024L * 1024L).build();
    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(config);
    Nd4j.getMemoryManager().togglePeriodicGc(false);

    AtomicLong passes = new AtomicLong(0);
    for (;;) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace()) {
            long allocStart = System.nanoTime();
            INDArray ones = Nd4j.create(1024 * 1024);
            long allocEnd = System.nanoTime();

            ones.addi(1.0f);
            assertEquals(1.0f, ones.meanNumber().floatValue(), 0.1f);

            long completed = passes.incrementAndGet();
            if (completed % 1000 == 0) {
                log.info("{} iterations passed... Allocation time: {} ns", completed, allocEnd - allocStart);
            }
        }
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 49 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

From class EndlessWorkspaceTests, method endlessTest6.

/**
 * Endless test that repeatedly opens and closes the "PEW-PEW" workspace and allocates one
 * 5-element array inside it on each pass. The workspace has an initial size of
 * {@code 10 * 1024 * 1024} and {@code LearningPolicy.NONE}, and periodic GC is disabled.
 * Every 1,000,000 passes it logs {@code Pointer.totalBytes()}, presumably so growth in
 * native allocations can be spotted. There is no exit condition.
 */
@Test
public void endlessTest6() throws Exception {
    Nd4j.getMemoryManager().togglePeriodicGc(false);
    WorkspaceConfiguration pewConf = WorkspaceConfiguration.builder()
            .initialSize(10 * 1024L * 1024L)
            .policyLearning(LearningPolicy.NONE)
            .build();

    final AtomicLong passes = new AtomicLong(0);
    for (;;) {
        try (MemoryWorkspace scope = Nd4j.getWorkspaceManager().getAndActivateWorkspace(pewConf, "PEW-PEW")) {
            INDArray small = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
        }

        boolean reportDue = passes.incrementAndGet() % 1000000 == 0;
        if (reportDue) {
            log.info("TotalBytes: {}", Pointer.totalBytes());
        }
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) WorkspaceConfiguration(org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Example 50 with MemoryWorkspace

Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

From class EndlessWorkspaceTests, method endlessTest3.

/**
 * This endless test checks for nested workspaces and cross-workspace memory use.
 *
 * <p>Each pass fills an array in outer workspace "WS_1" with 1.0. It then fills an array in
 * nested workspace "WS_2" with 1.0 and adds the inner array into the outer one in place.
 * Finally it asserts that the outer array's mean is about 2.0. Every 1000 passes it logs
 * progress and calls {@code System.gc()} while both workspaces are still open. There is no
 * exit condition.
 *
 * @throws Exception
 */
@Test
public void endlessTest3() throws Exception {
    WorkspaceConfiguration config = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L).build();
    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(config);
    Nd4j.getMemoryManager().togglePeriodicGc(false);

    AtomicLong passes = new AtomicLong(0);
    for (;;) {
        try (MemoryWorkspace outer = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS_1")) {
            INDArray outerArray = Nd4j.create(2 * 1024 * 1024);
            outerArray.assign(1.0);

            try (MemoryWorkspace inner = Nd4j.getWorkspaceManager().getAndActivateWorkspace("WS_2")) {
                INDArray innerArray = Nd4j.create(2 * 1024 * 1024);
                innerArray.assign(1.0);

                // Cross-workspace op: the outer array is updated from a nested-scope array.
                outerArray.addi(innerArray);
                assertEquals(2.0f, outerArray.meanNumber().floatValue(), 0.01);

                long completed = passes.incrementAndGet();
                if (completed % 1000 != 0) {
                    // Leaving via continue closes "WS_2" then "WS_1", same as falling through.
                    continue;
                }
                log.info("{} iterations passed...", completed);
                System.gc();
            }
        }
    }
}
Also used : AtomicLong(java.util.concurrent.atomic.AtomicLong) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)62 Test (org.junit.Test)39 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)35 INDArray (org.nd4j.linalg.api.ndarray.INDArray)35 Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace)18 WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration)14 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)6 AtomicLong (java.util.concurrent.atomic.AtomicLong)5 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)4 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)3 ByteArrayInputStream (java.io.ByteArrayInputStream)2 ByteArrayOutputStream (java.io.ByteArrayOutputStream)2 DataInputStream (java.io.DataInputStream)2 DataOutputStream (java.io.DataOutputStream)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2 Ignore (org.junit.Ignore)2 IOException (java.io.IOException)1 CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)1