Search in sources:

Example 31 with ND4JIllegalStateException

Usage of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j.

Class Transpose, method initFromTensorFlow.

/**
 * Configures this transpose op from an imported TensorFlow node.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Delegates to the superclass initialization.</li>
 *   <li>Returns early if the node has fewer than two inputs, because then no explicit
 *       permutation is supplied.</li>
 *   <li>Looks up the graph node whose name equals the second input. It reads that node's
 *       {@code "value"} tensor. If an array is obtained, the array becomes {@code permuteDims}
 *       and each entry is added as an integer argument.</li>
 *   <li>Returns early if the shape of {@code arg()} is unknown.</li>
 *   <li>Fetches the array for {@code arg()} from SameDiff and adds it as an input argument.
 *       If the array is absent, it is first created from the variable's weight-init scheme
 *       and registered in SameDiff.</li>
 *   <li>If {@code permuteDims} is still null, sets it to the reverse of
 *       {@code ArrayUtil.range(0, arr.rank())}.</li>
 * </ol>
 *
 * @param nodeDef           the TensorFlow node being imported
 * @param initWith          the SameDiff instance, passed through to the superclass
 * @param attributesForNode the TensorFlow attributes of {@code nodeDef}
 * @param graph             the full graph, searched for the node that supplies the permutation
 * @throws ND4JIllegalStateException if {@code permuteDims} has fewer entries than the length
 *         of {@code arg().getShape()}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    // Permute dimensions are not specified as a second input: nothing more to map here.
    if (nodeDef.getInputCount() < 2)
        return;
    // Find the node producing the second input (the permutation). The loop has no break,
    // so if several nodes share the name, the last one wins.
    // NOTE(review): if no node matches, permuteDimsNode stays null and is passed to
    // getNDArrayFromTensor as-is. How that call treats null is not visible here; confirm.
    NodeDef permuteDimsNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) {
            permuteDimsNode = graph.getNode(i);
        }
    }
    val permuteArrayOp = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph);
    if (permuteArrayOp != null) {
        // Keep the permutation as a field and also pass each axis as an integer argument.
        this.permuteDims = permuteArrayOp.data().asInt();
        for (int i = 0; i < permuteDims.length; i++) {
            addIArgument(permuteDims[i]);
        }
    }
    // handle once properly mapped: without a known input shape, the rest of the setup is skipped
    if (arg().getShape() == null) {
        return;
    }
    // Ensure the input variable has a backing array, creating and registering one if needed.
    INDArray arr = sameDiff.getArrForVarName(arg().getVarName());
    if (arr == null) {
        val arrVar = sameDiff.getVariable(arg().getVarName());
        arr = arrVar.getWeightInitScheme().create(arrVar.getShape());
        sameDiff.putArrayForVarName(arg().getVarName(), arr);
    }
    addInputArgument(arr);
    // No permutation was read: fall back to reversing the dimension order.
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }
    // Only permutations shorter than the input rank are rejected; longer ones are not checked here.
    if (permuteDims != null && permuteDims.length < arg().getShape().length)
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
}
Also used: lombok.val(lombok.val) NodeDef(org.tensorflow.framework.NodeDef) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 32 with ND4JIllegalStateException

Usage of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j.

Class Concat, method resolvePropertiesFromSameDiffBeforeExecution.

/**
 * Resolves the concat dimension from SameDiff before execution.
 *
 * <p>If SameDiff reports properties to resolve for this function, the first reported
 * variable name is looked up. The first element of that variable's array becomes
 * {@code concatDimension}, which is also added as an integer argument.
 *
 * <p>Afterwards, the last input array is removed if the number of input arrays equals the
 * number of argument variables. This prevents libnd4j from receiving the axis both as an
 * input and as an integer argument. When there are no inputs at all, nothing is removed.
 *
 * @throws ND4JIllegalStateException if the named variable does not exist or has no array set
 */
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
    val propertiesToResolve = sameDiff.propertiesToResolveForFunction(this);
    if (!propertiesToResolve.isEmpty()) {
        val varName = propertiesToResolve.get(0);
        val var = sameDiff.getVariable(varName);
        if (var == null) {
            throw new ND4JIllegalStateException("No variable found with name " + varName);
        } else if (var.getArr() == null) {
            throw new ND4JIllegalStateException("Array with variable name " + varName + " unset!");
        }
        concatDimension = var.getArr().getInt(0);
        addIArgument(concatDimension);
    }
    // don't pass both iArg and last axis down to libnd4j.
    // The length > 0 guard matters: with zero inputs and zero args the lengths are equal,
    // and indexing [length - 1] would throw ArrayIndexOutOfBoundsException.
    val inputArgs = inputArguments();
    if (inputArgs.length > 0 && inputArgs.length == args().length) {
        removeInputArgument(inputArgs[inputArgs.length - 1]);
    }
}
Also used: lombok.val(lombok.val) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 33 with ND4JIllegalStateException

