
Example 31 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class BaseNDArray method detach.

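/**
 * Detaches this INDArray from its workspace by returning a copy. In effect this is a dup() into a new
 * memory chunk that is not owned by any workspace.
 * <p>
 * PLEASE NOTE: If this INDArray instance is NOT attached, it is returned unmodified.
 *
 * @return a detached copy of this array, or this instance if it is not attached to a workspace
 */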
@Override
public INDArray detach() {
    if (!isAttached())
        return this;
    Nd4j.getExecutioner().commit();
    /*
     Two options here:
     1) we're within some workspace
     2) we're out of any workspace
     */
    if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) {
        if (!isView()) {
            Nd4j.getExecutioner().commit();
            DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
            Nd4j.getMemoryManager().memcpy(buffer, this.data());
            return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
        } else {
            INDArray copy = Nd4j.createUninitialized(this.shape(), this.ordering());
            copy.assign(this);
            Nd4j.getExecutioner().commit();
            return copy;
        }
    } else {
        MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace(null);
        INDArray copy = null;
        if (!isView()) {
            Nd4j.getExecutioner().commit();
            DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
            // Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType()));
            Nd4j.getMemoryManager().memcpy(buffer, this.data());
            // this.dup(this.ordering());
            copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
        } else {
            copy = Nd4j.createUninitialized(this.shape(), this.ordering());
            copy.assign(this);
            Nd4j.getExecutioner().commit();
        }
        Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
        return copy;
    }
}
Also used: MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
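The save / clear / restore sequence in detach() can be illustrated without ND4J. This is a minimal sketch under stated assumptions: DetachSketch, ToyWorkspace, ToyArray, current(), setCurrent() and allocate() are hypothetical stand-ins, not ND4J classes, and a ThreadLocal plays the role of Nd4j.getMemoryManager().getCurrentWorkspace() / setCurrentWorkspace(). Unlike the quoted method, which restores the workspace only on the normal path, the sketch restores it in a finally block so an exception during the copy cannot leave the thread outside its workspace.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical toy model of BaseNDArray.detach(); none of these types are ND4J API.
final class DetachSketch {
    // Stand-in for a workspace: records the arrays allocated while it was current.
    static final class ToyWorkspace {
        final String id;
        final List<ToyArray> allocations = new ArrayList<>();

        ToyWorkspace(String id) { this.id = id; }
    }

    // Stand-in for an array that remembers which workspace owns its memory.
    static final class ToyArray {
        final double[] data;
        final ToyWorkspace owner; // null means "not attached to any workspace"

        ToyArray(double[] data, ToyWorkspace owner) {
            this.data = data;
            this.owner = owner;
        }

        boolean isAttached() { return owner != null; }
    }

    // Per-thread "current workspace" (the role played by the memory manager in ND4J).
    private static final ThreadLocal<ToyWorkspace> CURRENT = new ThreadLocal<>();

    static ToyWorkspace current() { return CURRENT.get(); }

    static void setCurrent(ToyWorkspace ws) { CURRENT.set(ws); }

    // Copies values into memory owned by the current workspace, or by no workspace if none is current.
    static ToyArray allocate(double[] values) {
        ToyWorkspace ws = CURRENT.get();
        ToyArray arr = new ToyArray(values.clone(), ws);
        if (ws != null) {
            ws.allocations.add(arr);
        }
        return arr;
    }

    static ToyArray detach(ToyArray arr) {
        if (!arr.isAttached()) {
            return arr; // same early return as the quoted detach()
        }
        ToyWorkspace saved = CURRENT.get();
        CURRENT.set(null); // leave every workspace so the copy is unowned
        try {
            return allocate(arr.data);
        } finally {
            CURRENT.set(saved); // restore the caller's workspace even on failure
        }
    }
}
```

Because the copy is allocated while no workspace is current, it is not registered with the caller's workspace, which is why a detached result can typically be used after that workspace's memory is reused.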

Example 32 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class BaseNDArray method migrate.

/**
 * This method pulls this INDArray into the current workspace, or optionally detaches it if no workspace is present.<br>
 * That is:<br>
 * If a current workspace is present/active, the INDArray is migrated to it.<br>
 * If no current workspace is present/active, one of two things occurs:<br>
 * 1. If the detachOnNoWs arg is true, the INDArray is detached.<br>
 * 2. If the detachOnNoWs arg is false, this INDArray is returned as-is (no-op), equivalent to {@link #migrate()}.
 *
 * @param detachOnNoWs If true: detach when there is no current workspace. If false and there is no workspace: return this.
 * @return the migrated or detached INDArray, or this instance if nothing was done
 */
@Override
public INDArray migrate(boolean detachOnNoWs) {
    MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
    if (current == null) {
        if (detachOnNoWs) {
            return detach();
        } else {
            return this;
        }
    }
    INDArray copy = null;
    if (!this.isView()) {
        Nd4j.getExecutioner().commit();
        DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
        Nd4j.getMemoryManager().memcpy(buffer, this.data());
        copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
    } else {
        copy = this.dup(this.ordering());
        Nd4j.getExecutioner().commit();
    }
    return copy;
}
Also used: MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
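The branch structure of migrate(boolean) can be summarized in a small model. This is a hypothetical sketch, not ND4J code: MigrateSketch and ToyArray are invented for illustration, the current workspace is passed in as an id instead of being read from the memory manager, and ownerId == null stands for "not attached". Note that with an active workspace the quoted method copies unconditionally, even if the array already lives there; the sketch mirrors that.

```java
// Hypothetical toy model of BaseNDArray.migrate(boolean); not ND4J API.
final class MigrateSketch {
    // Stand-in array: values plus the id of the owning workspace (null = not attached).
    static final class ToyArray {
        final double[] data;
        final String ownerId;

        ToyArray(double[] data, String ownerId) {
            this.data = data;
            this.ownerId = ownerId;
        }
    }

    // currentWs is passed in explicitly here; ND4J reads it from the memory manager instead.
    static ToyArray migrate(ToyArray arr, String currentWs, boolean detachOnNoWs) {
        if (currentWs == null) {
            if (!detachOnNoWs) {
                return arr; // no workspace and no detach requested: no-op
            }
            // detach(): unattached arrays come back unchanged, attached ones are copied out
            return arr.ownerId == null ? arr : new ToyArray(arr.data.clone(), null);
        }
        // A workspace is active: copy into it unconditionally, as the quoted code does.
        return new ToyArray(arr.data.clone(), currentWs);
    }
}
```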

Example 33 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class CudaWorkspaceManager method createNewWorkspace.

@Override
public MemoryWorkspace createNewWorkspace(WorkspaceConfiguration configuration, String id) {
    ensureThreadExistense();
    MemoryWorkspace workspace = new CudaWorkspace(configuration, id);
    backingMap.get().put(id, workspace);
    pickReference(workspace);
    return workspace;
}
Also used: MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)
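The per-thread registry behind backingMap can be sketched as follows. RegistrySketch, ToyWorkspace, ensureThreadMap() and lookup() are hypothetical. The sketch assumes (suggested by the backingMap.get() call, but not shown in the quoted code) that backingMap is a ThreadLocal map, so a workspace registered on one thread is invisible to other threads. pickReference() is omitted because its behavior is not shown here. As in createNewWorkspace(configuration, id), put() silently replaces any workspace previously registered under the same id.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical per-thread workspace registry; not ND4J API.
final class RegistrySketch {
    static final class ToyWorkspace {
        final String id;

        ToyWorkspace(String id) { this.id = id; }
    }

    // One map per thread, analogous to backingMap.get() in the quoted code.
    private static final ThreadLocal<Map<String, ToyWorkspace>> BACKING_MAP = new ThreadLocal<>();

    // Lazily creates this thread's map; plays the role of ensureThreadExistense().
    private static Map<String, ToyWorkspace> ensureThreadMap() {
        Map<String, ToyWorkspace> map = BACKING_MAP.get();
        if (map == null) {
            map = new HashMap<>();
            BACKING_MAP.set(map);
        }
        return map;
    }

    // Always creates and registers a new workspace, replacing any previous one with this id.
    static ToyWorkspace createNewWorkspace(String id) {
        ToyWorkspace ws = new ToyWorkspace(id);
        ensureThreadMap().put(id, ws);
        return ws;
    }

    // Looks up a workspace registered by the calling thread; null if there is none.
    static ToyWorkspace lookup(String id) {
        return ensureThreadMap().get(id);
    }
}
```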

Example 34 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class CudaWorkspaceManager method createNewWorkspace.

@Override
public MemoryWorkspace createNewWorkspace() {
    ensureThreadExistense();
    MemoryWorkspace workspace = new CudaWorkspace(defaultConfiguration);
    backingMap.get().put(workspace.getId(), workspace);
    pickReference(workspace);
    return workspace;
}
Also used: MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)

Example 35 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

the class CudaWorkspaceManager method getWorkspaceForCurrentThread.

@Override
public MemoryWorkspace getWorkspaceForCurrentThread(@NonNull WorkspaceConfiguration configuration, @NonNull String id) {
    ensureThreadExistense();
    MemoryWorkspace workspace = backingMap.get().get(id);
    if (workspace == null) {
        workspace = new CudaWorkspace(configuration, id);
        backingMap.get().put(id, workspace);
        pickReference(workspace);
    }
    return workspace;
}
Also used: MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)
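getWorkspaceForCurrentThread() is a get-or-create lookup, which Map.computeIfAbsent expresses directly as an alternative to the explicit null check. In this hypothetical sketch, GetOrCreateSketch is invented, a long initialSize stands in for WorkspaceConfiguration, and a counter records how many workspaces were constructed. As in the quoted method, the configuration only matters on the first call for a given id; later calls return the existing workspace unchanged.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical get-or-create registry modeled on getWorkspaceForCurrentThread(); not ND4J API.
final class GetOrCreateSketch {
    static final class ToyWorkspace {
        final String id;
        final long initialSize; // stand-in for a WorkspaceConfiguration

        ToyWorkspace(String id, long initialSize) {
            this.id = id;
            this.initialSize = initialSize;
        }
    }

    // One map per thread; withInitial replaces an explicit ensureThreadExistense()-style check.
    private static final ThreadLocal<Map<String, ToyWorkspace>> BACKING_MAP =
            ThreadLocal.withInitial(HashMap::new);

    // Counts constructions, for demonstration only.
    static final AtomicInteger CREATED = new AtomicInteger();

    // Returns this thread's workspace for id, constructing it only on the first request.
    static ToyWorkspace getWorkspaceForCurrentThread(long initialSize, String id) {
        return BACKING_MAP.get().computeIfAbsent(id, key -> {
            CREATED.incrementAndGet();
            return new ToyWorkspace(key, initialSize);
        });
    }
}
```

Contrast this with createNewWorkspace(configuration, id), which always constructs a new workspace and overwrites the map entry.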

Aggregations

MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace): 62
Test (org.junit.Test): 39
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 35
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 35
Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace): 18
WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration): 14
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 6
AtomicLong (java.util.concurrent.atomic.AtomicLong): 5
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 4
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 3
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 3
ByteArrayInputStream (java.io.ByteArrayInputStream): 2
ByteArrayOutputStream (java.io.ByteArrayOutputStream): 2
DataInputStream (java.io.DataInputStream): 2
DataOutputStream (java.io.DataOutputStream): 2
File (java.io.File): 2
ArrayList (java.util.ArrayList): 2
Ignore (org.junit.Ignore): 2
IOException (java.io.IOException): 1
CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList): 1