Search in sources :

Example 41 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

the class HiveParserCalcitePlanner method genGBRelNode.

/**
 * Builds the Calcite plan fragment for a GROUP BY: a {@link LogicalProject} over {@code srcRel}
 * that exposes the group-by keys followed by the aggregate parameters, topped by a
 * {@link LogicalAggregate}.
 *
 * <p>The projection lists the converted group-by keys first (positions {@code 0..n-1}), then
 * every aggregate parameter whose converted expression (compared by its {@code toString()}
 * form) is not already present. When grouping sets are given, a {@code GROUPING_ID} aggregate
 * call is appended after the regular aggregate calls.
 *
 * @param gbExprs group-by key expressions
 * @param aggInfos aggregate descriptors whose parameters are added to the projection
 * @param groupSets encoded grouping sets, each decoded through {@code convert(int, int)};
 *     {@code null} or empty when the query has no grouping sets
 * @param srcRel input relational expression
 * @return the aggregate node created on top of the projection
 * @throws SemanticException declared for the conversion helpers invoked here
 */
private RelNode genGBRelNode(List<ExprNodeDesc> gbExprs, List<AggInfo> aggInfos, List<Integer> groupSets, RelNode srcRel) throws SemanticException {
    Map<String, Integer> colNameToPos = relToHiveColNameCalcitePosMap.get(srcRel);
    HiveParserRexNodeConverter converter = new HiveParserRexNodeConverter(cluster, srcRel.getRowType(), colNameToPos, 0, false, funcConverter);
    final boolean hasGroupSets = groupSets != null && !groupSets.isEmpty();
    final List<RexNode> gbInputRexNodes = new ArrayList<>();
    // Maps the string form of each projected expression to its position in the projection
    final HashMap<String, Integer> inputRexNodeToIndex = new HashMap<>();
    final List<Integer> gbKeyIndices = new ArrayList<>();
    int inputIndex = 0;
    for (ExprNodeDesc key : gbExprs) {
        // also convert null literal here to support grouping by NULLs
        RexNode keyRex = convertNullLiteral(converter.convert(key)).accept(funcConverter);
        gbInputRexNodes.add(keyRex);
        gbKeyIndices.add(inputIndex);
        inputRexNodeToIndex.put(keyRex.toString(), inputIndex);
        inputIndex++;
    }
    final ImmutableBitSet groupSet = ImmutableBitSet.of(gbKeyIndices);
    // Grouping sets: we need to transform them into ImmutableBitSet objects for Calcite
    List<ImmutableBitSet> transformedGroupSets = null;
    if (hasGroupSets) {
        Set<ImmutableBitSet> set = new HashSet<>(groupSets.size());
        for (int val : groupSets) {
            set.add(convert(val, groupSet.cardinality()));
        }
        // Calcite expects the grouping sets sorted and without duplicates
        transformedGroupSets = new ArrayList<>(set);
        transformedGroupSets.sort(ImmutableBitSet.COMPARATOR);
    }
    // add Agg parameters to inputs, skipping expressions that are already projected
    for (AggInfo aggInfo : aggInfos) {
        for (ExprNodeDesc expr : aggInfo.getAggParams()) {
            RexNode paramRex = converter.convert(expr).accept(funcConverter);
            String paramKey = paramRex.toString();
            if (!inputRexNodeToIndex.containsKey(paramKey)) {
                inputRexNodeToIndex.put(paramKey, gbInputRexNodes.size());
                gbInputRexNodes.add(paramRex);
            }
        }
    }
    if (gbInputRexNodes.isEmpty()) {
        // This will happen for count(*), in such cases we arbitrarily pick
        // first element from srcRel.
        // This fallback has to run BEFORE the projection is created: previously it was
        // appended after the projection had already been built from the list, so it had
        // no effect on the plan.
        gbInputRexNodes.add(cluster.getRexBuilder().makeInputRef(srcRel, 0));
    }
    // create the actual input before creating agg calls so that the calls can properly infer
    // return type
    RelNode gbInputRel = LogicalProject.create(srcRel, Collections.emptyList(), gbInputRexNodes, (List<String>) null);
    List<AggregateCall> aggregateCalls = new ArrayList<>();
    for (AggInfo aggInfo : aggInfos) {
        aggregateCalls.add(HiveParserUtils.toAggCall(aggInfo, converter, inputRexNodeToIndex, groupSet.cardinality(), gbInputRel, cluster, funcConverter));
    }
    // GROUPING__ID is a virtual col in Hive, so we use Flink's function
    if (hasGroupSets) {
        // Create GroupingID column
        AggregateCall aggCall = AggregateCall.create(SqlStdOperatorTable.GROUPING_ID, false, false, false, gbKeyIndices, -1, RelCollations.EMPTY, groupSet.cardinality(), gbInputRel, null, null);
        aggregateCalls.add(aggCall);
    }
    return LogicalAggregate.create(gbInputRel, groupSet, transformedGroupSets, aggregateCalls);
}
Also used : ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) LinkedHashMap(java.util.LinkedHashMap) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) HiveParserBaseSemanticAnalyzer.getHiveAggInfo(org.apache.flink.table.planner.delegation.hive.copy.HiveParserBaseSemanticAnalyzer.getHiveAggInfo) AggInfo(org.apache.flink.table.planner.delegation.hive.copy.HiveParserBaseSemanticAnalyzer.AggInfo) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) ExprNodeDesc(org.apache.hadoop.hive.ql.plan.ExprNodeDesc) RexNode(org.apache.calcite.rex.RexNode) HashSet(java.util.HashSet)

