
Example 1 with AttrValue

Use of org.tensorflow.framework.AttrValue in project vespa by vespa-engine.

In class Squeeze, method lazyGetType.

@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(1)) {
        return null;
    }
    OrderedTensorType inputType = inputs.get(0).type().get();
    List<TensorType.Dimension> dimensions = inputType.type().dimensions();
    AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
    if (squeezeDimsAttr == null) {
        // No explicit axes: squeeze every dimension of size 1
        squeezeDimensions = dimensions.stream()
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    } else {
        // Explicit axes: a negative index counts from the end, so it maps to
        // size + i (the previous size - i indexed past the last dimension)
        squeezeDimensions = squeezeDimsAttr.getList().getIList().stream()
                .map(i -> i < 0 ? dimensions.size() + i : i)
                .map(i -> dimensions.get(i.intValue()))
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    }
    return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
Also used : TensorConverter(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter) List(java.util.List) NodeDef(org.tensorflow.framework.NodeDef) AttrValue(org.tensorflow.framework.AttrValue) Optional(java.util.Optional) Reduce(com.yahoo.tensor.functions.Reduce) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) TensorType(com.yahoo.tensor.TensorType) DimensionRenamer(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer) TensorFunction(com.yahoo.tensor.functions.TensorFunction) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList)
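The axis-selection logic in lazyGetType can be illustrated without the Vespa or TensorFlow classes. The following is a minimal, self-contained sketch, assuming a tensor shape is given as a long[] of dimension sizes and squeeze_dims as a List<Long>; the class and method names (SqueezeSketch, squeezeAxes, squeezedShape) are hypothetical. A negative axis d is normalized to rank + d, so for rank 4 the axis -2 becomes 2, whereas rank - d would give 6 and index past the end.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration of squeeze-axis selection; not a Vespa or TensorFlow API.
public class SqueezeSketch {

    // Returns the axes to remove. With no squeezeDims, every size-1 axis is chosen;
    // otherwise only listed axes (negative values count from the end) of size 1.
    static List<Integer> squeezeAxes(long[] shape, List<Long> squeezeDims) {
        List<Integer> axes = new ArrayList<>();
        if (squeezeDims == null) {
            for (int axis = 0; axis < shape.length; axis++) {
                if (shape[axis] == 1) axes.add(axis);
            }
        } else {
            for (long d : squeezeDims) {
                int axis = (int) (d < 0 ? shape.length + d : d); // normalize negative axis
                if (axis < 0 || axis >= shape.length) {
                    throw new IllegalArgumentException("axis " + d + " out of range for rank " + shape.length);
                }
                if (shape[axis] == 1 && !axes.contains(axis)) axes.add(axis);
            }
        }
        return axes;
    }

    // Returns a copy of the shape with the given axes removed.
    static long[] squeezedShape(long[] shape, List<Integer> axes) {
        long[] result = new long[shape.length - axes.size()];
        int j = 0;
        for (int axis = 0; axis < shape.length; axis++) {
            if (!axes.contains(axis)) result[j++] = shape[axis];
        }
        return result;
    }

    public static void main(String[] args) {
        long[] shape = {1, 3, 1, 5};
        System.out.println(squeezeAxes(shape, null));        // [0, 2]
        System.out.println(squeezeAxes(shape, List.of(-2L))); // [2]
        System.out.println(Arrays.toString(
                squeezedShape(shape, squeezeAxes(shape, List.of(-2L))))); // [1, 3, 5]
    }
}
```

Unlike the method above, this sketch rejects out-of-range axes explicitly and ignores repeated axes; both are choices made for the illustration, not behavior confirmed by the source.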

Example 2 with AttrValue

Use of org.tensorflow.framework.AttrValue in project vespa by vespa-engine.

In class OrderedTensorType, method tensorFlowShape.

private static TensorShapeProto tensorFlowShape(NodeDef node) {
    AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
    if (attrValueList == null) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' does not exist");
    }
    // The attribute must hold a list (the oneof case LIST), not a single value
    if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' is not of expected type");
    }
    List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
    if (shapeList.isEmpty()) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' is empty");
    }
    // TODO: support nodes with multiple outputs; only the first output's shape is used
    return shapeList.get(0);
}
Also used : TensorShapeProto(org.tensorflow.framework.TensorShapeProto) AttrValue(org.tensorflow.framework.AttrValue)
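tensorFlowShape only returns the shape of a node's first output. Below is a minimal sketch of how the same validate-then-extract steps could be generalized to an arbitrary output index, assuming a plain Map<String, Object> as a stand-in for the protobuf attribute map and List<long[]> as a stand-in for the list of TensorShapeProto values. OutputShapeSketch and outputShape are hypothetical names, not part of the Vespa or TensorFlow APIs.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in: attributes are a Map, and a list of shapes is a List<long[]>.
public class OutputShapeSketch {

    // Validates the "_output_shapes" attribute and returns the shape of one output.
    static long[] outputShape(String nodeName, Map<String, Object> attrs, int outputIndex) {
        Object value = attrs.get("_output_shapes");
        if (value == null) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + nodeName + "' does not exist");
        }
        if (!(value instanceof List<?>)) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + nodeName + "' is not of expected type");
        }
        List<?> shapes = (List<?>) value;
        if (outputIndex < 0 || outputIndex >= shapes.size()) {
            throw new IllegalArgumentException("node '" + nodeName + "' has no output " + outputIndex);
        }
        Object shape = shapes.get(outputIndex);
        if (!(shape instanceof long[])) {
            throw new IllegalArgumentException("output " + outputIndex + " of '" + nodeName + "' is not a shape");
        }
        return (long[]) shape;
    }

    public static void main(String[] args) {
        Map<String, Object> attrs = Map.of(
                "_output_shapes", List.of(new long[]{-1, 784}, new long[]{-1, 10}));
        System.out.println(Arrays.toString(outputShape("dense", attrs, 0))); // [-1, 784]
        System.out.println(Arrays.toString(outputShape("dense", attrs, 1))); // [-1, 10]
    }
}
```

Passing the output index explicitly keeps the existing single-output behavior as the special case outputIndex = 0.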

Aggregations

AttrValue (org.tensorflow.framework.AttrValue): 2
DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer): 1
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 1
TensorConverter (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter): 1
TensorType (com.yahoo.tensor.TensorType): 1
Reduce (com.yahoo.tensor.functions.Reduce): 1
TensorFunction (com.yahoo.tensor.functions.TensorFunction): 1
ArrayList (java.util.ArrayList): 1
List (java.util.List): 1
Optional (java.util.Optional): 1
Collectors (java.util.stream.Collectors): 1
NodeDef (org.tensorflow.framework.NodeDef): 1
TensorShapeProto (org.tensorflow.framework.TensorShapeProto): 1