Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project drill by axbaretto.
From the class DrillReduceAggregatesRule, method reduceAvg.
/**
 * Builds the replacement expression for an averaging aggregate call: a guarded sum divided by
 * a count.
 *
 * <p>A sum call (backed by {@code SqlSumEmptyIsZeroAggFunction}) and a COUNT call over the same
 * arguments are registered through {@code rexBuilder.addAggCall}. The numerator is
 * {@code CASE WHEN count = 0 THEN NULL ELSE sum END} passed through {@code CastHighOp}; the
 * denominator is the count reference.
 *
 * @param oldAggRel aggregate that owns the call being reduced
 * @param oldCall the aggregate call to reduce; its first argument is the averaged input
 * @param newCalls list handed to {@code addAggCall} to collect the replacement calls
 * @param aggCallMapping map handed to {@code addAggCall} alongside {@code newCalls}
 * @return the division expression; cast to {@code ANY} when type inference is disabled
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  final int inputOrdinal = oldCall.getArgList().get(0);
  final RelDataType inputType = getFieldType(oldAggRel.getInput(), inputOrdinal);

  // Sum side: infer Drill's SUM return type, then force it nullable when there are no group keys.
  final RelDataType inferredSumType =
      TypeInferenceUtils
          .getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.<DrillFuncHolder>of())
          .inferReturnType(oldCall.createBinding(oldAggRel));
  final RelDataType sumType =
      typeFactory.createTypeWithNullability(
          inferredSumType, inferredSumType.isNullable() || groupCount == 0);
  final SqlAggFunction sumFunction =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumCall =
      AggregateCall.create(
          sumFunction,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          -1,
          sumType,
          null);

  // Count side: standard COUNT over the same arguments.
  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(typeFactory);
  final AggregateCall countCall =
      AggregateCall.create(
          countFunction,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          -1,
          countType,
          null);

  // NOTE: the references below are with respect to the output of the new aggregate.
  final RexNode sumRef =
      rexBuilder.addAggCall(
          sumCall,
          groupCount,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(inputType));
  final RexNode countRef =
      rexBuilder.addAggCall(
          countCall,
          groupCount,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(inputType));

  // CASE WHEN count = 0 THEN NULL ELSE sum END
  final RexNode countIsZero =
      rexBuilder.makeCall(
          SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
  final RexNode guardedSum =
      rexBuilder.makeCall(
          SqlStdOperatorTable.CASE, countIsZero, rexBuilder.constantNull(), sumRef);

  final RexNode numerator = rexBuilder.makeCall(CastHighOp, guardedSum);
  // The count call is registered a second time, exactly as before.
  // NOTE(review): presumably this resolves to the same reference as countRef - confirm against
  // RexBuilder.addAggCall.
  final RexNode denominator =
      rexBuilder.addAggCall(
          countCall,
          groupCount,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(inputType));

  if (!inferTypes) {
    // Without type inference, use the standard DIVIDE operator and cast the result to ANY.
    final RexNode quotient =
        rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), quotient);
  }
  // With type inference, use a Drill "divide" operator typed with the original call's type.
  return rexBuilder.makeCall(
      new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numerator, denominator);
}
Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project drill by apache.
From the class DrillReduceAggregatesRule, method reduceSum.
/**
 * Builds the replacement expression for a summing aggregate call using a
 * {@code SqlSumEmptyIsZeroAggFunction}-backed call.
 *
 * <p>When the original call's type is not nullable, the reference to the new sum call is
 * returned directly. Otherwise the result is
 * {@code CASE WHEN count = 0 THEN NULL ELSE sum0 END}, where the count is a COUNT call over
 * the same arguments.
 *
 * @param oldAggRel aggregate that owns the call being reduced
 * @param oldCall the aggregate call to reduce; its first argument is the summed input
 * @param newCalls list handed to {@code addAggCall} to collect the replacement calls
 * @param aggCallMapping map handed to {@code addAggCall} alongside {@code newCalls}
 * @return the sum reference, or the CASE expression guarding it against a zero count
 */
