Use of org.tensorflow.framework.AttrValue in the project vespa by vespa-engine.
Class Squeeze, method lazyGetType.
/**
 * Computes the output type of this Squeeze node: the input type with its squeezable
 * size-1 dimensions removed.
 * <p>
 * If the {@code squeeze_dims} attribute is absent or lists no axes, every size-1 dimension
 * is squeezed (TensorFlow semantics). Otherwise only the listed axes are considered, and
 * negative axes count from the end (-1 is the last dimension). Listed axes whose size is
 * not 1 are skipped.
 * <p>
 * Side effect: stores the names of the dimensions to remove in {@code squeezeDimensions}.
 *
 * @return {@code null} if the input type is not yet available, the unchanged input type if
 *         nothing is squeezed, and otherwise {@code reducedType(inputType)}
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(1)) {
        return null;
    }
    OrderedTensorType inputType = inputs.get(0).type().get();
    AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
    // An absent attribute and an empty axis list both mean "squeeze all size-1 dimensions".
    if (squeezeDimsAttr == null || squeezeDimsAttr.getList().getIList().isEmpty()) {
        squeezeDimensions = inputType.type().dimensions().stream()
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    } else {
        int rank = inputType.type().dimensions().size();
        squeezeDimensions = squeezeDimsAttr.getList().getIList().stream()
                // Negative axes are offsets from the end: rank + (-1) is the last dimension.
                .map(i -> i < 0 ? rank + i : i)
                .map(i -> inputType.type().dimensions().get(i.intValue()))
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    }
    return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
Use of org.tensorflow.framework.AttrValue in the project vespa by vespa-engine.
Class OrderedTensorType, method tensorFlowShape.
/**
 * Returns the shape of the first output of the given node, read from its
 * {@code _output_shapes} attribute.
 *
 * @param node the TensorFlow node definition to read the shape from
 * @return the first shape listed in {@code _output_shapes}
 * @throws IllegalArgumentException if {@code _output_shapes} is missing, is not a list,
 *         or contains no shapes
 */
private static TensorShapeProto tensorFlowShape(NodeDef node) {
    AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
    if (attrValueList == null) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "does not exist");
    }
    if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "is not of expected type");
    }
    List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
    // Report an empty list as a malformed attribute rather than an IndexOutOfBoundsException.
    if (shapeList.isEmpty()) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "is empty");
    }
    // TODO: support multiple outputs? Only the first output's shape is used.
    return shapeList.get(0);
}
Aggregations