
Example 1 with ComparisonNode

Use of com.yahoo.searchlib.rankingexpression.rule.ComparisonNode in project vespa by vespa-engine.

From class Reshape, method reshape.

public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
    if (!tensorSize(inputType).equals(tensorSize(outputType))) {
        throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
    }
    // Conceptually, reshaping consists of unrolling a tensor to an array using the dimension order,
    // then using the dimension order of the new shape to roll it back into a tensor.
    // Here we create a transformation tensor that is multiplied with the input tensor to map it into
    // the new shape. We have to introduce temporary dimension names and rename back if dimension names
    // in the new and old tensor types overlap.
    ExpressionNode unrollFrom = unrollTensorExpression(inputType);
    ExpressionNode unrollTo = unrollTensorExpression(outputType);
    // Evaluates to 1 where an input cell and an output cell unroll to the same position, 0 elsewhere
    ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
    TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
    Generate transformTensor = new Generate(transformationType,
            new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
    // Join (multiply) the input with the transformation tensor, then sum out the input dimensions
    TensorFunction outputFunction = new Reduce(
            new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
            Reduce.Aggregator.sum,
            inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
    return outputFunction;
}
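The transformation-tensor idea can be illustrated without Vespa on plain Java arrays. The sketch below is a minimal illustration, assuming rank-2 input and output shapes and row-major unrolling (last dimension varies fastest); the class `ReshapeSketch` and its methods `transformation` and `reshape` are hypothetical names, not part of the Vespa API. It builds the 0/1 tensor `T[i0][i1][j0][j1]`, multiplies it with the input (the "join"), and sums out the input dimensions (the "reduce"), mirroring the structure of the method above.

```java
import java.util.Arrays;

// Hypothetical illustration class; not Vespa's implementation.
public class ReshapeSketch {

    // Builds the 0/1 transformation tensor T[i0][i1][j0][j1]: 1 exactly when input
    // cell (i0, i1) and output cell (j0, j1) unroll to the same row-major position.
    static double[][][][] transformation(int a, int b, int c, int d) {
        double[][][][] t = new double[a][b][c][d];
        for (int i0 = 0; i0 < a; i0++)
            for (int i1 = 0; i1 < b; i1++)
                for (int j0 = 0; j0 < c; j0++)
                    for (int j1 = 0; j1 < d; j1++)
                        t[i0][i1][j0][j1] = (i0 * b + i1 == j0 * d + j1) ? 1.0 : 0.0;
        return t;
    }

    // Reshapes an a x b matrix into c x d: multiply by T (the "join") and sum
    // out the input dimensions (the "reduce"), leaving only the output dimensions.
    static double[][] reshape(double[][] input, int c, int d) {
        int a = input.length, b = input[0].length;
        if (a * b != c * d)
            throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
        double[][][][] t = transformation(a, b, c, d);
        double[][] output = new double[c][d];
        for (int j0 = 0; j0 < c; j0++)
            for (int j1 = 0; j1 < d; j1++) {
                double sum = 0.0;
                for (int i0 = 0; i0 < a; i0++)
                    for (int i1 = 0; i1 < b; i1++)
                        sum += input[i0][i1] * t[i0][i1][j0][j1];
                output[j0][j1] = sum;
            }
        return output;
    }

    public static void main(String[] args) {
        double[][] input = {{1, 2, 3}, {4, 5, 6}};  // shape 2 x 3
        double[][] output = reshape(input, 3, 2);    // shape 3 x 2
        System.out.println(Arrays.deepToString(output));
        // prints [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    }
}
```

Materializing `T` densely costs (a·b)² cells, so this only illustrates why the join-then-reduce formulation produces a reshape; it is not meant as an efficient way to reshape arrays.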
Also used: GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), TensorType (com.yahoo.tensor.TensorType), Reduce (com.yahoo.tensor.functions.Reduce), ComparisonNode (com.yahoo.searchlib.rankingexpression.rule.ComparisonNode), TensorFunction (com.yahoo.tensor.functions.TensorFunction), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), Generate (com.yahoo.tensor.functions.Generate)
