Search in sources:

Example 26 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class SameDiff method addOutgoingFor.

/**
 * Registers {@code varNames} as the outputs produced by {@code function}.
 * <p>
 * After validation, the association is recorded in both directions: {@code outgoingArgsReverse}
 * maps the function's own name to the output names, and {@code outgoingArgs} maps the output-name
 * array to the function. The function is also appended to the {@code functionOutputFor} list of
 * every output name, creating that list on first use.
 *
 * @param varNames names of the output variables; neither the array nor any element may be null
 * @param function the function producing those outputs; must already have an own name
 * @throws ND4JIllegalStateException if the function has no own name, already has outputs
 *                                   registered, or {@code varNames} (or one of its elements) is null
 */
public void addOutgoingFor(String[] varNames, DifferentialFunction function) {
    val ownName = function.getOwnName();
    if (ownName == null)
        throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
    if (outgoingArgsReverse.containsKey(ownName))
        throw new ND4JIllegalStateException("Outgoing arguments already declared for " + function);
    if (varNames == null)
        throw new ND4JIllegalStateException("Var names can not be null!");
    for (String candidate : varNames) {
        if (candidate == null)
            throw new ND4JIllegalStateException("Variable name elements can not be null!");
    }

    outgoingArgsReverse.put(ownName, varNames);
    outgoingArgs.put(varNames, function);
    for (String outputName : varNames) {
        // Lazily create the per-output list, then record this function as a producer of it.
        functionOutputFor.computeIfAbsent(outputName, key -> new ArrayList<>()).add(function);
    }
}
Also used : DifferentialFunction(org.nd4j.autodiff.functions.DifferentialFunction) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 27 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class OnnxGraphMapper method getNDArrayFromTensor.

/**
 * Builds an {@link INDArray} from the raw bytes of the graph initializer named {@code tensorName}.
 * <p>
 * The element type and shape are taken from {@code tensorProto} (via {@code dataTypeForTensor} and
 * {@code getShapeFromTensor}). The initializer's {@code raw_data} bytes are copied into a direct
 * buffer in native byte order and wrapped as an array of that shape.
 *
 * @param tensorName  name of the initializer to look up in {@code graph}
 * @param tensorProto type information (element type and shape) for the tensor
 * @param graph       graph whose initializers are searched
 * @return the array, or {@code null} if {@code graph} has no initializer named {@code tensorName}
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized, or if the initializer
 *                                   has no raw data although its shape has a non-zero element count
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    // Validate before deriving anything from the proto.
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }
    DataBuffer.Type type = dataTypeForTensor(tensorProto);

    OnnxProto3.TensorProto tensor = null;
    for (int i = 0; i < graph.getInitializerCount(); i++) {
        val initializer = graph.getInitializer(i);
        if (initializer.getName().equals(tensorName)) {
            tensor = initializer;
            break;
        }
    }
    if (tensor == null)
        return null;

    int[] shape = getShapeFromTensor(tensorProto);
    val length = ArrayUtil.prod(shape);
    ByteString bytes = tensor.getRawData();
    // Only raw_data is read here; an empty payload for a non-empty shape would otherwise make
    // createBuffer read past the end of a zero-byte buffer.
    if (bytes.isEmpty() && length > 0) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. No raw data found for tensor " + tensorName);
    }

    ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    // Size by remaining(), not capacity(): the read-only view may be a window onto a larger backing
    // array, in which case capacity() over-counts and position() is non-zero.
    ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.remaining()).order(ByteOrder.nativeOrder());
    directAlloc.put(byteBuffer);
    directAlloc.rewind();

    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, length);
    INDArray arr = Nd4j.create(buffer).reshape(shape);
    return arr;
}
Also used : lombok.val(lombok.val) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ByteString(com.github.os72.protobuf351.ByteString) OnnxProto3(onnx.OnnxProto3) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException) ByteBuffer(java.nio.ByteBuffer) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer)

Example 28 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class TFGraphMapper method mapProperty.

/**
 * Sets the property {@code name} on {@code on} from the TensorFlow node {@code node}.
 * <p>
 * The {@code PropertyMapping} for the property is looked up by the node's op type. If the mapping
 * names an input position within the node's input count (negative positions count from the end),
 * the value is read from that input's array:
 * <ul>
 *   <li>If no array is available yet but the input node exists in {@code graph}, the property is
 *       registered on {@code sameDiff} for later resolution.</li>
 *   <li>If the input node does not exist in {@code graph}, the input is registered as a
 *       placeholder.</li>
 * </ul>
 * Otherwise the value is read from the node attribute named by the mapping, if that attribute is
 * present.
 *
 * @param name                        property name on {@code on}
 * @param on                          function whose field is set
 * @param node                        TensorFlow node the value is read from
 * @param graph                       graph containing {@code node}
 * @param sameDiff                    instance used for array lookup and deferred resolution
 * @param propertyMappingsForFunction property mappings keyed by op type, then by property name
 * @throws ND4JIllegalStateException if {@code node} is null, no mapping exists for the op type or
 *                                   property, the input position is out of range, or the target
 *                                   fields/properties cannot be found
 */