Example 42 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

the class CommonPythonUtil method extractPythonAggregateFunctionInfosFromAggregateCall.

/**
 * Collects, for every aggregate call, a {@link PythonAggregateFunctionInfo} whose inputs are
 * offsets into a de-duplicated list of argument indices shared by all calls.
 *
 * @param aggCalls aggregate calls to inspect
 * @return a pair of (distinct argument indices in first-seen order, one function info per call)
 */
public static Tuple2<int[], PythonFunctionInfo[]> extractPythonAggregateFunctionInfosFromAggregateCall(AggregateCall[] aggCalls) {
    // argument index -> offset assigned when the argument was first encountered;
    // the LinkedHashMap keeps the keys in first-seen order
    final Map<Integer, Integer> offsetByArg = new LinkedHashMap<>();
    final List<PythonFunctionInfo> functionInfos = new ArrayList<>();
    for (AggregateCall call : aggCalls) {
        final List<Integer> callInputs = new ArrayList<>();
        for (Integer arg : call.getArgList()) {
            Integer offset = offsetByArg.get(arg);
            if (offset == null) {
                // first occurrence: the next free offset equals the current map size
                offset = offsetByArg.size();
                offsetByArg.put(arg, offset);
            }
            callInputs.add(offset);
        }
        final SqlAggFunction aggregation = call.getAggregation();
        final PythonFunction pythonFunction;
        if (aggregation instanceof AggSqlFunction) {
            pythonFunction = (PythonFunction) ((AggSqlFunction) aggregation).aggregateFunction();
        } else if (aggregation instanceof BridgingSqlAggFunction) {
            pythonFunction = (PythonFunction) ((BridgingSqlAggFunction) aggregation).getDefinition();
        } else {
            // any other kind of aggregation yields a null python function
            pythonFunction = null;
        }
        functionInfos.add(new PythonAggregateFunctionInfo(pythonFunction, callInputs.toArray(), call.filterArg, call.isDistinct()));
    }
    final int[] udafInputOffsets = offsetByArg.keySet().stream().mapToInt(Integer::intValue).toArray();
    return Tuple2.of(udafInputOffsets, functionInfos.toArray(new PythonFunctionInfo[0]));
}
Also used : TypeInference(org.apache.flink.table.types.inference.TypeInference) MapView(org.apache.flink.table.api.dataview.MapView) SumAggFunction(org.apache.flink.table.planner.functions.aggfunctions.SumAggFunction) DataType(org.apache.flink.table.types.DataType) Sum0AggFunction(org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction) Arrays(java.util.Arrays) Tuple2(org.apache.flink.api.java.tuple.Tuple2) DataViewSpec(org.apache.flink.table.runtime.dataview.DataViewSpec) StructuredType(org.apache.flink.table.types.logical.StructuredType) LastValueWithRetractAggFunction(org.apache.flink.table.runtime.functions.aggregate.LastValueWithRetractAggFunction) ListViewSpec(org.apache.flink.table.runtime.dataview.ListViewSpec) CountAggFunction(org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction) BigDecimal(java.math.BigDecimal) PythonFunction(org.apache.flink.table.functions.python.PythonFunction) RexNode(org.apache.calcite.rex.RexNode) MaxAggFunction(org.apache.flink.table.planner.functions.aggfunctions.MaxAggFunction) Map(java.util.Map) FirstValueAggFunction(org.apache.flink.table.runtime.functions.aggregate.FirstValueAggFunction) Method(java.lang.reflect.Method) TableSqlFunction(org.apache.flink.table.planner.functions.utils.TableSqlFunction) BuiltInPythonAggregateFunction(org.apache.flink.table.functions.python.BuiltInPythonAggregateFunction) TableConfig(org.apache.flink.table.api.TableConfig) ListAggWsWithRetractAggFunction(org.apache.flink.table.runtime.functions.aggregate.ListAggWsWithRetractAggFunction) RexLiteral(org.apache.calcite.rex.RexLiteral) AggregateInfoList(org.apache.flink.table.planner.plan.utils.AggregateInfoList) MapViewSpec(org.apache.flink.table.runtime.dataview.MapViewSpec) UserDefinedFunction(org.apache.flink.table.functions.UserDefinedFunction) Count1AggFunction(org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction) 
FirstValueWithRetractAggFunction(org.apache.flink.table.runtime.functions.aggregate.FirstValueWithRetractAggFunction) InvocationTargetException(java.lang.reflect.InvocationTargetException) Objects(java.util.Objects) AggSqlFunction(org.apache.flink.table.planner.functions.utils.AggSqlFunction) List(java.util.List) ListAggFunction(org.apache.flink.table.planner.functions.aggfunctions.ListAggFunction) BridgingSqlAggFunction(org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction) LogicalType(org.apache.flink.table.types.logical.LogicalType) PythonAggregateFunctionInfo(org.apache.flink.table.functions.python.PythonAggregateFunctionInfo) StreamExecutionEnvironment(org.apache.flink.streaming.api.environment.StreamExecutionEnvironment) RexCall(org.apache.calcite.rex.RexCall) IntStream(java.util.stream.IntStream) MinAggFunction(org.apache.flink.table.planner.functions.aggfunctions.MinAggFunction) ListAggWithRetractAggFunction(org.apache.flink.table.runtime.functions.aggregate.ListAggWithRetractAggFunction) DummyStreamExecutionEnvironment(org.apache.flink.table.planner.utils.DummyStreamExecutionEnvironment) MinWithRetractAggFunction(org.apache.flink.table.runtime.functions.aggregate.MinWithRetractAggFunction) RowType(org.apache.flink.table.types.logical.RowType) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) DataView(org.apache.flink.table.api.dataview.DataView) FieldsDataType(org.apache.flink.table.types.FieldsDataType) ConfigOption(org.apache.flink.configuration.ConfigOption) SqlOperator(org.apache.calcite.sql.SqlOperator) ListView(org.apache.flink.table.api.dataview.ListView) AggregateInfo(org.apache.flink.table.planner.plan.utils.AggregateInfo) SqlTypeName(org.apache.calcite.sql.type.SqlTypeName) FunctionDefinition(org.apache.flink.table.functions.FunctionDefinition) AvgAggFunction(org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction) 
MaxWithRetractAggFunction(org.apache.flink.table.runtime.functions.aggregate.MaxWithRetractAggFunction) Configuration(org.apache.flink.configuration.Configuration) TableException(org.apache.flink.table.api.TableException) SumWithRetractAggFunction(org.apache.flink.table.planner.functions.aggfunctions.SumWithRetractAggFunction) PythonFunctionInfo(org.apache.flink.table.functions.python.PythonFunctionInfo) Field(java.lang.reflect.Field) LastValueAggFunction(org.apache.flink.table.runtime.functions.aggregate.LastValueAggFunction) BridgingSqlFunction(org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction) ScalarSqlFunction(org.apache.flink.table.planner.functions.utils.ScalarSqlFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) PythonFunctionInfo(org.apache.flink.table.functions.python.PythonFunctionInfo) ArrayList(java.util.ArrayList) BridgingSqlAggFunction(org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BridgingSqlAggFunction(org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction) LinkedHashMap(java.util.LinkedHashMap) AggregateCall(org.apache.calcite.rel.core.AggregateCall) PythonAggregateFunctionInfo(org.apache.flink.table.functions.python.PythonAggregateFunctionInfo) PythonFunction(org.apache.flink.table.functions.python.PythonFunction) AggSqlFunction(org.apache.flink.table.planner.functions.utils.AggSqlFunction)

Example 43 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

the class StreamExecPythonOverAggregate method translateToPlanInternal.

/**
 * Translates this Python OVER aggregate node into a keyed one-input transformation.
 *
 * <p>Validates the window first: a single group, ordered ascending on exactly one time
 * attribute, with a bounded lower bound and CURRENT ROW as upper bound.
 *
 * @param planner planner used to translate the input edge and to obtain the execution env
 * @param config node configuration (state retention settings and table config)
 * @return the transformation, with state key selector and key type set from the partition keys
 */
@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner, ExecNodeConfig config) {
    if (overSpec.getGroups().size() > 1) {
        throw new TableException("All aggregates must be computed on the same window.");
    }
    final OverSpec.GroupSpec windowGroup = overSpec.getGroups().get(0);

    // exactly one ascending sort key is accepted
    final int[] sortFields = windowGroup.getSort().getFieldIndices();
    final boolean[] ascendingFlags = windowGroup.getSort().getAscendingOrders();
    if (sortFields.length != 1 || ascendingFlags.length != 1) {
        throw new TableException("The window can only be ordered by a single time column.");
    }
    if (!ascendingFlags[0]) {
        throw new TableException("The window can only be ordered in ASCENDING mode.");
    }

    final int[] partitionKeys = overSpec.getPartition().getFieldIndices();
    if (partitionKeys.length > 0 && config.getStateRetentionTime() < 0) {
        LOG.warn("No state retention interval configured for a query which accumulates state. " + "Please provide a query configuration with valid retention interval to prevent " + "excessive state size. You may specify a retention time of 0 to not clean up the state.");
    }

    final ExecEdge inputEdge = getInputEdges().get(0);
    final Transformation<RowData> inputTransform = (Transformation<RowData>) inputEdge.translateToPlan(planner);
    final RowType inputRowType = (RowType) inputEdge.getOutputType();

    // check time field && identify window rowtime attribute:
    // a rowtime key keeps its field index, a proctime key is encoded as -1
    final int sortField = sortFields[0];
    final LogicalType sortFieldType = inputRowType.getFields().get(sortField).getType();
    final int rowTimeIdx;
    if (isRowtimeAttribute(sortFieldType)) {
        rowTimeIdx = sortField;
    } else if (isProctimeAttribute(sortFieldType)) {
        rowTimeIdx = -1;
    } else {
        throw new TableException("OVER windows' ordering in stream mode must be defined on a time attribute.");
    }

    // window bounds: the lower bound must be bounded, the upper bound must be CURRENT ROW
    if (windowGroup.getLowerBound().isPreceding() && windowGroup.getLowerBound().isUnbounded()) {
        throw new TableException("Python UDAF is not supported to be used in UNBOUNDED PRECEDING OVER windows.");
    }
    if (!windowGroup.getUpperBound().isCurrentRow()) {
        throw new TableException("Python UDAF is not supported to be used in UNBOUNDED FOLLOWING OVER windows.");
    }

    final Object lowerBoundary = OverAggregateUtil.getBoundary(overSpec, windowGroup.getLowerBound());
    if (lowerBoundary instanceof BigDecimal) {
        throw new TableException("the specific value is decimal which haven not supported yet.");
    }
    // the boundary is negated to obtain the preceding offset
    final long precedingOffset = -1 * (long) lowerBoundary;

    final Configuration pythonConfig = CommonPythonUtil.getMergedConfig(planner.getExecEnv(), config.getTableConfig());
    final OneInputTransformation<RowData, RowData> pythonTransform =
            createPythonOneInputTransformation(
                    inputTransform,
                    inputRowType,
                    InternalTypeInfo.of(getOutputType()).toRowType(),
                    rowTimeIdx,
                    windowGroup.getAggCalls().toArray(new AggregateCall[0]),
                    precedingOffset,
                    windowGroup.isRows(),
                    config.getStateRetentionTime(),
                    config.getMaxIdleStateRetentionTime(),
                    pythonConfig,
                    config);
    if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig)) {
        pythonTransform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
    }

    // set KeyType and Selector for state
    final RowDataKeySelector keySelector = KeySelectorUtil.getRowDataSelector(partitionKeys, InternalTypeInfo.of(inputRowType));
    pythonTransform.setStateKeySelector(keySelector);
    pythonTransform.setStateKeyType(keySelector.getProducedType());
    return pythonTransform;
}
Also used : TableException(org.apache.flink.table.api.TableException) OneInputTransformation(org.apache.flink.streaming.api.transformations.OneInputTransformation) Transformation(org.apache.flink.api.dag.Transformation) ExecEdge(org.apache.flink.table.planner.plan.nodes.exec.ExecEdge) Configuration(org.apache.flink.configuration.Configuration) RowType(org.apache.flink.table.types.logical.RowType) LogicalType(org.apache.flink.table.types.logical.LogicalType) OverSpec(org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RowData(org.apache.flink.table.data.RowData) RowDataKeySelector(org.apache.flink.table.runtime.keyselector.RowDataKeySelector)

Example 44 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

the class FlinkAggregateJoinTransposeRule method toRegularAggregate.

/**
 * Rewrites an aggregate that carries an AUXILIARY_GROUP into a regular aggregate.
 *
 * <p>If the given aggregate has no auxiliary group, it is returned unchanged together with a
 * {@code null} projection. Otherwise the result holds a new aggregate grouped by the full group
 * set (without AUXILIARY_GROUP calls) and a list of input references that permutes the new
 * aggregate's output columns back to the original column order.
 *
 * @param aggregate aggregate to normalize
 * @return pair of (aggregate, projection to apply on top of it, or {@code null})
 */
private Pair<Aggregate, List<RexNode>> toRegularAggregate(Aggregate aggregate) {
    final Tuple2<int[], Seq<AggregateCall>> split = AggregateUtil.checkAndSplitAggCalls(aggregate);
    final int[] auxGroup = split._1;
    final Seq<AggregateCall> regularAggCalls = split._2;
    if (auxGroup.length == 0) {
        // nothing to rewrite
        return new Pair<>(aggregate, null);
    }

    final int[] fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(aggregate);
    final ImmutableBitSet regularGroupSet = ImmutableBitSet.of(fullGroupSet);
    final List<AggregateCall> javaAggCalls = JavaConverters.seqAsJavaListConverter(regularAggCalls).asJava();
    final Aggregate regularAggregate =
            aggregate.copy(
                    aggregate.getTraitSet(),
                    aggregate.getInput(),
                    aggregate.indicator,
                    regularGroupSet,
                    com.google.common.collect.ImmutableList.of(regularGroupSet),
                    javaAggCalls);

    final List<RelDataTypeField> originalFields = aggregate.getRowType().getFieldList();
    final int originalFieldCount = originalFields.size();
    final List<RexNode> permutation = new ArrayList<>();
    // group columns: look up where each original group key sits in the new group set
    for (int pos = 0; pos < fullGroupSet.length; ++pos) {
        final int keyPosition = regularGroupSet.indexOf(fullGroupSet[pos]);
        permutation.add(new RexInputRef(keyPosition, originalFields.get(pos).getType()));
    }
    // remaining columns keep their position
    for (int pos = fullGroupSet.length; pos < originalFieldCount; ++pos) {
        permutation.add(new RexInputRef(pos, originalFields.get(pos).getType()));
    }
    Preconditions.checkArgument(permutation.size() == originalFieldCount);
    return new Pair<>(regularAggregate, permutation);
}
Also used : ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RexInputRef(org.apache.calcite.rex.RexInputRef) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) Seq(scala.collection.Seq) RexNode(org.apache.calcite.rex.RexNode) Pair(org.apache.calcite.util.Pair)

