Search in sources :

Example 1 with ResolvedComputedColumn

use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.

the class AggregateScanConverter method convert.

/**
 * Converts a ZetaSQL aggregate scan into a Calcite {@code LogicalAggregate} whose input is the
 * {@code LogicalProject} built by {@code convertAggregateScanInputScanToLogicalProject}.
 *
 * @param zetaNode the resolved aggregate scan to convert
 * @param inputs already-converted inputs; only the first element is used
 * @return the {@code LogicalAggregate} over the projected input
 */
@Override
public RelNode convert(ResolvedAggregateScan zetaNode, List<RelNode> inputs) {
    LogicalProject input = convertAggregateScanInputScanToLogicalProject(zetaNode, inputs.get(0));
    // Calcite LogicalAggregate's GroupSet is indexes of group fields starting from 0.
    // range(0, n) yields {0, ..., n - 1} and the empty set when n == 0.
    int groupFieldsListSize = zetaNode.getGroupByList().size();
    ImmutableBitSet groupSet = ImmutableBitSet.range(0, groupFieldsListSize);
    // TODO: add support for indicator
    List<AggregateCall> aggregateCalls = new ArrayList<>();
    // For aggregate calls, their input ref follow after GROUP BY input ref.
    int columnRefoff = groupFieldsListSize;
    for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) {
        // Pass the number of group keys as the group count. ImmutableBitSet.size() reports the
        // allocated bit capacity (a multiple of 64), not the number of set bits, so it must not
        // be used here.
        AggregateCall aggCall = convertAggCall(computedColumn, columnRefoff, groupFieldsListSize, input);
        aggregateCalls.add(aggCall);
        if (!aggCall.getArgList().isEmpty()) {
            // Only increment column reference offset when aggregates use them (BEAM-8042).
            // Ex: COUNT(*) does not have arguments, while COUNT(`field`) does.
            // NOTE(review): convertAggregateScanInputScanToLogicalProject appears to project one
            // column per aggregate argument (literals included), so advancing by a single column
            // may misalign references after a multi-argument aggregate -- TODO confirm.
            columnRefoff++;
        }
    }
    return new LogicalAggregate(getCluster(), input.getTraitSet(), input, groupSet, ImmutableList.of(groupSet), aggregateCalls);
}
Also used : AggregateCall(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall) LogicalAggregate(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalAggregate) ImmutableBitSet(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.ImmutableBitSet) LogicalProject(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject) ResolvedComputedColumn(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn)

Example 2 with ResolvedComputedColumn

use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.

the class AggregateScanConverter method convertAggCall.

/**
 * Translates one ZetaSQL aggregate computed column into a Calcite {@code AggregateCall}.
 *
 * <p>Rejects AVG whose first signature argument is INT64 and any DISTINCT aggregation. Functions
 * in the user-defined Java aggregate group get a freshly created UDAF operator; all others are
 * looked up through {@code SqlOperatorMappingTable}.
 *
 * @param computedColumn the aggregate column; its expression is cast to an aggregate function call
 * @param columnRefOff input field index used as the reference for the first argument
 * @param groupCount group count forwarded to {@code AggregateCall.create}
 * @param input relational input forwarded to {@code AggregateCall.create}
 * @return the Calcite aggregate call, named by the resolved alias of the computed column
 * @throws UnsupportedOperationException for AVG(INT64), DISTINCT aggregates, unknown builtin
 *     aggregates, or unsupported argument shapes
 */
