use of org.apache.calcite.rel.core.AggregateCall in project beam by apache.
the class AggregateScanConverter method convertAggCall.
/**
 * Converts a ZetaSQL aggregate computed column into a Calcite {@link AggregateCall}.
 *
 * <p>User-defined Java aggregate functions get a freshly created UDAF operator; every other
 * function is looked up in {@code SqlOperatorMappingTable}. Only the first argument contributes
 * an input reference ({@code columnRefOff}); literal arguments after it are accepted but are not
 * added to the call's argument list.
 *
 * @param computedColumn computed column whose expression is the resolved aggregate function call
 * @param columnRefOff input field index recorded as the aggregate's argument
 * @param groupCount forwarded unchanged to {@code AggregateCall.create}
 * @param input forwarded unchanged to {@code AggregateCall.create}
 * @return the Calcite aggregate call, named with the alias resolved for the computed column
 * @throws UnsupportedOperationException for AVG over INT64, for DISTINCT aggregations, for
 *     builtin functions without a mapping, and for unsupported argument shapes
 */
private AggregateCall convertAggCall(ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) {
  final ResolvedAggregateFunctionCall aggCall = (ResolvedAggregateFunctionCall) computedColumn.getExpr();

  // AVG over an INT64 argument is rejected.
  if (aggCall.getFunction().getName().equals("avg")) {
    final FunctionSignature avgSignature = aggCall.getSignature();
    if (avgSignature.getFunctionArgumentList().get(0).getType().getKind().equals(TypeKind.TYPE_INT64)) {
      throw new UnsupportedOperationException(AVG_ILLEGAL_LONG_INPUT_TYPE);
    }
  }

  // DISTINCT aggregations are rejected.
  if (aggCall.getDistinct()) {
    throw new UnsupportedOperationException(
        "Does not support "
            + aggCall.getFunction().getSqlName()
            + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before"
            + " aggregation.");
  }

  final SqlAggFunction sqlAggFunction;
  final boolean isJavaUdaf =
      aggCall.getFunction().getGroup().equals(BeamZetaSqlCatalog.USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS);
  if (!isJavaUdaf) {
    // Builtin functions come from SqlOperatorMappingTable.
    sqlAggFunction = (SqlAggFunction) SqlOperatorMappingTable.create(aggCall);
    if (sqlAggFunction == null) {
      throw new UnsupportedOperationException(
          "Does not support ZetaSQL aggregate function: " + aggCall.getFunction().getName());
    }
  } else {
    // User-defined functions get a new operator. The result type is resolved lazily, each time
    // the inference callback is invoked.
    final SqlReturnTypeInference resultTypeInference =
        opBinding ->
            ZetaSqlCalciteTranslationUtils.toCalciteType(
                aggCall.getFunction().getSignatureList().get(0).getResultType().getType(),
                // TODO(BEAM-9514) set nullable=true
                false,
                getCluster().getRexBuilder());
    final UdafImpl<?, ?, ?> udaf =
        new UdafImpl<>(
            getExpressionConverter()
                .userFunctionDefinitions
                .javaAggregateFunctions()
                .get(aggCall.getFunction().getNamePath()));
    sqlAggFunction = SqlOperators.createUdafOperator(aggCall.getFunction().getName(), resultTypeInference, udaf);
  }

  // Node kinds allowed for the first argument.
  final List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> firstArgKinds =
      Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD);
  final List<Integer> inputRefs = new ArrayList<>();
  // TODO: is there a general way to handle aggregation calls conversion?
  for (int pos = 0; pos < aggCall.getArgumentList().size(); pos++) {
    final ZetaSQLResolvedNodeKind.ResolvedNodeKind kind = aggCall.getArgumentList().get(pos).nodeKind();
    final boolean first = pos == 0;
    final boolean accepted = first ? firstArgKinds.contains(kind) : kind == RESOLVED_LITERAL;
    if (!accepted) {
      throw new UnsupportedOperationException(
          "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and "
              + "Literals as subsequent arguments as its inputs");
    }
    if (first) {
      // Only the first argument becomes an input reference; later literals are skipped.
      inputRefs.add(columnRefOff);
    }
  }

  final String alias = getTrait().resolveAlias(computedColumn.getColumn());
  return AggregateCall.create(
      sqlAggFunction,
      false,
      false,
      false,
      inputRefs,
      -1,
      null,
      RelCollations.EMPTY,
      groupCount,
      input,
      // When we pass null as the return type, Calcite infers it for us.
      null,
      alias);
}
use of org.apache.calcite.rel.core.AggregateCall in project druid by druid-io.
the class ArrayConcatSqlAggregator method toDruidAggregation.
/**
 * Plans an array-concatenating aggregate call as an {@link ExpressionLambdaAggregatorFactory}.
 *
 * <p>The first argument is the value to aggregate. An optional second argument gives the maximum
 * size in bytes and must be a literal. A DISTINCT call uses {@code array_set_add_all} for its
 * fold and combine expressions; otherwise {@code array_concat} is used.
 *
 * @return the planned aggregation, or {@code null} when the call cannot be planned: the size
 *     argument is not a literal, the first argument cannot be translated to a Druid expression,
 *     or the call's return type is not an array type
 */
