Use of org.tensorflow.framework.TensorShapeProto in project vespa by vespa-engine.
In class OrderedTensorType, method tensorFlowShape.
/**
 * Returns the shape of the first output of a TensorFlow node, as recorded in the node's
 * {@code _output_shapes} attribute.
 *
 * @param node the TensorFlow graph node to inspect
 * @return the shape of the node's first output
 * @throws IllegalArgumentException if the node has no {@code _output_shapes} attribute, if that
 *         attribute is not a list, or if the list contains no shapes
 */
private static TensorShapeProto tensorFlowShape(NodeDef node) {
    AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
    if (attrValueList == null) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "does not exist");
    }
    if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "is not of expected type");
    }
    List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
    // A LIST attribute may carry no shape entries; without this check get(0) below would fail
    // with an opaque IndexOutOfBoundsException instead of a descriptive argument error.
    if (shapeList.isEmpty()) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "contains no shapes");
    }
    // NOTE: only the first output is considered; nodes with multiple outputs are not yet supported.
    return shapeList.get(0);
}
Use of org.tensorflow.framework.TensorShapeProto in project vespa by vespa-engine.
In class OrderedTensorType, method verifyType.
/**
 * Checks that the TensorFlow output shape recorded on {@code node} agrees with this Vespa type.
 *
 * <p>The rank must match, and each TensorFlow dimension (looked up through
 * {@code dimensionMap}) must have the same size as its Vespa counterpart. A Vespa dimension
 * without a size is compared as {@code -1}.
 *
 * @param node the TensorFlow graph node whose {@code _output_shapes} attribute is checked
 * @throws IllegalArgumentException if the rank or any dimension size differs
 */
public void verifyType(NodeDef node) {
    TensorShapeProto tfShape = tensorFlowShape(node);
    if (tfShape == null) {
        return; // nothing to verify against
    }
    if (tfShape.getDimCount() != type.rank()) {
        throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + "does not match Vespa shape");
    }
    for (int tfIdx = 0; tfIdx < dimensions.size(); tfIdx++) {
        int vespaIdx = dimensionMap[tfIdx];
        TensorShapeProto.Dim tfDim = tfShape.getDim(tfIdx);
        TensorType.Dimension vespaDim = type().dimensions().get(vespaIdx);
        long actualSize = tfDim.getSize();
        long expectedSize = vespaDim.size().orElse(-1L);
        if (actualSize != expectedSize) {
            throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + "does not match Vespa dimensions");
        }
    }
}
Use of org.tensorflow.framework.TensorShapeProto in project vespa by vespa-engine.
In class OrderedTensorType, method fromTensorFlowType.
/**
 * Builds an ordered Vespa tensor type from the first output shape of a TensorFlow node.
 *
 * <p>Dimension {@code i} is named {@code dimensionPrefix + i}. A TensorFlow size that is
 * zero or greater becomes a bound indexed dimension of that size; a negative size becomes an
 * unbound indexed dimension.
 *
 * @param node the TensorFlow graph node whose {@code _output_shapes} attribute is read
 * @param dimensionPrefix prefix used to name the generated dimensions
 * @return the resulting ordered tensor type
 */
public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
    Builder result = new Builder(node);
    TensorShapeProto tfShape = tensorFlowShape(node);
    int rank = tfShape.getDimCount();
    for (int d = 0; d < rank; d++) {
        String name = dimensionPrefix + d;
        long size = tfShape.getDim(d).getSize();
        TensorType.Dimension dimension = size >= 0
                ? TensorType.Dimension.indexed(name, size)
                : TensorType.Dimension.indexed(name);
        result.add(dimension);
    }
    return result.build();
}
Aggregations