
Example 1 with TensorShapeProto

Use of org.tensorflow.framework.TensorShapeProto in project vespa by vespa-engine.

From class OrderedTensorType, method tensorFlowShape.

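private static TensorShapeProto tensorFlowShape(NodeDef node) {
    // The "_output_shapes" attribute holds one shape per output of the node.
    AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
    if (attrValueList == null) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' does not exist");
    }
    if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' is not of expected type");
    }
    List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
    if (shapeList.isEmpty()) {
        // Fail with a descriptive message instead of an IndexOutOfBoundsException from get(0).
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' is empty");
    }
    // Only the first output's shape is used; nodes with multiple outputs are not supported yet.
    return shapeList.get(0);
}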
Also used: TensorShapeProto (org.tensorflow.framework.TensorShapeProto), AttrValue (org.tensorflow.framework.AttrValue)
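
The lookup above can be tried without the TensorFlow protobuf classes on the classpath. The following is a minimal sketch under stated assumptions: FakeAttrValue and OutputShapeLookup are hypothetical stand-ins, not TensorFlow or Vespa API, and a shape is modeled as a long[] where -1 marks an unknown size (the convention TensorShapeProto uses for its dims). It mirrors the two checks in the listing and additionally takes an output index instead of always reading entry 0.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for TensorFlow's AttrValue: only the two cases the lookup distinguishes.
final class FakeAttrValue {
    enum Kind { LIST, OTHER }

    final Kind kind;
    final List<long[]> shapes; // one long[] per node output; -1 marks an unknown size

    FakeAttrValue(Kind kind, List<long[]> shapes) {
        this.kind = kind;
        this.shapes = shapes;
    }
}

final class OutputShapeLookup {
    // Returns the shape of output outputIndex, applying the same two checks as the listing first.
    static long[] outputShape(String nodeName, Map<String, FakeAttrValue> attrs, int outputIndex) {
        FakeAttrValue value = attrs.get("_output_shapes");
        if (value == null) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + nodeName + "' does not exist");
        }
        if (value.kind != FakeAttrValue.Kind.LIST) {
            throw new IllegalArgumentException("_output_shapes attribute of '" + nodeName + "' is not of expected type");
        }
        // Explicit bounds check instead of an IndexOutOfBoundsException from List.get.
        if (outputIndex < 0 || outputIndex >= value.shapes.size()) {
            throw new IllegalArgumentException("'" + nodeName + "' has no output " + outputIndex);
        }
        return value.shapes.get(outputIndex);
    }

    public static void main(String[] args) {
        Map<String, FakeAttrValue> attrs = Map.of("_output_shapes",
                new FakeAttrValue(FakeAttrValue.Kind.LIST, List.of(new long[] {-1, 784}, new long[] {10})));
        System.out.println(Arrays.toString(outputShape("x", attrs, 0)));
        System.out.println(Arrays.toString(outputShape("x", attrs, 1)));
    }
}
```

Passing the index explicitly keeps the single-output case unchanged (index 0) while making multi-output nodes addressable.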

Example 2 with TensorShapeProto

Use of org.tensorflow.framework.TensorShapeProto in project vespa by vespa-engine.

From class OrderedTensorType, method verifyType.

public void verifyType(NodeDef node) {
    TensorShapeProto shape = tensorFlowShape(node);
    if (shape != null) {
        // The rank (number of dimensions) must agree before individual sizes are compared.
        if (shape.getDimCount() != type.rank()) {
            throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' does not match Vespa shape");
        }
        for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
            // dimensionMap translates a TensorFlow dimension index into the matching Vespa dimension index.
            int vespaIndex = dimensionMap[tensorFlowIndex];
            TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
            TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
            // TensorFlow reports an unknown size as -1, so an unbound Vespa dimension is compared as -1 too.
            if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
                throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' do not match Vespa dimensions");
            }
        }
    }
}
Also used: TensorShapeProto (org.tensorflow.framework.TensorShapeProto), TensorType (com.yahoo.tensor.TensorType)
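
The verification logic can be isolated from the protobuf and Vespa types. This is a hedged sketch, not Vespa code: ShapeVerifier is a hypothetical name, the TensorFlow shape is a long[] with -1 for unknown sizes, the Vespa side is a list of Optional<Long> in Vespa's dimension order (empty meaning unbound), and dimensionMap plays the same role as the field in the listing, mapping a TensorFlow index to a Vespa index.

```java
import java.util.List;
import java.util.Optional;

final class ShapeVerifier {
    // tfShape: TensorFlow sizes in TensorFlow order, -1 meaning unknown.
    // vespaSizes: sizes in Vespa's dimension order, Optional.empty() meaning unbound.
    // dimensionMap[i]: the Vespa index of TensorFlow dimension i.
    static void verify(String nodeName, long[] tfShape, List<Optional<Long>> vespaSizes, int[] dimensionMap) {
        // Rank check first, as in the listing; also guard the mapping's length.
        if (tfShape.length != vespaSizes.size() || dimensionMap.length != tfShape.length) {
            throw new IllegalArgumentException("TensorFlow shape of '" + nodeName + "' does not match Vespa shape");
        }
        for (int tfIndex = 0; tfIndex < tfShape.length; ++tfIndex) {
            int vespaIndex = dimensionMap[tfIndex];
            // Map "unbound" to -1 so it compares equal to TensorFlow's unknown size.
            long vespaSize = vespaSizes.get(vespaIndex).orElse(-1L);
            if (tfShape[tfIndex] != vespaSize) {
                throw new IllegalArgumentException("TensorFlow dimensions of '" + nodeName + "' do not match Vespa dimensions");
            }
        }
    }

    public static void main(String[] args) {
        // TensorFlow shape [-1, 784]; on the Vespa side the two dimensions are stored in swapped order.
        verify("x", new long[] {-1, 784}, List.of(Optional.of(784L), Optional.empty()), new int[] {1, 0});
        System.out.println("shapes match");
    }
}
```

With the identity mapping {0, 1} instead of {1, 0}, the same inputs are rejected, because TensorFlow's unknown first dimension would then be compared against Vespa's bound size 784.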

Example 3 with TensorShapeProto

Use of org.tensorflow.framework.TensorShapeProto in project vespa by vespa-engine.

From class OrderedTensorType, method fromTensorFlowType.

public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
    Builder builder = new Builder(node);
    TensorShapeProto shape = tensorFlowShape(node);
    for (int i = 0; i < shape.getDimCount(); ++i) {
        // Dimensions are named by position: prefix "d" gives d0, d1, ...
        String dimensionName = dimensionPrefix + i;
        TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
        if (tensorFlowDimension.getSize() >= 0) {
            // Known size: bound indexed dimension.
            builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
        } else {
            // Size -1 (unknown in TensorFlow): unbound indexed dimension.
            builder.add(TensorType.Dimension.indexed(dimensionName));
        }
    }
    return builder.build();
}
Also used: TensorShapeProto (org.tensorflow.framework.TensorShapeProto)
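
The conversion rule in fromTensorFlowType (size >= 0 gives a bound dimension, -1 gives an unbound one, names are prefix plus position) can be sketched without Vespa's TensorType.Builder by emitting a type string directly. TypeFromShapeSketch and vespaTypeSpec are hypothetical names; the string approximates Vespa's tensor type notation (bound indexed x[N], unbound indexed x[]) and is not produced by the Vespa API. The sketch keeps dimensions in TensorFlow order.

```java
import java.util.StringJoiner;

final class TypeFromShapeSketch {
    // Builds a type string such as "tensor(d0[],d1[784])" from a TensorFlow shape.
    // A size >= 0 gives a bound dimension "name[size]"; -1 (unknown) gives an unbound "name[]".
    static String vespaTypeSpec(long[] tfShape, String dimensionPrefix) {
        StringJoiner dims = new StringJoiner(",", "tensor(", ")");
        for (int i = 0; i < tfShape.length; ++i) {
            String name = dimensionPrefix + i; // positional names: d0, d1, ...
            dims.add(tfShape[i] >= 0 ? name + "[" + tfShape[i] + "]" : name + "[]");
        }
        return dims.toString();
    }

    public static void main(String[] args) {
        System.out.println(vespaTypeSpec(new long[] {-1, 784}, "d"));
    }
}
```

For a shape of [-1, 784] and prefix "d" this yields tensor(d0[],d1[784]); an empty shape yields tensor().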

Aggregations

TensorShapeProto (org.tensorflow.framework.TensorShapeProto): 3
TensorType (com.yahoo.tensor.TensorType): 1
AttrValue (org.tensorflow.framework.AttrValue): 1