@Nullable
@Override
public Aggregation toDruidAggregation(PlannerContext plannerContext, RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, RexBuilder rexBuilder, String name, AggregateCall aggregateCall, Project project, List<Aggregation> existingAggregations, boolean finalizeAggregations) {
  final List<RexNode> arguments =
      aggregateCall.getArgList().stream()
          .map(i -> Expressions.fromFieldAccess(rowSignature, project, i))
          .collect(Collectors.toList());

  Integer maxSizeBytes = null;
  if (arguments.size() > 1) {
    RexNode maxBytes = arguments.get(1);
    if (!maxBytes.isA(SqlKind.LITERAL)) {
      // maxBytes must be a literal
      return null;
    }
    maxSizeBytes = ((Number) RexLiteral.value(maxBytes)).intValue();
  }

  final DruidExpression arg = Expressions.toDruidExpression(plannerContext, rowSignature, arguments.get(0));
  if (arg == null) {
    // The argument could not be translated; report "cannot plan" instead of failing with a
    // NullPointerException further down.
    return null;
  }

  final ExprMacroTable macroTable = plannerContext.getExprMacroTable();
  final ColumnType druidType = Calcites.getValueTypeForRelDataTypeFull(aggregateCall.getType());
  if (druidType == null || !druidType.isArray()) {
    // must be an array
    return null;
  }
  final String initialValue = ExpressionType.fromColumnTypeStrict(druidType).asTypeString() + "[]";

  // Aggregate the column directly when possible, otherwise through a virtual column.
  final String fieldName;
  if (arg.isDirectColumnAccess()) {
    fieldName = arg.getDirectColumn();
  } else {
    VirtualColumn vc = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, arg, druidType);
    fieldName = vc.getOutputName();
  }

  // The DISTINCT and non-DISTINCT variants differ only in the expression function applied to the
  // accumulator.
  final String accumulateFunction = aggregateCall.isDistinct() ? "array_set_add_all" : "array_concat";
  final String foldExpression = StringUtils.format("%s(\"__acc\", \"%s\")", accumulateFunction, fieldName);
  final String combineExpression = StringUtils.format("%s(\"__acc\", \"%s\")", accumulateFunction, name);

  return Aggregation.create(
      new ExpressionLambdaAggregatorFactory(
          name,
          ImmutableSet.of(fieldName),
          null,
          initialValue,
          null,
          true,
          false,
          false,
          foldExpression,
          combineExpression,
          null,
          null,
          maxSizeBytes != null ? new HumanReadableBytes(maxSizeBytes) : null,
          macroTable));
}
use of org.apache.calcite.rel.core.AggregateCall in project hazelcast by hazelcast.
the class AggregateAbstractPhysicalRule method aggregateOperation.
/**
 * Builds the {@link AggregateOperation} for a group set and a list of aggregate calls.
 *
 * <p>Two parallel lists are assembled: one aggregation supplier and one row-value extractor per
 * output slot. Slots for the grouping fields come first, followed by one slot per aggregate call
 * in the given order.
 *
 * @param inputType row type of the input; supplies operand types for SUM and AVG
 * @param groupSet indexes of the grouping fields
 * @param aggregateCalls aggregate calls to translate (COUNT, MIN, MAX, SUM and AVG are handled)
 * @return the assembled aggregate operation
 * @throws QueryException if an aggregate call has any other kind
 */
