Use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.
Class Slice, method initFromTensorFlow.
/**
 * Populates this op's integer arguments from a TensorFlow node definition.
 *
 * <p>Reads the names of inputs 1 and 2 of {@code nodeDef}, looks up the graph nodes carrying
 * those names, and asks {@code TFGraphMapper} for the array stored under each node's
 * {@code "value"} entry. When both arrays are available, every element of the first and then
 * every element of the second is appended as an integer argument; otherwise nothing is added.
 *
 * <p>NOTE(review): the earlier comment here described strided slice (begin / end / strides),
 * but no stride input is read by this method. TensorFlow documents Slice's inputs as
 * (input, begin, size) - confirm that input 2 is meant to be consumed as sizes.
 *
 * <p>NOTE(review): if no graph node matches an input name, a null node is handed to
 * {@code getNDArrayFromTensor} - confirm that helper tolerates it.
 *
 * @param nodeDef the TensorFlow node being imported; inputs 1 and 2 are read
 * @param initWith not referenced by this method
 * @param attributesForNode not referenced by this method
 * @param graph the graph that is searched for the two input nodes
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final String beginName = nodeDef.getInput(1);
    final String endName = nodeDef.getInput(2);

    NodeDef beginNode = null;
    NodeDef endNode = null;
    final int nodeCount = graph.getNodeCount();
    for (int idx = 0; idx < nodeCount; idx++) {
        final NodeDef candidate = graph.getNode(idx);
        final String candidateName = candidate.getName();
        // No early exit on purpose: should several nodes share a name, the last one wins.
        if (candidateName.equals(beginName)) {
            beginNode = candidate;
        }
        if (candidateName.equals(endName)) {
            endNode = candidate;
        }
    }

    final TFGraphMapper mapper = TFGraphMapper.getInstance();
    val beginArr = mapper.getNDArrayFromTensor("value", beginNode, graph);
    val endArr = mapper.getNDArrayFromTensor("value", endNode, graph);

    // Both arrays are required; with either one missing no integer arguments are set.
    if (beginArr == null || endArr == null) {
        return;
    }

    for (int pos = 0; pos < beginArr.length(); pos++) {
        addIArgument(beginArr.getInt(pos));
    }
    for (int pos = 0; pos < endArr.length(); pos++) {
        addIArgument(endArr.getInt(pos));
    }
}
Use of org.tensorflow.framework.NodeDef in project nd4j by deeplearning4j.
Class TensorArrayV3, method initFromTensorFlow.
/**
 * Populates this op's integer argument from a TensorFlow node definition.
 *
 * <p>Takes the name of the last input of {@code nodeDef}, looks up the graph node carrying that
 * name, and asks {@code TFGraphMapper} for the array stored under that node's {@code "value"}
 * entry. When an array is returned, its element 0 is appended as an integer argument;
 * otherwise nothing is added.
 *
 * <p>NOTE(review): if no graph node matches the input name, a null node is handed to
 * {@code getNDArrayFromTensor} - confirm that helper tolerates it.
 *
 * @param nodeDef the TensorFlow node being imported; only its last input is read
 * @param initWith not referenced by this method
 * @param attributesForNode not referenced by this method
 * @param graph the graph that is searched for the input node
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final int lastInput = nodeDef.getInputCount() - 1;
    final String targetName = nodeDef.getInput(lastInput);

    NodeDef targetNode = null;
    final int nodeCount = graph.getNodeCount();
    for (int pos = 0; pos < nodeCount; pos++) {
        final NodeDef candidate = graph.getNode(pos);
        // No early exit on purpose: should several nodes share a name, the last one wins.
        if (candidate.getName().equals(targetName)) {
            targetNode = candidate;
        }
    }

    val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph);
    if (arr == null) {
        // No array available for this input; leave the integer arguments untouched.
        return;
    }

    addIArgument(arr.getInt(0));
}
Aggregations