Search in sources :

Example 1 with ConstantNode

use of com.yahoo.searchlib.rankingexpression.rule.ConstantNode in project vespa by vespa-engine.

Method expandBatchDimensionsAtOutput of class TensorFlowFeatureConverter.

/**
 * Restores size-1 dimensions of {@code before} that no longer appear in {@code after}.
 * If batch dimensions have been reduced away above, this brings them back for any following
 * computation of the tensor. They are restored by joining {@code node} (with multiplication)
 * with a generated tensor of ones that has exactly those dimensions, each of size 1.
 *
 * TODO: determine when this is not necessary!
 *
 * @param node   the expression whose result may be missing dimensions
 * @param before the tensor type before dimensions were reduced away
 * @param after  the tensor type after reduction (presumably the type of {@code node} — confirm with callers)
 * @return {@code node} itself when the types are equal or no dimension needs restoring,
 *         otherwise a new {@link TensorFunctionNode} wrapping the expanding join
 */
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
    if (after.equals(before)) {
        return node;
    }

    // Gather every dimension of 'before' that has size exactly 1 and is absent from 'after'.
    TensorType.Builder droppedDimensions = new TensorType.Builder();
    for (TensorType.Dimension candidate : before.dimensions()) {
        if (candidate.size().orElse(-1L) != 1) {
            continue; // no size (mapped to -1) or larger than 1: never re-added
        }
        if (after.dimensionNames().contains(candidate.name())) {
            continue; // still present in the output type
        }
        droppedDimensions.indexed(candidate.name(), 1);
    }

    TensorType onesType = droppedDimensions.build();
    if (onesType.dimensions().isEmpty()) {
        return node;
    }

    // Multiplying by an all-ones tensor leaves the values unchanged but adds the dimensions.
    ExpressionNode one = new ConstantNode(new DoubleValue(1.0));
    Generate ones = new Generate(onesType, new GeneratorLambdaFunctionNode(onesType, one).asLongListToDoubleOperator());
    return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(node), ones, ScalarFunctions.multiply()));
}
Also used : GeneratorLambdaFunctionNode(com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) TensorFunctionNode(com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) Generate(com.yahoo.tensor.functions.Generate) Join(com.yahoo.tensor.functions.Join) TensorType(com.yahoo.tensor.TensorType)

Example 2 with ConstantNode

use of com.yahoo.searchlib.rankingexpression.rule.ConstantNode in project vespa by vespa-engine.

Method transformArithmetic of class Simplifier.

/**
 * Simplifies an arithmetic node by folding constant operands.
 *
 * <p>When the node has more than one child, adjacent constant operands are folded one operator
 * at a time, in the order given by {@code ArithmeticOperator.operatorsByPrecedence}. If the
 * resulting node is entirely constant, it is replaced by its evaluated value. If it consists only
 * of multiplications/divisions and {@code hasZero} reports a zero, it is replaced by the constant 0.
 *
 * @param node the arithmetic node to simplify
 * @return a {@link ConstantNode} when the node reduces to a constant, otherwise the
 *         (possibly partially folded) arithmetic node
 */
private ExpressionNode transformArithmetic(ArithmeticNode node) {
    ArithmeticNode simplified = node;
    if (simplified.children().size() > 1) {
        List<ExpressionNode> operands = new ArrayList<>(simplified.children());
        List<ArithmeticOperator> pendingOperators = new ArrayList<>(simplified.operators());
        for (ArithmeticOperator operator : ArithmeticOperator.operatorsByPrecedence) {
            transform(operator, operands, pendingOperators);
        }
        simplified = new ArithmeticNode(operands, pendingOperators);
    }

    if (isConstant(simplified)) {
        return new ConstantNode(simplified.evaluate(null));
    }
    // Zero shortcut, disregarding the /0 case: a zero that acts as a divisor is not treated
    // specially here (NOTE(review): depends on what hasZero inspects — confirm).
    if (allMultiplicationOrDivision(simplified) && hasZero(simplified)) {
        return new ConstantNode(new DoubleValue(0));
    }
    return simplified;
}
Also used : ArithmeticNode(com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) ArrayList(java.util.ArrayList) ArithmeticOperator(com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator)

Example 3 with ConstantNode

use of com.yahoo.searchlib.rankingexpression.rule.ConstantNode in project vespa by vespa-engine.

Method lazyGetFunction of class Mean.

// TODO (optimization): with keepDims and a single reduce dimension of size 1, this is the same as identity.
/**
 * Builds the tensor function that averages the first input over {@code reduceDimensions}.
 *
 * <p>When dimensions should be kept, each reduced dimension is added back with size 1 by
 * multiplying the averaged result with a generated all-ones tensor over those dimensions.
 *
 * @return the mean tensor function, or {@code null} while {@code allInputTypesPresent(2)} is false
 *         (presumably: not all input types are known yet — confirm)
 */
@Override
protected TensorFunction lazyGetFunction() {
    if ( ! allInputTypesPresent(2)) {
        return null;
    }
    TensorFunction input = inputs.get(0).function().get();
    TensorFunction averaged = new Reduce(input, Reduce.Aggregator.avg, reduceDimensions);
    if ( ! shouldKeepDimensions()) {
        return averaged;
    }

    // Re-add every reduced dimension with size 1; multiplying by ones keeps the values intact.
    TensorType.Builder unitDimensions = new TensorType.Builder();
    for (String reducedName : reduceDimensions) {
        unitDimensions.indexed(reducedName, 1);
    }
    TensorType onesType = unitDimensions.build();
    ExpressionNode one = new ConstantNode(new DoubleValue(1));
    Generate ones = new Generate(onesType, new GeneratorLambdaFunctionNode(onesType, one).asLongListToDoubleOperator());
    return new com.yahoo.tensor.functions.Join(averaged, ones, ScalarFunctions.multiply());
}
Also used : GeneratorLambdaFunctionNode(com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) TensorType(com.yahoo.tensor.TensorType) Reduce(com.yahoo.tensor.functions.Reduce) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) TensorFunction(com.yahoo.tensor.functions.TensorFunction) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) Generate(com.yahoo.tensor.functions.Generate)

