use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.
the class AggregateScanConverter method convert.
/**
 * Converts a ZetaSQL {@code ResolvedAggregateScan} into a Calcite {@code LogicalAggregate}.
 *
 * <p>The aggregate reads from the {@code LogicalProject} built by {@code
 * convertAggregateScanInputScanToLogicalProject}. That projection emits the GROUP BY expressions
 * first, followed by one column per aggregate-function argument (the first argument and any
 * subsequent literal arguments), in aggregate-list order.
 *
 * @param zetaNode the aggregate scan to convert
 * @param inputs already-converted inputs; only the first element is used
 * @return a {@code LogicalAggregate} grouping on the leading projected columns
 */
@Override
public RelNode convert(ResolvedAggregateScan zetaNode, List<RelNode> inputs) {
  LogicalProject input = convertAggregateScanInputScanToLogicalProject(zetaNode, inputs.get(0));

  // Calcite LogicalAggregate's GroupSet is indexes of group fields starting from 0.
  // range(0, 0) yields the empty set, so no special case is needed for "no GROUP BY".
  int groupFieldsListSize = zetaNode.getGroupByList().size();
  ImmutableBitSet groupSet = ImmutableBitSet.range(0, groupFieldsListSize);

  // TODO: add support for indicator
  List<AggregateCall> aggregateCalls = new ArrayList<>();
  // For aggregate calls, their input ref follow after GROUP BY input ref.
  int columnRefoff = groupFieldsListSize;
  for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) {
    // Pass the number of GROUP BY columns as the group count. ImmutableBitSet.size() reports the
    // storage size of the bit set rather than the number of bits that are set.
    aggregateCalls.add(convertAggCall(computedColumn, columnRefoff, groupFieldsListSize, input));

    // Advance past every column the input projection emitted for this aggregate. The projection
    // adds one column per argument, so an aggregate without arguments (e.g. COUNT(*), BEAM-8042)
    // consumes none, and an aggregate with trailing literal arguments consumes more than one.
    ResolvedAggregateFunctionCall aggregateFunctionCall =
        (ResolvedAggregateFunctionCall) computedColumn.getExpr();
    if (aggregateFunctionCall.getArgumentList() != null) {
      columnRefoff += aggregateFunctionCall.getArgumentList().size();
    }
  }

  return new LogicalAggregate(
      getCluster(),
      input.getTraitSet(),
      input,
      groupSet,
      ImmutableList.of(groupSet),
      aggregateCalls);
}
use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.
the class AggregateScanConverter method convertAggCall.
/**
 * Translates one ZetaSQL aggregate computed column into a Calcite {@code AggregateCall}.
 *
 * <p>The expression of {@code computedColumn} is cast to {@code ResolvedAggregateFunctionCall}.
 * Functions in the user-defined Java aggregate group get a freshly created UDAF operator; every
 * other function is looked up through {@code SqlOperatorMappingTable}.
 *
 * @param computedColumn the aggregate column to translate
 * @param columnRefOff input column index recorded as the call's single argument reference
 * @param groupCount group count handed to {@code AggregateCall.create}
 * @param input relational input handed to {@code AggregateCall.create}
 * @return the aggregate call, named after the resolved alias of the computed column
 * @throws UnsupportedOperationException for AVG over INT64, for DISTINCT aggregation, for a
 *     builtin function without an operator mapping, or for an unsupported argument shape
 */
