use of org.apache.calcite.sql.fun.SqlSumAggFunction in project flink by apache.
the class FlinkAggregateExpandDistinctAggregatesRule method onMatch.
//~ Methods ----------------------------------------------------------------
/**
 * Rewrites an aggregate that contains at least one DISTINCT aggregate call
 * into an equivalent plan, picking one of several rewrite strategies based on
 * how many distinct / non-distinct / filtered calls are present.
 *
 * @param call rule call whose first operand is the {@code Aggregate} to rewrite
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    if (!aggregate.containsDistinctCall()) {
        // Nothing to expand.
        return;
    }

    // Tally the aggregate calls by category and collect the
    // (argument list, filter) pair of every distinct call. A LinkedHashSet
    // keeps the iteration order deterministic.
    int plainCallCount = 0;
    int distinctCallCount = 0;
    int filteredCallCount = 0;
    int unsupportedPlainCallCount = 0;
    final Set<Pair<List<Integer>, Integer>> distinctArgLists = new LinkedHashSet<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (aggCall.filterArg >= 0) {
            ++filteredCallCount;
        }
        if (aggCall.isDistinct()) {
            ++distinctCallCount;
            distinctArgLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
        } else {
            ++plainCallCount;
            // Only COUNT, SUM and MIN/MAX count as "supported" for the
            // multi-phase rewrite further below.
            final boolean supported =
                    aggCall.getAggregation() instanceof SqlCountAggFunction
                            || aggCall.getAggregation() instanceof SqlSumAggFunction
                            || aggCall.getAggregation() instanceof SqlMinMaxAggFunction;
            if (!supported) {
                ++unsupportedPlainCallCount;
            }
        }
    }
    Preconditions.checkState(!distinctArgLists.isEmpty(), "containsDistinctCall lied");

    // If all of the agg expressions are distinct and have the same
    // arguments then we can use a more efficient form.
    if (plainCallCount == 0 && distinctArgLists.size() == 1) {
        final Pair<List<Integer>, Integer> onlyArgList = Iterables.getOnlyElement(distinctArgLists);
        final RelBuilder monopoleBuilder = call.builder();
        convertMonopole(monopoleBuilder, aggregate, onlyArgList.left, onlyArgList.right);
        call.transformTo(monopoleBuilder.build());
        return;
    }

    if (useGroupingSets) {
        rewriteUsingGroupingSets(call, aggregate, distinctArgLists);
        return;
    }

    // Exactly one distinct aggregate, no filters, and one or more
    // non-distinct aggregates that are all sum/min/max/count: we can
    // generate multi-phase aggregates.
    final boolean multiPhaseApplicable = distinctCallCount == 1
            && filteredCallCount == 0
            && unsupportedPlainCallCount == 0
            && plainCallCount > 0;
    if (multiPhaseApplicable) {
        final RelBuilder singletonBuilder = call.builder();
        convertSingletonDistinct(singletonBuilder, aggregate, distinctArgLists);
        call.transformTo(singletonBuilder.build());
        return;
    }

    // Create a list of the expressions which will yield the final result.
    // Initially, the expressions point to the input field.
    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final List<RexInputRef> refs = new ArrayList<>();
    for (int field = 0; field < groupAndIndicatorCount; field++) {
        refs.add(RexInputRef.of(field, aggFields));
    }

    // Aggregate the original relation, including any non-distinct aggregates.
    // Distinct calls get a null placeholder in refs; doRewrite receives refs
    // below.
    final List<AggregateCall> plainCalls = new ArrayList<>();
    final List<AggregateCall> allCalls = aggregate.getAggCallList();
    for (int ordinal = 0; ordinal < allCalls.size(); ordinal++) {
        final AggregateCall aggCall = allCalls.get(ordinal);
        if (aggCall.isDistinct()) {
            refs.add(null);
        } else {
            refs.add(new RexInputRef(
                    groupAndIndicatorCount + plainCalls.size(),
                    aggFields.get(groupAndIndicatorCount + ordinal).getType()));
            plainCalls.add(aggCall);
        }
    }

    // In the case where there are no non-distinct aggregates (regardless of
    // whether there are group bys), there's no need to generate the
    // extra aggregate and join.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    int rewriteOrdinal = 0;
    if (!plainCalls.isEmpty()) {
        final RelBuilder.GroupKey groupKey =
                relBuilder.groupKey(groupSet, aggregate.indicator, aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, plainCalls);
        ++rewriteOrdinal;
    }

    // For each set of operands, rewrite the distinct calls that use that
    // set of operands.
    for (Pair<List<Integer>, Integer> argList : distinctArgLists) {
        doRewrite(relBuilder, aggregate, rewriteOrdinal++, argList.left, argList.right, refs);
    }

    relBuilder.project(refs, fieldNames);
    call.transformTo(relBuilder.build());
}
use of org.apache.calcite.sql.fun.SqlSumAggFunction in project drill by apache.
the class DrillReduceAggregatesRule method reduceStddev.
/**
 * Rewrites a STDDEV/VAR aggregate call in terms of SUM and COUNT calls.
 *
 * <pre>
 * stddev_pop(x)  ==> power((sum(x * x) - sum(x) * sum(x) / count(x)) / count(x), .5)
 * stddev_samp(x) ==> power((sum(x * x) - sum(x) * sum(x) / count(x))
 *                          / nullif(count(x) - 1, 0), .5)
 * </pre>
 *
 * The variance forms are the same expressions without the final power.
 *
 * @param oldAggRel      aggregate that contains {@code oldCall}
 * @param oldCall        the single-argument call being rewritten
 * @param biased         true for the population (_POP) form, which divides by count(x);
 *                       false for the sample form, which divides by count(x) - 1
 * @param sqrt           true to raise the result to the power 0.5 (standard deviation)
 * @param newCalls       receives the SUM/COUNT calls registered by this rewrite
 * @param aggCallMapping passed through to {@code RexBuilder.addAggCall}
 * @param inputExprs     input expressions; the argument entry is replaced in place and
 *                       the squared argument is looked up or added
 * @return expression computing the requested statistic
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final PlannerSettings settings =
            (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

    final List<Integer> oldArgs = oldCall.getArgList();
    assert oldArgs.size() == 1 : oldArgs;
    final int argOrdinal = oldArgs.get(0);
    final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
    final List<RelDataType> argTypes = ImmutableList.of(argType);

    // The argument is wrapped in CastHighOp instead of being used as-is, and
    // the wrapped expression replaces the original entry in inputExprs.
    final RexNode castArg = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
    inputExprs.set(argOrdinal, castArg);
    final RexNode squaredArg =
            rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, castArg, castArg);
    final int squaredArgOrdinal = lookupOrAdd(inputExprs, squaredArg);

    final RelDataType sumType = typeFactory.createTypeWithNullability(argType, true);

    // sum(x * x)
    final AggregateCall sumOfSquaresCall = AggregateCall.create(
            new SqlSumAggFunction(sumType), oldCall.isDistinct(),
            ImmutableIntList.of(squaredArgOrdinal), -1, sumType, null);
    final RexNode sumOfSquares = rexBuilder.addAggCall(
            sumOfSquaresCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

    // sum(x)
    final AggregateCall sumCall = AggregateCall.create(
            new SqlSumAggFunction(sumType), oldCall.isDistinct(),
            ImmutableIntList.of(argOrdinal), -1, sumType, null);
    final RexNode sum = rexBuilder.addAggCall(
            sumCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

    // sum(x) * sum(x)
    final RexNode squareOfSum = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sum, sum);

    // count(x)
    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countFunction.getReturnType(typeFactory);
    final AggregateCall countCall = AggregateCall.create(
            countFunction, oldCall.isDistinct(), oldArgs, -1, countType, null);
    final RexNode count = rexBuilder.addAggCall(
            countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode squareOfSumOverCount =
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count);
    final RexNode numerator =
            rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumOfSquares, squareOfSumOverCount);

    final RexNode denominator;
    if (biased) {
        denominator = count;
    } else {
        // CASE count(x) = 1 WHEN TRUE THEN NULL ELSE count(x) - 1
        final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nullLiteral = rexBuilder.makeNullLiteral(count.getType().getSqlTypeName());
        final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, count, one);
        final RexNode countIsOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, one);
        denominator = rexBuilder.makeCall(
                SqlStdOperatorTable.CASE, countIsOne, nullLiteral, countMinusOne);
    }

    // With type inference enabled, a Drill "divide" operator constructed with
    // the old call's type is used in place of the standard DIVIDE.
    final SqlOperator divideOp = inferTypes
            ? new DrillSqlOperator("divide", 2, true, oldCall.getType(), false)
            : SqlStdOperatorTable.DIVIDE;
    final RexNode quotient = rexBuilder.makeCall(divideOp, numerator, denominator);

    final RexNode statistic = sqrt
            ? rexBuilder.makeCall(SqlStdOperatorTable.POWER, quotient,
                    rexBuilder.makeExactLiteral(new BigDecimal("0.5")))
            : quotient;

    if (inferTypes) {
        return statistic;
    }
    /*
     * Currently calcite's strategy to infer the return type of aggregate functions
     * is wrong because it uses the first known argument to determine output type. For
     * instance if we are performing stddev on an integer column then it interprets the
     * output type to be integer which is incorrect as it should be double. So based on
     * this if we add cast after rewriting the aggregate we add an additional cast which
     * would cause wrong results. So we simply add a cast to ANY.
     */
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), statistic);
}
use of org.apache.calcite.sql.fun.SqlSumAggFunction in project drill by apache.
the class DrillReduceAggregatesRule method reduceAgg.
/**
 * Dispatches one aggregate call to the matching reduction: SUM, AVG, the
 * STDDEV/VAR family, or "keep as is" for anything else.
 *
 * @param oldAggRel      aggregate that contains {@code oldCall}
 * @param oldCall        call to reduce
 * @param newCalls       receives the calls that replace {@code oldCall}
 * @param aggCallMapping passed through to the reduction helpers / {@code addAggCall}
 * @param inputExprs     input expressions of the aggregate
 * @return expression that stands in for the original call
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final SqlAggFunction aggFunction =
            DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());

    if (aggFunction instanceof SqlSumAggFunction) {
        // replace original SUM(x) with
        // case COUNT(x) when 0 then null else SUM0(x) end
        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    }

    if (!(aggFunction instanceof SqlAvgAggFunction)) {
        // anything else: preserve original call
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final List<Integer> argOrdinals = oldCall.getArgList();
        assert argOrdinals.size() <= inputExprs.size();
        final List<RelDataType> argTypes = new ArrayList<>();
        for (int argOrdinal : argOrdinals) {
            argTypes.add(inputExprs.get(argOrdinal).getType());
        }
        return rexBuilder.addAggCall(
                oldCall, oldAggRel.getGroupCount(), oldAggRel.indicator,
                newCalls, aggCallMapping, argTypes);
    }

    // AVG family: AVG has its own reduction; the four STDDEV/VAR subtypes
    // differ only in the (biased, sqrt) flags handed to reduceStddev.
    final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) aggFunction).getSubtype();
    final boolean biased;
    final boolean sqrt;
    switch (subtype) {
    case AVG:
        // replace original AVG(x) with SUM(x) / COUNT(x)
        return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
    case STDDEV_POP:
        // SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))
        biased = true;
        sqrt = true;
        break;
    case STDDEV_SAMP:
        // SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
        //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
        biased = false;
        sqrt = true;
        break;
    case VAR_POP:
        // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
        biased = true;
        sqrt = false;
        break;
    case VAR_SAMP:
        // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
        //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
        biased = false;
        sqrt = false;
        break;
    default:
        throw Util.unexpected(subtype);
    }
    return reduceStddev(oldAggRel, oldCall, biased, sqrt, newCalls, aggCallMapping, inputExprs);
}
use of org.apache.calcite.sql.fun.SqlSumAggFunction in project flink by apache.
the class FlinkAggregateExpandDistinctAggregatesRule method convertSingletonDistinct.
/**
* Converts an aggregate with one distinct aggregate and one or more
* non-distinct aggregates to multi-phase aggregates (see reference example
* below).
*
* <p>Two aggregates are pushed onto {@code relBuilder}: an inner one
* ("aggregate B") grouped by the original group keys plus the distinct call's
* arguments, and an outer one ("aggregate A") grouped by the original group
* keys only.
*
* @param relBuilder Contains the input relational expression
* @param aggregate Original aggregate
* @param argLists Arguments and filters to the distinct aggregate function
* @return the same {@code relBuilder}, with aggregate A on top of its stack
*
*/
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
// For example,
// SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
// FROM emp
// GROUP BY deptno
//
// becomes
//
// SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
// FROM (
// SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
// FROM EMP
// GROUP BY deptno, sal) // Aggregate B
// GROUP BY deptno // Aggregate A
relBuilder.push(aggregate.getInput());
// NOTE(review): within this method `projects` is only appended to and
// queried for its size; it is never handed to relBuilder. It effectively
// acts as a running counter of allocated output positions.
final List<Pair<RexNode, String>> projects = new ArrayList<>();
// Maps an input field ordinal (real or "fake") to the position recorded for
// it (projects.size() at registration time). The values are later used as
// argument and group-key ordinals of aggregate A.
final Map<Integer, Integer> sourceOf = new HashMap<>();
SortedSet<Integer> newGroupSet = new TreeSet<>();
final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
SortedSet<Integer> groupSet = new TreeSet<>(aggregate.getGroupSet().asList());
// Add the distinct aggregate column(s) to the group-by columns,
// if not already a part of the group-by
newGroupSet.addAll(aggregate.getGroupSet().asList());
for (Pair<List<Integer>, Integer> argList : argLists) {
newGroupSet.addAll(argList.getKey());
}
// Register every key of the widened group set (original group keys plus
// distinct arguments), recording its position for the transformation.
for (int arg : newGroupSet) {
sourceOf.put(arg, projects.size());
projects.add(RexInputRef.of2(arg, childFields));
}
// Generate the intermediate aggregate B
final List<AggregateCall> aggCalls = aggregate.getAggCallList();
final List<AggregateCall> newAggCalls = new ArrayList<>();
// Ordinals handed out to non-distinct calls that have no argument of their
// own, or whose argument is a group key.
final List<Integer> fakeArgs = new ArrayList<>();
// Maps a call created for aggregate B to the fake argument assigned to it.
final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
// First identify the real arguments, then use the remaining ordinals as fake
// arguments,
// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
//
// Step 1: temporarily mark the non-group-key arguments of non-distinct
// calls as taken in sourceOf (the stored value is only a placeholder and
// is removed again in step 3).
for (final AggregateCall aggCall : aggCalls) {
if (!aggCall.isDistinct()) {
for (int arg : aggCall.getArgList()) {
if (!groupSet.contains(arg)) {
sourceOf.put(arg, projects.size());
}
}
}
}
// Step 2: for every non-distinct call that has no arguments or uses a group
// key as an argument, pick the next ordinal that is not a key of sourceOf.
int fakeArg0 = 0;
for (final AggregateCall aggCall : aggCalls) {
// We will deal with non-distinct aggregates below
if (!aggCall.isDistinct()) {
boolean isGroupKeyUsedInAgg = false;
for (int arg : aggCall.getArgList()) {
if (groupSet.contains(arg)) {
isGroupKeyUsedInAgg = true;
break;
}
}
if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
while (sourceOf.get(fakeArg0) != null) {
++fakeArg0;
}
fakeArgs.add(fakeArg0);
++fakeArg0;
}
}
}
// Step 3: drop the temporary marks added in step 1; the real positions of
// those arguments are assigned in the loop below.
for (final AggregateCall aggCall : aggCalls) {
if (!aggCall.isDistinct()) {
for (int arg : aggCall.getArgList()) {
if (!groupSet.contains(arg)) {
sourceOf.remove(arg);
}
}
}
}
// Compute the remapped arguments using fake arguments for non-distinct
// aggregates with no arguments e.g. count(*).
int fakeArgIdx = 0;
for (final AggregateCall aggCall : aggCalls) {
// Re-create each non-distinct aggregate for aggregate B, keeping its
// aggregation function and argument list; distinct calls are skipped here.
if (!aggCall.isDistinct()) {
final AggregateCall newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.getArgList(), -1, ImmutableBitSet.of(newGroupSet).cardinality(), relBuilder.peek(), null, aggCall.name);
newAggCalls.add(newCall);
if (newCall.getArgList().size() == 0) {
// Argument-less call (e.g. count(*)): its result is tracked under the
// next fake argument.
int fakeArg = fakeArgs.get(fakeArgIdx);
callArgMap.put(newCall, fakeArg);
sourceOf.put(fakeArg, projects.size());
projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
++fakeArgIdx;
} else {
for (int arg : newCall.getArgList()) {
if (groupSet.contains(arg)) {
// The argument is also a group key, whose sourceOf entry must keep
// pointing at the key itself, so the call's result is tracked under
// a fake argument instead.
int fakeArg = fakeArgs.get(fakeArgIdx);
callArgMap.put(newCall, fakeArg);
sourceOf.put(fakeArg, projects.size());
projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
++fakeArgIdx;
} else {
// Ordinary argument: record the position allocated for this call
// under the argument's own ordinal.
sourceOf.put(arg, projects.size());
projects.add(Pair.of((RexNode) new RexInputRef(arg, newCall.getType()), newCall.getName()));
}
}
}
}
}
// Generate the aggregate B (see the reference example above)
// It is grouped by the widened group set, with the indicator flag forced to
// false.
relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
// Convert the existing aggregate to aggregate A (see the reference example above)
final List<AggregateCall> newTopAggCalls = Lists.newArrayList(aggregate.getAggCallList());
// Use the remapped arguments for the (non)distinct aggregate calls
for (int i = 0; i < newTopAggCalls.size(); i++) {
// Re-map arguments.
final AggregateCall aggCall = newTopAggCalls.get(i);
final int argCount = aggCall.getArgList().size();
final List<Integer> newArgs = new ArrayList<>(argCount);
final AggregateCall newCall;
for (int j = 0; j < argCount; j++) {
final Integer arg = aggCall.getArgList().get(j);
// NOTE(review): callArgMap was populated with the calls created for
// aggregate B, but is queried here with the original call; this
// presumably relies on AggregateCall#equals/hashCode treating the two as
// equal - confirm.
if (callArgMap.containsKey(aggCall)) {
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
} else {
newArgs.add(sourceOf.get(arg));
}
}
if (aggCall.isDistinct()) {
// The distinct call is re-created as a non-distinct call over the
// remapped arguments.
newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
} else {
// If aggregate B had a COUNT aggregate call, the corresponding call in
// aggregate A must be SUM. For other aggregates, it remains the same.
if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
if (aggCall.getArgList().size() == 0) {
// count(*) has no argument to remap in the loop above, so add the
// position recorded for its fake argument.
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
}
if (hasGroupBy) {
SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
} else {
// Without a GROUP BY, SqlSumEmptyIsZeroAggFunction is used instead of
// plain SUM (presumably so that an empty input still yields 0, as the
// class name suggests - confirm).
SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
}
} else {
newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
}
}
newTopAggCalls.set(i, newCall);
}
// Populate the group-by keys with the remapped arguments for aggregate A
newGroupSet.clear();
for (int arg : aggregate.getGroupSet()) {
newGroupSet.add(sourceOf.get(arg));
}
relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
return relBuilder;
}
Aggregations