Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in the nd4j project by deeplearning4j.
Class BasicWorkspaceTests, method testLoop4.
/**
 * Verifies host-offset bookkeeping for a workspace created from {@code loopFirstConfig}.
 *
 * <p>After each scope closes the host offset is back at zero. After the first scope the current
 * size equals the 200 elements allocated there. Inside a later scope, a single 100-element
 * allocation advances the offset by exactly 100 elements.
 */
@Test
public void testLoop4() throws Exception {
    Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(loopFirstConfig);
    Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
    assertNotNull(Nd4j.getMemoryManager().getCurrentWorkspace());
    assertEquals(0, workspace.getHostOffset());

    // First scope: the arrays are created only for their allocation side effect on the workspace.
    try (MemoryWorkspace cW = workspace.notifyScopeEntered()) {
        INDArray array1 = Nd4j.create(100);
        INDArray array2 = Nd4j.create(100);
    }

    // Leaving the scope rewinds the offset, while the size reflects the 200 elements allocated above.
    assertEquals(0, workspace.getHostOffset());
    assertEquals(200 * Nd4j.sizeOfDataType(), workspace.getCurrentSize());

    // Second scope: a single 100-element allocation advances the offset by exactly that amount.
    try (MemoryWorkspace cW = workspace.notifyScopeEntered()) {
        INDArray array1 = Nd4j.create(100);
        assertEquals(100 * Nd4j.sizeOfDataType(), workspace.getHostOffset());
    }

    assertEquals(0, workspace.getHostOffset());
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in the nd4j project by deeplearning4j.
Class BasicWorkspaceTests, method testAllocation4.
/**
 * Checks a STRICT, 1 MiB workspace configured with {@code SpillPolicy.FAIL}.
 *
 * <p>An allocation far larger than the limit must throw {@link ND4JIllegalStateException} and
 * must leave the host offset unchanged. Small allocations must keep working afterwards.
 */
@Test
public void testAllocation4() throws Exception {
    WorkspaceConfiguration failConfig = WorkspaceConfiguration.builder()
            .initialSize(1024 * 1024)
            .maxSize(1024 * 1024)
            .overallocationLimit(0.1)
            .policyAllocation(AllocationPolicy.STRICT)
            .policyLearning(LearningPolicy.FIRST_LOOP)
            .policyMirroring(MirroringPolicy.FULL)
            .policySpill(SpillPolicy.FAIL)
            .build();

    Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().createNewWorkspace(failConfig);
    Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
    assertNotNull(Nd4j.getMemoryManager().getCurrentWorkspace());
    assertEquals(0, workspace.getHostOffset());

    INDArray array = Nd4j.create(new int[] {1, 5}, 'c');

    // checking if allocation actually happened.
    // NOTE(review): the expected padded size (reqMem + reqMem % 8) is the formula this test has
    // always used; confirm it matches the workspace's alignment logic for every data type.
    long reqMem = 5 * Nd4j.sizeOfDataType();
    long paddedMem = reqMem + reqMem % 8;
    assertEquals(paddedMem, workspace.getHostOffset());

    // 10,000,000 elements cannot fit in a 1 MiB workspace, so the FAIL spill policy is expected to
    // throw. Any other exception type propagates and fails the test, exactly as before.
    boolean rejected = false;
    try {
        Nd4j.create(10000000);
    } catch (ND4JIllegalStateException e) {
        rejected = true;
    }
    assertTrue("Oversized allocation should have been rejected with ND4JIllegalStateException", rejected);

    // The rejected allocation must not have advanced the offset.
    assertEquals(paddedMem, workspace.getHostOffset());

    INDArray array2 = Nd4j.create(new int[] {1, 5}, 'c');
    assertEquals(paddedMem * 2, workspace.getHostOffset());
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in the nd4j project by deeplearning4j.
Class BasicWorkspaceTests, method testLeverageTo2.
/**
 * Checks that an array created in the nested "INT" workspace and leveraged into the outer "EXT"
 * workspace keeps its contents (sum 15). This must hold after "INT" has been closed, reopened
 * and allocated into again.
 */
@Test
public void testLeverageTo2() throws Exception {
    try (Nd4jWorkspace wsOne = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopOverTimeConfig, "EXT")) {
        INDArray array1 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
        INDArray array3 = null;

        try (Nd4jWorkspace wsTwo = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "INT")) {
            INDArray array2 = Nd4j.create(new float[] {1f, 2f, 3f, 4f, 5f});
            array3 = array2.leverageTo("EXT");

            assertEquals(0, wsOne.getCurrentSize());
            assertEquals(15f, array3.sumNumber().floatValue(), 0.01f);
        }

        // Reopen "INT" and allocate into it again. Presumably this reuses the memory array2 lived
        // in, so array3 must not still alias it.
        try (Nd4jWorkspace wsTwo = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "INT")) {
            INDArray array2 = Nd4j.create(100);
        }

        assertEquals(15f, array3.sumNumber().floatValue(), 0.01f);
    }
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in the nd4j project by deeplearning4j.
Class BasicWorkspaceTests, method testIsAttached1.
/**
 * An array created while the "ITER" workspace is active reports {@code isAttached() == true}.
 * An array created after that workspace scope has closed reports {@code false}.
 */
@Test
public void testIsAttached1() {
    try (Nd4jWorkspace iterWorkspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(loopFirstConfig, "ITER")) {
        INDArray insideScope = Nd4j.create(100);
        assertTrue(insideScope.isAttached());
    }

    INDArray outsideScope = Nd4j.create(100);
    assertFalse(outsideScope.isAttached());
}
Usage of org.nd4j.linalg.memory.abstracts.Nd4jWorkspace in the nd4j project by deeplearning4j.
Class BasicWorkspaceTests, method testNoShape1.
/**
 * Builds a 4d array inside the "ITER" workspace and permutes its first two dimensions. It then
 * checks the resulting shape/stride and that {@code Shape.newShapeNoCopy} returns a non-null
 * 2d view of shape [outDepth, miniBatch * outH * outW].
 */
@Test
public void testNoShape1() {
    int outDepth = 50;
    int miniBatch = 64;
    int outH = 8;
    int outW = 8;

    try (Nd4jWorkspace wsI = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "ITER")) {
        // Shape [outDepth, miniBatch, outH, outW] with strides that make miniBatch the outermost
        // dimension in memory, i.e. the buffer is C-ordered as [miniBatch, outDepth, outH, outW].
        INDArray delta = Nd4j.create(new int[] {outDepth, miniBatch, outH, outW},
                new int[] {outH * outW, outDepth * outH * outW, outW, 1}, 'c');

        // Swapping the first two dimensions yields a view whose strides are plain C order.
        delta = delta.permute(1, 0, 2, 3);
        assertArrayEquals(new int[] {miniBatch, outDepth, outH, outW}, delta.shape());
        assertArrayEquals(new int[] {outDepth * outH * outW, outH * outW, outW, 1}, delta.stride());

        INDArray delta2d = Shape.newShapeNoCopy(delta, new int[] {outDepth, miniBatch * outH * outW}, false);
        assertNotNull(delta2d);
    }
}
Aggregations