@Override
public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
    if (node == null) {
        throw new ND4JIllegalStateException("No node found for name " + name);
    }
    val opType = getOpType(node);
    val mappingsForOp = propertyMappingsForFunction.get(opType);
    if (mappingsForOp == null) {
        throw new ND4JIllegalStateException("No property mappings found for op type " + opType);
    }
    val mapping = mappingsForOp.get(name);
    if (mapping == null) {
        throw new ND4JIllegalStateException("No property mapping found for property " + name + " and op type " + opType);
    }
    val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
    val tfInputPosition = mapping.getTfInputPosition();
    if (tfInputPosition != null && tfInputPosition < node.getInputCount()) {
        int tfMappingIdx = tfInputPosition;
        // Negative positions count from the end of the input list.
        if (tfMappingIdx < 0)
            tfMappingIdx += node.getInputCount();
        if (tfMappingIdx < 0) {
            throw new ND4JIllegalStateException("Input position " + tfInputPosition + " is out of range for node "
                    + node.getName() + " with " + node.getInputCount() + " inputs");
        }
        val input = node.getInput(tfMappingIdx);
        val inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, input);
        INDArray arr = getArrayFrom(inputNode, graph);
        if (arr == null) {
            arr = sameDiff.getArrForVarName(input);
        }
        if (arr == null && inputNode != null) {
            // Value not available yet: defer until the input node's array can be resolved.
            sameDiff.addPropertyToResolve(on, name);
            sameDiff.addVariableMappingForField(on, name, inputNode.getName());
            return;
        } else if (inputNode == null) {
            // Input has no node in this graph: register it as a placeholder instead.
            sameDiff.addAsPlaceHolder(input);
            return;
        }
        if (fields == null) {
            throw new ND4JIllegalStateException("No fields found for op " + mapping);
        }
        val field = fields.get(name);
        if (field == null) {
            throw new ND4JIllegalStateException("No field found for property " + name + " and op " + on.opName());
        }
        val type = field.getType();
        if (type.equals(int[].class)) {
            on.setValueFor(field, arr.data().asInt());
        } else if (type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
            if (mapping.getShapePosition() != null) {
                on.setValueFor(field, arr.size(mapping.getShapePosition()));
            } else
                on.setValueFor(field, arr.getInt(0));
        } else if (type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
            on.setValueFor(field, arr.getDouble(0));
        }
    } else {
        val tfMappingAttrName = mapping.getTfAttrName();
        if (tfMappingAttrName == null) {
            return;
        }
        if (!node.containsAttr(tfMappingAttrName)) {
            return;
        }
        val attr = node.getAttrOrThrow(tfMappingAttrName);
        // NOTE(review): in TensorFlow's AttrValue proto, getType() returns the attribute's own `type`
        // value (set only for type-valued attrs), not the kind of value stored in the attr.
        // Switching on getValueCase() may be what was intended - confirm before changing.
        val type = attr.getType();
        if (fields == null) {
            throw new ND4JIllegalStateException("No fields found for op " + mapping);
        }
        if (mapping.getPropertyNames() == null) {
            throw new ND4JIllegalStateException("no property found for " + name + " and op " + on.opName());
        }
        val field = fields.get(mapping.getPropertyNames()[0]);
        Object valueToSet = null;
        switch (type) {
            case DT_BOOL:
                valueToSet = attr.getB();
                break;
            case DT_INT8:
            case DT_INT16:
            case DT_INT32:
            case DT_INT64:
                valueToSet = attr.getI();
                break;
            case DT_FLOAT:
            case DT_DOUBLE:
                valueToSet = attr.getF();
                break;
            case DT_STRING:
                valueToSet = attr.getS();
                break;
            default:
                // Other types are not mapped; the field is left untouched.
                break;
        }
        if (field != null && valueToSet != null)
            on.setValueFor(field, valueToSet);
    }
}
Also used : lombok.val(lombok.val) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 29 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class TFGraphMapper method mapNodeType.

/**
 * Imports a single TensorFlow node into the {@code SameDiff} instance held by {@code importState}.
 * <p>
 * Nodes that should be skipped, that have already been seen, or that are variable nodes are
 * ignored here. Placeholder nodes are registered as placeholders on their existing variable.
 * <p>
 * Every other node is mapped to the op registered for its TensorFlow op name:
 * <ul>
 *   <li>A new instance of that op is created.</li>
 *   <li>Its inputs are resolved; unknown inputs become zero-initialised placeholder variables.</li>
 *   <li>It is initialised from the node's attributes and its properties are mapped.</li>
 *   <li>It is registered with {@code SameDiff} under the node's name.</li>
 * </ul>
 *
 * @param tfNode      the TensorFlow node to import
 * @param importState import state providing the target {@code SameDiff} and the source graph
 * @throws ND4JIllegalStateException if a placeholder node has no matching variable, or if no op is
 *                                   registered for the node's TensorFlow op name
 * @throws RuntimeException          wrapping any exception raised while creating or initialising
 *                                   the op instance
 */