Usage of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j.

Class ArrowSerde, method fromTensor.

/**
 * Convert an Arrow {@link Tensor} to an {@link INDArray}.
 *
 * <p>Tensor strides are treated as byte offsets, and each one is divided by the element size
 * because ND4J strides count elements. The element size is the byte length of the tensor's
 * data buffer divided by the number of elements.
 *
 * <p>If the tensor stores no strides, row-major (C-order) element strides are derived from
 * the shape.
 *
 * @param tensor the input tensor
 * @return the equivalent {@link INDArray}
 * @throws ND4JIllegalStateException if the tensor has no data buffer, or its stride count
 *         differs from its rank. Also thrown if it has zero elements, or if its buffer is
 *         smaller than one byte per element; in both cases the element size cannot be deduced.
 */
public static INDArray fromTensor(Tensor tensor) {
    byte b = tensor.typeType();
    int[] shape = new int[tensor.shapeLength()];
    for (int i = 0; i < shape.length; i++) {
        shape[i] = (int) tensor.shape(i).size();
    }
    int[] stride;
    boolean stridesInBytes;
    if (tensor.stridesLength() == 0) {
        // No strides stored: assume row-major layout and compute element strides directly.
        stride = new int[shape.length];
        int step = 1;
        for (int i = shape.length - 1; i >= 0; i--) {
            stride[i] = step;
            step *= shape[i];
        }
        stridesInBytes = false;
    } else {
        // Read strides over their own length rather than the shape's, so that a rank
        // mismatch fails with a clear message instead of an index error.
        stride = new int[tensor.stridesLength()];
        for (int i = 0; i < stride.length; i++) {
            stride[i] = (int) tensor.strides(i);
        }
        if (stride.length != shape.length) {
            throw new ND4JIllegalStateException("Tensor has " + stride.length + " strides but rank " + shape.length);
        }
        stridesInBytes = true;
    }
    int length = ArrayUtil.prod(shape);
    Buffer buffer = tensor.data();
    if (buffer == null) {
        throw new ND4JIllegalStateException("Buffer was not serialized properly.");
    }
    if (length == 0) {
        throw new ND4JIllegalStateException("Cannot deduce element size of a tensor with zero elements.");
    }
    // deduce element size; divide in long arithmetic before narrowing to int
    int elementSize = (int) (buffer.length() / length);
    if (elementSize == 0) {
        throw new ND4JIllegalStateException("Tensor buffer of " + buffer.length() + " bytes is too small for " + length + " elements.");
    }
    // nd4j strides aren't based on element size
    if (stridesInBytes) {
        for (int i = 0; i < stride.length; i++) {
            stride[i] /= elementSize;
        }
    }
    DataBuffer.Type type = typeFromTensorType(b, elementSize);
    DataBuffer dataBuffer = DataBufferStruct.createFromByteBuffer(tensor.getByteBuffer(), (int) buffer.offset(), type, length, elementSize);
    INDArray arr = Nd4j.create(dataBuffer, shape);
    arr.setShapeAndStride(shape, stride);
    return arr;
}
Also used: DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 34 with ND4JIllegalStateException

Usage of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j.

Class CudaZeroHandler, method memcpySpecial.

/**
 * Special memcpy version, addressing shapeInfoDataBuffer copies.
 *
 * <p>Copies {@code length} bytes from {@code srcPointer} into the host memory of
 * {@code dstBuffer}, starting {@code dstOffset} bytes past the host pointer. If the
 * destination is allocated on the device, the same bytes are then copied host-to-device at
 * the same offset.
 *
 * <p>Both copies are queued on the context's old stream. The stream is synchronized before
 * returning, so the call is blocking for both the host-only and the device path.
 *
 * @param dstBuffer  destination buffer; it is cast to {@code BaseCudaDataBuffer}
 * @param srcPointer source pointer. It is copied with {@code cudaMemcpyHostToHost}, so it is
 *                   presumably host memory; confirm with callers.
 * @param length     number of bytes to copy
 * @param dstOffset  byte offset added to the destination address(es)
 * @throws ND4JIllegalStateException if a native memcpyAsync call reports failure (returns 0)
 */