private AggregateCall convertAggCall(ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) {
    final ResolvedAggregateFunctionCall aggCallNode = (ResolvedAggregateFunctionCall) computedColumn.getExpr();

    // Reject AVG(INT64)
    boolean isAvg = aggCallNode.getFunction().getName().equals("avg");
    if (isAvg && aggCallNode.getSignature().getFunctionArgumentList().get(0).getType().getKind().equals(TypeKind.TYPE_INT64)) {
        throw new UnsupportedOperationException(AVG_ILLEGAL_LONG_INPUT_TYPE);
    }

    // Reject aggregation DISTINCT
    if (aggCallNode.getDistinct()) {
        throw new UnsupportedOperationException("Does not support " + aggCallNode.getFunction().getSqlName() + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before" + " aggregation.");
    }

    final boolean isJavaUdaf = aggCallNode.getFunction().getGroup().equals(BeamZetaSqlCatalog.USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS);
    final SqlAggFunction operator;
    if (!isJavaUdaf) {
        // Builtin functions come from SqlOperatorMappingTable.
        operator = (SqlAggFunction) SqlOperatorMappingTable.create(aggCallNode);
        if (operator == null) {
            throw new UnsupportedOperationException("Does not support ZetaSQL aggregate function: " + aggCallNode.getFunction().getName());
        }
    } else {
        // User-defined functions get a new operator; its return type is taken from the first
        // entry of the function's signature list.
        SqlReturnTypeInference returnTypeInference = binding -> ZetaSqlCalciteTranslationUtils.toCalciteType(
            aggCallNode.getFunction().getSignatureList().get(0).getResultType().getType(),
            // TODO(BEAM-9514) set nullable=true
            false,
            getCluster().getRexBuilder());
        UdafImpl<?, ?, ?> udaf = new UdafImpl<>(getExpressionConverter().userFunctionDefinitions.javaAggregateFunctions().get(aggCallNode.getFunction().getNamePath()));
        operator = SqlOperators.createUdafOperator(aggCallNode.getFunction().getName(), returnTypeInference, udaf);
    }

    // Only the first argument becomes an input reference; later arguments must be literals and
    // contribute nothing to the argument list.
    // TODO: is there a general way to handle aggregation calls conversion?
    List<Integer> argRefs = new ArrayList<>();
    List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> firstArgKinds = Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD);
    for (int argIndex = 0; argIndex < aggCallNode.getArgumentList().size(); argIndex++) {
        ZetaSQLResolvedNodeKind.ResolvedNodeKind kind = aggCallNode.getArgumentList().get(argIndex).nodeKind();
        boolean isFirst = argIndex == 0;
        boolean supported = isFirst ? firstArgKinds.contains(kind) : kind == RESOLVED_LITERAL;
        if (!supported) {
            throw new UnsupportedOperationException("Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and " + "Literals as subsequent arguments as its inputs");
        }
        if (isFirst) {
            argRefs.add(columnRefOff);
        }
    }

    String alias = getTrait().resolveAlias(computedColumn.getColumn());
    // When we pass null as the return type, Calcite infers it for us.
    return AggregateCall.create(operator, false, false, false, argRefs, -1, null, RelCollations.EMPTY, groupCount, input, null, alias);
}
Also used : ZetaSQLResolvedNodeKind(com.google.zetasql.ZetaSQLResolvedNodeKind) IntStream(java.util.stream.IntStream) AggregateCall(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall) ZetaSqlCalciteTranslationUtils(org.apache.beam.sdk.extensions.sql.zetasql.ZetaSqlCalciteTranslationUtils) RelNode(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode) Arrays(java.util.Arrays) ResolvedNode(com.google.zetasql.resolvedast.ResolvedNode) RexNode(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) ResolvedAggregateFunctionCall(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall) ResolvedAggregateScan(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan) RelCollations(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelCollations) ArrayList(java.util.ArrayList) ImmutableBitSet(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.ImmutableBitSet) SqlAggFunction(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction) RESOLVED_COLUMN_REF(com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF) LogicalProject(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject) FunctionSignature(com.google.zetasql.FunctionSignature) RESOLVED_CAST(com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST) RESOLVED_GET_STRUCT_FIELD(com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD) RESOLVED_LITERAL(com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL) Collectors(java.util.stream.Collectors) LogicalAggregate(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalAggregate) List(java.util.List) UdafImpl(org.apache.beam.sdk.extensions.sql.impl.UdafImpl) 
ResolvedComputedColumn(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn) ResolvedExpr(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr) TypeKind(com.google.zetasql.ZetaSQLType.TypeKind) SqlReturnTypeInference(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.type.SqlReturnTypeInference) ImmutableList(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList) BeamZetaSqlCatalog(org.apache.beam.sdk.extensions.sql.zetasql.BeamZetaSqlCatalog) Collections(java.util.Collections) ArrayList(java.util.ArrayList) SqlAggFunction(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction) FunctionSignature(com.google.zetasql.FunctionSignature) UdafImpl(org.apache.beam.sdk.extensions.sql.impl.UdafImpl) ResolvedAggregateFunctionCall(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall) SqlReturnTypeInference(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.type.SqlReturnTypeInference) ZetaSQLResolvedNodeKind(com.google.zetasql.ZetaSQLResolvedNodeKind) ZetaSQLResolvedNodeKind(com.google.zetasql.ZetaSQLResolvedNodeKind)

