Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in the Apache Calcite project.
Class AggregateExpandDistinctAggregatesRule, method convertSingletonDistinct.
/**
 * Rewrites an aggregate that contains a single distinct aggregate call
 * together with one or more non-distinct calls as two stacked aggregates.
 *
 * <p>The bottom aggregate groups by the original keys plus the distinct
 * call's arguments and evaluates the non-distinct calls; the top aggregate
 * groups by the original keys only and finishes every call.
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate Original aggregate
 * @param argLists Arguments and filters to the distinct aggregate function;
 *     must contain exactly one entry
 * @return {@code relBuilder}, with the top aggregate pushed onto it
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
  // Exactly one distinct function is supported by this rewrite.
  Preconditions.checkArgument(argLists.size() == 1);

  // Reference example:
  //
  //   SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
  //   FROM emp
  //   GROUP BY deptno
  //
  // is rewritten to
  //
  //   SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
  //   FROM (
  //     SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
  //     FROM EMP
  //     GROUP BY deptno, sal)            // Aggregate B (bottom)
  //   GROUP BY deptno                    // Aggregate A (top)
  relBuilder.push(aggregate.getInput());
  final List<AggregateCall> originalAggCalls = aggregate.getAggCallList();
  final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

  // Keys of aggregate B: the original keys plus the arguments of the
  // distinct call. The sorted set both de-duplicates and fixes field order.
  final SortedSet<Integer> bottomGroupSet = new TreeSet<>(aggregate.getGroupSet().asList());
  for (AggregateCall candidate : originalAggCalls) {
    if (!candidate.isDistinct()) {
      continue;
    }
    bottomGroupSet.addAll(candidate.getArgList());
    // There is only one distinct call, so stop at the first hit.
    break;
  }
  final ImmutableBitSet bottomGroupBits = ImmutableBitSet.of(bottomGroupSet);

  // Calls of aggregate B: every non-distinct call is carried over as-is;
  // the distinct call is represented by the extra group-by keys instead.
  final List<AggregateCall> bottomCalls = new ArrayList<>();
  for (AggregateCall original : originalAggCalls) {
    if (original.isDistinct()) {
      continue;
    }
    bottomCalls.add(
        AggregateCall.create(
            original.getAggregation(),
            false,
            original.isApproximate(),
            original.getArgList(),
            -1,
            bottomGroupBits.cardinality(),
            relBuilder.peek(),
            null,
            original.getName()));
  }

  // Build aggregate B on top of the original input.
  relBuilder.push(
      aggregate.copy(
          aggregate.getTraitSet(),
          relBuilder.build(),
          false,
          bottomGroupBits,
          null,
          bottomCalls));

  // Calls of aggregate A, with arguments remapped to B's output fields.
  // B emits its group keys first, followed by one field per non-distinct
  // call in original order; nextBottomAggField walks those trailing fields.
  final List<AggregateCall> topCalls = new ArrayList<>();
  int nextBottomAggField = bottomGroupSet.size();
  for (AggregateCall original : originalAggCalls) {
    final AggregateCall rewritten;
    if (!original.isDistinct()) {
      final List<Integer> remappedArgs = Lists.newArrayList(nextBottomAggField);
      nextBottomAggField++;
      if (original.getAggregation().getKind() != SqlKind.COUNT) {
        // Any other supported aggregate is simply re-applied to B's result.
        rewritten =
            AggregateCall.create(
                original.getAggregation(),
                false,
                original.isApproximate(),
                remappedArgs,
                -1,
                originalGroupSet.cardinality(),
                relBuilder.peek(),
                original.getType(),
                original.getName());
      } else {
        // A COUNT computed in B has to be summed up in A.
        rewritten =
            AggregateCall.create(
                new SqlSumEmptyIsZeroAggFunction(),
                false,
                original.isApproximate(),
                remappedArgs,
                -1,
                originalGroupSet.cardinality(),
                relBuilder.peek(),
                original.getType(),
                original.getName());
      }
    } else {
      // The distinct call becomes a plain call over B's group-key fields;
      // an argument's position in B is its rank within the sorted key set.
      final List<Integer> remappedArgs = new ArrayList<>();
      for (int arg : original.getArgList()) {
        remappedArgs.add(bottomGroupSet.headSet(arg).size());
      }
      rewritten =
          AggregateCall.create(
              original.getAggregation(),
              false,
              original.isApproximate(),
              remappedArgs,
              -1,
              originalGroupSet.cardinality(),
              relBuilder.peek(),
              original.getType(),
              original.getName());
    }
    topCalls.add(rewritten);
  }

  // Keys of aggregate A: the positions, within B's key fields, of the keys
  // that were part of the original group set (the distinct call's own
  // inputs are left out unless they were already grouping keys).
  final Set<Integer> topGroupSet = new HashSet<>();
  int bottomOrdinal = 0;
  for (int bottomKey : bottomGroupSet) {
    if (originalGroupSet.get(bottomKey)) {
      topGroupSet.add(bottomOrdinal);
    }
    bottomOrdinal++;
  }

  // Build aggregate A on top of aggregate B.
  relBuilder.push(
      aggregate.copy(
          aggregate.getTraitSet(),
          relBuilder.build(),
          aggregate.indicator,
          ImmutableBitSet.of(topGroupSet),
          null,
          topCalls));
  return relBuilder;
}
Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in the Apache Drill project.
Class DrillReduceAggregatesRule, method reduceAvg.
/**
 * Builds the expression that replaces an AVG aggregate call with a division
 * of a sum aggregate by a count aggregate over the same arguments.
 *
 * <p>The numerator is {@code CASE WHEN count = 0 THEN NULL ELSE sum END}
 * wrapped in {@code CastHighOp}; the denominator is the count reference.
 * The sum and count calls are registered through
 * {@code RexBuilder.addAggCall} using {@code newCalls} and
 * {@code aggCallMapping}.
 *
 * @param oldAggRel aggregate that contains the call being reduced
 * @param oldCall the call being reduced; its first argument is the averaged field
 * @param newCalls list passed to {@code addAggCall} for the sum and count calls
 * @param aggCallMapping mapping passed to {@code addAggCall} alongside {@code newCalls}
 * @return the division expression; built with a Drill "divide" operator typed
 *     as {@code oldCall.getType()} when type inference is enabled, otherwise
 *     built with the standard DIVIDE operator and cast to ANY
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  // Type of the field being averaged (first argument of the call).
  final int avgInputOrdinal = oldCall.getArgList().get(0);
  final RelDataType avgInputType = getFieldType(oldAggRel.getInput(), avgInputOrdinal);
  final List<RelDataType> aggArgTypes = ImmutableList.of(avgInputType);

  // Sum type comes from Drill's SUM return-type inference and is forced
  // nullable when the aggregate has no group keys.
  final RelDataType inferredSumType =
      TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
          .inferReturnType(oldCall.createBinding(oldAggRel));
  final boolean sumNullable = inferredSumType.isNullable() || groupCount == 0;
  final RelDataType sumType = typeFactory.createTypeWithNullability(inferredSumType, sumNullable);

  final SqlAggFunction sumAgg =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumCall =
      AggregateCall.create(
          sumAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);

  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  final AggregateCall countCall =
      AggregateCall.create(
          countAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

  // NOTE: these references are with respect to the output of the new
  // aggregate that will hold newCalls.
  final RexNode sumRef = rexBuilder.addAggCall(sumCall, groupCount, newCalls, aggCallMapping, aggArgTypes);
  final RexNode countRef = rexBuilder.addAggCall(countCall, groupCount, newCalls, aggCallMapping, aggArgTypes);

  // CASE WHEN count = 0 THEN NULL ELSE sum END
  final RexNode countIsZero =
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
  final RexNode sumOrNull =
      rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, rexBuilder.constantNull(), sumRef);

  final RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, sumOrNull);
  // The count call is registered a second time for the denominator.
  // NOTE(review): presumably addAggCall resolves this to the reference already
  // obtained above via aggCallMapping - confirm before simplifying.
  final RexNode denominatorRef =
      rexBuilder.addAggCall(countCall, groupCount, newCalls, aggCallMapping, aggArgTypes);

  if (!isInferenceEnabled) {
    final RexNode quotient = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), quotient);
  }
  return rexBuilder.makeCall(
      new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numeratorRef, denominatorRef);
}
Aggregations