protected static AggregateOperation<?, JetSqlRow> aggregateOperation(RelDataType inputType, ImmutableBitSet groupSet, List<AggregateCall> aggregateCalls) {
  final List<QueryDataType> fieldTypes = OptUtils.schema(inputType).getTypes();
  final List<SupplierEx<SqlAggregation>> suppliers = new ArrayList<>();
  final List<FunctionEx<JetSqlRow, Object>> extractors = new ArrayList<>();

  // Grouping fields are carried through as plain values.
  for (Integer groupField : groupSet.toList()) {
    suppliers.add(ValueSqlAggregation::new);
    // getMaybeSerialized is safe for ValueAggr because it only passes the value on
    extractors.add(new RowGetMaybeSerializedFn(groupField));
  }

  for (AggregateCall call : aggregateCalls) {
    final boolean distinct = call.isDistinct();
    final List<Integer> args = call.getArgList();
    final SqlKind kind = call.getAggregation().getKind();

    switch (kind) {
      case COUNT: {
        if (distinct || args.size() == 1) {
          // A DISTINCT count or a single-argument count reads argument 0.
          final int field = args.get(0);
          suppliers.add(new AggregateCountSupplier(true, distinct));
          // getMaybeSerialized is safe for COUNT because the aggregation only looks whether it is null or not
          extractors.add(new RowGetMaybeSerializedFn(field));
        } else {
          // No field is read in this case.
          suppliers.add(new AggregateCountSupplier(false, false));
          extractors.add(NullFunction.INSTANCE);
        }
        break;
      }
      case MIN: {
        final int field = args.get(0);
        suppliers.add(MinSqlAggregation::new);
        extractors.add(new RowGetFn(field));
        break;
      }
      case MAX: {
        final int field = args.get(0);
        suppliers.add(MaxSqlAggregation::new);
        extractors.add(new RowGetFn(field));
        break;
      }
      case SUM: {
        final int field = args.get(0);
        final QueryDataType operandType = fieldTypes.get(field);
        suppliers.add(new AggregateSumSupplier(distinct, operandType));
        extractors.add(new RowGetFn(field));
        break;
      }
      case AVG: {
        final int field = args.get(0);
        final QueryDataType operandType = fieldTypes.get(field);
        suppliers.add(new AggregateAvgSupplier(distinct, operandType));
        extractors.add(new RowGetFn(field));
        break;
      }
      default:
        throw QueryException.error("Unsupported aggregation function: " + kind);
    }
  }

  return AggregateOperation
      .withCreate(new AggregateCreateSupplier(suppliers))
      .andAccumulate(new AggregateAccumulateFunction(extractors))
      .andCombine(AggregateCombineFunction.INSTANCE)
      .andExportFinish(AggregateExportFinishFunction.INSTANCE);
}
use of org.apache.calcite.rel.core.AggregateCall in project druid by druid-io.
the class QuantileSqlAggregator method toDruidAggregation.
/**
 * Plans a quantile aggregate call as an approximate-histogram aggregator combined with a
 * {@link QuantilePostAggregator}.
 *
 * <p>Argument 0 is the input, argument 1 the probability (must be a literal), and the optional
 * argument 2 the histogram resolution (must be a literal; defaults to
 * {@code ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE}). If an existing aggregation already holds
 * an {@code ApproximateHistogramAggregatorFactory} over the same input with the same settings,
 * only a post-aggregator referring to that factory is emitted.
 *
 * @return the planned aggregation, or {@code null} when no input expression is produced or a
 *     required literal argument is not a literal
 */
