
Example 6 with NodeDef

Use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.

Source: class Slice, method initFromTensorFlow.

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    /*
            TensorFlow's Slice op takes 3 tensor arguments:
            0) input; its shape determines the number of elements in the other arguments
            1) begin indices
            2) size along each dimension (named "end" in the variables below)
         */
    val inputBegin = nodeDef.getInput(1);
    val inputEnd = nodeDef.getInput(2);
    NodeDef beginNode = null;
    NodeDef endNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(inputBegin)) {
            beginNode = graph.getNode(i);
        }
        if (graph.getNode(i).getName().equals(inputEnd)) {
            endNode = graph.getNode(i);
        }
    }
    val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", beginNode, graph);
    val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);
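    if (beginArr != null && endArr != null) {
        // Both vectors resolved to arrays: pass them as integer arguments,
        // all begin values first, then all values of the second vector.
        for (int e = 0; e < beginArr.length(); e++) {
            addIArgument(beginArr.getInt(e));
        }
        for (int e = 0; e < endArr.length(); e++) {
            addIArgument(endArr.getInt(e));
        }
    }
    // Otherwise no integer arguments are added.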
}
Also used: lombok.val, org.tensorflow.framework.NodeDef
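The integer-argument layout the Slice method produces can be sketched in plain Java. This is a minimal sketch, not nd4j code: the SliceArgs class and its methods are hypothetical, and it assumes the second vector is TensorFlow Slice's size input (TensorFlow documents Slice as taking input, begin and size, where a size of -1 means "all remaining elements in that dimension"). It packs begin values first and size values second, the same order in which the loops above call addIArgument, and shows how a consumer could split them back apart by rank.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper (not part of nd4j) illustrating the flat argument layout:
// [begin_0 .. begin_{r-1}, size_0 .. size_{r-1}] for an input of rank r.
final class SliceArgs {
    private SliceArgs() {}

    /** Packs begin and size vectors into one flat list, begin values first. */
    static List<Integer> pack(int[] begin, int[] size) {
        if (begin.length != size.length) {
            throw new IllegalArgumentException("begin and size must have the same length");
        }
        List<Integer> args = new ArrayList<>();
        for (int b : begin) args.add(b);
        for (int s : size) args.add(s);
        return args;
    }

    /** Splits a flat list back into {begin, size} for an input of the given rank. */
    static int[][] unpack(List<Integer> args, int rank) {
        if (args.size() != 2 * rank) {
            throw new IllegalArgumentException("expected " + 2 * rank + " arguments, got " + args.size());
        }
        int[] begin = new int[rank];
        int[] size = new int[rank];
        for (int i = 0; i < rank; i++) {
            begin[i] = args.get(i);
            size[i] = args.get(rank + i);
        }
        return new int[][] {begin, size};
    }

    /** Resolves TensorFlow Slice sizes: -1 means all remaining elements in that dimension. */
    static int[] resolveSize(int[] shape, int[] begin, int[] size) {
        int[] out = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            out[i] = size[i] == -1 ? shape[i] - begin[i] : size[i];
        }
        return out;
    }

    public static void main(String[] args) {
        // Slice a 3x4 input starting at row 1, column 0: one row, all columns.
        List<Integer> packed = pack(new int[] {1, 0}, new int[] {1, -1});
        System.out.println(packed); // [1, 0, 1, -1]
        int[][] parts = unpack(packed, 2);
        System.out.println(Arrays.toString(resolveSize(new int[] {3, 4}, parts[0], parts[1]))); // [1, 4]
    }
}
```

Because the rank is not stored in the flat list, a consumer has to know it (for example from the input's shape) to split the arguments correctly, which is why the comment in the method ties the argument lengths to the input's shape.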

Example 7 with NodeDef

Use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.

Source: class TensorArrayV3, method initFromTensorFlow.

@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // The last input (for TensorArrayV3, its size input) names the node that produces it.
    val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
    NodeDef iddNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(idd)) {
            iddNode = graph.getNode(i);
        }
    }
    // If that node carries a "value" tensor, pass its first element as an integer argument.
    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
    if (arr != null) {
        int idx = arr.getInt(0);
        addIArgument(idx);
    }
}
Also used: lombok.val, org.tensorflow.framework.NodeDef
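Both methods find the producing node by scanning every node in the GraphDef and comparing names exactly. In a GraphDef, an input string has the form node:output_index (the :0 suffix may be omitted) or ^node for a control dependency, so an exact match would miss an input such as begin:1. A minimal sketch of an alternative, assuming a hypothetical Node class in place of org.tensorflow.framework.NodeDef: index the nodes by name once, then normalize each input reference before looking it up.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical name index (not part of nd4j): built once per graph, then queried per input.
final class NodeIndex {
    private final Map<String, Node> byName = new HashMap<>();

    NodeIndex(List<Node> nodes) {
        // Node names in a GraphDef are unique, so each name maps to one node.
        for (Node n : nodes) byName.put(n.name(), n);
    }

    /** Strips a leading '^' (control input) and a trailing ":<digits>" output index. */
    static String baseName(String input) {
        String s = input.startsWith("^") ? input.substring(1) : input;
        int colon = s.lastIndexOf(':');
        if (colon >= 0 && colon < s.length() - 1
                && s.substring(colon + 1).chars().allMatch(Character::isDigit)) {
            s = s.substring(0, colon);
        }
        return s;
    }

    /** Returns the node that produces the referenced input, or null if there is none. */
    Node find(String input) {
        return byName.get(baseName(input));
    }

    public static void main(String[] args) {
        NodeIndex index = new NodeIndex(List.of(new Node("slice/begin"), new Node("slice/size")));
        System.out.println(index.find("slice/begin").name());  // slice/begin
        System.out.println(index.find("slice/size:0").name()); // slice/size
        System.out.println(index.find("^slice/begin").name()); // slice/begin
        System.out.println(index.find("missing"));             // null
    }
}

// Hypothetical stand-in for org.tensorflow.framework.NodeDef; only the name is modeled.
final class Node {
    private final String name;

    Node(String name) { this.name = name; }

    String name() { return name; }
}
```

In the real methods the map would be keyed on NodeDef.getName(); building it once per graph replaces a full scan for every lookup with a single hash lookup.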

Aggregations

Classes used across the NodeDef examples, with usage counts:
NodeDef (org.tensorflow.framework.NodeDef): 7
lombok.val (lombok.val): 6
ND4JIllegalStateException (org.nd4j.linalg.exception.ND4JIllegalStateException): 2
TensorFlowOperation (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1