
Example 26 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class SameDiff, method addOutgoingFor.

/**
 * Adds outgoing arguments to the graph.
 * Also checks for input arguments
 * and updates the graph, adding an appropriate edge
 * when the full graph is declared.
 * @param varNames the names of the variables produced by the function
 * @param function the function whose outputs are being registered
 */
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    if (function.getOwnName() == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    if (outgoingArgsReverse.containsKey(function.getOwnName())) {
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    }
    if (varNames == null)
        throw new ND4JIllegalStateException("Var names can not be null!");
    for (int i = 0; i < varNames.length; i++) {
        if (varNames[i] == null)
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
    }
    // Record the mapping in both directions: function -> outputs and outputs -> function.
    outgoingArgsReverse.put(function.getOwnName(), varNames);
    outgoingArgs.put(varNames, function);
    // Register the function as a producer of each of its output variables.
    for (val resultName : varNames) {
        List<DifferentialFunction> funcs = functionOutputFor.get(resultName);
        if (funcs == null) {
            funcs = new ArrayList<>();
            functionOutputFor.put(resultName, funcs);
        }
        funcs.add(function);
    }
}
Also used: DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)
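
The bookkeeping in addOutgoingFor amounts to argument validation followed by a one-to-many index from each output variable name to the functions that produce it. The sketch below reproduces that pattern in plain Java so it runs without nd4j. The class OutgoingRegistry, its method names, and the use of function names as String keys are hypothetical, and java.lang.IllegalStateException stands in for ND4JIllegalStateException.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the output bookkeeping in SameDiff.addOutgoingFor.
final class OutgoingRegistry {
    // function name -> output variable names (plays the role of outgoingArgsReverse)
    private final Map<String, String[]> outputsByFunction = new HashMap<>();
    // output variable name -> functions producing it (plays the role of functionOutputFor)
    private final Map<String, List<String>> producersByOutput = new HashMap<>();

    void addOutgoingFor(String[] varNames, String functionName) {
        if (functionName == null)
            throw new IllegalStateException("Function name can not be null");
        if (outputsByFunction.containsKey(functionName))
            throw new IllegalStateException("Outgoing arguments already declared for " + functionName);
        if (varNames == null)
            throw new IllegalStateException("Var names can not be null!");
        for (String name : varNames) {
            if (name == null)
                throw new IllegalStateException("Variable name elements can not be null!");
        }
        // All checks passed, so nothing is recorded for invalid input.
        outputsByFunction.put(functionName, varNames);
        for (String resultName : varNames) {
            // computeIfAbsent replaces the get / null check / put sequence.
            producersByOutput.computeIfAbsent(resultName, k -> new ArrayList<>()).add(functionName);
        }
    }

    // Returns the functions registered as producers of varName (empty if none).
    List<String> producersOf(String varName) {
        return producersByOutput.getOrDefault(varName, Collections.emptyList());
    }
}
```

Using Map.computeIfAbsent collapses the get, null check, and put of the original loop into one call; for a single-threaded caller the resulting map contents are the same.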

Example 27 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class OnnxGraphMapper, method getNDArrayFromTensor.

public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    DataBuffer.Type type = dataTypeForTensor(tensorProto);
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }
    // Look up the initializer (stored weights) with the requested name.
    OnnxProto3.TensorProto tensor = null;
    for (int i = 0; i < graph.getInitializerCount(); i++) {
        val initializer = graph.getInitializer(i);
        if (initializer.getName().equals(tensorName)) {
            tensor = initializer;
            break;
        }
    }
    if (tensor == null)
        return null;
    ByteString bytes = tensor.getRawData();
    ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
    // A freshly allocated direct buffer is zero-filled: copy the raw bytes in, then reset the position.
    directAlloc.put(byteBuffer);
    directAlloc.rewind();
    int[] shape = getShapeFromTensor(tensorProto);
    // The buffer length is the element count, i.e. the product of the shape dimensions.
    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
    INDArray arr = Nd4j.create(buffer).reshape(shape);
    return arr;
}
Also used: lombok.val, INDArray (org.nd4j.linalg.api.ndarray.INDArray), ByteString (com.github.os72.protobuf351.ByteString), OnnxProto3 (onnx.OnnxProto3), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException), ByteBuffer (java.nio.ByteBuffer), DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer)
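
The copy step in getNDArrayFromTensor is easy to get wrong: a buffer from ByteBuffer.allocateDirect starts out zero-filled, so the raw bytes must be copied in and the position reset before anything reads from it. The plain java.nio sketch below shows that step for a float tensor. The class RawTensorCopy and its method are hypothetical. It uses little-endian order because the ONNX format stores raw_data little-endian; the nd4j code uses ByteOrder.nativeOrder(), which gives the same result on little-endian hosts.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper illustrating the raw-bytes -> direct buffer -> values path.
final class RawTensorCopy {
    static float[] decodeFloats(ByteBuffer raw, int elementCount) {
        // Work on a read-only view so the caller's position and limit are untouched.
        ByteBuffer src = raw.asReadOnlyBuffer().order(ByteOrder.LITTLE_ENDIAN);
        if (src.remaining() < elementCount * Float.BYTES)
            throw new IllegalStateException("Raw data too short for " + elementCount + " floats");
        ByteBuffer direct = ByteBuffer.allocateDirect(src.remaining()).order(ByteOrder.LITTLE_ENDIAN);
        direct.put(src); // copy the bytes; without this the buffer stays all zeros
        direct.flip();   // position back to 0, limit at the number of bytes copied
        float[] out = new float[elementCount];
        // The float view inherits the little-endian order set above.
        direct.asFloatBuffer().get(out);
        return out;
    }
}
```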

Example 28 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class TFGraphMapper, method mapProperty.

public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
    if (node == null) {
        throw new ND4JIllegalStateException("No node found for name " + name);
    }
    val mapping = propertyMappingsForFunction.get(getOpType(node)).get(name);
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
    if (mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) {
        // The property comes from one of the node's inputs; a negative position counts from the end.
        int tfMappingIdx = mapping.getTfInputPosition();
        if (tfMappingIdx < 0)
            tfMappingIdx += node.getInputCount();
        val input = node.getInput(tfMappingIdx);
        val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, input);
        INDArray arr = getArrayFrom(inputNode, graph);
        if (arr == null) {
            arr = sameDiff.getArrForVarName(input);
        }
        if (arr == null && inputNode != null) {
            // Value not available yet: defer resolution until the input variable has an array.
            sameDiff.addPropertyToResolve(on, name);
            sameDiff.addVariableMappingForField(on, name, inputNode.getName());
            return;
        } else if (inputNode == null) {
            // No graph node for this input; the original handling is elided in this listing.
            return;
        }
        val field = fields.get(name);
        val type = field.getType();
        if (type.equals(int[].class)) {
            on.setValueFor(field, arr.data().asInt());
        } else if (type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
            if (mapping.getShapePosition() != null) {
                on.setValueFor(field, arr.size(mapping.getShapePosition()));
            } else
                on.setValueFor(field, arr.getInt(0));
        } else if (type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
            on.setValueFor(field, arr.getDouble(0));
        }
    } else {
        // The property comes from a node attribute instead of an input.
        val tfMappingAttrName = mapping.getTfAttrName();
        if (tfMappingAttrName == null) {
            return;
        }
        if (!node.containsAttr(tfMappingAttrName)) {
            return;
        }
        val attr = node.getAttrOrThrow(tfMappingAttrName);
        val type = attr.getType();
        if (fields == null) {
            throw new ND4JIllegalStateException("No fields found for op " + mapping);
        }
        if (mapping.getPropertyNames() == null) {
            throw new ND4JIllegalStateException("no property found for " + name + " and op " + on.opName());
        }
        val field = fields.get(mapping.getPropertyNames()[0]);
        Object valueToSet = null;
        // Every case needs a break: without one, execution falls through and later cases overwrite the value.
        switch(type) {
            case DT_BOOL:
                valueToSet = attr.getB();
                break;
            case DT_INT8:
            case DT_INT16:
            case DT_INT32:
            case DT_INT64:
                valueToSet = attr.getI();
                break;
            case DT_FLOAT:
            case DT_DOUBLE:
                valueToSet = attr.getF();
                break;
            case DT_STRING:
                valueToSet = attr.getS();
                break;
        }
        if (field != null && valueToSet != null)
            on.setValueFor(field, valueToSet);
    }
}
Also used: lombok.val, INDArray (org.nd4j.linalg.api.ndarray.INDArray), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)
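
Two details of mapProperty generalize beyond nd4j: a negative input position is counted from the end of the node's input list, and a switch that assigns a value per type needs a break (or return) in every case, because Java switch cases fall through. The sketch below returns from each case instead. The DType enum and the Attr holder are hypothetical stand-ins for TensorFlow's DataType and AttrValue; in the real protobuf API, getS() returns a ByteString rather than a String.

```java
// Hypothetical, self-contained model of the attribute dispatch in mapProperty.
final class AttrDispatch {
    // Subset of TensorFlow DataType names used in the switch above (stand-in enum).
    enum DType { DT_BOOL, DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_STRING, DT_INVALID }

    // Stand-in for TensorFlow's AttrValue: one slot per value kind.
    static final class Attr {
        final boolean b;
        final long i;
        final float f;
        final String s;

        Attr(boolean b, long i, float f, String s) {
            this.b = b;
            this.i = i;
            this.f = f;
            this.s = s;
        }
    }

    // Returning from each case makes fall-through impossible.
    static Object valueFor(DType type, Attr attr) {
        switch (type) {
            case DT_BOOL:
                return attr.b;
            case DT_INT8:
            case DT_INT16:
            case DT_INT32:
            case DT_INT64:
                return attr.i;
            case DT_FLOAT:
            case DT_DOUBLE:
                return attr.f;
            case DT_STRING:
                return attr.s;
            default:
                return null; // unsupported type: caller skips setting the field
        }
    }

    // Negative indices count from the end: -1 is the last of `count` inputs.
    static int normalizeIndex(int idx, int count) {
        return idx < 0 ? idx + count : idx;
    }
}
```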

Example 29 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class TFGraphMapper, method mapNodeType.

public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
    if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) {
        return;
    }
    val diff = importState.getSameDiff();
    // Note: unreachable as written, because variable nodes already returned above.
    if (isVariableNode(tfNode)) {
        List<Integer> dimensions = new ArrayList<>();
        Map<String, AttrValue> attributes = getAttrMap(tfNode);
        if (attributes.containsKey(VALUE_ATTR_KEY)) {
            diff.var(getName(tfNode), getArrayFrom(tfNode, importState.getGraph()));
        } else if (attributes.containsKey(SHAPE_KEY)) {
            AttrValue shape = attributes.get(SHAPE_KEY);
            int[] shapeArr = getShapeFromAttr(shape);
            int dims = shapeArr.length;
            if (dims > 0) {
                // even vector is 2d in nd4j
                if (dims == 1)
                    dimensions.add(1);
                for (int e = 0; e < dims; e++) {
                    // TODO: eventually we want long shapes :(
                    dimensions.add(shapeArr[e]);
                }
            }
        }
    } else if (isPlaceHolder(tfNode)) {
        val vertexId = diff.getVariable(getName(tfNode));
    } else {
        val opName = tfNode.getOp();
        val nodeName = tfNode.getName();
        // FIXME: early draft
        // conditional import
        /*
            if (nodeName.startsWith("cond") && nodeName.contains("/")) {
                val str = nodeName.replaceAll("/.*$","");
                importCondition(str, tfNode, importState);
            } else if (nodeName.startsWith("while")) {
                // while loop import
            }
        */
        val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
        if (differentialFunction == null) {
            throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
        }
        try {
            val newInstance = differentialFunction.getClass().newInstance();
            val args = new SDVariable[tfNode.getInputCount()];
            for (int i = 0; i < tfNode.getInputCount(); i++) {
                val name = getNodeName(tfNode.getInput(i));
                args[i] = diff.getVariable(name);
                if (args[i] == null) {
                    // Input not imported yet: create a variable for it with a zero weight-init scheme ('f' order).
                    args[i] = diff.var(name, null, new ZeroInitScheme('f'));
                }
                /*
                 * Note here that we are associating
                 * the output/result variable
                 * with its inputs and notifying
                 * the variable that it has a place holder argument
                 * it should resolve before trying to execute
                 * anything.
                 */
                if (diff.isPlaceHolder(args[i].getVarName())) {
                    diff.putPlaceHolderForVariable(args[i].getVarName(), name);
                }
            }
            diff.addArgsFor(args, newInstance);
            newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
            mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
            importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
            // ensure we can track node name to function instance later.
            diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
        } catch (Exception e) {
            log.error("Failed with [{}]", opName);
            throw new RuntimeException(e);
        }
    }
}
Also used: lombok.val, ZeroInitScheme (org.nd4j.weightinit.impl.ZeroInitScheme), ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)
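
The shape handling in mapNodeType follows the comment "even vector is 2d in nd4j": a rank-1 TensorFlow shape gets an extra unit dimension. Below is a minimal sketch of that conversion, assuming the unit dimension is prepended so that [n] becomes the row-vector shape [1, n]. ShapePromotion and toNd4jDims are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: TensorFlow shape -> nd4j-style dimension list.
final class ShapePromotion {
    static List<Integer> toNd4jDims(int[] tfShape) {
        List<Integer> dims = new ArrayList<>();
        if (tfShape.length == 1)
            dims.add(1); // assumption: prepend a unit dimension, giving a row vector [1, n]
        for (int d : tfShape)
            dims.add(d); // remaining dimensions are copied unchanged
        return dims;     // a scalar shape (length 0) yields an empty list
    }
}
```

Higher-rank shapes pass through untouched, so only the rank-1 case differs from the TensorFlow shape.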

Example 30 with ND4JIllegalStateException

Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

From the class SDVariable, method storeAndAllocateNewArray.

/**
 * Allocate and return a new array
 * based on the variable's stored shape and its weight initialization scheme.
 * If an array with the expected shape is already stored, it is returned as-is.
 * @return the allocated array
 */
public INDArray storeAndAllocateNewArray() {
    val shape = sameDiff.getShapeForVarName(getVarName());
    if (getArr() != null && Arrays.equals(getArr().shape(), shape))
        return getArr();
    if (varName == null)
        throw new ND4JIllegalStateException("Unable to store array for null variable name!");
    if (shape == null) {
        throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
    }
    val arr = getWeightInitScheme().create(shape);
    sameDiff.putArrayForVarName(getVarName(), arr);
    return arr;
}
Also used: ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException)
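
storeAndAllocateNewArray behaves like a get-or-allocate cache keyed by variable name: it reuses the stored array only while its shape still matches the registered shape. The sketch below models that logic with plain Java collections. ArrayStore, Stored, and the DoubleSupplier used in place of a weight-initialization scheme are hypothetical, arrays are flat double[] buffers rather than INDArrays, and IllegalStateException stands in for ND4JIllegalStateException.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.DoubleSupplier;

// Hypothetical model of a shape-checked array cache.
final class ArrayStore {
    // An allocated buffer together with the shape it was created for.
    static final class Stored {
        final long[] shape;
        final double[] data;

        Stored(long[] shape, double[] data) {
            this.shape = shape;
            this.data = data;
        }
    }

    private final Map<String, long[]> shapes = new HashMap<>();
    private final Map<String, Stored> arrays = new HashMap<>();
    private final DoubleSupplier init; // stand-in for a weight-initialization scheme

    ArrayStore(DoubleSupplier init) {
        this.init = init;
    }

    void putShape(String varName, long... shape) {
        shapes.put(varName, shape.clone());
    }

    Stored storeAndAllocate(String varName) {
        if (varName == null)
            throw new IllegalStateException("Unable to store array for null variable name!");
        long[] shape = shapes.get(varName);
        if (shape == null)
            throw new IllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
        Stored existing = arrays.get(varName);
        if (existing != null && Arrays.equals(existing.shape, shape))
            return existing; // shape unchanged: reuse the stored array
        long length = 1;
        for (long d : shape)
            length *= d;
        double[] data = new double[Math.toIntExact(length)];
        for (int k = 0; k < data.length; k++)
            data[k] = init.getAsDouble();
        Stored fresh = new Stored(shape.clone(), data);
        arrays.put(varName, fresh);
        return fresh;
    }
}
```

Comparing shapes with Arrays.equals, as the original does, means a variable whose registered shape changes gets a fresh allocation on the next call, while repeated calls with an unchanged shape return the same object.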

