Search in sources :

Example 6 with SqlSumEmptyIsZeroAggFunction

use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project calcite by apache.

In class AggregateExpandDistinctAggregatesRule, method convertSingletonDistinct.

/**
 * Rewrites an aggregate that mixes exactly one distinct aggregate call with
 * one or more non-distinct calls into two stacked aggregates (a worked
 * example is given at the top of the method body).
 *
 * @param relBuilder Builder that receives the aggregate's input and the two
 *                   replacement aggregates
 * @param aggregate  Aggregate being rewritten
 * @param argLists   Arguments and filters of the distinct aggregate function;
 *                   must hold exactly one entry
 * @return the same {@code relBuilder}, with the top aggregate pushed last
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Only a single distinct function is supported here, so argLists must
    // contain exactly one entry.
    Preconditions.checkArgument(argLists.size() == 1);

    // Reference example. The query
    //
    //   SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //   FROM emp
    //   GROUP BY deptno
    //
    // is turned into
    //
    //   SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //   FROM (
    //     SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //     FROM EMP
    //     GROUP BY deptno, sal)            // inner aggregate ("B")
    //   GROUP BY deptno                    // outer aggregate ("A")
    relBuilder.push(aggregate.getInput());
    final List<AggregateCall> sourceCalls = aggregate.getAggCallList();
    final ImmutableBitSet sourceGroupSet = aggregate.getGroupSet();

    // Inner group keys = original group keys plus the argument column(s) of
    // the distinct call (columns already grouped on are simply absorbed by
    // the set).
    final SortedSet<Integer> innerKeys = new TreeSet<>();
    innerKeys.addAll(sourceGroupSet.asList());
    for (AggregateCall call : sourceCalls) {
        if (!call.isDistinct()) {
            continue;
        }
        innerKeys.addAll(call.getArgList());
        // There is only one distinct call, so stop at the first one found.
        break;
    }
    final ImmutableBitSet innerGroupSet = ImmutableBitSet.of(innerKeys);

    // Inner aggregate B: carries every non-distinct call over unchanged; the
    // distinct call is represented by its columns being group keys instead.
    final List<AggregateCall> innerCalls = new ArrayList<>();
    for (AggregateCall call : sourceCalls) {
        if (call.isDistinct()) {
            continue;
        }
        innerCalls.add(
            AggregateCall.create(call.getAggregation(), false, call.isApproximate(), call.getArgList(), -1,
                innerGroupSet.cardinality(), relBuilder.peek(), null, call.name));
    }
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, innerGroupSet, null, innerCalls));

    // Outer aggregate A: finishes the aggregation on top of B, with every
    // call's arguments remapped to B's output fields.
    final List<AggregateCall> outerCalls = Lists.newArrayList();
    final int outerGroupCount = sourceGroupSet.cardinality();
    int plainCallsSeen = 0;
    for (AggregateCall call : sourceCalls) {
        if (call.isDistinct()) {
            // Each argument becomes its ordinal position among B's group keys.
            final List<Integer> remappedArgs = new ArrayList<>();
            for (int arg : call.getArgList()) {
                remappedArgs.add(innerKeys.headSet(arg).size());
            }
            outerCalls.add(
                AggregateCall.create(call.getAggregation(), false, call.isApproximate(), remappedArgs, -1,
                    outerGroupCount, relBuilder.peek(), call.getType(), call.name));
            continue;
        }

        // Non-distinct calls read the matching aggregate output of B, which
        // sits after B's group-key fields.
        final List<Integer> inputRef = Lists.newArrayList(innerKeys.size() + plainCallsSeen);
        plainCallsSeen++;
        final AggregateCall outerCall;
        if (call.getAggregation().getKind() != SqlKind.COUNT) {
            // Every aggregate other than COUNT is re-applied as-is.
            outerCall =
                AggregateCall.create(call.getAggregation(), false, call.isApproximate(), inputRef, -1,
                    outerGroupCount, relBuilder.peek(), call.getType(), call.name);
        } else {
            // A COUNT computed in B has to be rolled up with SUM in A.
            outerCall =
                AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false, call.isApproximate(), inputRef, -1,
                    outerGroupCount, relBuilder.peek(), call.getType(), call.getName());
        }
        outerCalls.add(outerCall);
    }

    // Group keys of A: the ordinal (within B's keys) of every key that was
    // also an original group key, i.e. B's keys minus the columns that were
    // added only for the distinct call.
    final Set<Integer> outerKeys = new HashSet<>();
    int ordinal = 0;
    for (int innerKey : innerKeys) {
        if (sourceGroupSet.get(innerKey)) {
            outerKeys.add(ordinal);
        }
        ordinal++;
    }
    relBuilder.push(
        aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator,
            ImmutableBitSet.of(outerKeys), null, outerCalls));
    return relBuilder;
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) TreeSet(java.util.TreeSet) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) LinkedHashSet(java.util.LinkedHashSet)

Example 7 with SqlSumEmptyIsZeroAggFunction

use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project drill by apache.

In class DrillReduceAggregatesRule, method reduceAvg.

/**
 * Builds an expression that divides a SUM aggregate by a COUNT aggregate
 * over the arguments of {@code oldCall}; as the method name indicates, this
 * stands in for an AVG call.
 *
 * <p>The numerator is the sum guarded by a CASE that yields NULL when the
 * count is zero, wrapped in {@code CastHighOp}; the denominator is the count.
 *
 * @param oldAggRel      Aggregate that contains {@code oldCall}
 * @param oldCall        Aggregate call being replaced
 * @param newCalls       Passed through to {@code RexBuilder.addAggCall}
 *                       (presumably the list that collects the generated
 *                       calls; confirm against the caller)
 * @param aggCallMapping Passed through to {@code RexBuilder.addAggCall}
 * @return the division expression; cast to {@code ANY} when type inference
 *         is disabled in the planner settings
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferenceOn = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

    // Type of the column being averaged (the call's first argument).
    final int inputOrdinal = oldCall.getArgList().get(0);
    final RelDataType avgInputType = getFieldType(oldAggRel.getInput(), inputOrdinal);
    final ImmutableList<RelDataType> argTypes = ImmutableList.of(avgInputType);

    // SUM call: its type is inferred for SUM over the old call's binding and
    // is forced nullable when there are no group keys.
    final RelDataType inferredSumType =
        TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
            .inferReturnType(oldCall.createBinding(oldAggRel));
    final boolean sumNullable = inferredSumType.isNullable() || groupCount == 0;
    final RelDataType sumType = typeFactory.createTypeWithNullability(inferredSumType, sumNullable);
    final SqlAggFunction sumFunction = new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
    final AggregateCall sumCall =
        AggregateCall.create(sumFunction, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);

    // COUNT call over the same arguments.
    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countFunction.getReturnType(typeFactory);
    final AggregateCall countCall =
        AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

    // NOTE: these references are with respect to the output of newAggRel.
    final RexNode sumRef = rexBuilder.addAggCall(sumCall, groupCount, newCalls, aggCallMapping, argTypes);
    final RexNode countRef = rexBuilder.addAggCall(countCall, groupCount, newCalls, aggCallMapping, argTypes);

    // CASE WHEN count = 0 THEN NULL ELSE sum END
    final RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
    final RexNode countIsZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, zero);
    final RexNode nullLiteral = rexBuilder.constantNull();
    final RexNode guardedSum = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, nullLiteral, sumRef);

    // The numerator is the guarded sum (rather than the raw SUM reference)
    // wrapped in CastHighOp; the denominator is the COUNT reference.
    final RexNode numerator = rexBuilder.makeCall(CastHighOp, guardedSum);
    final RexNode denominator = rexBuilder.addAggCall(countCall, groupCount, newCalls, aggCallMapping, argTypes);

    if (!inferenceOn) {
        // Without type inference, use the standard DIVIDE and cast to ANY.
        final RexNode quotient = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
        return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), quotient);
    }
    // With type inference, use a "divide" operator typed like the old call.
    return rexBuilder.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numerator, denominator);
}
Also used : PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) DrillCalciteSqlAggFunctionWrapper(org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall)7 SqlSumEmptyIsZeroAggFunction (org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction)7 RexNode (org.apache.calcite.rex.RexNode)5 SqlCountAggFunction (org.apache.calcite.sql.fun.SqlCountAggFunction)5 RelDataType (org.apache.calcite.rel.type.RelDataType)4 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)4 RexBuilder (org.apache.calcite.rex.RexBuilder)4 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)4 PlannerSettings (org.apache.drill.exec.planner.physical.PlannerSettings)4 DrillCalciteSqlAggFunctionWrapper (org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper)4 ArrayList (java.util.ArrayList)3 TreeSet (java.util.TreeSet)3 HashSet (java.util.HashSet)2 LinkedHashSet (java.util.LinkedHashSet)2 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)2 DrillSqlOperator (org.apache.drill.exec.planner.sql.DrillSqlOperator)2 ImmutableList (com.google.common.collect.ImmutableList)1 HashMap (java.util.HashMap)1 List (java.util.List)1 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)1