Example 4 with ConstantNode

use of com.yahoo.searchlib.rankingexpression.rule.ConstantNode in project vespa by vespa-engine.

Method unrollTensorExpression of class Reshape.

/**
 * Builds an arithmetic expression that computes the flat (linear) index of a cell of
 * {@code type} from references to its dimension names.
 *
 * <p>For dimensions d0 .. d(n-1) the result is {@code stride0 * d0 + stride1 * d1 + ... + d(n-1)},
 * where the stride of dimension i is the product of the sizes of all dimensions after i, so
 * the last dimension varies fastest. The multiplication is omitted whenever a stride is not
 * greater than 1.
 *
 * @param type the tensor type to unroll; sizes are read via {@code TensorConverter.dimensionSize}
 * @return a constant 0 for a rank-0 type, otherwise an {@link ArithmeticNode}
 */
private static ExpressionNode unrollTensorExpression(TensorType type) {
    if (type.rank() == 0) {
        return new ConstantNode(DoubleValue.zero);
    }
    int dimensionCount = type.dimensions().size();

    // Pass 1 (last to first): stride of dimension i = product of the sizes of dimensions i+1 .. n-1.
    int[] strides = new int[dimensionCount];
    int stride = 1;
    for (int i = dimensionCount - 1; i >= 0; --i) {
        strides[i] = stride;
        stride *= TensorConverter.dimensionSize(type.dimensions().get(i));
    }

    // Pass 2 (first to last): emit "[stride *] dimension" terms joined by '+'.
    List<ExpressionNode> terms = new ArrayList<>();
    List<ArithmeticOperator> operators = new ArrayList<>();
    for (int i = 0; i < dimensionCount; ++i) {
        if (i > 0) {
            operators.add(ArithmeticOperator.PLUS);
        }
        if (strides[i] > 1) {
            terms.add(new ConstantNode(new DoubleValue(strides[i])));
            operators.add(ArithmeticOperator.MULTIPLY);
        }
        terms.add(new ReferenceNode(type.dimensions().get(i).name()));
    }
    return new ArithmeticNode(terms, operators);
}
Also used : ArithmeticNode(com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) ReferenceNode(com.yahoo.searchlib.rankingexpression.rule.ReferenceNode) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) ArrayList(java.util.ArrayList) ArithmeticOperator(com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) TensorType(com.yahoo.tensor.TensorType)

Example 5 with ConstantNode

use of com.yahoo.searchlib.rankingexpression.rule.ConstantNode in project vespa by vespa-engine.

Method transform of class Simplifier.

/**
 * Folds adjacent constant operands joined by {@code operator} into single constants, in place.
 *
 * <p>{@code children} and {@code operators} describe an infix expression in which
 * {@code operators.get(i)} sits between {@code children.get(i)} and {@code children.get(i + 1)}.
 * A pair is folded when the operator at i equals {@code operator}, both operands are constant, and
 * {@code hasPrecedence(operators, i)} holds. Folding evaluates the pair, replaces it with one
 * {@link ConstantNode}, and removes the operator. The same position is then re-examined, so the
 * new constant may fold again with its next neighbour.
 *
 * @param operator  the operator to fold in this pass
 * @param children  the operands; modified in place
 * @param operators the operators between the operands; modified in place
 */
private void transform(ArithmeticOperator operator, List<ExpressionNode> children, List<ArithmeticOperator> operators) {
    int position = 0;
    while (position < children.size() - 1) {
        boolean foldable = operators.get(position).equals(operator)
                           && isConstant(children.get(position))
                           && isConstant(children.get(position + 1))
                           && hasPrecedence(operators, position);
        if ( ! foldable) {
            position++; // nothing to fold here: move on to the next operator
            continue;
        }
        ExpressionNode left = children.get(position);
        ExpressionNode right = children.get(position + 1);
        ArithmeticOperator folded = operators.remove(position);
        Value evaluated = new ArithmeticNode(left, folded, right).evaluate(null);
        children.set(position, new ConstantNode(evaluated.freeze()));
        children.remove(position + 1);
        // position is deliberately not advanced
    }
}
Also used : ArithmeticNode(com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) Value(com.yahoo.searchlib.rankingexpression.evaluation.Value)

Aggregations

Number of examples using each class:

- DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue): 6
- ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode): 6
- ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode): 6
- TensorType (com.yahoo.tensor.TensorType): 4
- OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 3
- ArithmeticNode (com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode): 3
- GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode): 3
- Generate (com.yahoo.tensor.functions.Generate): 3
- ArithmeticOperator (com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator): 2
- ArrayList (java.util.ArrayList): 2
- Value (com.yahoo.searchlib.rankingexpression.evaluation.Value): 1
- ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode): 1
- TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode): 1
- Join (com.yahoo.tensor.functions.Join): 1
- Reduce (com.yahoo.tensor.functions.Reduce): 1
- TensorFunction (com.yahoo.tensor.functions.TensorFunction): 1