use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.
the class CommonPythonUtil method extractPythonAggregateFunctionInfosFromAggregateCall.
public static Tuple2<int[], PythonFunctionInfo[]> extractPythonAggregateFunctionInfosFromAggregateCall(AggregateCall[] aggCalls) {
    Map<Integer, Integer> inputNodes = new LinkedHashMap<>();
    List<PythonFunctionInfo> pythonFunctionInfos = new ArrayList<>();
    for (AggregateCall aggregateCall : aggCalls) {
        List<Integer> inputs = new ArrayList<>();
        List<Integer> argList = aggregateCall.getArgList();
        for (Integer arg : argList) {
            if (inputNodes.containsKey(arg)) {
                inputs.add(inputNodes.get(arg));
            } else {
                Integer inputOffset = inputNodes.size();
                inputs.add(inputOffset);
                inputNodes.put(arg, inputOffset);
            }
        }
        PythonFunction pythonFunction = null;
        SqlAggFunction aggregateFunction = aggregateCall.getAggregation();
        if (aggregateFunction instanceof AggSqlFunction) {
            pythonFunction = (PythonFunction) ((AggSqlFunction) aggregateFunction).aggregateFunction();
        } else if (aggregateFunction instanceof BridgingSqlAggFunction) {
            pythonFunction = (PythonFunction) ((BridgingSqlAggFunction) aggregateFunction).getDefinition();
        }
        PythonFunctionInfo pythonFunctionInfo =
                new PythonAggregateFunctionInfo(
                        pythonFunction, inputs.toArray(), aggregateCall.filterArg, aggregateCall.isDistinct());
        pythonFunctionInfos.add(pythonFunctionInfo);
    }
    int[] udafInputOffsets = inputNodes.keySet().stream().mapToInt(i -> i).toArray();
    return Tuple2.of(udafInputOffsets, pythonFunctionInfos.toArray(new PythonFunctionInfo[0]));
}
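The interesting part of this helper is the argument remapping: inputNodes maps each original input field index to a compact, zero-based offset in first-seen order, every aggregate call's argument list is rewritten against those compact offsets, and the key set (in insertion order) becomes the projection of input fields the Python UDAF operator has to read. A minimal, self-contained sketch of that remapping, using only java.util (the field indices 4, 1, 7 are made up for illustration):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class UdafInputDeduplication {
    public static void main(String[] args) {
        // Two hypothetical aggregate calls referencing input fields (4, 1) and (1, 7).
        List<List<Integer>> argLists = List.of(List.of(4, 1), List.of(1, 7));

        // Maps an original input field index to a compact offset, in first-seen order.
        Map<Integer, Integer> inputNodes = new LinkedHashMap<>();
        List<List<Integer>> remappedArgs = new ArrayList<>();
        for (List<Integer> argList : argLists) {
            List<Integer> inputs = new ArrayList<>();
            for (Integer arg : argList) {
                Integer offset = inputNodes.get(arg);
                if (offset == null) {
                    offset = inputNodes.size();
                    inputNodes.put(arg, offset);
                }
                inputs.add(offset);
            }
            remappedArgs.add(inputs);
        }

        // The projection of required input fields, and the per-call offsets into it.
        int[] udafInputOffsets = inputNodes.keySet().stream().mapToInt(i -> i).toArray();
        System.out.println(Arrays.toString(udafInputOffsets)); // [4, 1, 7]
        System.out.println(remappedArgs);                      // [[0, 1], [1, 2]]
    }
}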
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.
the class SqlValidatorImpl method validatePivot.
public void validatePivot(SqlPivot pivot) {
    final PivotScope scope = (PivotScope) getJoinScope(pivot);
    final PivotNamespace ns = getNamespace(pivot).unwrap(PivotNamespace.class);
    assert ns.rowType == null;

    // Given
    //   query PIVOT (agg1 AS a, agg2 AS b, ...
    //   FOR (axis1, ..., axisN)
    //   IN ((v11, ..., v1N) AS label1,
    //       (v21, ..., v2N) AS label2, ...))
    // the type is
    //   k1, ... kN, a_label1, b_label1, ..., a_label2, b_label2, ...
    // where k1, ... kN are columns that are not referenced as an argument to
    // an aggregate or as an axis.

    // Aggregates, e.g. "PIVOT (sum(x) AS sum_x, count(*) AS c)"
    final List<Pair<String, RelDataType>> aggNames = new ArrayList<>();
    pivot.forEachAgg((alias, call) -> {
        call.validate(this, scope);
        final RelDataType type = deriveType(scope, call);
        aggNames.add(Pair.of(alias, type));
        if (!(call instanceof SqlCall) || !(((SqlCall) call).getOperator() instanceof SqlAggFunction)) {
            throw newValidationError(call, RESOURCE.pivotAggMalformed());
        }
    });

    // Axes, e.g. "FOR (JOB, DEPTNO)"
    final List<RelDataType> axisTypes = new ArrayList<>();
    final List<SqlIdentifier> axisIdentifiers = new ArrayList<>();
    for (SqlNode axis : pivot.axisList) {
        SqlIdentifier identifier = (SqlIdentifier) axis;
        identifier.validate(this, scope);
        final RelDataType type = deriveType(scope, identifier);
        axisTypes.add(type);
        axisIdentifiers.add(identifier);
    }

    // Columns that have been seen as arguments to aggregates or as axes
    // do not appear in the output.
    final Set<String> columnNames = pivot.usedColumnNames();
    final RelDataTypeFactory.Builder typeBuilder = typeFactory.builder();
    scope.getChild().getRowType().getFieldList().forEach(field -> {
        if (!columnNames.contains(field.getName())) {
            typeBuilder.add(field);
        }
    });
    // Values, e.g. "IN (('CLERK', 10) AS c10, ('MANAGER', 20) AS m20)"
    pivot.forEachNameValues((alias, nodeList) -> {
        if (nodeList.size() != axisTypes.size()) {
            throw newValidationError(nodeList,
                    RESOURCE.pivotValueArityMismatch(nodeList.size(), axisTypes.size()));
        }
        final SqlOperandTypeChecker typeChecker =
                OperandTypes.COMPARABLE_UNORDERED_COMPARABLE_UNORDERED;
        Pair.forEach(axisIdentifiers, nodeList, (identifier, subNode) -> {
            subNode.validate(this, scope);
            typeChecker.checkOperandTypes(
                    new SqlCallBinding(this, scope,
                            SqlStdOperatorTable.EQUALS.createCall(
                                    subNode.getParserPosition(), identifier, subNode)),
                    true);
        });
        Pair.forEach(aggNames, (aggAlias, aggType) ->
                typeBuilder.add(aggAlias == null ? alias : alias + "_" + aggAlias, aggType));
    });
    final RelDataType rowType = typeBuilder.build();
    ns.setType(rowType);
}
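The subtle part is the naming rule at the end: columns referenced by the aggregates or the FOR axis disappear, the remaining columns are kept as-is, and each IN value label is crossed with each aggregate alias, label-major, producing either "label" (no aggregate alias) or "label_alias". A minimal sketch of the resulting column names for a hypothetical emp table (table and column names are illustrative, not taken from the source above):

import java.util.ArrayList;
import java.util.List;

public class PivotColumnNames {
    public static void main(String[] args) {
        // For "SELECT * FROM emp
        //        PIVOT (SUM(sal) AS ss, COUNT(*) AS c
        //               FOR job IN ('CLERK' AS clerk, 'ANALYST' AS analyst))"
        // the columns used by the aggregates (sal) and the axis (job) are dropped.
        List<String> keptColumns = List.of("empno", "ename", "deptno");
        List<String> aggAliases = List.of("ss", "c");
        List<String> labels = List.of("clerk", "analyst");

        List<String> outputColumns = new ArrayList<>(keptColumns);
        for (String label : labels) {
            for (String aggAlias : aggAliases) {
                // Mirrors the rule above: label alone if the agg has no alias, else label_alias.
                outputColumns.add(aggAlias == null ? label : label + "_" + aggAlias);
            }
        }
        System.out.println(outputColumns);
        // [empno, ename, deptno, clerk_ss, clerk_c, analyst_ss, analyst_c]
    }
}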
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.
the class SqlValidatorImpl method validateAggregateParams.
public void validateAggregateParams(SqlCall aggCall, SqlNode filter, SqlNodeList orderList, SqlValidatorScope scope) {
    // For "agg(expr)", expr cannot itself contain aggregate function
    // invocations. For example, "SUM(2 * MAX(x))" is illegal; when
    // we see it, we'll report the error for the SUM (not the MAX).
    // For more than one level of nesting, the error which results
    // depends on the traversal order for validation.
    //
    // For a windowed aggregate "agg(expr)", expr can contain an aggregate
    // function. For example,
    //   SELECT AVG(2 * MAX(x)) OVER (PARTITION BY y)
    //   FROM t
    //   GROUP BY y
    // is legal. Only one level of nesting is allowed since non-windowed
    // aggregates cannot nest aggregates.

    // Store nesting level of each aggregate. If an aggregate is found at an invalid
    // nesting level, throw an assert.
    final AggFinder a;
    if (inWindow) {
        a = overFinder;
    } else {
        a = aggOrOverFinder;
    }
    for (SqlNode param : aggCall.getOperandList()) {
        if (a.findAgg(param) != null) {
            throw newValidationError(aggCall, RESOURCE.nestedAggIllegal());
        }
    }
    if (filter != null) {
        if (a.findAgg(filter) != null) {
            throw newValidationError(filter, RESOURCE.aggregateInFilterIllegal());
        }
    }
    if (orderList != null) {
        for (SqlNode param : orderList) {
            if (a.findAgg(param) != null) {
                throw newValidationError(aggCall, RESOURCE.aggregateInWithinGroupIllegal());
            }
        }
    }
    final SqlAggFunction op = (SqlAggFunction) aggCall.getOperator();
    switch (op.requiresGroupOrder()) {
        case MANDATORY:
            if (orderList == null || orderList.size() == 0) {
                throw newValidationError(aggCall, RESOURCE.aggregateMissingWithinGroupClause(op.getName()));
            }
            break;
        case OPTIONAL:
            break;
        case IGNORED:
            // rewrite the order list to empty
            if (orderList != null) {
                orderList.getList().clear();
            }
            break;
        case FORBIDDEN:
            if (orderList != null && orderList.size() != 0) {
                throw newValidationError(aggCall, RESOURCE.withinGroupClauseIllegalInAggregate(op.getName()));
            }
            break;
        default:
            throw new AssertionError(op);
    }
}
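The tail of the method is effectively a decision table for the WITHIN GROUP clause, keyed on the aggregate's requiresGroupOrder() value: MANDATORY functions (Calcite's PERCENTILE_CONT, for instance) must have an ORDER BY, IGNORED tolerates the clause but drops it, and FORBIDDEN (the default for ordinary aggregates such as SUM) rejects it. A small stand-alone sketch of the same decision table, with a local enum standing in for Calcite's Optionality:

import java.util.List;

public class WithinGroupCheck {
    // Local stand-in for Calcite's Optionality, used only for illustration.
    enum GroupOrderRequirement { MANDATORY, OPTIONAL, IGNORED, FORBIDDEN }

    // Returns the (possibly rewritten) WITHIN GROUP order keys, or throws if the
    // combination is illegal; mirrors the switch in validateAggregateParams above.
    static List<String> checkWithinGroup(String aggName, GroupOrderRequirement requirement, List<String> orderKeys) {
        switch (requirement) {
            case MANDATORY:
                if (orderKeys == null || orderKeys.isEmpty()) {
                    throw new IllegalArgumentException(aggName + " requires a WITHIN GROUP clause");
                }
                return orderKeys;
            case OPTIONAL:
                return orderKeys;
            case IGNORED:
                // The clause is tolerated but discarded, like clearing orderList above.
                return List.of();
            case FORBIDDEN:
                if (orderKeys != null && !orderKeys.isEmpty()) {
                    throw new IllegalArgumentException(aggName + " must not have a WITHIN GROUP clause");
                }
                return orderKeys;
            default:
                throw new AssertionError(requirement);
        }
    }

    public static void main(String[] args) {
        // PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY sal) keeps its order keys,
        // while SUM(sal) must not carry a WITHIN GROUP clause at all.
        System.out.println(checkWithinGroup("PERCENTILE_CONT", GroupOrderRequirement.MANDATORY, List.of("sal")));
        System.out.println(checkWithinGroup("SUM", GroupOrderRequirement.FORBIDDEN, null));
    }
}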
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project beam by apache.
the class AggregateScanConverter method convertAggCall.
private AggregateCall convertAggCall(ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) {
    ResolvedAggregateFunctionCall aggregateFunctionCall = (ResolvedAggregateFunctionCall) computedColumn.getExpr();

    // Reject AVG(INT64)
    if (aggregateFunctionCall.getFunction().getName().equals("avg")) {
        FunctionSignature signature = aggregateFunctionCall.getSignature();
        if (signature.getFunctionArgumentList().get(0).getType().getKind().equals(TypeKind.TYPE_INT64)) {
            throw new UnsupportedOperationException(AVG_ILLEGAL_LONG_INPUT_TYPE);
        }
    }

    // Reject aggregation DISTINCT
    if (aggregateFunctionCall.getDistinct()) {
        throw new UnsupportedOperationException(
                "Does not support "
                        + aggregateFunctionCall.getFunction().getSqlName()
                        + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before"
                        + " aggregation.");
    }

    final SqlAggFunction sqlAggFunction;
    if (aggregateFunctionCall.getFunction().getGroup().equals(BeamZetaSqlCatalog.USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS)) {
        // Create a new operator for user-defined functions.
        SqlReturnTypeInference typeInference =
                x -> ZetaSqlCalciteTranslationUtils.toCalciteType(
                        aggregateFunctionCall.getFunction().getSignatureList().get(0).getResultType().getType(),
                        // TODO(BEAM-9514) set nullable=true
                        false,
                        getCluster().getRexBuilder());
        UdafImpl<?, ?, ?> impl =
                new UdafImpl<>(
                        getExpressionConverter()
                                .userFunctionDefinitions
                                .javaAggregateFunctions()
                                .get(aggregateFunctionCall.getFunction().getNamePath()));
        sqlAggFunction = SqlOperators.createUdafOperator(aggregateFunctionCall.getFunction().getName(), typeInference, impl);
    } else {
        // Look up builtin functions in SqlOperatorMappingTable.
        sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(aggregateFunctionCall);
        if (sqlAggFunction == null) {
            throw new UnsupportedOperationException(
                    "Does not support ZetaSQL aggregate function: " + aggregateFunctionCall.getFunction().getName());
        }
    }

    List<Integer> argList = new ArrayList<>();
    ResolvedAggregateFunctionCall expr = ((ResolvedAggregateFunctionCall) computedColumn.getExpr());
    List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> resolvedNodeKinds =
            Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD);
    for (int i = 0; i < expr.getArgumentList().size(); i++) {
        // Throw an error if aggregate function's input isn't either a ColumnRef or a cast(ColumnRef).
        // TODO: is there a general way to handle aggregation calls conversion?
        ZetaSQLResolvedNodeKind.ResolvedNodeKind resolvedNodeKind = expr.getArgumentList().get(i).nodeKind();
        if (i == 0 && resolvedNodeKinds.contains(resolvedNodeKind)) {
            argList.add(columnRefOff);
        } else if (i > 0 && resolvedNodeKind == RESOLVED_LITERAL) {
            continue;
        } else {
            throw new UnsupportedOperationException(
                    "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and "
                            + "Literals as subsequent arguments as its inputs");
        }
    }

    String aggName = getTrait().resolveAlias(computedColumn.getColumn());
    return AggregateCall.create(
            sqlAggFunction,
            false,
            false,
            false,
            argList,
            -1,
            null,
            RelCollations.EMPTY,
            groupCount,
            input,
            // When we pass null as the return type, Calcite infers it for us.
            null,
            aggName);
}
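The argument loop encodes the converter's restriction on aggregate inputs: the first argument must resolve to a column reference, possibly wrapped in a CAST or struct field access, and every later argument must be a literal (for example the delimiter in STRING_AGG(x, ', ')); anything else, such as SUM(x + 1), is rejected. A self-contained sketch of that shape check, with a local enum standing in for the ZetaSQL resolved node kinds:

import java.util.EnumSet;
import java.util.List;

public class AggArgumentShapeCheck {
    // Local stand-in for ZetaSQL's resolved node kinds, for illustration only.
    enum NodeKind { COLUMN_REF, CAST, GET_STRUCT_FIELD, LITERAL, FUNCTION_CALL }

    private static final EnumSet<NodeKind> ALLOWED_FIRST_ARG =
            EnumSet.of(NodeKind.CAST, NodeKind.COLUMN_REF, NodeKind.GET_STRUCT_FIELD);

    // Mirrors the loop above: the first argument must be a (possibly cast) column
    // reference, every subsequent argument must be a literal.
    static void check(List<NodeKind> argumentKinds) {
        for (int i = 0; i < argumentKinds.size(); i++) {
            NodeKind kind = argumentKinds.get(i);
            boolean ok = (i == 0) ? ALLOWED_FIRST_ARG.contains(kind) : kind == NodeKind.LITERAL;
            if (!ok) {
                throw new UnsupportedOperationException(
                        "Aggregate function only accepts Column Reference or CAST(Column Reference) as the"
                                + " first argument and Literals as subsequent arguments as its inputs");
            }
        }
    }

    public static void main(String[] args) {
        check(List.of(NodeKind.CAST, NodeKind.LITERAL)); // e.g. STRING_AGG(CAST(x AS STRING), ', '): accepted
        try {
            check(List.of(NodeKind.FUNCTION_CALL));      // e.g. SUM(x + 1): rejected
        } catch (UnsupportedOperationException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}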
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.
the class HiveCalciteUtil method createSingleArgAggCall.
public static AggregateCall createSingleArgAggCall(String funcName, RelOptCluster cluster, PrimitiveTypeInfo typeInfo, Integer pos, RelDataType aggFnRetType) {
    ImmutableList.Builder<RelDataType> aggArgRelDTBldr = new ImmutableList.Builder<RelDataType>();
    aggArgRelDTBldr.add(TypeConverter.convert(typeInfo, cluster.getTypeFactory()));
    SqlAggFunction aggFunction =
            SqlFunctionConverter.getCalciteAggFn(funcName, false, aggArgRelDTBldr.build(), aggFnRetType);
    List<Integer> argList = new ArrayList<Integer>();
    argList.add(pos);
    return AggregateCall.create(aggFunction, false, argList, -1, aggFnRetType, null);
}
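A hedged usage sketch: the helper takes the Hive type of the single argument and the position of that argument in the input row, so a count over the column at input offset 3 returning BIGINT might be built roughly as below. The surrounding class, the RelOptCluster, and the return type are assumed to come from the caller's planner context; the class and import names are illustrative, not taken from the Hive source above.

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;

class SingleArgAggCallSketch {
    // Builds count(<bigint column at input position 3>); the cluster and the BIGINT
    // return type are supplied by the surrounding planner code.
    static AggregateCall countOfColumn3(RelOptCluster cluster, RelDataType bigintRetType) {
        return HiveCalciteUtil.createSingleArgAggCall(
                "count", cluster, TypeInfoFactory.longTypeInfo, 3, bigintRetType);
    }
}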