private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  final int inputOrdinal = oldCall.getArgList().get(0);
  final RelDataType inputType = getFieldType(oldAggRel.getInput(), inputOrdinal);

  // With type inference the original type is used as-is; otherwise its nullability is taken
  // from the input column.
  final RelDataType sumType =
      inferTypes
          ? oldCall.getType()
          : typeFactory.createTypeWithNullability(oldCall.getType(), inputType.isNullable());
  final SqlAggFunction sumZeroFunction =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumZeroCall =
      AggregateCall.create(
          sumZeroFunction,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          -1,
          sumType,
          null);

  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(typeFactory);
  final AggregateCall countCall =
      AggregateCall.create(
          countFunction,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          -1,
          countType,
          null);

  // NOTE: the references below are with respect to the output of the new aggregate.
  final RexNode sumZeroRef =
      rexBuilder.addAggCall(
          sumZeroCall, groupCount, newCalls, aggCallMapping, ImmutableList.of(inputType));

  final boolean resultMayBeNull = oldCall.getType().isNullable();
  if (!resultMayBeNull) {
    // The original call's type is not nullable, so no NULL guard is built and the call is
    // translated to SUM0(x) alone.
    return sumZeroRef;
  }

  final RexNode countRef =
      rexBuilder.addAggCall(
          countCall, groupCount, newCalls, aggCallMapping, ImmutableList.of(inputType));

  // CASE WHEN count = 0 THEN NULL ELSE sum0 END
  final RexNode countIsZero =
      rexBuilder.makeCall(
          SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
  return rexBuilder.makeCall(
      SqlStdOperatorTable.CASE, countIsZero, rexBuilder.constantNull(), sumZeroRef);
}
Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project flink by apache.
From the class FlinkAggregateExpandDistinctAggregatesRule, method convertSingletonDistinct.
/**
 * Rewrites an aggregate containing a single distinct aggregate call and any number of
 * non-distinct calls as two stacked aggregates (the reference example is in the method body).
 *
 * @param relBuilder builder onto which the aggregate's input and the two new aggregates are
 *     pushed
 * @param aggregate the original aggregate
 * @param argLists arguments and filters of the distinct aggregate function; must contain
 *     exactly one entry
 * @return {@code relBuilder}, with the top aggregate on its stack
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
  // Only a single distinct function is handled here, so argLists must have exactly one entry.
  Preconditions.checkArgument(argLists.size() == 1);

  // Reference example:
  //
  //   SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
  //   FROM emp
  //   GROUP BY deptno
  //
  // becomes
  //
  //   SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
  //   FROM (
  //     SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
  //     FROM EMP
  //     GROUP BY deptno, sal)            // Aggregate B (bottom)
  //   GROUP BY deptno                    // Aggregate A (top)
  relBuilder.push(aggregate.getInput());

  final List<AggregateCall> inputCalls = aggregate.getAggCallList();
  final ImmutableBitSet inputGroupKeys = aggregate.getGroupSet();

  // Bottom group keys = original group keys plus the arguments of the first distinct call.
  final SortedSet<Integer> bottomKeys = new TreeSet<>(aggregate.getGroupSet().asList());
  for (AggregateCall call : inputCalls) {
    if (!call.isDistinct()) {
      continue;
    }
    bottomKeys.addAll(call.getArgList());
    // There is only one distinct call, so stop at the first hit.
    break;
  }
  final ImmutableBitSet bottomGroupBits = ImmutableBitSet.of(bottomKeys);
  final int bottomKeyCount = bottomGroupBits.cardinality();

  // Aggregate B carries every non-distinct call unchanged; the distinct call has been turned
  // into extra group keys instead.
  final List<AggregateCall> bottomCalls = new ArrayList<>();
  for (AggregateCall call : inputCalls) {
    if (call.isDistinct()) {
      continue;
    }
    bottomCalls.add(
        AggregateCall.create(
            call.getAggregation(),
            false,
            call.isApproximate(),
            false,
            call.getArgList(),
            -1,
            RelCollations.EMPTY,
            bottomKeyCount,
            relBuilder.peek(),
            null,
            call.name));
  }
  relBuilder.push(
      aggregate.copy(
          aggregate.getTraitSet(), relBuilder.build(), bottomGroupBits, null, bottomCalls));

  // Aggregate A finishes whatever aggregate B has not handled, reading B's output columns.
  final int topKeyCount = inputGroupKeys.cardinality();
  final List<AggregateCall> topCalls = new ArrayList<>();
  int nonDistinctSeen = 0;
  for (AggregateCall call : inputCalls) {
    final AggregateCall topCall;
    if (!call.isDistinct()) {
      // Non-distinct results sit after B's group keys, in the order they were added.
      // Note: this is a one-element list holding the column ordinal, not a capacity hint.
      final List<Integer> topArgs = new ArrayList<>();
      topArgs.add(bottomKeys.size() + nonDistinctSeen);
      if (call.getAggregation().getKind() == SqlKind.COUNT) {
        // A COUNT computed in B must be summed in A.
        topCall =
            AggregateCall.create(
                new SqlSumEmptyIsZeroAggFunction(),
                false,
                call.isApproximate(),
                false,
                topArgs,
                -1,
                RelCollations.EMPTY,
                topKeyCount,
                relBuilder.peek(),
                call.getType(),
                call.getName());
      } else {
        // Every other aggregate function is re-applied as-is.
        topCall =
            AggregateCall.create(
                call.getAggregation(),
                false,
                call.isApproximate(),
                false,
                topArgs,
                -1,
                RelCollations.EMPTY,
                topKeyCount,
                relBuilder.peek(),
                call.getType(),
                call.name);
      }
      nonDistinctSeen++;
    } else {
      // Each distinct argument is now a group key of B; its new ordinal is its rank among
      // B's sorted keys.
      final List<Integer> remappedArgs = new ArrayList<>();
      for (int inputArg : call.getArgList()) {
        remappedArgs.add(bottomKeys.headSet(inputArg).size());
      }
      topCall =
          AggregateCall.create(
              call.getAggregation(),
              false,
              call.isApproximate(),
              false,
              remappedArgs,
              -1,
              RelCollations.EMPTY,
              topKeyCount,
              relBuilder.peek(),
              call.getType(),
              call.name);
    }
    topCalls.add(topCall);
  }

  // Top group keys: the positions, within B's sorted keys, of the original group keys
  // (that is, B's keys minus the distinct call's inputs).
  final Set<Integer> topKeys = new HashSet<>();
  int position = 0;
  for (int bottomKey : bottomKeys) {
    if (inputGroupKeys.get(bottomKey)) {
      topKeys.add(position);
    }
    position++;
  }
  relBuilder.push(
      aggregate.copy(
          aggregate.getTraitSet(),
          relBuilder.build(),
          ImmutableBitSet.of(topKeys),
          null,
          topCalls));
  return relBuilder;
}
Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project flink by apache.
From the class FlinkAggregateExpandDistinctAggregatesRule, method convertSingletonDistinct.
/**
 * Converts an aggregate with one distinct aggregate and one or more
 * non-distinct aggregates to multi-phase aggregates (see reference example
 * below).
 *
 * <p>The method pushes the aggregate's input, then a bottom aggregate ("B") whose
 * group keys are the original keys plus the distinct arguments, then a top
 * aggregate ("A") whose calls and group keys are remapped through {@code sourceOf}.
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate Original aggregate
 * @param argLists Arguments and filters to the distinct aggregate function
 * @return {@code relBuilder}, with the top aggregate pushed last
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
// For example,
// SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
// FROM emp
// GROUP BY deptno
//
// becomes
//
// SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
// FROM (
// SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
// FROM EMP
// GROUP BY deptno, sal) // Aggregate B
// GROUP BY deptno // Aggregate A
relBuilder.push(aggregate.getInput());
// NOTE: within this method 'projects' is only appended to and measured with size();
// it serves as a running counter of output positions and is never handed to relBuilder.
final List<Pair<RexNode, String>> projects = new ArrayList<>();
// Maps an input field ordinal (real or fake) to its position in 'projects'.
final Map<Integer, Integer> sourceOf = new HashMap<>();
SortedSet<Integer> newGroupSet = new TreeSet<>();
final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
// NOTE(review): this tests getGroupSet().size(), not cardinality(); presumably it is
// meant as "has at least one group key" - confirm against ImmutableBitSet.size().
final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
// Sorted copy of the original group keys, used for membership tests below.
SortedSet<Integer> groupSet = new TreeSet<>(aggregate.getGroupSet().asList());
// Add the distinct aggregate column(s) to the group-by columns,
// if not already a part of the group-by
newGroupSet.addAll(aggregate.getGroupSet().asList());
for (Pair<List<Integer>, Integer> argList : argLists) {
newGroupSet.addAll(argList.getKey());
}
// Re-map the group keys of aggregate B: each key is assigned the next position,
// because the arguments get renumbered by the intermediate aggregate B generated
// as part of the transformation.
for (int arg : newGroupSet) {
sourceOf.put(arg, projects.size());
projects.add(RexInputRef.of2(arg, childFields));
}
// Generate the intermediate aggregate B
final List<AggregateCall> aggCalls = aggregate.getAggCallList();
final List<AggregateCall> newAggCalls = new ArrayList<>();
// Ordinals that are not used by any real argument; see the selection loop below.
final List<Integer> fakeArgs = new ArrayList<>();
// Maps a bottom (aggregate B) call to the fake argument assigned to it.
final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
// First compute the total number of fake args and then compute them.
// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
// Pass 1: temporarily mark the non-group-key arguments of non-distinct calls as taken
// in sourceOf, so the fake-argument search below skips them. The stored value is only
// a placeholder; these entries are removed again in pass 3.
for (final AggregateCall aggCall : aggCalls) {
if (!aggCall.isDistinct()) {
for (int arg : aggCall.getArgList()) {
if (!groupSet.contains(arg)) {
sourceOf.put(arg, projects.size());
}
}
}
}
// Pass 2: pick one fake argument for every non-distinct call that either has no
// arguments or uses a group key as an argument. A fake argument is the smallest
// ordinal not yet present in sourceOf.
int fakeArg0 = 0;
for (final AggregateCall aggCall : aggCalls) {
// We will deal with non-distinct aggregates below
if (!aggCall.isDistinct()) {
boolean isGroupKeyUsedInAgg = false;
for (int arg : aggCall.getArgList()) {
if (groupSet.contains(arg)) {
isGroupKeyUsedInAgg = true;
break;
}
}
if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
while (sourceOf.get(fakeArg0) != null) {
++fakeArg0;
}
fakeArgs.add(fakeArg0);
++fakeArg0;
}
}
}
// Pass 3: undo the temporary markings made in pass 1.
for (final AggregateCall aggCall : aggCalls) {
if (!aggCall.isDistinct()) {
for (int arg : aggCall.getArgList()) {
if (!groupSet.contains(arg)) {
sourceOf.remove(arg);
}
}
}
}
// Compute the remapped arguments using fake arguments for non-distinct
// aggregates with no arguments e.g. count(*).
int fakeArgIdx = 0;
for (final AggregateCall aggCall : aggCalls) {
// as-is all the non-distinct aggregates
if (!aggCall.isDistinct()) {
// The call is re-created for aggregate B with a null type passed to create().
final AggregateCall newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.getArgList(), -1, ImmutableBitSet.of(newGroupSet).cardinality(), relBuilder.peek(), null, aggCall.name);
newAggCalls.add(newCall);
if (newCall.getArgList().size() == 0) {
// Zero-argument call: give its result a position under the next fake argument.
int fakeArg = fakeArgs.get(fakeArgIdx);
callArgMap.put(newCall, fakeArg);
sourceOf.put(fakeArg, projects.size());
projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
++fakeArgIdx;
} else {
for (int arg : newCall.getArgList()) {
if (groupSet.contains(arg)) {
// The argument is also a group key, whose sourceOf entry must stay intact, so the
// call's result is registered under a fake argument instead.
// NOTE(review): pass 2 reserved one fake argument per call, but this branch
// consumes one per group-key argument; a call with two group-key arguments
// would read past the reserved entries - confirm whether that can occur.
int fakeArg = fakeArgs.get(fakeArgIdx);
callArgMap.put(newCall, fakeArg);
sourceOf.put(fakeArg, projects.size());
projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
++fakeArgIdx;
} else {
// Ordinary argument: the call's result is registered under the real ordinal.
sourceOf.put(arg, projects.size());
projects.add(Pair.of((RexNode) new RexInputRef(arg, newCall.getType()), newCall.getName()));
}
}
}
}
}
// Generate the aggregate B (see the reference example above)
// Note that 'false' is passed as the indicator argument for aggregate B.
relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
// Convert the existing aggregate to aggregate A (see the reference example above)
final List<AggregateCall> newTopAggCalls = Lists.newArrayList(aggregate.getAggCallList());
// Use the remapped arguments for the (non)distinct aggregate calls
for (int i = 0; i < newTopAggCalls.size(); i++) {
// Re-map arguments.
final AggregateCall aggCall = newTopAggCalls.get(i);
final int argCount = aggCall.getArgList().size();
final List<Integer> newArgs = new ArrayList<>(argCount);
final AggregateCall newCall;
for (int j = 0; j < argCount; j++) {
final Integer arg = aggCall.getArgList().get(j);
// NOTE(review): callArgMap was keyed by the bottom calls created above but is
// queried with the original call; this presumably relies on AggregateCall
// equality treating the two as equal - confirm.
if (callArgMap.containsKey(aggCall)) {
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
} else {
newArgs.add(sourceOf.get(arg));
}
}
if (aggCall.isDistinct()) {
// The distinct call is re-created as non-distinct over its remapped arguments.
newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
} else {
// If aggregate B had a COUNT aggregate call the corresponding aggregate at
// aggregate A must be SUM. For other aggregates, it remains the same.
if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
if (aggCall.getArgList().size() == 0) {
// A zero-argument COUNT skipped the remapping loop above; point it at the position
// registered under its fake argument.
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
}
if (hasGroupBy) {
// With group keys, the counts from B are combined with SqlSumAggFunction.
SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
} else {
// Without group keys, SqlSumEmptyIsZeroAggFunction is used instead.
SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
}
} else {
newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
}
}
newTopAggCalls.set(i, newCall);
}
// Populate the group-by keys with the remapped arguments for aggregate A
newGroupSet.clear();
for (int arg : aggregate.getGroupSet()) {
newGroupSet.add(sourceOf.get(arg));
}
// Aggregate A keeps the original aggregate's indicator flag.
relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
return relBuilder;
}
Use of org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction in project drill by axbaretto.
From the class DrillReduceAggregatesRule, method reduceSum.
/**
 * Builds the replacement expression for a summing aggregate call using a
 * {@code SqlSumEmptyIsZeroAggFunction}-backed call.
 *
 * <p>When the original call's type is not nullable, the reference to the new sum call is
 * returned directly. Otherwise the result is
 * {@code CASE WHEN count = 0 THEN NULL ELSE sum0 END}, where the count is a COUNT call over
 * the same arguments. Both calls are registered with the aggregate's {@code indicator} flag.
 *
 * @param oldAggRel aggregate that owns the call being reduced
 * @param oldCall the aggregate call to reduce; its first argument is the summed input
 * @param newCalls list handed to {@code addAggCall} to collect the replacement calls
 * @param aggCallMapping map handed to {@code addAggCall} alongside {@code newCalls}
 * @return the sum reference, or the CASE expression guarding it against a zero count
 */
