Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.
Class BasicWorkspaceManager, method destroyAllWorkspacesForCurrentThread.
/**
 * Destroys every workspace registered for the calling thread, then requests a GC run.
 *
 * <p>The registered workspaces are copied into a separate list first, so that
 * {@code destroyWorkspace(...)} can modify the per-thread map while we iterate.
 */
@Override
public void destroyAllWorkspacesForCurrentThread() {
    ensureThreadExistense();

    // Snapshot the current thread's workspaces before destroying them one by one.
    List<MemoryWorkspace> snapshot = new ArrayList<>(backingMap.get().values());
    for (MemoryWorkspace ws : snapshot) {
        destroyWorkspace(ws);
    }

    System.gc();
}
Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.
Class BasicWorkspaceManager, method destroyWorkspace.
/**
 * Removes the default workspace ({@link MemoryWorkspace#DEFAULT_ID}) from the calling
 * thread's workspace map, if one is registered.
 *
 * <p>Note: despite the method name, no explicit destroy is invoked on the removed workspace;
 * only the map entry is dropped. NOTE(review): confirm whether the workspace's memory is
 * expected to be released explicitly here.
 */
@Override
public void destroyWorkspace() {
    ensureThreadExistense();
    backingMap.get().remove(MemoryWorkspace.DEFAULT_ID);
}
Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.
Class WorkspaceProviderTests, method testWorkspacesSerde1.
/**
 * Serializes an all-ones array outside of any workspace, then deserializes it inside
 * workspace "WS_1". Checks that the length and values survive the round trip and that
 * the restored array shares the same shapeInfo buffer instance as the original.
 */
@Test
public void testWorkspacesSerde1() throws Exception {
    int[] shape = new int[] { 17, 57, 79 };
    INDArray array = Nd4j.create(shape).assign(1.0);

    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(bos);
    Nd4j.write(array, dos);

    try (MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(bigConfiguration, "WS_1")) {
        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
        DataInputStream dis = new DataInputStream(bis);
        INDArray restored = Nd4j.read(dis);

        assertEquals(array.length(), restored.length());
        // Every element is 1.0, so the mean must be 1.0. A tight delta is required:
        // a delta of 1.0 would accept any mean in [0, 2] and miss corrupted data.
        assertEquals(1.0f, restored.meanNumber().floatValue(), 1e-5f);
        // we want to ensure it's the same cached shapeInfo used here
        assertTrue("restored array should reuse the cached shapeInfo buffer",
                array.shapeInfoDataBuffer() == restored.shapeInfoDataBuffer());
    }
}
Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.
Class WorkspaceProviderTests, method testCircularLearning1.
/**
 * Opens workspace "WSX" (with {@code circularConfiguration}) twice, allocating one small
 * array each time. After each pass it checks the workspace's current size and host offset:
 * the offset is 0 after the first pass and equals the initial block size after the second.
 */
@Test
public void testCircularLearning1() throws Exception {
    for (int i = 0; i < 2; i++) {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(circularConfiguration, "WSX")) {
            // Allocation inside the workspace is the point; the array itself is not inspected.
            Nd4j.create(10).assign(1);
        }

        Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(circularConfiguration, "WSX");
        assertEquals(10 * 1024 * 1024L, workspace.getCurrentSize());
        log.info("Current step number: {}", workspace.getStepNumber());

        // i is only ever 0 or 1 here.
        if (i == 0) {
            assertEquals(0, workspace.getHostOffset());
        } else {
            assertEquals(workspace.getInitialBlockSize(), workspace.getHostOffset());
        }
    }
}
Use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.
Class WorkspaceProviderTests, method testNestedWorkspaces7.
/**
 * Nests workspace "FeedForward" inside "External", then borrows "External" from within
 * "FeedForward" and finally scopes completely out of workspaces.
 *
 * <p>Verifies that:
 * <ul>
 *   <li>arrays created while "External" is borrowed have {@code wsExternal} as their parent
 *       workspace;</li>
 *   <li>arrays created out of workspaces have no parent and are detached;</li>
 *   <li>these attachment states hold after the borrow scope closes.</li>
 * </ul>
 */
@Test
public void testNestedWorkspaces7() throws Exception {
    try (Nd4jWorkspace wsExternal = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "External")) {
        // Allocate once in the outer workspace; the array itself is not inspected.
        Nd4j.create(10);

        INDArray array3 = null;
        INDArray array4 = null;
        INDArray array5 = null;

        try (Nd4jWorkspace wsFeedForward = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfiguration, "FeedForward")) {
            INDArray array2 = Nd4j.create(10);
            assertTrue(array2.isAttached());

            try (Nd4jWorkspace borrowed = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("External").notifyScopeBorrowed()) {
                array3 = Nd4j.create(10);
                assertTrue(wsExternal == array3.data().getParentWorkspace());

                try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                    array4 = Nd4j.create(10);
                }

                array5 = Nd4j.create(10);
                log.info("Workspace5: {}", array5.data().getParentWorkspace());
                assertTrue(null == array4.data().getParentWorkspace());
                assertFalse(array4.isAttached());
                assertTrue(wsExternal == array5.data().getParentWorkspace());
            }

            // Attachment state must be unchanged once the borrow scope has closed.
            assertTrue(array3.isAttached());
            assertFalse(array4.isAttached());
            assertTrue(array5.isAttached());
        }
    }
}
Aggregations