@Nullable
@Override
public Aggregation toDruidAggregation(final PlannerContext plannerContext, final RowSignature rowSignature, final VirtualColumnRegistry virtualColumnRegistry, final RexBuilder rexBuilder, final String name, final AggregateCall aggregateCall, final Project project, final List<Aggregation> existingAggregations, final boolean finalizeAggregations) {
  final RexNode inputArg = Expressions.fromFieldAccess(rowSignature, project, aggregateCall.getArgList().get(0));
  final DruidExpression input =
      Aggregations.toDruidExpressionForNumericAggregator(plannerContext, rowSignature, inputArg);
  if (input == null) {
    return null;
  }

  final RexNode probabilityArg = Expressions.fromFieldAccess(rowSignature, project, aggregateCall.getArgList().get(1));
  if (!probabilityArg.isA(SqlKind.LITERAL)) {
    // Probability must be a literal in order to plan.
    return null;
  }
  final float probability = ((Number) RexLiteral.value(probabilityArg)).floatValue();

  final int resolution;
  if (aggregateCall.getArgList().size() < 3) {
    resolution = ApproximateHistogram.DEFAULT_HISTOGRAM_SIZE;
  } else {
    final RexNode resolutionArg = Expressions.fromFieldAccess(rowSignature, project, aggregateCall.getArgList().get(2));
    if (!resolutionArg.isA(SqlKind.LITERAL)) {
      // Resolution must be a literal in order to plan.
      return null;
    }
    resolution = ((Number) RexLiteral.value(resolutionArg)).intValue();
  }

  final int numBuckets = ApproximateHistogram.DEFAULT_BUCKET_SIZE;
  final float lowerLimit = Float.NEGATIVE_INFINITY;
  final float upperLimit = Float.POSITIVE_INFINITY;

  // Try to reuse a histogram factory from an aggregation that was planned earlier.
  for (final Aggregation existing : existingAggregations) {
    for (AggregatorFactory candidate : existing.getAggregatorFactories()) {
      if (!(candidate instanceof ApproximateHistogramAggregatorFactory)) {
        continue;
      }
      final ApproximateHistogramAggregatorFactory histogram = (ApproximateHistogramAggregatorFactory) candidate;

      // Compare against the virtual-column expression when one is registered for the
      // factory's required fields, otherwise against the direct column name.
      final DruidExpression virtualInput =
          virtualColumnRegistry.findVirtualColumnExpressions(histogram.requiredFields())
              .stream()
              .findFirst()
              .orElse(null);
      final boolean sameInput =
          virtualInput != null
              ? virtualInput.equals(input)
              : input.isDirectColumnAccess() && input.getDirectColumn().equals(histogram.getFieldName());

      if (sameInput
          && histogram.getResolution() == resolution
          && histogram.getNumBuckets() == numBuckets
          && histogram.getLowerLimit() == lowerLimit
          && histogram.getUpperLimit() == upperLimit) {
        // Found existing one. Use this.
        return Aggregation.create(
            ImmutableList.of(),
            new QuantilePostAggregator(name, candidate.getName(), probability));
      }
    }
  }

  // Nothing reusable: build a new histogram factory.
  final String histogramName = StringUtils.format("%s:agg", name);
  final AggregatorFactory aggregatorFactory;
  if (!input.isDirectColumnAccess()) {
    final String virtualColumnName =
        virtualColumnRegistry.getOrCreateVirtualColumnForExpression(input, ColumnType.FLOAT);
    aggregatorFactory = new ApproximateHistogramAggregatorFactory(histogramName, virtualColumnName, resolution, numBuckets, lowerLimit, upperLimit, false);
  } else if (rowSignature.getColumnType(input.getDirectColumn()).map(type -> type.is(ValueType.COMPLEX)).orElse(false)) {
    // COMPLEX-typed columns use the folding factory.
    aggregatorFactory = new ApproximateHistogramFoldingAggregatorFactory(histogramName, input.getDirectColumn(), resolution, numBuckets, lowerLimit, upperLimit, false);
  } else {
    aggregatorFactory = new ApproximateHistogramAggregatorFactory(histogramName, input.getDirectColumn(), resolution, numBuckets, lowerLimit, upperLimit, false);
  }

  return Aggregation.create(
      ImmutableList.of(aggregatorFactory),
      new QuantilePostAggregator(name, histogramName, probability));
}
use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.
the class HiveRelOptUtil method createSingleValueAggRel.
/**
 * Creates an aggregate over {@code rel} with an empty group set that applies
 * {@code SqlStdOperatorTable.SINGLE_VALUE} to each field of the input, one
 * aggregate call per field.
 *
 * @param cluster not referenced by this method
 * @param rel underlying rel
 * @param aggregateFactory factory used to create the resulting aggregate
 * @return rel implementing SingleValueAgg
 */
public static RelNode createSingleValueAggRel(RelOptCluster cluster, RelNode rel, RelFactories.AggregateFactory aggregateFactory) {
  final int fieldCount = rel.getRowType().getFieldCount();
  final List<AggregateCall> singleValueCalls = new ArrayList<>();
  // The input is not required to have exactly one field: every field gets its own call.
  for (int field = 0; field < fieldCount; field++) {
    final AggregateCall singleValue =
        AggregateCall.create(
            SqlStdOperatorTable.SINGLE_VALUE,
            false,
            false,
            ImmutableList.of(field),
            -1,
            0,
            rel,
            null,
            null);
    singleValueCalls.add(singleValue);
  }
  return aggregateFactory.createAggregate(
      rel,
      Collections.emptyList(),
      ImmutableBitSet.of(),
      null,
      singleValueCalls);
}
Aggregations