Example 3 with ResolvedComputedColumn

use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.

the class ExpressionConverter method retrieveRexNode.

/**
 * Builds one {@code RexNode} per output column of a project scan.
 *
 * <p>A column found in the scan's expression list is converted from its computed expression; any
 * other column is emitted as an input reference into the input scan's column list.
 *
 * @param node the project scan whose columns are converted
 * @param fieldList fields of the converted input, used for input-reference types
 * @return the converted expressions, in the order of {@code node.getColumnList()}
 * @throws IllegalStateException if a non-expression column is not found in the input scan
 */
public List<RexNode> retrieveRexNode(ResolvedProjectScan node, List<RelDataTypeField> fieldList) {
    List<RexNode> rexNodes = new ArrayList<>();
    for (ResolvedColumn column : node.getColumnList()) {
        int exprIndex = indexOfResolvedColumnInExprList(node.getExprList(), column);
        if (exprIndex == -1) {
            // ResolvedColumn is not a expression, which means it has to be an input column reference.
            int inputIndex = indexOfProjectionColumnRef(column.getId(), node.getInputScan().getColumnList());
            boolean outOfRange = inputIndex < 0 || inputIndex >= node.getInputScan().getColumnList().size();
            if (outOfRange) {
                throw new IllegalStateException(String.format("Cannot find %s in fieldList %s", column, fieldList));
            }
            rexNodes.add(rexBuilder().makeInputRef(fieldList.get(inputIndex).getType(), inputIndex));
            continue;
        }

        ResolvedComputedColumn computedColumn = node.getExprList().get(exprIndex);
        // Stays -1 unless the expression is a call to a function in WINDOW_START_END_FUNCTION_SET.
        int windowFieldIndex = -1;
        if (computedColumn.getExpr().nodeKind() == RESOLVED_FUNCTION_CALL) {
            String functionName = ((ResolvedFunctionCall) computedColumn.getExpr()).getFunction().getName();
            if (WINDOW_START_END_FUNCTION_SET.contains(functionName)) {
                // In this case the input scan is cast to an aggregate scan to locate the window field.
                ResolvedAggregateScan aggregateScan = (ResolvedAggregateScan) node.getInputScan();
                windowFieldIndex = indexOfWindowField(aggregateScan.getGroupByList(), aggregateScan.getColumnList(), WINDOW_START_END_TO_WINDOW_MAP.get(functionName));
            }
        }
        rexNodes.add(convertRexNodeFromComputedColumnWithFieldList(computedColumn, node.getInputScan().getColumnList(), fieldList, windowFieldIndex));
    }
    return rexNodes;
}
Also used : ResolvedAggregateScan(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan) ArrayList(java.util.ArrayList) ResolvedColumn(com.google.zetasql.resolvedast.ResolvedColumn) ResolvedComputedColumn(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn) RexNode(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode)

Example 4 with ResolvedComputedColumn

use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.

the class AggregateScanConverter method convertAggregateScanInputScanToLogicalProject.

/**
 * Builds the {@code LogicalProject} that feeds the aggregate: first one expression per GROUP BY
 * column, then one expression per argument of every aggregate function call that has arguments.
 *
 * @param node the aggregate scan whose GROUP BY and aggregate lists are projected
 * @param input the converted input of the aggregate scan
 * @return a {@code LogicalProject} over {@code input} with the collected expressions and names
 */
