Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class BasicWorkspaceTests, method testIsAttached3.
/**
 * Checks attachment state of arrays created outside any workspace, both while the "ITER"
 * workspace is active and after it has been closed.
 */
@Test
public void testIsAttached3() {
    // Created before any workspace is activated in this test.
    INDArray detached = Nd4j.create(100);

    try (Nd4jWorkspace iterWs = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "ITER")) {
        INDArray leveraged = detached.leverageTo("ITER");
        // The test expects neither the source nor the leverageTo result to be attached.
        assertFalse(detached.isAttached());
        assertFalse(leveraged.isAttached());
    }

    // Once "ITER" is closed, a freshly created array must not be attached either.
    INDArray createdAfter = Nd4j.create(100);
    assertFalse(detached.isAttached());
    assertFalse(createdAfter.isAttached());
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class BasicWorkspaceTests, method testAllocation2.
/**
 * Verifies that creating an array while the "testAllocation2" workspace is current advances
 * the workspace's host offset, and that the allocated array is usable.
 */
@Test
public void testAllocation2() throws Exception {
    // try-with-resources guarantees close() runs even when an assertion below fails, so a
    // failing run cannot leave this workspace open for subsequent tests.
    try (Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "testAllocation2")) {
        // NOTE(review): getAndActivateWorkspace presumably already makes it current; kept explicit.
        Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
        assertNotEquals(null, Nd4j.getMemoryManager().getCurrentWorkspace());
        assertEquals(0, workspace.getHostOffset());

        INDArray array = Nd4j.create(5);

        // checking if allocation actually happened
        long reqMem = 5 * Nd4j.sizeOfDataType();
        // NOTE(review): reqMem + reqMem % 8 equals the 8-byte-aligned size only when
        // reqMem % 8 is 0 or 4 (true for 4- and 8-byte data types); confirm against the
        // workspace's alignment rule before using other data types.
        assertEquals(reqMem + reqMem % 8, workspace.getHostOffset());

        array.assign(1.0f);
        assertEquals(5, array.sumNumber().doubleValue(), 0.01);
    }
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class BasicWorkspaceTests, method testOverallocation3.
/**
 * Verifies that a workspace with OVER_TIME learning allocates nothing while scopes are
 * entered, and after initializeWorkspace() is sized at 200 elements of the current data type.
 */
@Test
public void testOverallocation3() throws Exception {
    // Remember the caller's current workspace so this test does not leak its own into others.
    MemoryWorkspace previous = Nd4j.getMemoryManager().getCurrentWorkspace();

    WorkspaceConfiguration overallocationConfig = WorkspaceConfiguration.builder()
            .initialSize(0)
            .maxSize(10 * 1024 * 1024)
            .overallocationLimit(1.0)
            .policyAllocation(AllocationPolicy.OVERALLOCATE)
            .policyLearning(LearningPolicy.OVER_TIME)
            .policyMirroring(MirroringPolicy.FULL)
            .policySpill(SpillPolicy.EXTERNAL)
            .build();

    Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(overallocationConfig);
    Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
    try {
        assertEquals(0, workspace.getCurrentSize());

        // Allocate 10, 20, ..., 100 elements, each inside its own scope of the workspace.
        for (int x = 10; x <= 100; x += 10) {
            try (MemoryWorkspace cW = workspace.notifyScopeEntered()) {
                Nd4j.create(x);
            }
        }

        // Nothing is allocated for real until the workspace is initialized.
        assertEquals(0, workspace.getCurrentSize());
        workspace.initializeWorkspace();

        // Expected: 100 elements * sizeOfDataType() bytes * 2 (the 2x factor presumably comes
        // from overallocationLimit(1.0)); e.g. 800 bytes for a 4-byte float data type.
        assertEquals(200 * Nd4j.sizeOfDataType(), workspace.getCurrentSize());
    } finally {
        Nd4j.getMemoryManager().setCurrentWorkspace(previous);
    }
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class BasicWorkspaceTests, method testLeverage3.
/**
 * Leverages a tensor-along-dimension view from the inner "INT" workspace into the outer "EXT"
 * workspace and checks that its contents survive the inner workspace being closed.
 */
@Test
public void testLeverage3() throws Exception {
    final float expectedSum = 40.0f;
    final float tolerance = 0.01f;

    try (Nd4jWorkspace outerWs = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) {
        INDArray leveraged = null;

        try (Nd4jWorkspace innerWs = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "INT")) {
            INDArray source = Nd4j.create(32, 1, 40);
            INDArray tad = source.tensorAlongDimension(0, 1, 2);
            tad.assign(1.0f);

            // Writing through the view must be visible in the backing array as well.
            assertEquals(expectedSum, source.sumNumber().floatValue(), tolerance);
            assertEquals(expectedSum, tad.sumNumber().floatValue(), tolerance);

            leveraged = tad.leverageTo("EXT");
        }

        // "INT" is closed here; the leveraged copy must still hold the same values.
        assertEquals(expectedSum, leveraged.sumNumber().floatValue(), tolerance);
    }
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in project nd4j by deeplearning4j.
From class BasicWorkspaceTests, method testOutOfScope1.
/**
 * Checks that scopeOutOfWorkspaces() allocates outside the active "EXT" workspace, restores it
 * afterwards, and that detached and attached arrays can be combined.
 */
@Test
public void testOutOfScope1() throws Exception {
    try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "EXT")) {
        INDArray inside = Nd4j.create(new float[] { 1f, 2f, 3f, 4f, 5f });

        // Host-offset advance expected per 5-element array, using the file's alignment formula.
        long bytes = 5 * Nd4j.sizeOfDataType();
        long perArrayOffset = bytes + bytes % 8;
        assertEquals(perArrayOffset, wsOne.getHostOffset());

        INDArray outside;
        try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            outside = Nd4j.create(new float[] { 1f, 2f, 3f, 4f, 5f });
        }
        assertFalse(outside.isAttached());

        // Leaving the out-of-workspace scope must make "EXT" current again.
        log.info("Current workspace: {}", Nd4j.getMemoryManager().getCurrentWorkspace());
        assertTrue(wsOne == Nd4j.getMemoryManager().getCurrentWorkspace());

        // A second in-workspace allocation doubles the offset; the "outside" array consumed none.
        Nd4j.create(new float[] { 1f, 2f, 3f, 4f, 5f });
        assertEquals(perArrayOffset * 2, wsOne.getHostOffset());

        // 15 + 15 after element-wise in-place addition.
        inside.addi(outside);
        assertEquals(30.0f, inside.sumNumber().floatValue(), 0.01f);
    }
}
Aggregations