
Example 76 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class BaseAccumulation method initFromTensorFlow.

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    newFormat = true;
    if (!attributesForNode.containsKey("axis") && !hasReductionIndices(nodeDef)) {
        this.dimensions = new int[] { Integer.MAX_VALUE };
    } else if (hasReductionIndices(nodeDef)) {
        NodeDef reductionNode = null;
        for (int i = 0; i < graph.getNodeCount(); i++) {
            if (graph.getNode(i).getName().equals(nodeDef.getName() + "/reduction_indices")) {
                reductionNode = graph.getNode(i);
                val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", reductionNode, graph);
                boolean keepAxis = nodeDef.getAttrOrThrow("keep_dims").getB();
                // keepAxis = false by default
                // int[] dimensions = ArrayUtils.add(arr.data().asInt(), 0, keepAxis ? 1 : 0);
                int[] dimensions = arr.data().asInt();
                this.dimensions = dimensions;
                break;
            }
        }
        if (reductionNode == null)
            throw new ND4JIllegalStateException("No node found!");
    } else {
        val dims = TFGraphMapper.getInstance().getNDArrayFromTensor("axis", nodeDef, graph).data().asInt();
        this.dimensions = dims;
    }
    if (attributesForNode.containsKey("keep_dims")) {
        val keepDims = attributesForNode.get("keep_dims").getB();
        this.keepDims = keepDims;
    }
}
Also used : lombok.val(lombok.val) NodeDef(org.tensorflow.framework.NodeDef) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 77 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class BaseAccumulation method calculateOutputShape.

@Override
public List<int[]> calculateOutputShape() {
    if (args().length < 1) {
        throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
    }
    if (arg().getShape() == null)
        return Collections.emptyList();
    List<int[]> ret = new ArrayList<>(1);
    val reducedShape = Shape.getReducedShape(arg().getShape(), dimensions, isKeepDims(), newFormat);
    ret.add(reducedShape);
    return ret;
}
Also used : lombok.val(lombok.val) ArrayList(java.util.ArrayList) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)
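
The shape reduction that calculateOutputShape delegates to can be illustrated with a small standalone sketch. This is not ND4J's Shape.getReducedShape: the class and method names below are hypothetical, the newFormat flag is ignored, and it assumes that a single dimension equal to Integer.MAX_VALUE means "reduce over every axis" (as the initFromTensorFlow default above suggests). ND4J's own helper may represent scalar results differently.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of ND4J.
class ReducedShapeSketch {

    // Returns the shape left after reducing `shape` along `dimensions`.
    // With keepDims, reduced axes are kept with size 1; otherwise they are dropped.
    static int[] reducedShape(int[] shape, int[] dimensions, boolean keepDims) {
        // Assumption: {Integer.MAX_VALUE} is the "reduce all axes" marker.
        boolean reduceAll = dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE;
        boolean[] reduced = new boolean[shape.length];
        Arrays.fill(reduced, reduceAll);
        if (!reduceAll) {
            for (int d : dimensions) {
                int axis = d < 0 ? d + shape.length : d; // allow negative axes
                if (axis < 0 || axis >= shape.length) {
                    throw new IllegalStateException("Axis " + d + " is out of range for rank " + shape.length);
                }
                reduced[axis] = true;
            }
        }
        int kept = 0;
        for (boolean r : reduced) {
            if (!r || keepDims) {
                kept++;
            }
        }
        int[] out = new int[kept];
        int pos = 0;
        for (int i = 0; i < shape.length; i++) {
            if (!reduced[i]) {
                out[pos++] = shape[i];
            } else if (keepDims) {
                out[pos++] = 1;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] shape = {2, 3, 4};
        System.out.println(Arrays.toString(reducedShape(shape, new int[] {1}, false))); // prints [2, 4]
        System.out.println(Arrays.toString(reducedShape(shape, new int[] {1}, true)));  // prints [2, 1, 4]
    }
}
```

In this sketch a full reduction without keepDims yields an empty (rank-0) shape.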

Example 78 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class Fill method initFromTensorFlow.

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (nodeDef.getInputCount() == 2) {
        val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
        val mapper = TFGraphMapper.getInstance();
        val secondInputAsScalar = mapper.getNDArrayFromTensor("value", targetNode, graph);
        // must be scalar
        if (secondInputAsScalar.length() == 1) {
            addTArgument(secondInputAsScalar.getDouble(0));
        } else {
            throw new ND4JIllegalStateException("Second input to node " + nodeDef + " should be scalar!");
        }
    }
}
Also used : lombok.val(lombok.val) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 79 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class Shape method normalizeAxis.

public static int[] normalizeAxis(int rank, int... axis) {
    // first we should get rid of all negative axis
    int[] tmp = new int[axis.length];
    int cnt = 0;
    for (val v : axis) {
        val t = v < 0 ? v + rank : v;
        if ((t >= rank && t != Integer.MAX_VALUE) || t < 0)
            throw new ND4JIllegalStateException("Axis array " + Arrays.toString(axis) + " contains values above rank " + rank);
        tmp[cnt++] = t;
    }
    // now we're sorting array
    Arrays.sort(tmp);
    // and getting rid of possible duplicates
    return uniquify(tmp);
}
Also used : lombok.val(lombok.val) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)
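
To see the normalization in isolation, here is a minimal standalone sketch of the same steps (wrap negative axes, validate, sort, de-duplicate) without the lombok and ND4J dependencies. The class name is hypothetical, IllegalStateException stands in for ND4JIllegalStateException, and the stream-based de-duplication is only an assumed stand-in for uniquify, whose body is not shown above.

```java
import java.util.Arrays;

// Hypothetical standalone version; not the ND4J Shape class.
class AxisSketch {

    static int[] normalizeAxis(int rank, int... axis) {
        int[] tmp = new int[axis.length];
        for (int i = 0; i < axis.length; i++) {
            // Negative axes count from the end: -1 is the last axis.
            int t = axis[i] < 0 ? axis[i] + rank : axis[i];
            // Integer.MAX_VALUE is let through, as in the original check.
            if ((t >= rank && t != Integer.MAX_VALUE) || t < 0) {
                throw new IllegalStateException(
                        "Axis array " + Arrays.toString(axis) + " contains values outside rank " + rank);
            }
            tmp[i] = t;
        }
        // Sort, then drop duplicates (assumed behaviour of uniquify).
        return Arrays.stream(tmp).sorted().distinct().toArray();
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(normalizeAxis(3, -1, 0, 0, 2))); // prints [0, 2]
    }
}
```

Here -1 wraps to 2 for rank 3, so the input {-1, 0, 0, 2} collapses to the two distinct axes 0 and 2.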

Example 80 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class ProtectedCudaConstantHandler method moveToConstantSpace.

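/**
 * Moves the specified dataBuffer to CUDA constant memory space.
 *
 * PLEASE NOTE: CUDA constant memory is limited to 48KB per device.
 *
 * @param dataBuffer buffer to move to constant memory
 * @return address of the buffer in constant memory, or 0 if the buffer was copied to regular device memory instead
 */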
@Override
public synchronized long moveToConstantSpace(DataBuffer dataBuffer) {
    // now, we move things to constant memory
    Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
    ensureMaps(deviceId);
    AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
    long requiredMemoryBytes = AllocationUtils.getRequiredMemory(point.getShape());
    // logger.info("shape: " + point.getShape());
    // and release device memory :)
    long currentOffset = constantOffsets.get(deviceId).get();
    CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
    if (currentOffset + requiredMemoryBytes >= MAX_CONSTANT_LENGTH || requiredMemoryBytes > MAX_BUFFER_LENGTH) {
        if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
        }
        if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), requiredMemoryBytes, 1, context.getSpecialStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        flowController.commitTransfer(context.getSpecialStream());
        point.setConstant(true);
        point.tickDeviceWrite();
        point.tickHostRead();
        point.setDeviceId(deviceId);
        protector.persistDataBuffer(dataBuffer);
        return 0;
    }
    long bytes = requiredMemoryBytes;
    // hack for misalignment avoidance for 16bit data opType
    if (dataBuffer.dataType() == DataBuffer.Type.HALF) {
        if (bytes % 4 != 0) {
            bytes += 2;
        }
    } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE || dataBuffer.dataType() == DataBuffer.Type.LONG) {
        // for double data opType, all DOUBLE pointers must start at even addresses to avoid bank spills
        long div = bytes / 4;
        if (div % 2 != 0)
            bytes += 4;
        // for possible changes of dtype in the same jvm, we skip few bytes in constant memory
        div = currentOffset / 4;
        while (div % 2 != 0) {
            currentOffset = constantOffsets.get(deviceId).addAndGet(4);
            div = currentOffset / 4;
            // just break out, if we're stepped beyond constant memory space
            if (currentOffset > MAX_CONSTANT_LENGTH)
                break;
        }
    }
    currentOffset = constantOffsets.get(deviceId).getAndAdd(bytes);
    if (currentOffset >= MAX_CONSTANT_LENGTH) {
        if (point.getAllocationStatus() == AllocationStatus.HOST && CudaEnvironment.getInstance().getConfiguration().getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
        }
        if (NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), requiredMemoryBytes, 1, context.getSpecialStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
        flowController.commitTransfer(context.getSpecialStream());
        point.setConstant(true);
        point.tickDeviceWrite();
        point.tickHostRead();
        point.setDeviceId(deviceId);
        protector.persistDataBuffer(dataBuffer);
        return 0;
    }
    NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyConstantAsync(currentOffset, point.getPointers().getHostPointer(), requiredMemoryBytes, 1, context.getSpecialStream());
    flowController.commitTransfer(context.getSpecialStream());
    long cAddr = deviceAddresses.get(deviceId).address() + currentOffset;
    // if (resetHappened)
    // logger.info("copying to constant: {}, bufferLength: {}, bufferDtype: {}, currentOffset: {}, currentAddres: {}", requiredMemoryBytes, dataBuffer.length(), dataBuffer.dataType(), currentOffset, cAddr);
    point.setAllocationStatus(AllocationStatus.CONSTANT);
    point.getPointers().setDevicePointer(new CudaPointer(cAddr));
    point.setConstant(true);
    point.tickDeviceWrite();
    point.setDeviceId(deviceId);
    point.tickHostRead();
    protector.persistDataBuffer(dataBuffer);
    return cAddr;
}
Also used : CudaContext(org.nd4j.linalg.jcublas.context.CudaContext) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) AllocationPoint(org.nd4j.jita.allocator.impl.AllocationPoint) CudaPointer(org.nd4j.jita.allocator.pointers.CudaPointer)
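
The offset bookkeeping in moveToConstantSpace can be followed more easily without the CUDA calls. The sketch below reproduces only the alignment and reservation arithmetic for 8-byte types on a plain AtomicLong. The class and method names are hypothetical, the value of MAX_CONSTANT_LENGTH is an assumption taken from the 48KB figure in the javadoc (the real constant is not shown above), and -1 is used here to signal "does not fit" where the original falls back to ordinary device memory and returns 0.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical illustration of the offset arithmetic; not the ND4J handler.
class ConstantOffsetSketch {

    // Assumption: 48 KB, as stated in the javadoc of moveToConstantSpace.
    static final long MAX_CONSTANT_LENGTH = 48 * 1024;

    // Reserves `requiredBytes` in the constant space tracked by `offset`.
    // Returns the start offset of the reservation, or -1 if it starts beyond the limit.
    static long reserve(AtomicLong offset, long requiredBytes, boolean eightByteType) {
        long bytes = requiredBytes;
        if (eightByteType) {
            // Pad the size so that an odd number of 4-byte words becomes even.
            if ((bytes / 4) % 2 != 0) {
                bytes += 4;
            }
            // Skip 4 bytes at a time until the start offset is an even number of 4-byte words.
            long current = offset.get();
            while ((current / 4) % 2 != 0) {
                current = offset.addAndGet(4);
                if (current > MAX_CONSTANT_LENGTH) {
                    break;
                }
            }
        }
        long start = offset.getAndAdd(bytes);
        return start >= MAX_CONSTANT_LENGTH ? -1 : start;
    }

    public static void main(String[] args) {
        AtomicLong offset = new AtomicLong(4);
        System.out.println(reserve(offset, 8, true));  // prints 8
        System.out.println(reserve(offset, 12, true)); // prints 16
        System.out.println(offset.get());              // prints 32
    }
}
```

Unlike the synchronized original, this sketch makes no attempt to be safe under concurrent callers; it only shows how the start offset and padded size are derived.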

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 116
lombok.val (lombok.val): 26
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 23
CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 21
AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 19
DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 17
CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 15
PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 12
AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 8
BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer): 7
IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray): 6
Pointer (org.bytedeco.javacpp.Pointer): 5
ArrayList (java.util.ArrayList): 4
DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 4
Aeron (io.aeron.Aeron): 3
FragmentAssembler (io.aeron.FragmentAssembler): 3
MediaDriver (io.aeron.driver.MediaDriver): 3
AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 3
Slf4j (lombok.extern.slf4j.Slf4j): 3
CloseHelper (org.agrona.CloseHelper): 3