Usage of org.tensorflow.framework.NodeDef in the vespa project (by vespa-engine).
Class TensorFlowImporter, method importNode.
/**
 * Returns the imported operation for {@code nodeName}, importing it on first request.
 *
 * <p>If {@code index} already holds an operation for this name, that operation is returned as-is.
 * Otherwise:
 * <ol>
 *   <li>the {@code NodeDef} is looked up by the name part of {@code nodeName};</li>
 *   <li>its data inputs are imported through {@code importNodeInputs};</li>
 *   <li>the operation is built by {@code OperationMapper} using the port part of {@code nodeName};</li>
 *   <li>the operation is stored in {@code index};</li>
 *   <li>any control inputs (from {@code importControlInputs}) are attached.</li>
 * </ol>
 *
 * @param modelName name of the model being imported
 * @param nodeName  node name; split by {@code namePartOf}/{@code portPartOf}
 * @param graph     the TensorFlow graph to read from
 * @param index     cache of already-imported operations, keyed by {@code nodeName}
 * @return the (possibly cached) operation for {@code nodeName}
 */
private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
    if (index.alreadyImported(nodeName)) {
        return index.get(nodeName);
    }

    NodeDef tfNode = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
    List<TensorFlowOperation> dataInputs = importNodeInputs(modelName, tfNode, graph, index);
    TensorFlowOperation imported = OperationMapper.get(modelName, tfNode, dataInputs, portPartOf(nodeName));

    // The operation is registered before its control inputs are imported.
    // NOTE(review): presumably so a control dependency leading back to this node
    // resolves through alreadyImported() instead of recursing — confirm.
    index.put(nodeName, imported);

    List<TensorFlowOperation> controlDeps = importControlInputs(modelName, tfNode, graph, index);
    if (!controlDeps.isEmpty()) {
        imported.setControlInputs(controlDeps);
    }
    return imported;
}
Usage of org.tensorflow.framework.NodeDef in the nd4j project (by deeplearning4j).
Class Range, method initFromTensorFlow.
/**
 * Configures this Range op from its TensorFlow {@code NodeDef}.
 *
 * <p>The first three inputs of {@code nodeDef} (start, end, delta) are resolved by exact name
 * against the nodes of {@code graph}.
 *
 * <p>If all three can be read via {@code getNDArrayFromTensor("value", ...)}:
 * <ul>
 *   <li>the first element of each is stored in {@code from}, {@code to} and {@code delta} and
 *       added as T-arguments;</li>
 *   <li>an output array is created and registered when none exists yet.</li>
 * </ul>
 *
 * <p>In all cases the SameDiff variable names of the three inputs are recorded in
 * {@code fromVertexId}, {@code toVertexId} and {@code deltaVertexId}.
 *
 * @throws IllegalStateException if any of the three input nodes cannot be found in {@code graph}
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Input names are loop-invariant; read them once.
    String startName = nodeDef.getInput(0);
    String endName = nodeDef.getInput(1);
    String deltaName = nodeDef.getInput(2);

    // Resolve the three input nodes by name, stopping as soon as all have been found.
    NodeDef startNode = null, endNode = null, deltaNode = null;
    for (val node : graph.getNodeList()) {
        String name = node.getName();
        if (name.equals(startName)) {
            startNode = node;
        }
        if (name.equals(endName)) {
            endNode = node;
        }
        if (name.equals(deltaName)) {
            deltaNode = node;
        }
        if (startNode != null && endNode != null && deltaNode != null) {
            break;
        }
    }
    // All three nodes are dereferenced unconditionally below (getName()), so a missing node
    // must fail here with context instead of as an anonymous NullPointerException.
    if (startNode == null || endNode == null || deltaNode == null) {
        throw new IllegalStateException(String.format(
                "Range node \"%s\": input node(s) not found in graph (start \"%s\": %s, end \"%s\": %s, delta \"%s\": %s)",
                nodeDef.getName(),
                startName, startNode == null ? "missing" : "found",
                endName, endNode == null ? "missing" : "found",
                deltaName, deltaNode == null ? "missing" : "found"));
    }

    val startArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", startNode, graph);
    val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);
    val deltaArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", deltaNode, graph);
    // getNDArrayFromTensor may yield null. In that case the numeric arguments and the output
    // array are left untouched here.
    if (startArr != null && endArr != null && deltaArr != null) {
        val outputVars = outputVariables();
        this.from = startArr.getDouble(0);
        this.to = endArr.getDouble(0);
        this.delta = deltaArr.getDouble(0);
        addTArgument(this.from, this.to, this.delta);

        val outputVertexId = outputVars[0].getVarName();
        // NOTE(review): existence is checked on `sameDiff` but the array is stored in `initWith`;
        // presumably the same instance after super.initFromTensorFlow — confirm.
        if (sameDiff.getArrForVarName(outputVertexId) == null) {
            if (outputVars[0].getShape() == null) {
                val calcShape = calculateOutputShape();
                sameDiff.putShapeForVarName(outputVars[0].getVarName(), calcShape.get(0));
            }
            val arr = Nd4j.create(outputVars[0].getShape());
            initWith.putArrayForVarName(outputVertexId, arr);
            addOutputArgument(arr);
        }
    }

    // Record the SameDiff variable names backing the three inputs.
    val fromVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(startNode.getName()));
    val toVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(endNode.getName()));
    val deltaVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(deltaNode.getName()));
    this.fromVertexId = fromVar.getVarName();
    this.toVertexId = toVar.getVarName();
    this.deltaVertexId = deltaVar.getVarName();
}
Usage of org.tensorflow.framework.NodeDef in the nd4j project (by deeplearning4j).
Class StridedSlice, method initFromTensorFlow.
/**
 * Configures this StridedSlice op from its TensorFlow {@code NodeDef}.
 *
 * <p>Integer arguments are appended in this order:
 * <ol>
 *   <li>the five mask attributes: begin, ellipsis, end, new_axis, shrink_axis. All five are read
 *       (via {@code getAttrOrThrow}) before any is added.</li>
 *   <li>only if inputs 1, 2 and 3 can all be read via {@code getNDArrayFromTensor("value", ...)}:
 *       every element of the begin array, then the end array, then the strides array.</li>
 * </ol>
 *
 * <p>NOTE(review): this method does not call {@code super.initFromTensorFlow} — confirm that is
 * intentional.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val beginName = nodeDef.getInput(1);
    val endName = nodeDef.getInput(2);
    val stridesName = nodeDef.getInput(3);

    // Resolve the begin/end/strides input nodes by name. Every node is visited (no early exit),
    // so the last node carrying a matching name wins.
    NodeDef beginNode = null;
    NodeDef endNode = null;
    NodeDef stridesNode = null;
    for (NodeDef candidate : graph.getNodeList()) {
        String candidateName = candidate.getName();
        if (candidateName.equals(beginName)) {
            beginNode = candidate;
        }
        if (candidateName.equals(endName)) {
            endNode = candidate;
        }
        if (candidateName.equals(stridesName)) {
            stridesNode = candidate;
        }
    }

    // Slice bit masks, in the argument order expected by the op.
    String[] maskAttrs = {"begin_mask", "ellipsis_mask", "end_mask", "new_axis_mask", "shrink_axis_mask"};
    int[] maskValues = new int[maskAttrs.length];
    for (int m = 0; m < maskAttrs.length; m++) {
        maskValues[m] = (int) nodeDef.getAttrOrThrow(maskAttrs[m]).getI();
    }
    for (int mask : maskValues) {
        addIArgument(mask);
    }

    val beginArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", beginNode, graph);
    val endArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", endNode, graph);
    val stridesArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", stridesNode, graph);
    if (beginArr == null || endArr == null || stridesArr == null) {
        // At least one of begin/end/strides resolved to null: add no slice arguments.
        return;
    }
    for (int k = 0; k < beginArr.length(); k++) {
        addIArgument(beginArr.getInt(k));
    }
    for (int k = 0; k < endArr.length(); k++) {
        addIArgument(endArr.getInt(k));
    }
    for (int k = 0; k < stridesArr.length(); k++) {
        addIArgument(stridesArr.getInt(k));
    }
}
Usage of org.tensorflow.framework.NodeDef in the nd4j project (by deeplearning4j).
Class Transpose, method initFromTensorFlow.
/**
 * Configures this Transpose op from its TensorFlow {@code NodeDef}.
 *
 * <p>If the node has fewer than two inputs, nothing beyond the superclass initialisation happens.
 *
 * <p>Otherwise:
 * <ol>
 *   <li>The permutation node (input 1) is looked up by name. If its tensor can be read, its values
 *       become {@code permuteDims} and are added as integer arguments.</li>
 *   <li>If the input variable's shape is known, its array is fetched (or created from the
 *       variable's weight-init scheme and registered) and added as the input argument.</li>
 *   <li>When no permutation was read, {@code permuteDims} defaults to the reversed axis order.</li>
 * </ol>
 *
 * @throws ND4JIllegalStateException if the permutation has fewer entries than the input's rank
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Permute dimensions are not specified as a second input: nothing more to configure here.
    if (nodeDef.getInputCount() < 2) {
        return;
    }

    String permuteInputName = nodeDef.getInput(1);
    NodeDef permuteDimsNode = null;
    for (NodeDef candidate : graph.getNodeList()) {
        if (candidate.getName().equals(permuteInputName)) {
            permuteDimsNode = candidate;
        }
    }

    val permuteArray = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph);
    if (permuteArray != null) {
        this.permuteDims = permuteArray.data().asInt();
        for (int dim : permuteDims) {
            addIArgument(dim);
        }
    }

    // The remaining setup needs the input variable's shape; stop if it is not known yet.
    if (arg().getShape() == null) {
        return;
    }

    String inputName = arg().getVarName();
    INDArray inputArr = sameDiff.getArrForVarName(inputName);
    if (inputArr == null) {
        val inputVar = sameDiff.getVariable(inputName);
        inputArr = inputVar.getWeightInitScheme().create(inputVar.getShape());
        sameDiff.putArrayForVarName(inputName, inputArr);
    }
    addInputArgument(inputArr);

    // No permutation was read from the graph: default to reversing all axes.
    if (inputArr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, inputArr.rank()));
    }

    if (permuteDims != null && permuteDims.length < arg().getShape().length) {
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
    }
}
Usage of org.tensorflow.framework.NodeDef in the nd4j project (by deeplearning4j).
Class BaseAccumulation, method initFromTensorFlow.
/**
 * Configures this reduction from its TensorFlow {@code NodeDef}.
 *
 * <p>Sets {@code newFormat}, then chooses the reduction dimensions:
 * <ul>
 *   <li>no {@code axis} attribute and no reduction indices: {@code {Integer.MAX_VALUE}}
 *       (NOTE(review): presumably the "reduce over all dimensions" marker — confirm);</li>
 *   <li>reduction indices present: the values of the graph node named
 *       {@code <nodeName>/reduction_indices};</li>
 *   <li>otherwise: the values read for {@code axis} from the node itself.</li>
 * </ul>
 *
 * <p>{@code keepDims} is taken from the {@code keep_dims} attribute when present and left
 * unchanged otherwise. A missing {@code keep_dims} is not an error.
 *
 * @throws ND4JIllegalStateException if the reduction indices node is missing, or if the
 *     reduction indices / axis values cannot be read as an array
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    newFormat = true;
    boolean reductionIndices = hasReductionIndices(nodeDef);
    if (!attributesForNode.containsKey("axis") && !reductionIndices) {
        this.dimensions = new int[] { Integer.MAX_VALUE };
    } else if (reductionIndices) {
        String reductionNodeName = nodeDef.getName() + "/reduction_indices";
        NodeDef reductionNode = null;
        for (NodeDef candidate : graph.getNodeList()) {
            if (candidate.getName().equals(reductionNodeName)) {
                reductionNode = candidate;
                break;
            }
        }
        if (reductionNode == null) {
            throw new ND4JIllegalStateException("No node found! Expected reduction indices node \"" + reductionNodeName + "\" in graph");
        }
        val arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", reductionNode, graph);
        // getNDArrayFromTensor can return null; report that instead of failing with an NPE.
        if (arr == null) {
            throw new ND4JIllegalStateException("Could not read reduction indices from node \"" + reductionNodeName + "\"");
        }
        this.dimensions = arr.data().asInt();
    } else {
        val axis = TFGraphMapper.getInstance().getNDArrayFromTensor("axis", nodeDef, graph);
        if (axis == null) {
            throw new ND4JIllegalStateException("Could not read axis for reduction node \"" + nodeDef.getName() + "\"");
        }
        this.dimensions = axis.data().asInt();
    }
    if (attributesForNode.containsKey("keep_dims")) {
        this.keepDims = attributesForNode.get("keep_dims").getB();
    }
}
Aggregations