@Override
public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
    // Variable nodes return here and are never imported by this method.
    if (shouldSkip(tfNode) || alreadySeen(tfNode) || isVariableNode(tfNode)) {
        return;
    }
    val diff = importState.getSameDiff();
    if (isPlaceHolder(tfNode)) {
        val placeholderName = getName(tfNode);
        val vertexId = diff.getVariable(placeholderName);
        if (vertexId == null) {
            throw new ND4JIllegalStateException("No variable found for placeholder node " + placeholderName);
        }
        diff.addAsPlaceHolder(vertexId.getVarName());
        return;
    }

    val opName = tfNode.getOp();
    val nodeName = tfNode.getName();
    // TODO: conditional ("cond/...") and while-loop import are not implemented yet; such nodes
    // currently go through the generic op mapping below.
    val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
    if (differentialFunction == null) {
        throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
    }
    try {
        val newInstance = differentialFunction.getClass().newInstance();
        val args = new SDVariable[tfNode.getInputCount()];
        newInstance.setOwnName(nodeName);
        for (int i = 0; i < tfNode.getInputCount(); i++) {
            val name = getNodeName(tfNode.getInput(i));
            args[i] = diff.getVariable(name);
            if (args[i] == null) {
                // Input not declared yet: create a zero-initialised variable and mark it as a placeholder.
                args[i] = diff.var(name, null, new ZeroInitScheme('f'));
                diff.addAsPlaceHolder(args[i].getVarName());
            }
            // Tell placeholder inputs which variable they feed, so they are resolved before execution.
            if (diff.isPlaceHolder(args[i].getVarName())) {
                diff.putPlaceHolderForVariable(args[i].getVarName(), name);
            }
        }
        diff.addArgsFor(args, newInstance);
        newInstance.setSameDiff(diff);
        newInstance.initFromTensorFlow(tfNode, diff, getAttrMap(tfNode), importState.getGraph());
        mapProperties(newInstance, tfNode, importState.getGraph(), diff, newInstance.mappingsForFunction());
        diff.putFunctionForId(newInstance.getOwnName(), newInstance);
        // Ensure the node name can be traced back to this function instance later.
        diff.setBaseNameForFunctionInstanceId(nodeName, newInstance);
        diff.addVarNameForImport(nodeName);
    } catch (Exception e) {
        log.error("Failed with [{}] for node [{}]", opName, nodeName);
        throw new RuntimeException(e);
    }
}
Also used : lombok.val(lombok.val) ZeroInitScheme(org.nd4j.weightinit.impl.ZeroInitScheme) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 30 with ND4JIllegalStateException

use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.

the class SDVariable method storeAndAllocateNewArray.

/**
 * Returns the array for this variable, allocating and storing a new one when needed.
 * <p>
 * If the current array's shape equals the shape {@code sameDiff} records for this variable, that
 * array is returned unchanged. Otherwise a new array of the recorded shape is created with this
 * variable's weight initialisation scheme. It is stored in {@code sameDiff} under this variable's
 * name and returned.
 *
 * @return the existing array or the newly allocated one
 * @throws ND4JIllegalStateException if the variable name is null, or if no shape is recorded for
 *                                   the variable
 */
public INDArray storeAndAllocateNewArray() {
    // Validate the name first so no lookup is attempted with a null key.
    if (varName == null)
        throw new ND4JIllegalStateException("Unable to store array for null variable name!");
    val shape = sameDiff.getShapeForVarName(getVarName());
    val existing = getArr();
    if (existing != null && Arrays.equals(existing.shape(), shape))
        return existing;
    if (shape == null) {
        throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + varName);
    }
    val arr = getWeightInitScheme().create(shape);
    sameDiff.putArrayForVarName(getVarName(), arr);
    return arr;
}
Also used : ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Aggregations

ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 116; lombok.val (lombok.val): 26; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 23; CudaContext (org.nd4j.linalg.jcublas.context.CudaContext): 21; AllocationPoint (org.nd4j.jita.allocator.impl.AllocationPoint): 19; DataBuffer (org.nd4j.linalg.api.buffer.DataBuffer): 17; CudaPointer (org.nd4j.jita.allocator.pointers.CudaPointer): 15; PagedPointer (org.nd4j.linalg.api.memory.pointers.PagedPointer): 12; AtomicInteger (java.util.concurrent.atomic.AtomicInteger): 8; BaseDataBuffer (org.nd4j.linalg.api.buffer.BaseDataBuffer): 7; IComplexNDArray (org.nd4j.linalg.api.complex.IComplexNDArray): 6; Pointer (org.bytedeco.javacpp.Pointer): 5; ArrayList (java.util.ArrayList): 4; DifferentialFunction (org.nd4j.autodiff.functions.DifferentialFunction): 4; Aeron (io.aeron.Aeron): 3; FragmentAssembler (io.aeron.FragmentAssembler): 3; MediaDriver (io.aeron.driver.MediaDriver): 3; AtomicBoolean (java.util.concurrent.atomic.AtomicBoolean): 3; Slf4j (lombok.extern.slf4j.Slf4j): 3; CloseHelper (org.agrona.CloseHelper): 3