private AggregateCall convertAggCall(ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) {
  ResolvedAggregateFunctionCall call = (ResolvedAggregateFunctionCall) computedColumn.getExpr();

  // Reject AVG(INT64): only the first signature argument is inspected.
  boolean isAvg = call.getFunction().getName().equals("avg");
  if (isAvg
      && call.getSignature()
          .getFunctionArgumentList()
          .get(0)
          .getType()
          .getKind()
          .equals(TypeKind.TYPE_INT64)) {
    throw new UnsupportedOperationException(AVG_ILLEGAL_LONG_INPUT_TYPE);
  }

  // Reject aggregation DISTINCT.
  if (call.getDistinct()) {
    throw new UnsupportedOperationException(
        "Does not support "
            + call.getFunction().getSqlName()
            + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before"
            + " aggregation.");
  }

  final SqlAggFunction sqlAggFunction;
  boolean isJavaUdaf =
      call.getFunction()
          .getGroup()
          .equals(BeamZetaSqlCatalog.USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS);
  if (!isJavaUdaf) {
    // Builtin functions are resolved through SqlOperatorMappingTable.
    sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(call);
    if (sqlAggFunction == null) {
      throw new UnsupportedOperationException(
          "Does not support ZetaSQL aggregate function: " + call.getFunction().getName());
    }
  } else {
    // User-defined functions get a new operator whose return type comes from the first
    // signature of the function.
    SqlReturnTypeInference typeInference =
        opBinding ->
            ZetaSqlCalciteTranslationUtils.toCalciteType(
                call.getFunction().getSignatureList().get(0).getResultType().getType(),
                // TODO(BEAM-9514) set nullable=true
                false,
                getCluster().getRexBuilder());
    UdafImpl<?, ?, ?> impl =
        new UdafImpl<>(
            getExpressionConverter()
                .userFunctionDefinitions
                .javaAggregateFunctions()
                .get(call.getFunction().getNamePath()));
    sqlAggFunction =
        SqlOperators.createUdafOperator(call.getFunction().getName(), typeInference, impl);
  }

  // Only the first argument becomes an input reference. It must be a cast, a column reference or
  // a struct field access; every later argument must be a literal and is skipped here.
  // TODO: is there a general way to handle aggregation calls conversion?
  List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> firstArgumentKinds =
      Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD);
  List<Integer> argList = new ArrayList<>();
  List<ResolvedExpr> arguments = call.getArgumentList();
  for (int position = 0; position < arguments.size(); position++) {
    ZetaSQLResolvedNodeKind.ResolvedNodeKind kind = arguments.get(position).nodeKind();
    boolean isFirst = position == 0;
    if (isFirst && firstArgumentKinds.contains(kind)) {
      argList.add(columnRefOff);
      continue;
    }
    if (!isFirst && kind == RESOLVED_LITERAL) {
      continue;
    }
    throw new UnsupportedOperationException(
        "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and "
            + "Literals as subsequent arguments as its inputs");
  }

  String aggName = getTrait().resolveAlias(computedColumn.getColumn());
  return AggregateCall.create(
      sqlAggFunction,
      false,
      false,
      false,
      argList,
      -1,
      null,
      RelCollations.EMPTY,
      groupCount,
      input,
      // When we pass null as the return type, Calcite infers it for us.
      null,
      aggName);
}
use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.
the class ExpressionConverter method retrieveRexNode.
/**
 * Builds the Calcite expressions produced by a ZetaSQL project scan.
 *
 * <p>One {@link RexNode} is produced per output column of {@code node}, in column order. A column
 * found in the scan's expression list is converted from that computed column. When the computed
 * column is a function call whose name is in {@code WINDOW_START_END_FUNCTION_SET}, the input scan
 * is cast to {@code ResolvedAggregateScan} and the index of the matching window field is passed
 * along. Any other column becomes an input reference located through the input scan's column
 * list.
 *
 * @param node the project scan whose output columns are converted
 * @param fieldList fields that supply the type of each input reference; indexed with the position
 *     found in the input scan's column list
 * @return one expression per output column of {@code node}
 * @throws IllegalStateException if a non-computed column is not found in the input scan's column
 *     list
 */