@Override
public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
    CudaContext context = getCudaContext();
    AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    Pointer dP = new CudaPointer(point.getPointers().getHostPointer().address() + dstOffset);
    if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getOldStream()) == 0) {
        throw new ND4JIllegalStateException("memcpyAsync failed");
    }
    if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
        // Mirror the freshly written host bytes into the device allocation.
        Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
        if (nativeOps.memcpyAsync(rDP, dP, length, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    }
    // Both copies were issued on the same stream, so one sync waits for all of them.
    context.syncOldStream();
    point.tickDeviceWrite();
}
Also used: CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) BaseCudaDataBuffer(org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer) Pointer(org.bytedeco.javacpp.Pointer) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Example 35 with ND4JIllegalStateException

Usage of org.nd4j.linalg.exception.ND4JIllegalStateException in the nd4j project by deeplearning4j.

Class CudaZeroHandler, method memcpy.

/**
 * Synchronous version of memcpy.
 *
 * <p>Copies {@code srcBuffer.length() * srcBuffer.getElementSize()} bytes into the host
 * memory of {@code dstBuffer}. The bytes are read from the source's device memory when the
 * source is allocated on the device, and from its host memory otherwise. If the destination
 * is allocated on the device, the host copy is then pushed to the device.
 *
 * <p>The stream is synchronized before returning.
 *
 * @param dstBuffer destination buffer; it is cast to {@code BaseCudaDataBuffer}
 * @param srcBuffer source buffer; it is cast to {@code BaseCudaDataBuffer}
 * @throws ND4JIllegalStateException if the destination holds fewer bytes than the source,
 *         or a native memcpyAsync call reports failure (returns 0)
 */
@Override
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
    CudaContext context = getCudaContext();
    AllocationPoint dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    AllocationPoint srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();
    long numBytes = srcBuffer.length() * srcBuffer.getElementSize();
    long dstBytes = dstBuffer.length() * dstBuffer.getElementSize();
    // Guard against writing past the end of the destination allocation.
    if (dstBytes < numBytes) {
        throw new ND4JIllegalStateException("memcpy destination holds " + dstBytes + " bytes, source needs " + numBytes);
    }
    Pointer dP = new CudaPointer(dstPoint.getPointers().getHostPointer().address());
    Pointer sP;
    if (srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
        // Source lives on the device, destination is host memory: device -> host.
        sP = new CudaPointer(srcPoint.getPointers().getDevicePointer().address());
        if (nativeOps.memcpyAsync(dP, sP, numBytes, CudaConstants.cudaMemcpyDeviceToHost, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    } else {
        // Both source and destination are host memory: host -> host.
        sP = new CudaPointer(srcPoint.getPointers().getHostPointer().address());
        if (nativeOps.memcpyAsync(dP, sP, numBytes, CudaConstants.cudaMemcpyHostToHost, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    }
    if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
        // Push the refreshed host copy to the destination's device allocation.
        Pointer rDP = new CudaPointer(dstPoint.getPointers().getDevicePointer().address());
        if (nativeOps.memcpyAsync(rDP, dP, numBytes, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    }
    dstPoint.tickDeviceWrite();
    // it has to be blocking call
    context.syncOldStream();
}
Also used: CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) BaseCudaDataBuffer(org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer) Pointer(org.bytedeco.javacpp.Pointer) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint)

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)116 lombok.val (lombok.val)26 INDArray (org.nd4j.linalg.api.ndarray.INDArray)23 CudaContext (org.nd4j.linalg.jcublas.context.CudaContext)21 AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint)19 DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)17 CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer)15 PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer)12 AtomicInteger (java.util.concurrent.atomic.AtomicInteger)8 BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer)7 IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray)6 Pointer (org.bytedeco.javacpp.Pointer)5 ArrayList (java.util.ArrayList)4 DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction)4 Aeron (io.aeron.Aeron)3 FragmentAssembler (io.aeron.FragmentAssembler)3 MediaDriver (io.aeron.driver.MediaDriver)3 AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean)3 Slf4j (lombok.extern.slf4j.Slf4j)3 CloseHelper (org.agrona.CloseHelper)3