Search in sources:

Example 26 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

Method createUninitializedDetached of class JCublasNDArrayFactory.

/**
 * Creates a new array with the given shape and ordering that is not attached to any workspace.
 *
 * <p>The thread's current workspace is cleared while the array is allocated and is always
 * restored afterwards, even if allocation fails.
 *
 * @param shape    shape of the array to create
 * @param ordering element ordering used to compute the strides (e.g. 'c' or 'f')
 * @return a newly allocated, workspace-detached array (the trailing {@code false} constructor
 *         argument presumably skips zero-initialisation, per the method name — TODO confirm)
 */
@Override
public INDArray createUninitializedDetached(int[] shape, char ordering) {
    MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
    // Detach from any workspace so the allocation below is not placed inside it.
    Nd4j.getMemoryManager().setCurrentWorkspace(null);
    try {
        return new JCublasNDArray(shape, Nd4j.getStrides(shape, ordering), 0, ordering, false);
    } finally {
        // Restore the caller's workspace even when allocation throws.
        Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace)

Example 27 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

Method leverageTo of class JCublasNDArray.

/**
 * Copies this array into the workspace with the given id (as resolved for the current thread)
 * and returns the copy.
 *
 * <p>This array is returned unchanged when:
 * <ul>
 *   <li>it is not attached to any workspace,</li>
 *   <li>no workspace with {@code id} exists,</li>
 *   <li>the target workspace is already the current workspace, or</li>
 *   <li>this array's data buffer already belongs to the target workspace.</li>
 * </ul>
 *
 * <p>Non-view arrays are copied with a native {@code memcpyAsync} into a freshly created buffer.
 * The copy is device-to-device when the source is up to date on the device, and host-to-device
 * otherwise. Views are copied with {@code dup(ordering())}.
 *
 * <p>The target workspace is made current while the copy is allocated. The previous workspace
 * is always restored before returning, including when the copy fails.
 *
 * @param id id of the workspace to leverage to
 * @return the copy living in the target workspace, or this array if no copy was needed
 * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call reports failure
 */
@Override
public INDArray leverageTo(String id) {
    if (!isAttached()) {
        return this;
    }
    if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) {
        return this;
    }
    MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
    MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id);
    if (current == target) {
        return this;
    }
    if (this.data.getParentWorkspace() == target) {
        return this;
    }
    // Allocations made below go to the target workspace while it is current.
    Nd4j.getMemoryManager().setCurrentWorkspace(target);
    try {
        INDArray copy;
        if (!this.isView()) {
            Nd4j.getExecutioner().commit();
            DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
            AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
            AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
            CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
            long bytes = this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType());
            if (pointSrc.isActualOnDeviceSide()) {
                if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), bytes, CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
                    throw new ND4JIllegalStateException("memcpyAsync failed");
            } else {
                if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), bytes, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
                    throw new ND4JIllegalStateException("memcpyAsync failed");
            }
            // Wait for the async copy before wrapping the buffer.
            context.syncOldStream();
            copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
            // tag buffer as valid on device side
            pointDst.tickHostRead();
            pointDst.tickDeviceWrite();
            AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
        } else {
            copy = this.dup(this.ordering());
            Nd4j.getExecutioner().commit();
        }
        return copy;
    } finally {
        // Always restore the caller's workspace, even if the copy threw.
        Nd4j.getMemoryManager().setCurrentWorkspace(current);
    }
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 28 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

Method migrate of class JCublasNDArray.

/**
 * This method pulls this INDArray into current Workspace.
 *
 * <p>Unlike {@code leverageTo}, this method does not check whether the array is attached or
 * already lives in the current workspace. Whenever a current workspace exists, it always makes
 * a copy, allocated while that workspace is current.
 *
 * <p>Non-view arrays are copied with a native {@code memcpyAsync}. The copy is device-to-device
 * when the source is up to date on the device, and host-to-device otherwise. Views are copied
 * with {@code dup(ordering())}.
 *
 * PLEASE NOTE: If there's no current Workspace - INDArray returned as is
 *
 * @return the copy allocated while the current workspace is active, or this array when there
 *         is no current workspace
 * @throws ND4JIllegalStateException if the native {@code memcpyAsync} call reports failure
 */