public List<RexNode> retrieveRexNode(ResolvedProjectScan node, List<RelDataTypeField> fieldList) {
  List<RexNode> rexNodes = new ArrayList<>();
  for (ResolvedColumn column : node.getColumnList()) {
    int exprIndex = indexOfResolvedColumnInExprList(node.getExprList(), column);

    if (exprIndex == -1) {
      // ResolvedColumn is not an expression, which means it has to be an input column reference.
      int inputIndex =
          indexOfProjectionColumnRef(column.getId(), node.getInputScan().getColumnList());
      boolean inRange =
          inputIndex >= 0 && inputIndex < node.getInputScan().getColumnList().size();
      if (!inRange) {
        throw new IllegalStateException(
            String.format("Cannot find %s in fieldList %s", column, fieldList));
      }
      rexNodes.add(rexBuilder().makeInputRef(fieldList.get(inputIndex).getType(), inputIndex));
      continue;
    }

    ResolvedComputedColumn computedColumn = node.getExprList().get(exprIndex);
    // Stays -1 unless the expression is one of the window start/end functions.
    int windowFieldIndex = -1;
    if (computedColumn.getExpr().nodeKind() == RESOLVED_FUNCTION_CALL) {
      String functionName =
          ((ResolvedFunctionCall) computedColumn.getExpr()).getFunction().getName();
      if (WINDOW_START_END_FUNCTION_SET.contains(functionName)) {
        ResolvedAggregateScan aggregateScan = (ResolvedAggregateScan) node.getInputScan();
        windowFieldIndex =
            indexOfWindowField(
                aggregateScan.getGroupByList(),
                aggregateScan.getColumnList(),
                WINDOW_START_END_TO_WINDOW_MAP.get(functionName));
      }
    }
    rexNodes.add(
        convertRexNodeFromComputedColumnWithFieldList(
            computedColumn, node.getInputScan().getColumnList(), fieldList, windowFieldIndex));
  }
  return rexNodes;
}
use of com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn in project beam by apache.
the class AggregateScanConverter method convertAggregateScanInputScanToLogicalProject.
/**
 * Builds the {@code LogicalProject} that serves as the input of the {@code LogicalAggregate}.
 *
 * <p>The projection lists the GROUP BY expressions first. After them, for every aggregate whose
 * argument list is non-empty, it adds one column per argument in argument order. Each column
 * added for an aggregate is named with the resolved alias of that aggregate's computed column.
 *
 * @param node the aggregate scan being converted
 * @param input the converted input of the aggregate scan
 * @return a projection over {@code input} with the columns described above
 */
private LogicalProject convertAggregateScanInputScanToLogicalProject(ResolvedAggregateScan node, RelNode input) {
  // AggregateScan's input is the source of data (e.g. TableScan), which is different from the
  // design of CalciteSQL, in which the LogicalAggregate's input is a LogicalProject, whose input
  // is a LogicalTableScan. When AggregateScan's input is WithRefScan, the WithRefScan is
  // equivalent to a LogicalTableScan. So it's still required to build another LogicalProject as
  // the input of LogicalAggregate.
  List<RexNode> projects = new ArrayList<>();
  List<String> fieldNames = new ArrayList<>();

  // GROUP BY expressions come first so that their indexes start at 0.
  for (ResolvedComputedColumn groupByColumn : node.getGroupByList()) {
    RexNode groupByExpr =
        getExpressionConverter()
            .convertRexNodeFromResolvedExpr(
                groupByColumn.getExpr(),
                node.getInputScan().getColumnList(),
                input.getRowType().getFieldList(),
                ImmutableMap.of());
    projects.add(groupByExpr);
    fieldNames.add(getTrait().resolveAlias(groupByColumn.getColumn()));
  }

  // Columns used by aggregate functions follow after the GROUP BY columns.
  // TODO: remove duplicate columns in projects.
  // TODO: handle aggregate function with more than one argument and handle OVER
  // TODO: is there is general way for column reference tracking and deduplication for
  // aggregation?
  for (ResolvedComputedColumn aggregateColumn : node.getAggregateList()) {
    ResolvedAggregateFunctionCall call =
        (ResolvedAggregateFunctionCall) aggregateColumn.getExpr();
    if (call.getArgumentList() == null || call.getArgumentList().isEmpty()) {
      // Nothing to project for aggregates without arguments.
      continue;
    }

    for (int argIndex = 0; argIndex < call.getArgumentList().size(); argIndex++) {
      ResolvedExpr argument = call.getArgumentList().get(argIndex);
      RexNode projected;
      if (argIndex > 0) {
        // Later arguments are converted without any input column context.
        projected = getExpressionConverter().convertRexNodeFromResolvedExpr(argument);
      } else {
        // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
        // TODO: user might use multiple CAST so we need to handle this rare case.
        projected =
            getExpressionConverter()
                .convertRexNodeFromResolvedExpr(
                    argument,
                    node.getInputScan().getColumnList(),
                    input.getRowType().getFieldList(),
                    ImmutableMap.of());
      }
      projects.add(projected);
      fieldNames.add(getTrait().resolveAlias(aggregateColumn.getColumn()));
    }
  }

  return LogicalProject.create(input, ImmutableList.of(), projects, fieldNames);
}
Aggregations