Search in sources :

Example 1 with NodeDef

use of org.tensorflow.framework.NodeDef in project vespa by vespa-engine.

the class TensorFlowImporter method importNode.

/**
 * Imports the graph node identified by {@code nodeName} and returns the resulting operation.
 *
 * If the index reports the name as already imported, the indexed operation is returned
 * unchanged. Otherwise the node is looked up in the graph, its inputs are imported, the
 * operation is created through {@code OperationMapper} and stored in the index, and
 * finally any control inputs are imported and attached.
 *
 * @param modelName name of the model, passed through to the input importers and the mapper
 * @param nodeName  name of the node to import; split via namePartOf / portPartOf
 * @param graph     graph definition the node is looked up in
 * @param index     index of operations imported so far; updated with the new operation
 * @return the previously indexed or newly created operation for {@code nodeName}
 */
private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
    if (index.alreadyImported(nodeName)) {
        return index.get(nodeName);
    }

    NodeDef tfNode = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
    List<TensorFlowOperation> dataInputs = importNodeInputs(modelName, tfNode, graph, index);
    TensorFlowOperation imported = OperationMapper.get(modelName, tfNode, dataInputs, portPartOf(nodeName));

    // The operation is indexed before its control inputs are imported.
    // NOTE(review): presumably so control inputs referring back to this node resolve
    // through the index - confirm before reordering.
    index.put(nodeName, imported);

    List<TensorFlowOperation> controlDependencies = importControlInputs(modelName, tfNode, graph, index);
    if (!controlDependencies.isEmpty()) {
        imported.setControlInputs(controlDependencies);
    }
    return imported;
}
Also used : TensorFlowOperation(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation) NodeDef(org.tensorflow.framework.NodeDef)

Example 2 with NodeDef

use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.

the class Range method initFromTensorFlow.