@Override
public INDArray migrate() {
    MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
    if (current == null)
        return this;
    INDArray copy = null;
    if (!this.isView()) {
        // Flush pending ops before reading the source buffer.
        Nd4j.getExecutioner().commit();
        DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
        AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
        AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
        // CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
        CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
        // Copy from the device copy if it is current, otherwise from host memory.
        if (pointSrc.isActualOnDeviceSide()) {
            if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getDevicePointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyDeviceToDevice, context.getOldStream()) == 0)
                throw new ND4JIllegalStateException("memcpyAsync failed");
        } else {
            if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(pointDst.getDevicePointer(), pointSrc.getHostPointer(), this.lengthLong() * Nd4j.sizeOfDataType(buffer.dataType()), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
                throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        // Wait for the async copy to finish before using the destination buffer.
        context.syncOldStream();
        // Ensure the destination allocation point reports the same device id as the current workspace.
        if (pointDst.getDeviceId() != Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId()) {
            // log.info("Swapping [{}] -> [{}]", pointDst.getDeviceId(), Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
            pointDst.setDeviceId(Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
        }
        copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
        // tag buffer as valid on device side
        pointDst.tickHostRead();
        pointDst.tickDeviceWrite();
        AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
    } else {
        copy = this.dup(this.ordering());
    }
    return copy;
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 29 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

Method leverageTo of class BaseNDArray.

/**
 * Copies this INDArray into the Workspace with a given Id (as resolved for the current thread).
 *
 * <p>If enforceExistence == true and no workspace with the specified ID exists, a
 * {@link Nd4jNoSuchWorkspaceException} is thrown. If enforceExistence == false and no workspace
 * with the specified ID exists, this INDArray is returned unmodified.
 *
 * <p>This INDArray is also returned unmodified when:
 * <ul>
 *   <li>it is not attached to any workspace,</li>
 *   <li>the target workspace is already the current workspace, or</li>
 *   <li>its data buffer already belongs to the target workspace.</li>
 * </ul>
 *
 * <p>The target workspace is made current while the copy is allocated. The previous workspace
 * is always restored before returning, including when the copy throws.
 *
 * @param id ID of the workspace to leverage to
 * @param enforceExistence If true, and the specified workspace does not exist: an {@link Nd4jNoSuchWorkspaceException}
 *                         will be thrown.
 * @return The INDArray, leveraged to the specified workspace
 * @throws Nd4jNoSuchWorkspaceException if {@code enforceExistence} is true and no workspace with
 *                                      the given id exists
 * @see #leverageTo(String)
 */
@Override
public INDArray leverageTo(String id, boolean enforceExistence) throws Nd4jNoSuchWorkspaceException {
    if (!isAttached())
        return this;
    if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) {
        if (enforceExistence) {
            throw new Nd4jNoSuchWorkspaceException(id);
        }
        return this;
    }
    MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace();
    MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id);
    if (current == target)
        return this;
    if (this.data.getParentWorkspace() == target)
        return this;
    // Allocations made below go to the target workspace while it is current.
    Nd4j.getMemoryManager().setCurrentWorkspace(target);
    try {
        INDArray copy;
        if (!this.isView()) {
            Nd4j.getExecutioner().commit();
            DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
            Nd4j.getMemoryManager().memcpy(buffer, this.data());
            copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
        } else {
            copy = this.dup(this.ordering());
            Nd4j.getExecutioner().commit();
        }
        return copy;
    } finally {
        // Always restore the caller's workspace, even if the copy threw.
        Nd4j.getMemoryManager().setCurrentWorkspace(current);
    }
}
Also used : MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) Nd4jNoSuchWorkspaceException(org.nd4j.linalg.exception.Nd4jNoSuchWorkspaceException) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 30 with MemoryWorkspace

use of org.nd4j.linalg.api.memory.MemoryWorkspace in project nd4j by deeplearning4j.

Method leverage of class BaseNDArray.

/**
 * Copies this INDArray from the current Workspace into the Workspace above it (the current
 * workspace's parent), if any.
 * <p>
 * PLEASE NOTE: If this INDArray instance is NOT attached - it will be returned unmodified.
 * PLEASE NOTE: If there is no current Workspace, or the current Workspace is the top-tier one
 * (has no parent), the effect is equal to a detach() call - a detached copy will be returned.
 * <p>
 * The parent workspace is made current while the copy is allocated. The previous workspace is
 * always restored before returning, including when the copy throws.
 *
 * @return this array if it is not attached or its data already belongs to the parent workspace;
 *         otherwise a copy in the parent workspace, or the result of {@code detach()} when there
 *         is no current or no parent workspace
 */
@Override
public INDArray leverage() {
    if (!isAttached())
        return this;
    MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace();
    if (workspace == null) {
        return this.detach();
    }
    MemoryWorkspace parentWorkspace = workspace.getParentWorkspace();
    if (this.data.getParentWorkspace() == parentWorkspace)
        return this;
    // if there's no parent ws - just detach
    if (parentWorkspace == null)
        return this.detach();
    Nd4j.getExecutioner().commit();
    // temporary set parent ws as current ws
    Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace);
    try {
        INDArray copy;
        if (!this.isView()) {
            Nd4j.getExecutioner().commit();
            DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
            Nd4j.getMemoryManager().memcpy(buffer, this.data());
            copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
        } else {
            copy = this.dup(this.ordering());
            Nd4j.getExecutioner().commit();
        }
        return copy;
    } finally {
        // restore current ws, even if the copy threw
        Nd4j.getMemoryManager().setCurrentWorkspace(workspace);
    }
}
Also used : MemoryWorkspace(org.nd4j.linalg.api.memory.MemoryWorkspace) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Aggregations

MemoryWorkspace (org.nd4j.linalg.api.memory.MemoryWorkspace)62 Test (org.junit.Test)39 BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)35 INDArray (org.nd4j.linalg.api.ndarray.INDArray)35 Nd4jWorkspace (org.nd4j.linalg.memory.abstracts.Nd4jWorkspace)18 WorkspaceConfiguration (org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration)14 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)6 AtomicLong (java.util.concurrent.atomic.AtomicLong)5 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)4 ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)3 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)3 ByteArrayInputStream (java.io.ByteArrayInputStream)2 ByteArrayOutputStream (java.io.ByteArrayOutputStream)2 DataInputStream (java.io.DataInputStream)2 DataOutputStream (java.io.DataOutputStream)2 File (java.io.File)2 ArrayList (java.util.ArrayList)2 Ignore (org.junit.Ignore)2 IOException (java.io.IOException)1 CopyOnWriteArrayList (java.util.concurrent.CopyOnWriteArrayList)1