
Example 1 with TensorFunctionNode

Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.

In class TensorFlowFeatureConverter, method expandBatchDimensionsAtOutput.

/**
 * If batch dimensions have been reduced away above, bring them back here
 * for any following computation of the tensor.
 * Todo: determine when this is not necessary!
 */
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
    if (after.equals(before)) {
        return node;
    }
    TensorType.Builder typeBuilder = new TensorType.Builder();
    for (TensorType.Dimension dimension : before.dimensions()) {
        if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
            typeBuilder.indexed(dimension.name(), 1);
        }
    }
    TensorType expandDimensionsType = typeBuilder.build();
    if (expandDimensionsType.dimensions().size() > 0) {
        ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
        Generate generatedFunction = new Generate(expandDimensionsType, new GeneratorLambdaFunctionNode(expandDimensionsType, generatedExpression).asLongListToDoubleOperator());
        Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
        return new TensorFunctionNode(expand);
    }
    return node;
}
Also used: GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode), ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode), DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue), TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), Generate (com.yahoo.tensor.functions.Generate), Join (com.yahoo.tensor.functions.Join), TensorType (com.yahoo.tensor.TensorType)
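The selection and expansion steps in expandBatchDimensionsAtOutput can be illustrated without Vespa. The sketch below rests on some assumptions:

- A tensor type is modeled as an ordered map from dimension name to size, with null for an unbound dimension.
- The class and method names are hypothetical, and none of this is Vespa's API.

The sketch shows which dimensions would be brought back. It also shows why joining with a generated tensor whose only cell is 1.0, under multiplication, adds a size-1 dimension without changing any value.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: a tensor type is modeled as an ordered map from
// dimension name to size, with null standing for an unbound dimension.
public class ExpandBatchDimensionsSketch {

    // Mirrors the loop in expandBatchDimensionsAtOutput: keep dimensions that
    // had size 1 before and are missing after.
    static List<String> dimensionsToExpand(Map<String, Long> before, Map<String, Long> after) {
        List<String> expand = new ArrayList<>();
        if (before.equals(after)) {
            return expand; // types are identical, so nothing was reduced away
        }
        for (Map.Entry<String, Long> dimension : before.entrySet()) {
            long size = dimension.getValue() == null ? -1L : dimension.getValue(); // like size().orElse(-1L)
            if (size == 1 && !after.containsKey(dimension.getKey())) {
                expand.add(dimension.getKey());
            }
        }
        return expand;
    }

    // Joins a vector with a generated size-1 tensor whose only cell is 1.0,
    // using multiplication: the values are unchanged, but a new leading
    // dimension of size 1 appears.
    static double[][] joinWithOnes(double[] values) {
        double[] ones = {1.0};
        double[][] result = new double[ones.length][values.length];
        for (int b = 0; b < ones.length; b++) {
            for (int i = 0; i < values.length; i++) {
                result[b][i] = values[i] * ones[b];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Long> before = new LinkedHashMap<>();
        before.put("d0", 1L);   // batch dimension of size 1
        before.put("d1", 784L);
        Map<String, Long> after = new LinkedHashMap<>();
        after.put("d1", 784L);  // d0 was reduced away
        System.out.println(dimensionsToExpand(before, after));                         // [d0]
        System.out.println(Arrays.deepToString(joinWithOnes(new double[]{2.0, 3.0}))); // [[2.0, 3.0]]
    }
}
```

The multiply-by-ones join is what lets the original code reintroduce a dimension through an ordinary tensor operation rather than a dedicated "add dimension" step.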

Example 2 with TensorFunctionNode

Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.

In class TensorFlowFeatureConverter, method reduceBatchDimensionExpression.

private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
    TensorFunction result = function;
    TensorType type = function.type(context);
    if (type.dimensions().size() > 1) {
        List<String> reduceDimensions = new ArrayList<>();
        for (TensorType.Dimension dimension : type.dimensions()) {
            if (dimension.size().orElse(-1L) == 1) {
                reduceDimensions.add(dimension.name());
            }
        }
        if (reduceDimensions.size() > 0) {
            result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
        }
    }
    return new TensorFunctionNode(result);
}
Also used: TensorFunction (com.yahoo.tensor.functions.TensorFunction), TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), ArrayList (java.util.ArrayList), TensorType (com.yahoo.tensor.TensorType), Reduce (com.yahoo.tensor.functions.Reduce)
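The sketch below applies the selection rule from reduceBatchDimensionExpression. It uses a hypothetical map-based model of a tensor type rather than Vespa's TensorType, and the class and method names are illustrative. The sketch also shows why sum is a safe aggregator here: summing over a dimension with a single element returns that element unchanged, so the reduce removes the dimension without altering values.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: a tensor type is an ordered map from dimension name to
// size (null = unbound); this is not Vespa's TensorType.
public class ReduceBatchDimensionsSketch {

    // Mirrors reduceBatchDimensionExpression: only tensors with more than one
    // dimension are considered, and only their size-1 dimensions are selected.
    static List<String> dimensionsToReduce(Map<String, Long> type) {
        List<String> reduce = new ArrayList<>();
        if (type.size() > 1) {
            for (Map.Entry<String, Long> dimension : type.entrySet()) {
                long size = dimension.getValue() == null ? -1L : dimension.getValue();
                if (size == 1) {
                    reduce.add(dimension.getKey());
                }
            }
        }
        return reduce;
    }

    // Sum-reduces the first dimension of a 2-D array. When that dimension has
    // size 1, each output cell equals the single input cell it came from.
    static double[] sumOverFirst(double[][] values) {
        double[] result = new double[values[0].length];
        for (double[] row : values) {
            for (int i = 0; i < row.length; i++) {
                result[i] += row[i];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, Long> type = new LinkedHashMap<>();
        type.put("d0", 1L);
        type.put("d1", 10L);
        System.out.println(dimensionsToReduce(type));                                   // [d0]
        System.out.println(Arrays.toString(sumOverFirst(new double[][]{{4.0, 5.0}}))); // [4.0, 5.0]
    }
}
```

A one-dimensional tensor is left alone even when its only dimension has size 1, matching the size() > 1 guard in the original.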

Example 3 with TensorFunctionNode

Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.

In class TensorFlowFeatureConverter, method reduceBatchDimensionsAtInput.

private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, TypeContext<Reference> typeContext) {
    if (node instanceof TensorFunctionNode) {
        TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
        if (tensorFunction instanceof Rename) {
            List<ExpressionNode> children = ((TensorFunctionNode) node).children();
            if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
                ReferenceNode referenceNode = (ReferenceNode) children.get(0);
                if (model.requiredMacros().containsKey(referenceNode.getName())) {
                    return reduceBatchDimensionExpression(tensorFunction, typeContext);
                }
            }
        }
    }
    if (node instanceof ReferenceNode) {
        ReferenceNode referenceNode = (ReferenceNode) node;
        if (model.requiredMacros().containsKey(referenceNode.getName())) {
            return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
        }
    }
    if (node instanceof CompositeNode) {
        List<ExpressionNode> children = ((CompositeNode) node).children();
        List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
        for (ExpressionNode child : children) {
            transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
        }
        return ((CompositeNode) node).setChildren(transformedChildren);
    }
    return node;
}
Also used: CompositeNode (com.yahoo.searchlib.rankingexpression.rule.CompositeNode), TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), TensorFunction (com.yahoo.tensor.functions.TensorFunction), ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), ArrayList (java.util.ArrayList), Rename (com.yahoo.tensor.functions.Rename)
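The recursion in reduceBatchDimensionsAtInput follows a common tree-rewrite pattern with three steps:

- Handle the node kinds of interest.
- Rebuild composite nodes from their transformed children.
- Return every other node unchanged.

The sketch below reproduces that pattern with hypothetical stand-in classes (Node, Reference, Composite, Reduced) instead of Vespa's. It omits the Rename special case. Reduced only marks the place where reduceBatchDimensionExpression would be applied, and the required-input set plays the role of model.requiredMacros().

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical miniature of the recursion; none of these classes are Vespa's.
public class ReduceAtInputSketch {

    interface Node {}

    static final class Reference implements Node {
        final String name;
        Reference(String name) { this.name = name; }
        @Override public String toString() { return name; }
    }

    static final class Composite implements Node {
        final String operator;
        final List<Node> children;
        Composite(String operator, List<Node> children) { this.operator = operator; this.children = children; }
        @Override public String toString() {
            StringBuilder sb = new StringBuilder("(");
            for (int i = 0; i < children.size(); i++) {
                if (i > 0) sb.append(' ').append(operator).append(' ');
                sb.append(children.get(i));
            }
            return sb.append(')').toString();
        }
    }

    // Marks where the original would wrap the reference in a batch-dimension reduce.
    static final class Reduced implements Node {
        final Node argument;
        Reduced(Node argument) { this.argument = argument; }
        @Override public String toString() { return "reduceBatch(" + argument + ")"; }
    }

    static Node transform(Node node, Set<String> requiredInputs) {
        // A reference to a required input gets wrapped.
        if (node instanceof Reference && requiredInputs.contains(((Reference) node).name)) {
            return new Reduced(node);
        }
        // A composite is rebuilt from its recursively transformed children.
        if (node instanceof Composite) {
            Composite composite = (Composite) node;
            List<Node> transformed = new ArrayList<>(composite.children.size());
            for (Node child : composite.children) {
                transformed.add(transform(child, requiredInputs));
            }
            return new Composite(composite.operator, transformed);
        }
        // Anything else is returned unchanged.
        return node;
    }

    public static void main(String[] args) {
        Node expression = new Composite("+", Arrays.asList(
                new Composite("*", Arrays.asList(new Reference("input"), new Reference("weights"))),
                new Reference("bias")));
        Set<String> requiredInputs = new HashSet<>(Arrays.asList("input"));
        System.out.println(transform(expression, requiredInputs)); // ((reduceBatch(input) * weights) + bias)
    }
}
```

Like the original, the sketch builds new composite nodes from the transformed children rather than mutating the old ones.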

Example 4 with TensorFunctionNode

Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.

In class TensorTransformer, method replaceMaxAndMinFunction.

private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
    ExpressionNode arg1 = node.children().get(0);
    ExpressionNode arg2 = node.children().get(1);
    TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
    Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
    String dimension = ((ReferenceNode) arg2).getName();
    return new TensorFunctionNode(new Reduce(expression, aggregator, dimension));
}
Also used: TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), Reduce (com.yahoo.tensor.functions.Reduce)
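The conversion in replaceMaxAndMinFunction relies on Java's Enum.valueOf. Because the function constants and the aggregator constants share their names, the name() of one can be looked up in the other. The hypothetical sketch below uses illustrative enums, not Vespa's Function or Reduce.Aggregator. It shows the mapping and what happens when a function has no matching aggregator. The text form reduce(t, max, x) printed here is produced by this sketch's own method and is an assumption, not Vespa's output.

```java
// Hypothetical sketch: Function and Aggregator are illustrative enums, not
// Vespa's classes. The conversion works only because the constants share names.
public class MaxMinToReduceSketch {

    enum Function { max, min, pow }

    enum Aggregator { avg, count, max, min, prod, sum }

    // Same idea as Reduce.Aggregator.valueOf(node.getFunction().name()):
    // throws IllegalArgumentException when no aggregator has that name.
    static Aggregator toAggregator(Function function) {
        return Aggregator.valueOf(function.name());
    }

    // Renders max(t, x) in this sketch's own reduce(t, max, x) text form.
    static String rewrite(Function function, String tensor, String dimension) {
        return "reduce(" + tensor + ", " + toAggregator(function) + ", " + dimension + ")";
    }

    public static void main(String[] args) {
        System.out.println(rewrite(Function.max, "t", "x")); // reduce(t, max, x)
        try {
            toAggregator(Function.pow);
        } catch (IllegalArgumentException e) {
            System.out.println("pow has no matching aggregator");
        }
    }
}
```

The original method also casts its second argument to ReferenceNode. It therefore assumes the dimension is given as a plain name; any other node type there would raise a ClassCastException at runtime.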

Example 5 with TensorFunctionNode

Use of com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode in project vespa by vespa-engine.

In class RankingExpressionTestCase, method testProgrammaticBuilding.

@Test
public void testProgrammaticBuilding() throws ParseException {
    ReferenceNode input = new ReferenceNode("input");
    ReferenceNode constant = new ReferenceNode("constant");
    ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant);
    Reduce sum = new Reduce(new TensorFunctionNode.TensorFunctionExpressionNode(product), Reduce.Aggregator.sum);
    RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
    RankingExpression expected = new RankingExpression("sum(input * constant)");
    assertEquals(expected.toString(), expression.toString());
}
Also used: ArithmeticNode (com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode), ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode), TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), Reduce (com.yahoo.tensor.functions.Reduce), Test (org.junit.Test)
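The test builds an expression tree in code, parses the equivalent text, and checks that the two agree by comparing their toString() output. The sketch below mimics that pattern with a hypothetical grammar of three constructs: identifiers, *, and sum(...). It also uses illustrative node classes. Neither the parser nor the classes are Vespa's, and the printed form is defined by this sketch's own toString() methods.

```java
// Hypothetical miniature of the build-versus-parse comparison; the grammar and
// classes are illustrative only.
public class ProgrammaticBuildingSketch {

    interface Expr {}

    static final class Ref implements Expr {
        final String name;
        Ref(String name) { this.name = name; }
        @Override public String toString() { return name; }
    }

    static final class Mul implements Expr {
        final Expr left, right;
        Mul(Expr left, Expr right) { this.left = left; this.right = right; }
        @Override public String toString() { return left + " * " + right; }
    }

    static final class Sum implements Expr {
        final Expr argument;
        Sum(Expr argument) { this.argument = argument; }
        @Override public String toString() { return "sum(" + argument + ")"; }
    }

    // Minimal recursive-descent parser; whitespace is stripped up front.
    static final class Parser {
        private final String text;
        private int pos;
        Parser(String text) { this.text = text.replaceAll("\\s+", ""); }

        Expr parse() {
            Expr e = product();
            if (pos != text.length()) throw new IllegalArgumentException("Unexpected input at " + pos);
            return e;
        }

        // product := primary ('*' primary)*
        private Expr product() {
            Expr left = primary();
            while (pos < text.length() && text.charAt(pos) == '*') {
                pos++;
                left = new Mul(left, primary());
            }
            return left;
        }

        // primary := 'sum' '(' product ')' | identifier
        private Expr primary() {
            int start = pos;
            while (pos < text.length() && Character.isLetterOrDigit(text.charAt(pos))) pos++;
            if (start == pos) throw new IllegalArgumentException("Expected identifier at " + pos);
            String name = text.substring(start, pos);
            if (name.equals("sum") && pos < text.length() && text.charAt(pos) == '(') {
                pos++; // consume '('
                Expr argument = product();
                expect(')');
                return new Sum(argument);
            }
            return new Ref(name);
        }

        private void expect(char c) {
            if (pos >= text.length() || text.charAt(pos) != c) throw new IllegalArgumentException("Expected '" + c + "' at " + pos);
            pos++;
        }
    }

    public static void main(String[] args) {
        Expr built = new Sum(new Mul(new Ref("input"), new Ref("constant")));
        Expr parsed = new Parser("sum(input * constant)").parse();
        System.out.println(built);                                      // sum(input * constant)
        System.out.println(built.toString().equals(parsed.toString())); // true
    }
}
```

Comparing canonical string forms, as the original test does, avoids having to implement structural equals() on every node class.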

Aggregations

TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode): 5
ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode): 3
ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode): 3
Reduce (com.yahoo.tensor.functions.Reduce): 3
TensorType (com.yahoo.tensor.TensorType): 2
TensorFunction (com.yahoo.tensor.functions.TensorFunction): 2
ArrayList (java.util.ArrayList): 2
DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue): 1
ArithmeticNode (com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode): 1
CompositeNode (com.yahoo.searchlib.rankingexpression.rule.CompositeNode): 1
ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode): 1
GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode): 1
Generate (com.yahoo.tensor.functions.Generate): 1
Join (com.yahoo.tensor.functions.Join): 1
Rename (com.yahoo.tensor.functions.Rename): 1
Test (org.junit.Test): 1