//~ Methods ----------------------------------------------------------------
/**
 * Rule entry point: examines the matched aggregate (operand 0 of the call) and,
 * when it contains DISTINCT aggregate calls, selects one of several rewrite
 * strategies (convertMonopole, rewriteUsingGroupingSets, convertSingletonDistinct,
 * or the general doRewrite-per-argument-list path).
 *
 * NOTE(review): this listing is missing lines (closing braces, the statements
 * that update the counters, and whatever follows each strategy call, e.g. an
 * early return). The comments below describe only what is visible; confirm the
 * control flow against the complete source.
 *
 * @param call rule call; call.rel(0) is read as the Aggregate to rewrite, and
 *             call.builder() supplies the RelBuilder used by every strategy
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    // Guard for aggregates without any DISTINCT call; the body of this branch
    // is not visible here (presumably it returns without rewriting — confirm).
    if (!aggregate.containsDistinctCall()) {
    // Find all of the agg expressions. We use a LinkedHashSet to ensure
    // determinism.
    // The four counters below are read by the strategy conditions further
    // down; the statements that increment them are not visible in this listing.
    int nonDistinctCount = 0;
    int distinctCount = 0;
    int filterCount = 0;
    int unsupportedAggCount = 0;
    // Each element pairs an aggregate call's argument list with its filter argument.
    final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        // A non-negative filterArg marks a call that carries a filter.
        if (aggCall.filterArg >= 0) {
        if (!aggCall.isDistinct()) {
            // Non-distinct call whose function is not COUNT, SUM or MIN/MAX.
            if (!(aggCall.getAggregation() instanceof SqlCountAggFunction || aggCall.getAggregation() instanceof SqlSumAggFunction || aggCall.getAggregation() instanceof SqlMinMaxAggFunction)) {
        argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
    // At least one (argument list, filter) pair must have been collected.
    Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");
    // If there are no non-distinct calls and only one distinct set of
    // arguments then we can use a more efficient form.
    if (nonDistinctCount == 0 && argLists.size() == 1) {
        final Pair<List<Integer>, Integer> pair = Iterables.getOnlyElement(argLists);
        final RelBuilder relBuilder = call.builder();
        convertMonopole(relBuilder, aggregate, pair.left, pair.right);
    // useGroupingSets is declared outside this method (not visible here).
    if (useGroupingSets) {
        rewriteUsingGroupingSets(call, aggregate, argLists);
    // If only one distinct aggregate and one or more non-distinct aggregates,
    // we can generate multi-phase aggregates
    if (// one distinct aggregate
    distinctCount == 1 && // no filter
    filterCount == 0 && // sum/min/max/count in non-distinct aggregate
    unsupportedAggCount == 0 && nonDistinctCount > 0) {
        // one or more non-distinct aggregates
        final RelBuilder relBuilder = call.builder();
        convertSingletonDistinct(relBuilder, aggregate, argLists);
    // Create a list of the expressions which will yield the final result.
    // Initially, the expressions point to the input field.
    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final List<RexInputRef> refs = new ArrayList<>();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    // One reference per group/indicator column, in output order.
    for (int i : Util.range(groupAndIndicatorCount)) {
        refs.add(RexInputRef.of(i, aggFields));
    // Aggregate the original relation, including any non-distinct aggregates.
    final List<AggregateCall> newAggCallList = new ArrayList<>();
    // i is used below as an offset into aggFields; the statement that advances
    // it (and the handling of distinct calls) is not visible in this listing.
    int i = -1;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (aggCall.isDistinct()) {
        // Reference the position the call will occupy after the group and
        // indicator columns, typed like the original output field.
        refs.add(new RexInputRef(groupAndIndicatorCount + newAggCallList.size(), aggFields.get(groupAndIndicatorCount + i).getType()));
    // In the case where there are no non-distinct aggregates (regardless of
    // whether there are group bys), there's no need to generate the
    // extra aggregate and join.
    final RelBuilder relBuilder = call.builder();
    // n is the running index handed to doRewrite, one per argument list.
    int n = 0;
    if (!newAggCallList.isEmpty()) {
        final RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupSet, aggregate.indicator, aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, newAggCallList);
    // For each collected (argument list, filter) pair, rewrite the calls that
    // use that set of operands.
    for (Pair<List<Integer>, Integer> argList : argLists) {
        doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
    // Project the accumulated references under the aggregate's original field names.
    relBuilder.project(refs, fieldNames);
Also used : LinkedHashSet(java.util.LinkedHashSet) RelBuilder( SqlMinMaxAggFunction( ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) SqlCountAggFunction( AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) SqlSumAggFunction( RexInputRef(org.apache.calcite.rex.RexInputRef) ArrayList(java.util.ArrayList) ImmutableList( ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) Pair(org.apache.calcite.util.Pair)

/**
 * Builds the expression that replaces a single-argument variance / standard
 * deviation aggregate call with SUM and COUNT aggregates combined arithmetically
 * (formulas in the comment at the top of the body).
 *
 * NOTE(review): this listing is missing several closing braces and the
 * delimiters of one block comment; confirm structure against the full source.
 *
 * @param oldAggRel      aggregate being rewritten; supplies the cluster (planner
 *                       context, type factory, RexBuilder), group count and indicator flag
 * @param oldCall        the aggregate call to replace; asserted to have exactly one
 *                       argument; its distinct flag is copied onto the new SUM/COUNT calls
 * @param biased         true divides by count(x); false divides by count(x) - 1,
 *                       with NULL substituted when count(x) equals 1
 * @param sqrt           true raises the quotient to the power 0.5
 * @param newCalls       passed through to RexBuilder.addAggCall together with aggCallMapping
 * @param aggCallMapping passed through to RexBuilder.addAggCall
 * @param inputExprs     input expressions; the entry at the argument ordinal is
 *                       overwritten with a CastHighOp call, and the squared
 *                       argument is registered via lookupOrAdd
 * @return the replacement expression, additionally cast to ANY when type
 *         inference is disabled in the planner settings
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    // stddev_pop(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / count(x),
    //     .5)
    // stddev_samp(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / nullif(count(x) - 1, 0),
    //     .5)
    // The planner context is cast to PlannerSettings without a type check.
    final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
    final int nGroups = oldAggRel.getGroupCount();
    RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
    // final RexNode argRef = inputExprs.get(argOrdinal);
    // Wrap the argument in CastHighOp and write it back into inputExprs, so the
    // caller's list is mutated (presumably CastHighOp widens the type — confirm).
    RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
    inputExprs.set(argOrdinal, argRef);
    // x * x, registered as an input expression so that SUM can reference it by ordinal.
    final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
    final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
    // Both SUM calls use the argument type made nullable.
    final RelDataType sumType = typeFactory.createTypeWithNullability(argType, true);
    // sum(x * x)
    final AggregateCall sumArgSquaredAggCall = AggregateCall.create(new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argSquaredOrdinal), -1, sumType, null);
    final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    // sum(x)
    final AggregateCall sumArgAggCall = AggregateCall.create(new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argOrdinal), -1, sumType, null);
    final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    // sum(x) * sum(x)
    final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
    // count(x), over the original argument list
    final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countAgg.getReturnType(typeFactory);
    final AggregateCall countArgAggCall = AggregateCall.create(countAgg, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);
    final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    // sum(x) * sum(x) / count(x)
    final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
    final RexNode denominator;
    if (biased) {
        denominator = countArg;
    } else {
        // CASE count(x) = 1 THEN NULL ELSE count(x) - 1 END, which avoids a
        // division by zero for single-row groups.
        final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nul = rexBuilder.makeNullLiteral(countArg.getType().getSqlTypeName());
        final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
        final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
        denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
    // With type inference enabled the division uses a Drill operator named
    // "divide" constructed with the original call's type; otherwise the
    // standard DIVIDE operator is used.
    final SqlOperator divide;
    if (isInferenceEnabled) {
        divide = new DrillSqlOperator("divide", 2, true, oldCall.getType(), false);
    } else {
        divide = SqlStdOperatorTable.DIVIDE;
    final RexNode div = rexBuilder.makeCall(divide, diff, denominator);
    RexNode result = div;
    if (sqrt) {
        // Square root expressed as POWER(div, 0.5).
        final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
    if (isInferenceEnabled) {
        return result;
    } else {
        // NOTE(review): the six lines below are the interior of a block comment
        // whose opening and closing delimiters are missing from this listing.
      * Currently calcite's strategy to infer the return type of aggregate functions
      * is wrong because it uses the first known argument to determine output type. For
      * instance if we are performing stddev on an integer column then it interprets the
      * output type to be integer which is incorrect as it should be double. So based on
      * this if we add cast after rewriting the aggregate we add an additional cast which
      * would cause wrong results. So we simply add a cast to ANY.
        return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), result);
Also used : RexLiteral(org.apache.calcite.rex.RexLiteral) PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) SqlOperator(org.apache.calcite.sql.SqlOperator) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlCountAggFunction( BigDecimal(java.math.BigDecimal) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) SqlSumAggFunction( RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode)

/**
 * Dispatches one aggregate call to the matching reduction: SUM goes to reduceSum,
 * the AVG family (AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP) goes to
 * reduceAvg or reduceStddev, and any other function is re-registered unchanged
 * through RexBuilder.addAggCall.
 *
 * NOTE(review): this listing is missing closing braces, the switch's default
 * label and the body of the final loop; confirm against the full source.
 *
 * @param oldAggRel      aggregate that owns the call
 * @param oldCall        aggregate call to reduce
 * @param newCalls       forwarded to the reduce helpers and to RexBuilder.addAggCall
 * @param aggCallMapping forwarded to the reduce helpers and to RexBuilder.addAggCall
 * @param inputExprs     input expressions; forwarded to reduceStddev, and used
 *                       only in a size assertion on the pass-through path
 * @return the expression that stands in for oldCall
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    // Resolve the aggregate function through DrillCalciteWrapperUtility before
    // the instanceof checks (presumably unwrapping a Drill wrapper — confirm).
    final SqlAggFunction sqlAggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());
    if (sqlAggFunction instanceof SqlSumAggFunction) {
        // replace original SUM(x) with
        // case COUNT(x) when 0 then null else SUM0(x) end
        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    if (sqlAggFunction instanceof SqlAvgAggFunction) {
        final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) sqlAggFunction).getSubtype();
        switch(subtype) {
            case AVG:
                // replace original AVG(x) with SUM(x) / COUNT(x)
                return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
            case STDDEV_POP:
                // replace original STDDEV_POP(x) with
                //   POWER(
                //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
                //     / COUNT(x), .5)
                // (biased = true, sqrt = true)
                return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
            case STDDEV_SAMP:
                // replace original STDDEV_SAMP(x) with
                //   POWER(
                //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
                //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END, .5)
                // (biased = false, sqrt = true)
                return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
            case VAR_POP:
                // replace original VAR_POP(x) with
                //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
                //     / COUNT(x)
                // (biased = true, sqrt = false)
                return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
            case VAR_SAMP:
                // replace original VAR_SAMP(x) with
                //     (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
                //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
                // (biased = false, sqrt = false)
                return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
                // NOTE(review): a "default:" label appears to be missing before
                // this throw; as listed it directly follows a return.
                throw Util.unexpected(subtype);
    } else {
        // anything else:  preserve original call
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final int nGroups = oldAggRel.getGroupCount();
        List<RelDataType> oldArgTypes = new ArrayList<>();
        List<Integer> ordinals = oldCall.getArgList();
        assert ordinals.size() <= inputExprs.size();
        // Loop body not visible in this listing; presumably it fills
        // oldArgTypes with the type of each argument — confirm.
        for (int ordinal : ordinals) {
        return rexBuilder.addAggCall(oldCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, oldArgTypes);
Also used : SqlAvgAggFunction( SqlSumAggFunction( ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

	/**
	 * Converts an aggregate with one distinct aggregate and one or more
	 * non-distinct aggregates to multi-phase aggregates (see reference example
	 * below).
	 *
	 * @param relBuilder Contains the input relational expression
	 * @param aggregate  Original aggregate
	 * @param argLists   Arguments and filters to the distinct aggregate function
	 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // NOTE(review): this listing is missing many lines (loop bodies, closing
    // braces) and several calls have lost an argument (marked below). The
    // comments describe only what is visible; confirm against the full source.
    //
    // Visible outline: build an intermediate aggregate B grouped on newGroupSet,
    // push it on relBuilder, then rebuild the original aggregate's calls as
    // aggregate A on top of it, push that too, and return the same relBuilder.
    //
    // For example,
    //	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //	FROM emp
    //	GROUP BY deptno
    // becomes
    //	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //	FROM (
    //		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //		  FROM EMP
    //		  GROUP BY deptno, sal)			// Aggregate B
    //	GROUP BY deptno						// Aggregate A
    // Expressions (with names) collected for the intermediate result.
    final List<Pair<RexNode, String>> projects = new ArrayList<>();
    // Maps an argument ordinal to the index its expression received in projects.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    // Group keys for aggregate B; the statements that populate it are not visible here.
    SortedSet<Integer> newGroupSet = new TreeSet<>();
    final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
    // Decides below whether a COUNT is re-aggregated with SUM or with the
    // empty-is-zero SUM variant.
    final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
    // Sorted copy of the original group keys.
    SortedSet<Integer> groupSet = new TreeSet<>(aggregate.getGroupSet().asList());
    // Add the distinct aggregate column(s) to the group-by columns,
    // if not already a part of the group-by
    // (loop body not visible in this listing)
    for (Pair<List<Integer>, Integer> argList : argLists) {
    // Re-map the arguments to the aggregate A. These arguments will get
    // remapped because of the intermediate aggregate B generated as part of the
    // transformation.
    for (int arg : newGroupSet) {
        sourceOf.put(arg, projects.size());
        projects.add(RexInputRef.of2(arg, childFields));
    // Generate the intermediate aggregate B
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    // Placeholder argument ordinals, consumed below via fakeArgIdx; the
    // statements that fill this list are not visible here.
    final List<Integer> fakeArgs = new ArrayList<>();
    // Maps a newly created call to the placeholder ordinal assigned to it.
    final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
    // First identify the real arguments, then use the rest for fake arguments
    // e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
    for (final AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            for (int arg : aggCall.getArgList()) {
                if (!groupSet.contains(arg)) {
                    // Real (non-group-key) argument: record its source position.
                    sourceOf.put(arg, projects.size());
    // Candidate placeholder ordinal; advanced past ordinals already in sourceOf.
    int fakeArg0 = 0;
    for (final AggregateCall aggCall : aggCalls) {
        // We will deal with non-distinct aggregates below
        if (!aggCall.isDistinct()) {
            boolean isGroupKeyUsedInAgg = false;
            for (int arg : aggCall.getArgList()) {
                if (groupSet.contains(arg)) {
                    isGroupKeyUsedInAgg = true;
            // A placeholder is needed for argument-less calls (e.g. COUNT(*))
            // and for calls whose argument is itself a group key.
            if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
                // Skip ordinals already taken (loop body not visible).
                while (sourceOf.get(fakeArg0) != null) {
    // (bodies of this loop nest are not visible in this listing)
    for (final AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            for (int arg : aggCall.getArgList()) {
                if (!groupSet.contains(arg)) {
    // Compute the remapped arguments using fake arguments for non-distinct
    // aggregates with no arguments e.g. count(*).
    // fakeArgIdx indexes fakeArgs; the statement that advances it is not visible here.
    int fakeArgIdx = 0;
    for (final AggregateCall aggCall : aggCalls) {
        // Project the column corresponding to the distinct aggregate. Project
        // as-is all the non-distinct aggregates
        if (!aggCall.isDistinct()) {
            // NOTE(review): the trailing argument of this create(...) call is
            // missing from the listing ("null,;").
            final AggregateCall newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.getArgList(), -1, ImmutableBitSet.of(newGroupSet).cardinality(), relBuilder.peek(), null,;
            if (newCall.getArgList().size() == 0) {
                // Argument-less call: bind it to a placeholder ordinal.
                int fakeArg = fakeArgs.get(fakeArgIdx);
                callArgMap.put(newCall, fakeArg);
                sourceOf.put(fakeArg, projects.size());
                projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
            } else {
                for (int arg : newCall.getArgList()) {
                    if (groupSet.contains(arg)) {
                        // Argument is a group key: bind the call to a placeholder instead.
                        int fakeArg = fakeArgs.get(fakeArgIdx);
                        callArgMap.put(newCall, fakeArg);
                        sourceOf.put(fakeArg, projects.size());
                        projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
                    } else {
                        // Ordinary argument: keep its own ordinal.
                        sourceOf.put(arg, projects.size());
                        projects.add(Pair.of((RexNode) new RexInputRef(arg, newCall.getType()), newCall.getName()));
    // Generate the aggregate B (see the reference example above)
    // NOTE(review): the argument between the two commas after getTraitSet() is
    // missing from the listing.
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(),, false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
    // Convert the existing aggregate to aggregate A (see the reference example above)
    // Mutable copy so individual calls can be replaced in place with set(i, ...).
    final List<AggregateCall> newTopAggCalls = Lists.newArrayList(aggregate.getAggCallList());
    // Use the remapped arguments for the (non)distinct aggregate calls
    for (int i = 0; i < newTopAggCalls.size(); i++) {
        // Re-map arguments.
        final AggregateCall aggCall = newTopAggCalls.get(i);
        final int argCount = aggCall.getArgList().size();
        final List<Integer> newArgs = new ArrayList<>(argCount);
        final AggregateCall newCall;
        for (int j = 0; j < argCount; j++) {
            final Integer arg = aggCall.getArgList().get(j);
            // Both branches are empty in this listing; presumably they append
            // the remapped ordinal to newArgs — confirm.
            if (callArgMap.containsKey(aggCall)) {
            } else {
        if (aggCall.isDistinct()) {
            // The rebuilt call is created with distinct = false.
            // NOTE(review): trailing argument of create(...) is missing here.
            newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(),;
        } else {
            // If aggregate B had a COUNT aggregate call the corresponding aggregate at
            // aggregate A must be SUM. For other aggregates, it remains the same.
            if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
                // (body of the zero-argument branch is not visible in this listing)
                if (aggCall.getArgList().size() == 0) {
                if (hasGroupBy) {
                    SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
                    newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
                } else {
                    // Without a GROUP BY, the empty-is-zero SUM variant is used instead.
                    SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
                    newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
            } else {
                // NOTE(review): trailing argument of create(...) is missing here.
                newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(),;
        newTopAggCalls.set(i, newCall);
    // Populate the group-by keys with the remapped arguments for aggregate A
    // (loop body not visible in this listing)
    for (int arg : aggregate.getGroupSet()) {
    // NOTE(review): the argument between the two commas after getTraitSet() is
    // missing from the listing.
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(),, aggregate.indicator, ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
    return relBuilder;
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SqlCountAggFunction( AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlSumEmptyIsZeroAggFunction( RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) TreeSet(java.util.TreeSet) SqlSumAggFunction( RexInputRef(org.apache.calcite.rex.RexInputRef) ArrayList(java.util.ArrayList) ImmutableList( ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)


