Use of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class WorkspaceProviderTests, method testNestedWorkspacesOverlap2:
/**
 * Verifies that borrowing the outer workspace "WS1" from inside the nested "WS2" sends new
 * allocations back to WS1, and that an array duplicated into WS1 during the borrow is still
 * readable after the borrow scope has closed.
 *
 * @throws Exception if workspace handling fails
 */
@Test
public void testNestedWorkspacesOverlap2() throws Exception {
    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);

    // Both workspaces must be created by this test, not left over from a previous one.
    assertFalse(Nd4j.getWorkspaceManager().checkIfWorkspaceExists("WS1"));
    assertFalse(Nd4j.getWorkspaceManager().checkIfWorkspaceExists("WS2"));

    try (Nd4jWorkspace outer = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) {
        INDArray outerArray = Nd4j.create(new float[] {6f, 3f, 1f, 9f, 21f});
        INDArray duplicate = null;

        // Expected host-offset step for one 5-element array.
        // NOTE(review): adding (bytes % 8) equals 8-byte alignment padding only when
        // bytes % 8 is 0 or 4 (e.g. 4- or 8-byte element types); confirm against the allocator.
        long bytes = 5 * Nd4j.sizeOfDataType();
        long step = bytes + bytes % 8;
        assertEquals(step, outer.getHostOffset());

        try (Nd4jWorkspace inner = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2").notifyScopeEntered()) {
            INDArray innerArray = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
            // So far each workspace holds exactly one array.
            assertEquals(step, outer.getHostOffset());
            assertEquals(step, inner.getHostOffset());

            try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeBorrowed()) {
                // Borrowing yields the very same WS1 instance and makes it the current workspace.
                assertTrue(outer == borrowed);
                assertTrue(outer == Nd4j.getMemoryManager().getCurrentWorkspace());

                // The duplicate is allocated in WS1, so only WS1's offset grows.
                duplicate = innerArray.unsafeDuplication();
                assertTrue(outer == duplicate.data().getParentWorkspace());
                assertEquals(step, inner.getHostOffset());
                assertEquals(step * 2, outer.getHostOffset());
            }

            // Closing the borrow scope makes WS2 current again and leaves both offsets alone.
            log.info("Current workspace: {}", Nd4j.getMemoryManager().getCurrentWorkspace());
            assertTrue(inner == Nd4j.getMemoryManager().getCurrentWorkspace());
            assertEquals(step, inner.getHostOffset());
            assertEquals(step * 2, outer.getHostOffset());
            // 1 + 2 + 3 + 4 + 5 = 15: the duplicated data is still intact.
            assertEquals(15f, duplicate.sumNumber().floatValue(), 0.01f);
        }
    }

    log.info("------");
    assertNull(Nd4j.getMemoryManager().getCurrentWorkspace());
}
Use of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class WorkspaceProviderTests, method testUnboundedLoop2:
/**
 * Checks OVER_TIME learning on an over-allocating workspace that starts with size 0.
 *
 * <p>During the first iterations the workspace must stay empty (no early initialization).
 * After initialization its block size and total size must match the values derived below,
 * and they must stay stable for the rest of the loop.
 *
 * @throws Exception if workspace handling fails
 */
@Test
public void testUnboundedLoop2() throws Exception {
    WorkspaceConfiguration configuration = WorkspaceConfiguration.builder()
            .initialSize(0)
            .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .overallocationLimit(4.0)
            .policyLearning(LearningPolicy.OVER_TIME)
            .cyclesBeforeInitialization(5)
            .build();

    Nd4jWorkspace ws1 = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "ITER");

    // Expected learned block: 100 elements' worth of bytes scaled by 1.3, then raised by
    // (8 - scaled % 8). Note that a value already divisible by 8 still gains a full 8 bytes.
    long requiredMemory = 100 * Nd4j.sizeOfDataType();
    long scaled = (long) (requiredMemory * 1.3);
    long shiftedSize = scaled + (8 - scaled % 8);

    for (int iteration = 0; iteration < 100; iteration++) {
        try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration, "ITER").notifyScopeEntered()) {
            INDArray array = Nd4j.create(100);
        }

        if (iteration < 4) {
            // Still learning: nothing may have been allocated up front yet.
            assertEquals("Failed on iteration " + iteration, 0, ws1.getCurrentSize());
        } else if (iteration > 4) {
            // Initialized: sizes must match the learned block.
            assertEquals(shiftedSize, ws1.getInitialBlockSize());
            assertEquals(5 * shiftedSize, ws1.getCurrentSize());
        }
        // iteration == 4 is deliberately not checked; presumably this is the cycle on which
        // initialization happens (cyclesBeforeInitialization(5)). TODO confirm.
    }

    // Final size is 5 * shiftedSize. Presumably this is the learned block plus 4.0x
    // overallocation (overallocationLimit(4.0)); confirm against the workspace implementation.
    assertEquals(5 * shiftedSize, ws1.getCurrentSize());
    assertNull(Nd4j.getMemoryManager().getCurrentWorkspace());
}
Use of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class WorkspaceProviderTests, method testCircularLearning1:
/**
 * Activates the circular workspace "WSX" twice and checks its size and host offset after
 * each pass.
 *
 * <p>The workspace's current size must stay at 10 MiB on both passes. Its host offset must be
 * 0 after the first pass and equal to its initial block size after the second.
 *
 * @throws Exception if workspace handling fails
 */
