Use of com.yahoo.tensor.functions.Rename in the vespa project by vespa-engine.
Example: class TensorFlowImporter, method importRankingExpression.
/**
 * Adds the ranking expression for a single imported TensorFlow operation to the model.
 *
 * <p>Nothing happens when the operation has no tensor function. Nothing happens either when the
 * model already holds an expression under the operation's node name. If the operation is a
 * signature output whose type differs from the standard-naming type derived from its node, the
 * function is wrapped in a {@link Rename} that maps the operation's dimension names onto the
 * standard ones.
 *
 * @param model the model the expression is added to
 * @param operation the imported TensorFlow operation
 * @throws RuntimeException wrapping the {@code ParseException} raised when the function's string
 *         form cannot be parsed as a ranking expression
 */
private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
    if ( ! operation.function().isPresent()) {
        return;
    }
    String expressionName = operation.node().getName();
    if (model.expressions().containsKey(expressionName)) {
        return; // already imported
    }

    TensorFunction function = operation.function().get();
    // Outputs exposed through a signature must follow the standard dimension naming convention.
    if (isSignatureOutput(model, operation)) {
        OrderedTensorType actualType = operation.type().get();
        OrderedTensorType standardType = OrderedTensorType.fromTensorFlowType(operation.node());
        if ( ! actualType.equals(standardType)) {
            function = new Rename(function, actualType.dimensionNames(), standardType.dimensionNames());
        }
    }

    try {
        // Every intermediate node is stored as its own expression; only the ones referenced
        // from a signature output end up being used. Parsing the function's string form
        // converts it into a RankingExpression tree.
        model.expression(expressionName, new RankingExpression(expressionName, function.toString()));
    } catch (ParseException e) {
        throw new RuntimeException("Tensorflow function " + function + " cannot be parsed as a ranking expression", e);
    }
}
Use of com.yahoo.tensor.functions.Rename in the vespa project by vespa-engine.
Example: class TensorFlowFeatureConverter, method reduceBatchDimensionsAtInput.
/**
 * Rewrites an expression tree so that inputs referring to the model's required macros are passed
 * through {@code reduceBatchDimensionExpression}.
 *
 * <p>The checks run in this order:
 * <ol>
 *   <li>A {@link TensorFunctionNode} holding a {@link Rename} whose only child is a
 *       {@link ReferenceNode} naming a required macro is handed, as that Rename function, to
 *       {@code reduceBatchDimensionExpression}.</li>
 *   <li>A {@link ReferenceNode} naming a required macro is wrapped with
 *       {@code TensorFunctionNode.wrapArgument} and then handed to
 *       {@code reduceBatchDimensionExpression}.</li>
 *   <li>Any other {@link CompositeNode} gets its children rewritten recursively and is rebuilt
 *       via {@code setChildren}.</li>
 *   <li>Every remaining node is returned unchanged.</li>
 * </ol>
 *
 * @param node the expression to rewrite
 * @param model the model whose required macros trigger the rewrite
 * @param typeContext type context forwarded to {@code reduceBatchDimensionExpression}
 * @return the rewritten expression
 */
private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, TypeContext<Reference> typeContext) {
    if (node instanceof TensorFunctionNode) {
        TensorFunctionNode functionNode = (TensorFunctionNode) node;
        TensorFunction function = functionNode.function();
        if (function instanceof Rename) {
            List<ExpressionNode> arguments = functionNode.children();
            boolean renamesRequiredMacro =
                    arguments.size() == 1
                    && arguments.get(0) instanceof ReferenceNode
                    && model.requiredMacros().containsKey(((ReferenceNode) arguments.get(0)).getName());
            if (renamesRequiredMacro) {
                return reduceBatchDimensionExpression(function, typeContext);
            }
        }
    }

    if (node instanceof ReferenceNode
        && model.requiredMacros().containsKey(((ReferenceNode) node).getName())) {
        return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
    }

    if ( ! (node instanceof CompositeNode)) {
        return node;
    }
    CompositeNode composite = (CompositeNode) node;
    List<ExpressionNode> originalChildren = composite.children();
    List<ExpressionNode> rewrittenChildren = new ArrayList<>(originalChildren.size());
    for (ExpressionNode child : originalChildren) {
        rewrittenChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
    }
    return composite.setChildren(rewrittenChildren);
}
Use of com.yahoo.tensor.functions.Rename in the vespa project by vespa-engine.
Example: class Placeholder, method lazyGetFunction.
/**
 * Builds this placeholder's tensor function.
 *
 * <p>The function is a {@link VariableTensor} named {@code vespaName()} with the tensor type of
 * {@code standardNamingType}. When {@code standardNamingType} differs from {@code type}, the
 * variable is wrapped in a {@link Rename}. That Rename maps the standard dimension names onto the
 * dimension names of {@code type}.
 *
 * @return the variable tensor, renamed if the two types differ
 */
@Override
protected TensorFunction lazyGetFunction() {
    TensorFunction variable = new VariableTensor(vespaName(), standardNamingType.type());
    if (standardNamingType.equals(type)) {
        return variable; // already in the expected naming; no rename needed
    }
    return new Rename(variable, standardNamingType.dimensionNames(), type.dimensionNames());
}
Aggregations