Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.
Class TensorFlowFeatureConverter, method expandBatchDimensionsAtOutput.
/**
 * Brings back batch dimensions that were reduced away earlier, so that any
 * following computation sees them again.
 * <p>
 * Every dimension of {@code before} that has size 1 and whose name is not among
 * the dimension names of {@code after} is collected into a new tensor type.
 * If at least one such dimension exists, {@code node} is multiplied (joined) with
 * a tensor generated over that type from the constant expression 1.0.
 * TODO: determine when this is not necessary!
 *
 * @param node the expression to expand
 * @param before the tensor type the candidate dimensions are taken from
 * @param after the tensor type whose dimension names are checked against
 * @return {@code node} unchanged when the two types are equal or no dimension
 *         qualifies, otherwise a node wrapping the join described above
 */
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
    if (after.equals(before)) {
        return node;
    }

    // Collect the size-1 dimensions that exist in 'before' but are absent from 'after'.
    TensorType.Builder missingDimensions = new TensorType.Builder();
    for (TensorType.Dimension candidate : before.dimensions()) {
        boolean hasUnitSize = candidate.size().orElse(-1L) == 1;
        if (hasUnitSize && !after.dimensionNames().contains(candidate.name())) {
            missingDimensions.indexed(candidate.name(), 1);
        }
    }

    TensorType onesType = missingDimensions.build();
    if (onesType.dimensions().isEmpty()) {
        return node;
    }

    // Multiplying with a tensor generated from the constant 1.0 adds the missing dimensions.
    ExpressionNode one = new ConstantNode(new DoubleValue(1.0));
    Generate ones = new Generate(
            onesType,
            new GeneratorLambdaFunctionNode(onesType, one).asLongListToDoubleOperator());
    Join expanded = new Join(TensorFunctionNode.wrapArgument(node), ones, ScalarFunctions.multiply());
    return new TensorFunctionNode(expanded);
}
Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.
Class TensorFlowFeatureConverter, method reduceBatchDimensionExpression.
/**
 * Wraps {@code function} in a {@link TensorFunctionNode}, first summing away its
 * size-1 dimensions when its type has more than one dimension.
 *
 * @param function the tensor function to wrap
 * @param context the type context used to resolve the type of {@code function}
 * @return a node wrapping a sum-reduce over all size-1 dimensions, or a node wrapping
 *         {@code function} itself when the type has at most one dimension or
 *         none of its dimensions has size 1
 */
private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
    TensorType resolvedType = function.type(context);
    if (resolvedType.dimensions().size() <= 1) {
        return new TensorFunctionNode(function);
    }

    // Gather the names of all dimensions whose size is exactly 1.
    List<String> unitDimensions = new ArrayList<>();
    for (TensorType.Dimension candidate : resolvedType.dimensions()) {
        if (candidate.size().orElse(-1L) == 1) {
            unitDimensions.add(candidate.name());
        }
    }

    if (unitDimensions.isEmpty()) {
        return new TensorFunctionNode(function);
    }
    TensorFunction reduced = new Reduce(function, Reduce.Aggregator.sum, unitDimensions);
    return new TensorFunctionNode(reduced);
}
Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.
Class TensorFlowFeatureConverter, method reduceBatchDimensionsAtInput.
/**
 * Recursively rewrites an expression so that references to the model's required
 * macros are passed through {@code reduceBatchDimensionExpression}.
 * <p>
 * Three cases are tried in order:
 * <ol>
 *   <li>a {@link TensorFunctionNode} holding a {@link Rename} whose single child is a
 *       reference to a required macro: the rename function itself is reduced;</li>
 *   <li>a {@link ReferenceNode} naming a required macro: the wrapped reference is reduced;</li>
 *   <li>any other {@link CompositeNode}: each child is transformed recursively.</li>
 * </ol>
 *
 * @param node the expression to transform
 * @param model the model whose {@code requiredMacros()} decide which references are inputs
 * @param typeContext the type context handed on to {@code reduceBatchDimensionExpression}
 * @return the transformed expression, or {@code node} itself when no case applies
 */
private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, TypeContext<Reference> typeContext) {
    // Case 1: a rename applied directly to a single required-macro reference.
    if (node instanceof TensorFunctionNode) {
        TensorFunctionNode tensorNode = (TensorFunctionNode) node;
        TensorFunction wrapped = tensorNode.function();
        if (wrapped instanceof Rename) {
            List<ExpressionNode> arguments = tensorNode.children();
            if (arguments.size() == 1 && arguments.get(0) instanceof ReferenceNode) {
                ReferenceNode argument = (ReferenceNode) arguments.get(0);
                if (model.requiredMacros().containsKey(argument.getName())) {
                    return reduceBatchDimensionExpression(wrapped, typeContext);
                }
            }
        }
    }

    // Case 2: a bare reference to a required macro.
    if (node instanceof ReferenceNode) {
        ReferenceNode reference = (ReferenceNode) node;
        if (model.requiredMacros().containsKey(reference.getName())) {
            return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
        }
    }

    // Case 3: anything that is not composite is returned untouched ...
    if (!(node instanceof CompositeNode)) {
        return node;
    }

    // ... while composites get each of their children transformed in turn.
    CompositeNode composite = (CompositeNode) node;
    List<ExpressionNode> originals = composite.children();
    List<ExpressionNode> rewritten = new ArrayList<>(originals.size());
    for (ExpressionNode original : originals) {
        rewritten.add(reduceBatchDimensionsAtInput(original, model, typeContext));
    }
    return composite.setChildren(rewritten);
}
Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.
Class TensorTransformer, method replaceMaxAndMinFunction.
/**
 * Rewrites a two-argument function node into a tensor reduce.
 * <p>
 * The first child becomes the reduced argument, the aggregator is looked up by
 * the name of the node's function, and the name of the second child is used as
 * the dimension to reduce over.
 *
 * @param node the function node to replace; its second child is cast to
 *             {@link ReferenceNode} without a prior check
 * @return a {@link TensorFunctionNode} wrapping the resulting {@link Reduce}
 */
private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
    ExpressionNode reducedArgument = node.children().get(0);
    ExpressionNode dimensionArgument = node.children().get(1);

    // Constructor arguments are evaluated left to right: wrap, resolve aggregator, read dimension name.
    Reduce reduce = new Reduce(
            TensorFunctionNode.wrapArgument(reducedArgument),
            Reduce.Aggregator.valueOf(node.getFunction().name()),
            ((ReferenceNode) dimensionArgument).getName());
    return new TensorFunctionNode(reduce);
}
Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.
Class RankingExpressionTestCase, method testProgrammaticBuilding.
/**
 * Verifies that an expression assembled node by node prints the same string
 * as the expression parsed from {@code "sum(input * constant)"}.
 */
@Test
public void testProgrammaticBuilding() throws ParseException {
    // Build "input * constant" by hand.
    ArithmeticNode multiplication = new ArithmeticNode(
            new ReferenceNode("input"),
            ArithmeticOperator.MULTIPLY,
            new ReferenceNode("constant"));

    // Wrap the product as a tensor function argument and sum over it.
    Reduce summed = new Reduce(
            new TensorFunctionNode.TensorFunctionExpressionNode(multiplication),
            Reduce.Aggregator.sum);
    RankingExpression built = new RankingExpression(new TensorFunctionNode(summed));

    RankingExpression parsed = new RankingExpression("sum(input * constant)");
    assertEquals(parsed.toString(), built.toString());
}
Aggregations