@Test
public void testCircularLearning1() throws Exception {
    for (int i = 0; i < 2; i++) {
        try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfiguration, "WSX")) {
            // Allocate inside the activated workspace; the resulting array is never read.
            Nd4j.create(10).assign(1);
        }

        Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(circularConfiguration, "WSX");
        // 10 MiB; presumably the initial size set by circularConfiguration (TODO confirm).
        assertEquals(10 * 1024 * 1024L, workspace.getCurrentSize());
        log.info("Current step number: {}", workspace.getStepNumber());

        // The loop only runs for i == 0 and i == 1.
        if (i == 0) {
            assertEquals(0, workspace.getHostOffset());
        } else {
            assertEquals(workspace.getInitialBlockSize(), workspace.getHostOffset());
        }
    }
}
Use of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class WorkspaceProviderTests, method testNestedWorkspaces7:
/**
 * Checks where arrays are allocated when "External" is borrowed from inside "FeedForward".
 *
 * <p>Arrays created while "External" is borrowed belong to it. An array created under
 * {@code scopeOutOfWorkspaces()} has no parent workspace and is detached. All of these keep
 * their attachment state after the borrow scope closes.
 *
 * @throws Exception if workspace handling fails
 */
@Test
public void testNestedWorkspaces7() throws Exception {
    try (Nd4jWorkspace wsExternal = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "External")) {
        INDArray array1 = Nd4j.create(10);

        try (Nd4jWorkspace wsFeedForward = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "FeedForward")) {
            INDArray array2 = Nd4j.create(10);
            assertTrue(array2.isAttached());

            // Assigned inside the borrow scope below and checked again after it closes.
            INDArray array3 = null;
            INDArray array4 = null;
            INDArray array5 = null;

            try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) {
                // While "External" is borrowed, new arrays belong to it.
                array3 = Nd4j.create(10);
                assertTrue(wsExternal == array3.data().getParentWorkspace());

                // Allocating outside of all workspaces yields an array with no parent workspace.
                try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    array4 = Nd4j.create(10);
                }

                // Once the out-of-workspace scope closes, allocation goes back to "External".
                array5 = Nd4j.create(10);
                log.info("Workspace5: {}", array5.data().getParentWorkspace());
                assertTrue(null == array4.data().getParentWorkspace());
                assertFalse(array4.isAttached());
                assertTrue(wsExternal == array5.data().getParentWorkspace());
            }

            // Attachment state must survive closing the borrow scope.
            assertTrue(array3.isAttached());
            assertFalse(array4.isAttached());
            assertTrue(array5.isAttached());
        }
    }
}
Use of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class WorkspaceProviderTests, method testNestedWorkspacesOverlap1:
/**
 * Verifies that borrowing "WS1" while "WS2" is active returns the same WS1 instance. New
 * allocations then grow WS1's host offset and leave WS2's unchanged.
 *
 * @throws Exception if workspace handling fails
 */
@Test
public void testNestedWorkspacesOverlap1() throws Exception {
    Nd4j.getWorkspaceManager().setDefaultWorkspaceConfiguration(basicConfiguration);

    try (Nd4jWorkspace outer = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeEntered()) {
        INDArray outerArray = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});

        // Expected host-offset step for one 5-element array.
        // NOTE(review): adding (bytes % 8) equals 8-byte alignment padding only when
        // bytes % 8 is 0 or 4 (e.g. 4- or 8-byte element types); confirm against the allocator.
        long bytes = 5 * Nd4j.sizeOfDataType();
        long step = bytes + bytes % 8;
        assertEquals(step, outer.getHostOffset());

        try (Nd4jWorkspace inner = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS2").notifyScopeEntered()) {
            INDArray innerArray = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
            // So far each workspace holds exactly one array.
            assertEquals(step, outer.getHostOffset());
            assertEquals(step, inner.getHostOffset());

            try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1").notifyScopeBorrowed()) {
                assertTrue(outer == borrowed);

                // This allocation lands in WS1: its offset doubles while WS2's stays put.
                INDArray borrowedArray = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
                assertEquals(step, inner.getHostOffset());
                assertEquals(step * 2, outer.getHostOffset());
            }
        }
    }

    assertNull(Nd4j.getMemoryManager().getCurrentWorkspace());
}
Aggregations