private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings =
      (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  final int inputOrdinal = oldCall.getArgList().get(0);
  final RelDataType inputType = getFieldType(oldAggRel.getInput(), inputOrdinal);

  // With type inference the original type is used as-is; otherwise its nullability is taken
  // from the input column.
  final RelDataType sumType =
      inferTypes
          ? oldCall.getType()
          : typeFactory.createTypeWithNullability(oldCall.getType(), inputType.isNullable());
  final SqlAggFunction sumZeroFunction =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumZeroCall =
      AggregateCall.create(
          sumZeroFunction,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          -1,
          sumType,
          null);

  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(typeFactory);
  final AggregateCall countCall =
      AggregateCall.create(
          countFunction,
          oldCall.isDistinct(),
          oldCall.isApproximate(),
          oldCall.getArgList(),
          -1,
          countType,
          null);

  // NOTE: the references below are with respect to the output of the new aggregate.
  final RexNode sumZeroRef =
      rexBuilder.addAggCall(
          sumZeroCall,
          groupCount,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(inputType));

  final boolean resultMayBeNull = oldCall.getType().isNullable();
  if (!resultMayBeNull) {
    // The original call's type is not nullable, so no NULL guard is built and the call is
    // translated to SUM0(x) alone.
    return sumZeroRef;
  }

  final RexNode countRef =
      rexBuilder.addAggCall(
          countCall,
          groupCount,
          oldAggRel.indicator,
          newCalls,
          aggCallMapping,
          ImmutableList.of(inputType));

  // CASE WHEN count = 0 THEN NULL ELSE sum0 END
  final RexNode countIsZero =
      rexBuilder.makeCall(
          SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
  return rexBuilder.makeCall(
      SqlStdOperatorTable.CASE, countIsZero, rexBuilder.constantNull(), sumZeroRef);
}
Aggregations