Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
From the class HiveParserCalcitePlanner, method genGBRelNode:
private RelNode genGBRelNode(
        List<ExprNodeDesc> gbExprs,
        List<AggInfo> aggInfos,
        List<Integer> groupSets,
        RelNode srcRel)
        throws SemanticException {
    Map<String, Integer> colNameToPos = relToHiveColNameCalcitePosMap.get(srcRel);
    HiveParserRexNodeConverter converter =
            new HiveParserRexNodeConverter(
                    cluster, srcRel.getRowType(), colNameToPos, 0, false, funcConverter);
    final boolean hasGroupSets = groupSets != null && !groupSets.isEmpty();
    final List<RexNode> gbInputRexNodes = new ArrayList<>();
    final HashMap<String, Integer> inputRexNodeToIndex = new HashMap<>();
    final List<Integer> gbKeyIndices = new ArrayList<>();
    int inputIndex = 0;
    for (ExprNodeDesc key : gbExprs) {
        // also convert null literals here to support grouping by NULLs
        RexNode keyRex = convertNullLiteral(converter.convert(key)).accept(funcConverter);
        gbInputRexNodes.add(keyRex);
        gbKeyIndices.add(inputIndex);
        inputRexNodeToIndex.put(keyRex.toString(), inputIndex);
        inputIndex++;
    }
    final ImmutableBitSet groupSet = ImmutableBitSet.of(gbKeyIndices);
    // Grouping sets: transform them into ImmutableBitSet objects for Calcite
    List<ImmutableBitSet> transformedGroupSets = null;
    if (hasGroupSets) {
        Set<ImmutableBitSet> set = new HashSet<>(groupSets.size());
        for (int val : groupSets) {
            set.add(convert(val, groupSet.cardinality()));
        }
        // Calcite expects the grouping sets sorted and without duplicates
        transformedGroupSets = new ArrayList<>(set);
        transformedGroupSets.sort(ImmutableBitSet.COMPARATOR);
    }
    // add Agg parameters to inputs
    for (AggInfo aggInfo : aggInfos) {
        for (ExprNodeDesc expr : aggInfo.getAggParams()) {
            RexNode paramRex = converter.convert(expr).accept(funcConverter);
            Integer argIndex = inputRexNodeToIndex.get(paramRex.toString());
            if (argIndex == null) {
                argIndex = gbInputRexNodes.size();
                inputRexNodeToIndex.put(paramRex.toString(), argIndex);
                gbInputRexNodes.add(paramRex);
            }
        }
    }
    if (gbInputRexNodes.isEmpty()) {
        // This happens for count(*); in that case we arbitrarily pick the
        // first element of srcRel. This must happen before the project below
        // is created, because the project copies the expression list.
        gbInputRexNodes.add(cluster.getRexBuilder().makeInputRef(srcRel, 0));
    }
    // create the actual input before creating agg calls so that the calls can properly infer
    // return type
    RelNode gbInputRel =
            LogicalProject.create(
                    srcRel, Collections.emptyList(), gbInputRexNodes, (List<String>) null);
    List<AggregateCall> aggregateCalls = new ArrayList<>();
    for (AggInfo aggInfo : aggInfos) {
        aggregateCalls.add(
                HiveParserUtils.toAggCall(
                        aggInfo,
                        converter,
                        inputRexNodeToIndex,
                        groupSet.cardinality(),
                        gbInputRel,
                        cluster,
                        funcConverter));
    }
    // GROUPING__ID is a virtual column in Hive, so we use Flink's function
    if (hasGroupSets) {
        // create the GroupingID column
        AggregateCall aggCall =
                AggregateCall.create(
                        SqlStdOperatorTable.GROUPING_ID,
                        false,
                        false,
                        false,
                        gbKeyIndices,
                        -1,
                        RelCollations.EMPTY,
                        groupSet.cardinality(),
                        gbInputRel,
                        null,
                        null);
        aggregateCalls.add(aggCall);
    }
    return LogicalAggregate.create(gbInputRel, groupSet, transformedGroupSets, aggregateCalls);
}
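The helper convert(val, groupSet.cardinality()) called above is not part of the listing. The following is a minimal sketch of what such a decoder looks like, assuming Hive's convention that each grouping set arrives as an integer whose least significant bit marks the last of the group keys; it is not necessarily the verbatim Flink helper.

import java.util.BitSet;
import org.apache.calcite.util.ImmutableBitSet;

public class GroupingSetDecode {
    // Sketch only: decodes a Hive grouping-set integer into a Calcite bit set.
    // Assumption: the LSB of `value` corresponds to the LAST of `length` keys.
    static ImmutableBitSet convert(int value, int length) {
        BitSet bits = new BitSet();
        for (int index = length - 1; index >= 0; index--) {
            if (value % 2 != 0) {
                bits.set(index);
            }
            value >>>= 1;
        }
        return ImmutableBitSet.fromBitSet(bits);
    }

    public static void main(String[] args) {
        // With 3 group keys (a, b, c), the integer 0b101 selects keys a and c.
        System.out.println(convert(0b101, 3)); // {0, 2}
    }
}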
Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
From the class CommonPythonUtil, method extractPythonAggregateFunctionInfosFromAggregateCall:
public static Tuple2<int[], PythonFunctionInfo[]>
        extractPythonAggregateFunctionInfosFromAggregateCall(AggregateCall[] aggCalls) {
    Map<Integer, Integer> inputNodes = new LinkedHashMap<>();
    List<PythonFunctionInfo> pythonFunctionInfos = new ArrayList<>();
    for (AggregateCall aggregateCall : aggCalls) {
        List<Integer> inputs = new ArrayList<>();
        List<Integer> argList = aggregateCall.getArgList();
        for (Integer arg : argList) {
            if (inputNodes.containsKey(arg)) {
                inputs.add(inputNodes.get(arg));
            } else {
                Integer inputOffset = inputNodes.size();
                inputs.add(inputOffset);
                inputNodes.put(arg, inputOffset);
            }
        }
        PythonFunction pythonFunction = null;
        SqlAggFunction aggregateFunction = aggregateCall.getAggregation();
        if (aggregateFunction instanceof AggSqlFunction) {
            pythonFunction =
                    (PythonFunction) ((AggSqlFunction) aggregateFunction).aggregateFunction();
        } else if (aggregateFunction instanceof BridgingSqlAggFunction) {
            pythonFunction =
                    (PythonFunction) ((BridgingSqlAggFunction) aggregateFunction).getDefinition();
        }
        PythonFunctionInfo pythonFunctionInfo =
                new PythonAggregateFunctionInfo(
                        pythonFunction,
                        inputs.toArray(),
                        aggregateCall.filterArg,
                        aggregateCall.isDistinct());
        pythonFunctionInfos.add(pythonFunctionInfo);
    }
    int[] udafInputOffsets = inputNodes.keySet().stream().mapToInt(i -> i).toArray();
    return Tuple2.of(udafInputOffsets, pythonFunctionInfos.toArray(new PythonFunctionInfo[0]));
}
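The LinkedHashMap builds a dense, first-seen-order mapping from original input field indices to UDAF input offsets, so each field is projected once even when several aggregate calls share it. Below is a self-contained sketch of just that remapping; the field indices [3, 5] and [3] are illustrative and no Flink classes are involved.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class AggArgRemapDemo {
    public static void main(String[] args) {
        // Two hypothetical agg calls referencing input fields [3, 5] and [3].
        int[][] argLists = {{3, 5}, {3}};
        Map<Integer, Integer> inputNodes = new LinkedHashMap<>();
        for (int[] argList : argLists) {
            List<Integer> inputs = new ArrayList<>();
            for (int arg : argList) {
                // each distinct field gets the next dense offset, first-seen order
                inputs.add(inputNodes.computeIfAbsent(arg, k -> inputNodes.size()));
            }
            System.out.println(Arrays.toString(argList) + " -> " + inputs);
            // prints [3, 5] -> [0, 1], then [3] -> [0]
        }
        // fields to project once as UDAF inputs
        System.out.println("udafInputOffsets = " + inputNodes.keySet()); // [3, 5]
    }
}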
Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
From the class StreamExecPythonOverAggregate, method translateToPlanInternal:
@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(
        PlannerBase planner, ExecNodeConfig config) {
    if (overSpec.getGroups().size() > 1) {
        throw new TableException("All aggregates must be computed on the same window.");
    }
    final OverSpec.GroupSpec group = overSpec.getGroups().get(0);
    final int[] orderKeys = group.getSort().getFieldIndices();
    final boolean[] isAscendingOrders = group.getSort().getAscendingOrders();
    if (orderKeys.length != 1 || isAscendingOrders.length != 1) {
        throw new TableException("The window can only be ordered by a single time column.");
    }
    if (!isAscendingOrders[0]) {
        throw new TableException("The window can only be ordered in ASCENDING mode.");
    }
    final int[] partitionKeys = overSpec.getPartition().getFieldIndices();
    if (partitionKeys.length > 0 && config.getStateRetentionTime() < 0) {
        LOG.warn(
                "No state retention interval configured for a query which accumulates state. "
                        + "Please provide a query configuration with valid retention interval to "
                        + "prevent excessive state size. You may specify a retention time of 0 "
                        + "to not clean up the state.");
    }
    final ExecEdge inputEdge = getInputEdges().get(0);
    final Transformation<RowData> inputTransform =
            (Transformation<RowData>) inputEdge.translateToPlan(planner);
    final RowType inputRowType = (RowType) inputEdge.getOutputType();
    final int orderKey = orderKeys[0];
    final LogicalType orderKeyType = inputRowType.getFields().get(orderKey).getType();
    // check the time field and identify the window's rowtime attribute
    final int rowTimeIdx;
    if (isRowtimeAttribute(orderKeyType)) {
        rowTimeIdx = orderKey;
    } else if (isProctimeAttribute(orderKeyType)) {
        rowTimeIdx = -1;
    } else {
        throw new TableException(
                "OVER windows' ordering in stream mode must be defined on a time attribute.");
    }
    if (group.getLowerBound().isPreceding() && group.getLowerBound().isUnbounded()) {
        throw new TableException(
                "Python UDAF is not supported to be used in UNBOUNDED PRECEDING OVER windows.");
    } else if (!group.getUpperBound().isCurrentRow()) {
        throw new TableException(
                "Python UDAF is not supported to be used in UNBOUNDED FOLLOWING OVER windows.");
    }
    Object boundValue = OverAggregateUtil.getBoundary(overSpec, group.getLowerBound());
    if (boundValue instanceof BigDecimal) {
        throw new TableException("The specific value is decimal, which is not supported yet.");
    }
    long precedingOffset = -1 * (long) boundValue;
    Configuration pythonConfig =
            CommonPythonUtil.getMergedConfig(planner.getExecEnv(), config.getTableConfig());
    OneInputTransformation<RowData, RowData> transform =
            createPythonOneInputTransformation(
                    inputTransform,
                    inputRowType,
                    InternalTypeInfo.of(getOutputType()).toRowType(),
                    rowTimeIdx,
                    group.getAggCalls().toArray(new AggregateCall[0]),
                    precedingOffset,
                    group.isRows(),
                    config.getStateRetentionTime(),
                    config.getMaxIdleStateRetentionTime(),
                    pythonConfig,
                    config);
    if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
        transform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
    }
    // set the key type and selector for state
    final RowDataKeySelector selector =
            KeySelectorUtil.getRowDataSelector(partitionKeys, InternalTypeInfo.of(inputRowType));
    transform.setStateKeySelector(selector);
    transform.setStateKeyType(selector.getProducedType());
    return transform;
}
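For context, here is a minimal sketch of the boundary arithmetic, assuming (as the negation above implies) that OverAggregateUtil.getBoundary returns a signed Long that is negative for PRECEDING bounds; the window shape and the value -5L are illustrative, not taken from the Flink source.

public class PrecedingOffsetDemo {
    public static void main(String[] args) {
        // e.g. ROWS BETWEEN 5 PRECEDING AND CURRENT ROW
        Object boundValue = -5L; // assumed getBoundary result for "5 PRECEDING"
        long precedingOffset = -1 * (long) boundValue;
        System.out.println(precedingOffset); // 5: rows kept before the current row
    }
}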
Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
From the class FlinkAggregateJoinTransposeRule, method toRegularAggregate:
/**
 * Converts an aggregate with AUXILIARY_GROUP into a regular aggregate. Returns the original
 * aggregate and a null project if the given aggregate does not contain AUXILIARY_GROUP;
 * otherwise returns a new aggregate without AUXILIARY_GROUP and a project that permutes the
 * output columns if needed.
 */
private Pair<Aggregate, List<RexNode>> toRegularAggregate(Aggregate aggregate) {
    Tuple2<int[], Seq<AggregateCall>> auxGroupAndRegularAggCalls =
            AggregateUtil.checkAndSplitAggCalls(aggregate);
    final int[] auxGroup = auxGroupAndRegularAggCalls._1;
    final Seq<AggregateCall> regularAggCalls = auxGroupAndRegularAggCalls._2;
    if (auxGroup.length != 0) {
        int[] fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(aggregate);
        ImmutableBitSet newGroupSet = ImmutableBitSet.of(fullGroupSet);
        List<AggregateCall> aggCalls =
                JavaConverters.seqAsJavaListConverter(regularAggCalls).asJava();
        final Aggregate newAgg =
                aggregate.copy(
                        aggregate.getTraitSet(),
                        aggregate.getInput(),
                        aggregate.indicator,
                        newGroupSet,
                        com.google.common.collect.ImmutableList.of(newGroupSet),
                        aggCalls);
        final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
        final List<RexNode> projectAfterAgg = new ArrayList<>();
        for (int i = 0; i < fullGroupSet.length; ++i) {
            int group = fullGroupSet[i];
            int index = newGroupSet.indexOf(group);
            projectAfterAgg.add(new RexInputRef(index, aggFields.get(i).getType()));
        }
        int fieldCntOfAgg = aggFields.size();
        for (int i = fullGroupSet.length; i < fieldCntOfAgg; ++i) {
            projectAfterAgg.add(new RexInputRef(i, aggFields.get(i).getType()));
        }
        Preconditions.checkArgument(projectAfterAgg.size() == fieldCntOfAgg);
        return new Pair<>(newAgg, projectAfterAgg);
    } else {
        return new Pair<>(aggregate, null);
    }
}
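A worked example of the permutation the first loop builds, with illustrative field indices (not taken from Flink's tests): suppose the original group key is field 2 and the auxiliary key is field 0, so fullGroupSet is [2, 0] in output order. ImmutableBitSet.of sorts the keys, so the new aggregate emits field 0 first and the project must swap the two columns back.

import org.apache.calcite.util.ImmutableBitSet;

public class AuxGroupPermutationDemo {
    public static void main(String[] args) {
        int[] fullGroupSet = {2, 0}; // original output order: group key, then aux key
        ImmutableBitSet newGroupSet = ImmutableBitSet.of(fullGroupSet); // sorted: {0, 2}
        for (int i = 0; i < fullGroupSet.length; i++) {
            // indexOf maps a group key to its position among the sorted keys
            System.out.println("output field " + i + " <- new field "
                    + newGroupSet.indexOf(fullGroupSet[i]));
        }
        // prints: output field 0 <- new field 1
        //         output field 1 <- new field 0
    }
}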
Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
From the class FlinkAggregateExpandDistinctAggregatesRule, method convertSingletonDistinct:
/**
 * Converts an aggregate with one distinct aggregate and one or more non-distinct aggregates to
 * multi-phase aggregates (see the reference example below).
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate Original aggregate
 * @param argLists Arguments and filters to the distinct aggregate function
 */
private RelBuilder convertSingletonDistinct(
        RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // In this case, we are assuming that there is a single distinct function.
    // So make sure that argLists is of size one.
    Preconditions.checkArgument(argLists.size() == 1);
    // For example,
    //   SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //   FROM emp
    //   GROUP BY deptno
    //
    // becomes
    //
    //   SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //   FROM (
    //     SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //     FROM emp
    //     GROUP BY deptno, sal)  -- Aggregate B
    //   GROUP BY deptno          -- Aggregate A
    relBuilder.push(aggregate.getInput());
    final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
    final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();
    // Add the distinct aggregate column(s) to the group-by columns,
    // if not already a part of the group-by
    final SortedSet<Integer> bottomGroupSet = new TreeSet<>();
    bottomGroupSet.addAll(aggregate.getGroupSet().asList());
    for (AggregateCall aggCall : originalAggCalls) {
        if (aggCall.isDistinct()) {
            bottomGroupSet.addAll(aggCall.getArgList());
            // since we only have a single distinct call
            break;
        }
    }
    // Generate the intermediate aggregate B, the one on the bottom that converts
    // the distinct call to a group-by call. The bottom aggregate is the same as
    // the original aggregate, except that it has converted the DISTINCT aggregate
    // to a group-by clause.
    final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
    for (AggregateCall aggCall : originalAggCalls) {
        // keep all the non-distinct aggregates as-is
        if (!aggCall.isDistinct()) {
            final AggregateCall newCall =
                    AggregateCall.create(
                            aggCall.getAggregation(),
                            false,
                            aggCall.isApproximate(),
                            false,
                            aggCall.getArgList(),
                            -1,
                            RelCollations.EMPTY,
                            ImmutableBitSet.of(bottomGroupSet).cardinality(),
                            relBuilder.peek(),
                            null,
                            aggCall.name);
            bottomAggregateCalls.add(newCall);
        }
    }
    // Generate aggregate B (see the reference example above)
    relBuilder.push(
            aggregate.copy(
                    aggregate.getTraitSet(),
                    relBuilder.build(),
                    ImmutableBitSet.of(bottomGroupSet),
                    null,
                    bottomAggregateCalls));
    // Add aggregate A (see the reference example above), the top aggregate,
    // to handle the rest of the aggregation that the bottom aggregate hasn't handled
    final List<AggregateCall> topAggregateCalls = com.google.common.collect.Lists.newArrayList();
    // Use the remapped arguments for the (non)distinct aggregate calls
    int nonDistinctAggCallProcessedSoFar = 0;
    for (AggregateCall aggCall : originalAggCalls) {
        final AggregateCall newCall;
        if (aggCall.isDistinct()) {
            List<Integer> newArgList = new ArrayList<>();
            for (int arg : aggCall.getArgList()) {
                newArgList.add(bottomGroupSet.headSet(arg).size());
            }
            newCall =
                    AggregateCall.create(
                            aggCall.getAggregation(),
                            false,
                            aggCall.isApproximate(),
                            false,
                            newArgList,
                            -1,
                            RelCollations.EMPTY,
                            originalGroupSet.cardinality(),
                            relBuilder.peek(),
                            aggCall.getType(),
                            aggCall.name);
        } else {
            // If aggregate B had a COUNT aggregate call, the corresponding call at
            // aggregate A must be SUM. Other aggregates remain the same.
            // Note: newArrayList(E...) builds a one-element list holding the
            // remapped argument index, not a list with that initial capacity.
            final List<Integer> newArgs =
                    com.google.common.collect.Lists.newArrayList(
                            bottomGroupSet.size() + nonDistinctAggCallProcessedSoFar);
            if (aggCall.getAggregation().getKind() == SqlKind.COUNT) {
                newCall =
                        AggregateCall.create(
                                new SqlSumEmptyIsZeroAggFunction(),
                                false,
                                aggCall.isApproximate(),
                                false,
                                newArgs,
                                -1,
                                RelCollations.EMPTY,
                                originalGroupSet.cardinality(),
                                relBuilder.peek(),
                                aggCall.getType(),
                                aggCall.getName());
            } else {
                newCall =
                        AggregateCall.create(
                                aggCall.getAggregation(),
                                false,
                                aggCall.isApproximate(),
                                false,
                                newArgs,
                                -1,
                                RelCollations.EMPTY,
                                originalGroupSet.cardinality(),
                                relBuilder.peek(),
                                aggCall.getType(),
                                aggCall.name);
            }
            nonDistinctAggCallProcessedSoFar++;
        }
        topAggregateCalls.add(newCall);
    }
    // Populate the group-by keys with the remapped arguments for aggregate A.
    // The top group set is basically an identity (the first X fields of
    // aggregate B's output), minus the distinct aggCall's input.
    final Set<Integer> topGroupSet = new HashSet<>();
    int groupSetToAdd = 0;
    for (int bottomGroup : bottomGroupSet) {
        if (originalGroupSet.get(bottomGroup)) {
            topGroupSet.add(groupSetToAdd);
        }
        groupSetToAdd++;
    }
    relBuilder.push(
            aggregate.copy(
                    aggregate.getTraitSet(),
                    relBuilder.build(),
                    ImmutableBitSet.of(topGroupSet),
                    null,
                    topAggregateCalls));
    return relBuilder;
}
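A self-contained sketch of the headSet remapping used for the distinct call, reusing the javadoc's example with illustrative field indices: if deptno is field 0 and sal is field 5, aggregate B groups by {0, 5} and emits its keys in sorted order, so MIN's new argument is the number of bottom group keys smaller than 5.

import java.util.TreeSet;

public class DistinctArgRemapDemo {
    public static void main(String[] args) {
        TreeSet<Integer> bottomGroupSet = new TreeSet<>();
        bottomGroupSet.add(0); // deptno: original GROUP BY key (index assumed)
        bottomGroupSet.add(5); // sal: added for MIN(DISTINCT sal) (index assumed)
        int arg = 5; // the distinct call's original argument
        // position of sal among aggregate B's output keys
        System.out.println(bottomGroupSet.headSet(arg).size()); // prints 1
    }
}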