Example 6 with ConstantNode

Use of com.yahoo.searchlib.rankingexpression.rule.ConstantNode in project vespa by vespa-engine.

The class ExpandDims, method lazyGetFunction.

@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputFunctionsPresent(2)) {
        return null;
    }
    // Multiply with a generated tensor of ones that spans the dimensions to add (each of size 1)
    TensorType.Builder typeBuilder = new TensorType.Builder();
    for (String name : expandDimensions) {
        typeBuilder.indexed(name, 1);
    }
    TensorType generatedType = typeBuilder.build();
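    // Every cell of the generated tensor is the constant 1, so the join leaves the input values unchanged
    ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
    Generate generatedFunction = new Generate(generatedType,
            new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
    // Joining the first input with the generated tensor adds the new dimensions to the result
    return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());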
}
Also used:
GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode)
ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode)
DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue)
ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)
Generate (com.yahoo.tensor.functions.Generate)
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)
TensorType (com.yahoo.tensor.TensorType)
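
The method above expands dimensions by joining the input with a tensor of ones whose extra dimensions each have size 1. The following is a minimal sketch of that idea in plain Java, without the Vespa classes. The class ExpandDimsSketch, its helper methods, and the representation of a tensor as a map from cell address to value are all assumptions made for illustration; this is not how Vespa implements tensors.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration only: a tensor is modelled as a map from
// cell address (dimension name -> index) to cell value.
public class ExpandDimsSketch {

    /** Builds a tensor with a single cell of value 1.0, at index 0 in every given dimension. */
    static Map<Map<String, Integer>, Double> onesTensor(List<String> dimensions) {
        Map<String, Integer> address = new HashMap<>();
        for (String name : dimensions) {
            address.put(name, 0); // each new dimension has size 1, so 0 is its only index
        }
        Map<Map<String, Integer>, Double> cells = new HashMap<>();
        cells.put(address, 1.0);
        return cells;
    }

    /** Joins two tensors, multiplying cells whose addresses agree on all shared dimensions. */
    static Map<Map<String, Integer>, Double> joinMultiply(Map<Map<String, Integer>, Double> a,
                                                         Map<Map<String, Integer>, Double> b) {
        Map<Map<String, Integer>, Double> result = new HashMap<>();
        for (Map.Entry<Map<String, Integer>, Double> cellA : a.entrySet()) {
            for (Map.Entry<Map<String, Integer>, Double> cellB : b.entrySet()) {
                Map<String, Integer> combined = new HashMap<>(cellA.getKey());
                boolean compatible = true;
                for (Map.Entry<String, Integer> label : cellB.getKey().entrySet()) {
                    Integer existing = combined.putIfAbsent(label.getKey(), label.getValue());
                    if (existing != null && !existing.equals(label.getValue())) {
                        compatible = false; // the cells disagree on a shared dimension
                        break;
                    }
                }
                if (compatible) {
                    result.put(combined, cellA.getValue() * cellB.getValue());
                }
            }
        }
        return result;
    }

    /** Adds the given size-1 dimensions to the input by multiplying with a tensor of ones. */
    static Map<Map<String, Integer>, Double> expandDims(Map<Map<String, Integer>, Double> input,
                                                       List<String> newDimensions) {
        return joinMultiply(input, onesTensor(newDimensions));
    }

    public static void main(String[] args) {
        Map<Map<String, Integer>, Double> input = new HashMap<>();
        input.put(Map.of("x", 0), 2.0);
        input.put(Map.of("x", 1), 3.0);
        // Each cell keeps its value and gains the address component y=0
        System.out.println(expandDims(input, List.of("y")));
    }
}
```

Because every generated cell is 1, the multiplication leaves the values untouched; only the cell addresses gain the new dimensions.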

Aggregations

DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue): 6
ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode): 6
ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode): 6
TensorType (com.yahoo.tensor.TensorType): 4
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 3
ArithmeticNode (com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode): 3
GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode): 3
Generate (com.yahoo.tensor.functions.Generate): 3
ArithmeticOperator (com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator): 2
ArrayList (java.util.ArrayList): 2
Value (com.yahoo.searchlib.rankingexpression.evaluation.Value): 1
ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode): 1
TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode): 1
Join (com.yahoo.tensor.functions.Join): 1
Reduce (com.yahoo.tensor.functions.Reduce): 1
TensorFunction (com.yahoo.tensor.functions.TensorFunction): 1