Example 45 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method convertSingletonDistinct.

/**
 * Rewrites an aggregate containing exactly one distinct aggregate (plus any number of
 * non-distinct ones) into two stacked aggregates.
 *
 * <p>Reference example:
 *
 * <pre>
 * SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
 * FROM emp
 * GROUP BY deptno
 * </pre>
 *
 * becomes
 *
 * <pre>
 * SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
 * FROM (
 *   SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
 *   FROM EMP
 *   GROUP BY deptno, sal)            // Aggregate B (bottom)
 * GROUP BY deptno                    // Aggregate A (top)
 * </pre>
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate Original aggregate
 * @param argLists Arguments and filters to the distinct aggregate function; must hold exactly
 *     one entry
 * @return the same {@code relBuilder}, with the top aggregate pushed onto it
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Only a single distinct function is handled here.
    Preconditions.checkArgument(argLists.size() == 1);

    relBuilder.push(aggregate.getInput());
    final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
    final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

    // Bottom group keys = original group keys + the distinct call's arguments
    // (arguments that are already group keys are absorbed by the set).
    final SortedSet<Integer> bottomGroupSet = new TreeSet<>(aggregate.getGroupSet().asList());
    for (AggregateCall candidate : originalAggCalls) {
        if (!candidate.isDistinct()) {
            continue;
        }
        bottomGroupSet.addAll(candidate.getArgList());
        // there is only one distinct call, so stop at the first hit
        break;
    }

    // Aggregate B (bottom): carries every non-distinct call over as-is; the distinct call
    // has been turned into additional group keys above.
    final List<AggregateCall> bottomAggregateCalls = new ArrayList<>();
    for (AggregateCall original : originalAggCalls) {
        if (original.isDistinct()) {
            continue;
        }
        bottomAggregateCalls.add(
                AggregateCall.create(
                        original.getAggregation(),
                        false,
                        original.isApproximate(),
                        false,
                        original.getArgList(),
                        -1,
                        RelCollations.EMPTY,
                        ImmutableBitSet.of(bottomGroupSet).cardinality(),
                        relBuilder.peek(),
                        null,
                        original.name));
    }
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.of(bottomGroupSet), null, bottomAggregateCalls));

    // Aggregate A (top): finishes whatever the bottom aggregate has not handled, using
    // arguments remapped onto the bottom aggregate's output.
    final List<AggregateCall> topAggregateCalls = com.google.common.collect.Lists.newArrayList();
    final int bottomKeyCount = bottomGroupSet.size();
    int nonDistinctSeen = 0;
    for (AggregateCall original : originalAggCalls) {
        final AggregateCall topCall;
        if (!original.isDistinct()) {
            // The n-th non-distinct call reads the n-th aggregate column of B, which sits
            // right after B's group keys. Note: this creates a one-element list holding that
            // column index (varargs factory), not an empty list with that capacity.
            final List<Integer> remappedArgs = com.google.common.collect.Lists.newArrayList(bottomKeyCount + nonDistinctSeen);
            if (original.getAggregation().getKind() == SqlKind.COUNT) {
                // a COUNT computed in B has to be summed up in A
                topCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false, original.isApproximate(), false, remappedArgs, -1, RelCollations.EMPTY, originalGroupSet.cardinality(), relBuilder.peek(), original.getType(), original.getName());
            } else {
                // every other aggregate function stays the same
                topCall = AggregateCall.create(original.getAggregation(), false, original.isApproximate(), false, remappedArgs, -1, RelCollations.EMPTY, originalGroupSet.cardinality(), relBuilder.peek(), original.getType(), original.name);
            }
            nonDistinctSeen++;
        } else {
            // the distinct call's arguments are now group keys of B: an argument's new
            // index is the number of bottom group keys smaller than it
            final List<Integer> remappedArgs = new ArrayList<>();
            for (int arg : original.getArgList()) {
                remappedArgs.add(bottomGroupSet.headSet(arg).size());
            }
            topCall = AggregateCall.create(original.getAggregation(), false, original.isApproximate(), false, remappedArgs, -1, RelCollations.EMPTY, originalGroupSet.cardinality(), relBuilder.peek(), original.getType(), original.name);
        }
        topAggregateCalls.add(topCall);
    }

    // Top group keys: positions (within B's group-key prefix) of the keys that belonged to
    // the original group set, i.e. everything except the distinct call's input.
    final Set<Integer> topGroupSet = new HashSet<>();
    int bottomPosition = 0;
    for (int bottomKey : bottomGroupSet) {
        if (originalGroupSet.get(bottomKey)) {
            topGroupSet.add(bottomPosition);
        }
        bottomPosition++;
    }
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), ImmutableBitSet.of(topGroupSet), null, topAggregateCalls));
    return relBuilder;
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) TreeSet(java.util.TreeSet) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) LinkedHashSet(java.util.LinkedHashSet)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall)158 ArrayList (java.util.ArrayList)82 RexNode (org.apache.calcite.rex.RexNode)78 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)57 RelNode (org.apache.calcite.rel.RelNode)54 RexBuilder (org.apache.calcite.rex.RexBuilder)52 RelDataType (org.apache.calcite.rel.type.RelDataType)42 Aggregate (org.apache.calcite.rel.core.Aggregate)37 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)36 RexInputRef (org.apache.calcite.rex.RexInputRef)33 RelBuilder (org.apache.calcite.tools.RelBuilder)29 HashMap (java.util.HashMap)28 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)28 List (java.util.List)27 RexLiteral (org.apache.calcite.rex.RexLiteral)23 Pair (org.apache.calcite.util.Pair)20 ImmutableList (com.google.common.collect.ImmutableList)19 Project (org.apache.calcite.rel.core.Project)17 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)17 LogicalAggregate (org.apache.calcite.rel.logical.LogicalAggregate)16