private LogicalProject convertAggregateScanInputScanToLogicalProject(ResolvedAggregateScan node, RelNode input) {
    // AggregateScan's input is the source of data (e.g. TableScan), whereas in Calcite the
    // LogicalAggregate's input is a LogicalProject whose input is a LogicalTableScan. When
    // AggregateScan's input is WithRefScan, the WithRefScan is equivalent to a LogicalTableScan,
    // so another LogicalProject is still required as the input of LogicalAggregate.
    List<RexNode> projects = new ArrayList<>();
    List<String> fieldNames = new ArrayList<>();

    // GROUP BY expressions come first.
    for (ResolvedComputedColumn groupByColumn : node.getGroupByList()) {
        RexNode groupByRex = getExpressionConverter().convertRexNodeFromResolvedExpr(groupByColumn.getExpr(), node.getInputScan().getColumnList(), input.getRowType().getFieldList(), ImmutableMap.of());
        projects.add(groupByRex);
        fieldNames.add(getTrait().resolveAlias(groupByColumn.getColumn()));
    }

    // TODO: remove duplicate columns in projects.
    for (ResolvedComputedColumn aggregateColumn : node.getAggregateList()) {
        // Should create Calcite's RexInputRef from ResolvedColumn from ResolvedComputedColumn.
        // TODO: handle aggregate function with more than one argument and handle OVER
        // TODO: is there is general way for column reference tracking and deduplication for
        // aggregation?
        ResolvedAggregateFunctionCall aggregateCall = (ResolvedAggregateFunctionCall) aggregateColumn.getExpr();
        List<ResolvedExpr> arguments = aggregateCall.getArgumentList();
        if (arguments == null || arguments.isEmpty()) {
            // Aggregates without arguments contribute no projected column.
            continue;
        }
        for (int argIndex = 0; argIndex < arguments.size(); argIndex++) {
            ResolvedExpr argument = arguments.get(argIndex);
            // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
            // TODO: user might use multiple CAST so we need to handle this rare case.
            // The first argument is converted against the input scan's columns and fields; later
            // arguments use the single-argument conversion.
            RexNode argumentRex = argIndex == 0
                ? getExpressionConverter().convertRexNodeFromResolvedExpr(argument, node.getInputScan().getColumnList(), input.getRowType().getFieldList(), ImmutableMap.of())
                : getExpressionConverter().convertRexNodeFromResolvedExpr(argument);
            projects.add(argumentRex);
            // Every projected argument is named with the aggregate column's alias.
            fieldNames.add(getTrait().resolveAlias(aggregateColumn.getColumn()));
        }
    }
    return LogicalProject.create(input, ImmutableList.of(), projects, fieldNames);
}
Also used : ResolvedAggregateFunctionCall(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall) ArrayList(java.util.ArrayList) ResolvedExpr(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr) ResolvedComputedColumn(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn) RexNode(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode)

Aggregations

ResolvedComputedColumn (com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn)4 ArrayList (java.util.ArrayList)3 RexNode (org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode)3 ResolvedAggregateFunctionCall (com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall)2 ResolvedAggregateScan (com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateScan)2 ResolvedExpr (com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr)2 AggregateCall (org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall)2 LogicalAggregate (org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalAggregate)2 LogicalProject (org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject)2 ImmutableBitSet (org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.ImmutableBitSet)2 FunctionSignature (com.google.zetasql.FunctionSignature)1 ZetaSQLResolvedNodeKind (com.google.zetasql.ZetaSQLResolvedNodeKind)1 RESOLVED_CAST (com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_CAST)1 RESOLVED_COLUMN_REF (com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_COLUMN_REF)1 RESOLVED_GET_STRUCT_FIELD (com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_GET_STRUCT_FIELD)1 RESOLVED_LITERAL (com.google.zetasql.ZetaSQLResolvedNodeKind.ResolvedNodeKind.RESOLVED_LITERAL)1 TypeKind (com.google.zetasql.ZetaSQLType.TypeKind)1 ResolvedColumn (com.google.zetasql.resolvedast.ResolvedColumn)1 ResolvedNode (com.google.zetasql.resolvedast.ResolvedNode)1 Arrays (java.util.Arrays)1