/**
 * Initializes this range op from a TensorFlow node definition.
 *
 * The first three inputs of {@code nodeDef} are treated as the start, end and delta
 * nodes and are looked up by name in {@code graph}. When the "value" tensor of all three
 * can be read, element 0 of each is stored as from / to / delta and added as T arguments,
 * and an output array is created and registered if none exists for the first output
 * variable. In every case the variable names of the three inputs are recorded.
 *
 * @param nodeDef           node being imported; must have at least three inputs
 * @param initWith          SameDiff instance the input variables are looked up in
 * @param attributesForNode attributes of the node, forwarded to the super implementation
 * @param graph             graph the input nodes are searched in
 * @throws IllegalStateException if any of the three inputs does not name a node in {@code graph}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Resolve the input names once instead of on every loop iteration.
    val startName = nodeDef.getInput(0);
    val endName = nodeDef.getInput(1);
    val deltaName = nodeDef.getInput(2);

    NodeDef startNode = null, endNode = null, deltaNode = null;
    for (val node : graph.getNodeList()) {
        val candidateName = node.getName();
        // Independent checks: one node may feed more than one input.
        if (candidateName.equals(startName)) {
            startNode = node;
        }
        if (candidateName.equals(endName)) {
            endNode = node;
        }
        if (candidateName.equals(deltaName)) {
            deltaNode = node;
        }
        if (startNode != null && endNode != null && deltaNode != null)
            break;
    }

    // Fail fast with a descriptive message; previously a missing node surfaced as a
    // bare NullPointerException further down.
    if (startNode == null || endNode == null || deltaNode == null) {
        throw new IllegalStateException("Unable to resolve inputs of range node \"" + nodeDef.getName() + "\" in graph:"
                + (startNode == null ? " start=\"" + startName + "\"" : "")
                + (endNode == null ? " end=\"" + endName + "\"" : "")
                + (deltaNode == null ? " delta=\"" + deltaName + "\"" : ""));
    }

    val mapper = TFGraphMapper.getInstance();
    val startArr = mapper.getNDArrayFromTensor("value", startNode, graph);
    val endArr = mapper.getNDArrayFromTensor("value", endNode, graph);
    val deltaArr = mapper.getNDArrayFromTensor("value", deltaNode, graph);
    if (startArr != null && endArr != null && deltaArr != null) {
        val outputVars = outputVariables();
        this.from = startArr.getDouble(0);
        this.to = endArr.getDouble(0);
        this.delta = deltaArr.getDouble(0);
        addTArgument(this.from, this.to, this.delta);
        val outputVertexId = outputVars[0].getVarName();
        if (sameDiff.getArrForVarName(outputVertexId) == null) {
            // Make sure the output variable has a shape before allocating its array.
            if (outputVars[0].getShape() == null) {
                val calcShape = calculateOutputShape();
                sameDiff.putShapeForVarName(outputVars[0].getVarName(), calcShape.get(0));
            }
            val arr = Nd4j.create(outputVars[0].getShape());
            // NOTE(review): shape is stored via the sameDiff field but the array via initWith;
            // presumably both refer to the same instance - confirm.
            initWith.putArrayForVarName(outputVertexId, arr);
            addOutputArgument(arr);
        }
    }

    val fromVar = initWith.getVariable(mapper.getNodeName(startNode.getName()));
    val toVar = initWith.getVariable(mapper.getNodeName(endNode.getName()));
    val deltaVar = initWith.getVariable(mapper.getNodeName(deltaNode.getName()));
    this.fromVertexId = fromVar.getVarName();
    this.toVertexId = toVar.getVarName();
    this.deltaVertexId = deltaVar.getVarName();
}
Also used : lombok.val(lombok.val) NodeDef(org.tensorflow.framework.NodeDef)

Example 3 with NodeDef

use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.

the class StridedSlice method initFromTensorFlow.

/**
 * Initializes this strided-slice op from a TensorFlow node definition.
 *
 * Inputs 1, 2 and 3 of {@code nodeDef} are treated as the begin, end and strides nodes
 * and are looked up by name in {@code graph}. The five mask attributes are added as
 * integer arguments in the order begin, ellipsis, end, new-axis, shrink-axis. If the
 * "value" tensors of all three input nodes can be read, their elements are appended as
 * further integer arguments (begin, then end, then strides); otherwise nothing more is added.
 *
 * @param nodeDef           node being imported; must have at least four inputs and all five mask attributes
 * @param initWith          not used by this method
 * @param attributesForNode not used by this method
 * @param graph             graph the input nodes are searched in
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val beginName = nodeDef.getInput(1);
    val endName = nodeDef.getInput(2);
    val stridesName = nodeDef.getInput(3);

    // Scan the whole graph; the last node matching a name wins.
    NodeDef beginNode = null;
    NodeDef endNode = null;
    NodeDef stridesNode = null;
    for (NodeDef candidate : graph.getNodeList()) {
        val candidateName = candidate.getName();
        if (candidateName.equals(beginName)) {
            beginNode = candidate;
        }
        if (candidateName.equals(endName)) {
            endNode = candidate;
        }
        if (candidateName.equals(stridesName)) {
            stridesNode = candidate;
        }
    }

    // Bit masks for this slice. All five are fetched before any is added, so a missing
    // attribute leaves the argument list untouched.
    AttrValue[] masks = {
        nodeDef.getAttrOrThrow("begin_mask"),
        nodeDef.getAttrOrThrow("ellipsis_mask"),
        nodeDef.getAttrOrThrow("end_mask"),
        nodeDef.getAttrOrThrow("new_axis_mask"),
        nodeDef.getAttrOrThrow("shrink_axis_mask")
    };
    for (AttrValue mask : masks) {
        addIArgument((int) mask.getI());
    }

    val mapper = TFGraphMapper.getInstance();
    val beginArr = mapper.getNDArrayFromTensor("value", beginNode, graph);
    val endArr = mapper.getNDArrayFromTensor("value", endNode, graph);
    val stridesArr = mapper.getNDArrayFromTensor("value", stridesNode, graph);

    // Without all three arrays there is nothing further to add.
    if (beginArr == null || endArr == null || stridesArr == null) {
        return;
    }

    for (int e = 0; e < beginArr.length(); e++) {
        addIArgument(beginArr.getInt(e));
    }
    for (int e = 0; e < endArr.length(); e++) {
        addIArgument(endArr.getInt(e));
    }
    for (int e = 0; e < stridesArr.length(); e++) {
        addIArgument(stridesArr.getInt(e));
    }
}
Also used : lombok.val(lombok.val) NodeDef(org.tensorflow.framework.NodeDef)

Example 4 with NodeDef

use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.

the class Transpose method initFromTensorFlow.

/**
 * Initializes this transpose op from a TensorFlow node definition.
 *
 * After delegating to the super implementation, the method returns immediately when the
 * node has fewer than two inputs. Otherwise the node named by input 1 is looked up in
 * {@code graph}; if its "value" tensor can be read, its contents become the permute
 * dimensions and are added as integer arguments. If the argument variable has a shape,
 * its array is fetched (or created from the variable's weight init scheme and registered)
 * and added as input argument; when no permute dimensions were read, they default to the
 * reversed range of the array's rank.
 *
 * @param nodeDef           node being imported
 * @param initWith          forwarded to the super implementation
 * @param attributesForNode forwarded to the super implementation
 * @param graph             graph the permute-dimensions node is searched in
 * @throws ND4JIllegalStateException if fewer permute dimensions than argument dimensions are specified
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Permute dimensions are not specified as a second input.
    if (nodeDef.getInputCount() < 2) {
        return;
    }

    // Scan the whole graph; the last node matching the input name wins.
    val permuteInputName = nodeDef.getInput(1);
    NodeDef permuteDimsNode = null;
    for (NodeDef candidate : graph.getNodeList()) {
        if (candidate.getName().equals(permuteInputName)) {
            permuteDimsNode = candidate;
        }
    }

    val permuteTensor = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph);
    if (permuteTensor != null) {
        this.permuteDims = permuteTensor.data().asInt();
        for (int idx = 0; idx < permuteDims.length; idx++) {
            addIArgument(permuteDims[idx]);
        }
    }

    // Without a shape on the argument there is nothing more to set up yet.
    if (arg().getShape() == null) {
        return;
    }

    INDArray inputArr = sameDiff.getArrForVarName(arg().getVarName());
    if (inputArr == null) {
        // No array registered for the argument: create one from its weight init scheme.
        val inputVar = sameDiff.getVariable(arg().getVarName());
        inputArr = inputVar.getWeightInitScheme().create(inputVar.getShape());
        sameDiff.putArrayForVarName(arg().getVarName(), inputArr);
    }
    addInputArgument(inputArr);

    // Default: reverse all dimensions of the input.
    if (permuteDims == null && inputArr != null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, inputArr.rank()));
    }

    if (permuteDims != null && permuteDims.length < arg().getShape().length) {
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
    }
}
Also used : lombok.val(lombok.val) NodeDef(org.tensorflow.framework.NodeDef) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Example 5 with NodeDef

use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.

the class BaseAccumulation method initFromTensorFlow.

/**
 * Initializes the reduction dimensions and the keep-dims flag from a TensorFlow node.
 *
 * Dimensions are taken, in order of precedence, from:
 * the "value" tensor of the graph node named {@code <node name>/reduction_indices}
 * when {@code hasReductionIndices(nodeDef)} is true; else from the "axis" tensor of
 * the node when {@code attributesForNode} has an "axis" entry; else they are set to
 * {@code { Integer.MAX_VALUE }}.
 * The keep-dims flag is only updated when {@code attributesForNode} has a "keep_dims" entry.
 *
 * @param nodeDef           node being imported
 * @param initWith          not used by this method
 * @param attributesForNode attributes of the node; "axis" and "keep_dims" are optional
 * @param graph             graph the reduction-indices node is searched in
 * @throws ND4JIllegalStateException if reduction indices are expected but the node holding
 *                                   them is missing from the graph or its tensor cannot be read
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    newFormat = true;
    if (hasReductionIndices(nodeDef)) {
        val reductionNodeName = nodeDef.getName() + "/reduction_indices";
        NodeDef reductionNode = null;
        for (NodeDef candidate : graph.getNodeList()) {
            if (candidate.getName().equals(reductionNodeName)) {
                reductionNode = candidate;
                break;
            }
        }
        if (reductionNode == null)
            throw new ND4JIllegalStateException("No node found!");

        val indices = TFGraphMapper.getInstance().getNDArrayFromTensor("value", reductionNode, graph);
        if (indices == null)
            throw new ND4JIllegalStateException("Unable to read reduction indices from node " + reductionNodeName);
        // "keep_dims" is intentionally not read here: it is optional and handled below.
        this.dimensions = indices.data().asInt();
    } else if (attributesForNode.containsKey("axis")) {
        this.dimensions = TFGraphMapper.getInstance().getNDArrayFromTensor("axis", nodeDef, graph).data().asInt();
    } else {
        // NOTE(review): Integer.MAX_VALUE presumably is the sentinel for "all dimensions" - confirm.
        this.dimensions = new int[] { Integer.MAX_VALUE };
    }

    if (attributesForNode.containsKey("keep_dims")) {
        this.keepDims = attributesForNode.get("keep_dims").getB();
    }
}
Also used : lombok.val(lombok.val) NodeDef(org.tensorflow.framework.NodeDef) ND4JIllegalStateException(org.nd4j.linalg.exception.ND4JIllegalStateException)

Aggregations

NodeDef (org.tensorflow.framework.NodeDef): 7, lombok.val (lombok.val): 6, ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2, TensorFlowOperation (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation): 1, INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1