Usage of org.apache.calcite.sql.SqlAggFunction in the Apache Flink project.
Class FlinkAggregateExpandDistinctAggregatesRule, method rewriteUsingGroupingSets:
/**
 * Rewrites an aggregate that contains distinct aggregate calls into two stacked aggregates built
 * on grouping sets.
 *
 * <p>The lower aggregate groups the input by the union of all required group sets and computes:
 * every non-distinct call (adapted to the lower aggregate) plus a GROUPING indicator named
 * {@code $g}. One group set is required for the non-distinct calls (the aggregate's own group
 * keys), and one per distinct call (the call's arguments, its filter column if any, and the
 * group keys). A projection then replaces {@code $g} with one boolean filter column per group
 * set. The upper aggregate applies MIN to each pre-aggregated non-distinct result and the
 * original function to each distinct call, each restricted by the matching filter column. When
 * there are no group keys, COUNT results from the upper aggregate are wrapped so that NULL
 * becomes 0.
 *
 * @param call the rule call; the rewritten tree is registered via {@code call.transformTo}
 * @param aggregate the aggregate being rewritten
 */
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
// Required group sets, kept sorted by ImmutableBitSet.ORDERING and de-duplicated.
final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
// Group set of each distinct call -> that call's filterArg (negative when it has no filter).
final Map<ImmutableBitSet, Integer> groupSetToDistinctAggCallFilterArg = new HashMap<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (!aggCall.isDistinct()) {
// Non-distinct calls are evaluated over the aggregate's own group keys.
groupSetTreeSet.add(aggregate.getGroupSet());
} else {
// Distinct calls group by their arguments, their filter column (if any) and the group keys.
ImmutableBitSet groupSet = ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet());
groupSetToDistinctAggCallFilterArg.put(groupSet, aggCall.filterArg);
groupSetTreeSet.add(groupSet);
}
}
final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets = com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
// All columns the lower aggregate groups by.
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
// Despite its name, this list first receives the NON-distinct calls (adapted to the lower
// aggregate); the GROUPING call "$g" is appended to it further below.
final List<AggregateCall> distinctAggCalls = new ArrayList<>();
for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
if (!aggCall.left.isDistinct()) {
AggregateCall newAggCall = aggCall.left.adaptTo(aggregate.getInput(), aggCall.left.getArgList(), aggCall.left.filterArg, aggregate.getGroupCount(), fullGroupSet.cardinality());
distinctAggCalls.add(newAggCall.rename(aggCall.right));
}
}
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
// Number of group columns in the lower aggregate's output.
final int groupCount = fullGroupSet.cardinality();
// Group set -> index of its boolean filter column in the projection built below.
final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
// Index of the first filter column: group columns, then non-distinct results. Computed before
// "$g" is added because the projection drops "$g" before appending the filter columns.
final int z = groupCount + distinctAggCalls.size();
distinctAggCalls.add(AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false, ImmutableIntList.copyOf(fullGroupSet), -1, RelCollations.EMPTY, groupSets.size(), relBuilder.peek(), null, "$g"));
// Filter columns are appended in groupSets order, so the i-th group set maps to z + i.
for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
filters.put(groupSet.e, z + groupSet.i);
}
relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), distinctAggCalls);
// The lower aggregate; used below as the input when creating the upper aggregate calls.
final RelNode distinct = relBuilder.peek();
// GROUPING ("$g") produces an integer; add a project that converts its
// values to BOOLEAN.
if (!filters.isEmpty()) {
final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
// Drop the trailing "$g" column; it is replaced by the boolean filter columns.
final RexNode nodeZ = nodes.remove(nodes.size() - 1);
for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
// NOTE(review): groupValue is defined elsewhere; presumably the "$g" value of rows
// produced for this group set.
final long v = groupValue(fullGroupSet, entry.getKey());
// Get and remap the filterArg of the distinct aggregate call.
int distinctAggCallFilterArg = remap(fullGroupSet, groupSetToDistinctAggCallFilterArg.getOrDefault(entry.getKey(), -1));
RexNode expr;
if (distinctAggCallFilterArg < 0) {
expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
} else {
RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
// merge the filter of the distinct aggregate call itself.
expr = relBuilder.and(relBuilder.equals(nodeZ, relBuilder.literal(v)), rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, relBuilder.field(distinctAggCallFilterArg)));
}
nodes.add(relBuilder.alias(expr, "$g_" + v));
}
relBuilder.project(nodes);
}
int aggCallIdx = 0;
// Index of the next non-distinct result column produced by the lower aggregate.
int x = groupCount;
final List<AggregateCall> newCalls = new ArrayList<>();
// TODO supports more aggCalls (currently only supports COUNT)
// Some aggregate functions (e.g. COUNT) have the special property that they can return a
// non-null result without any input. We need to make sure we return a result in this case.
final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
final int newFilterArg;
final List<Integer> newArgList;
final SqlAggFunction aggregation;
if (!aggCall.isDistinct()) {
// Pick the pre-aggregated value from rows of the aggregate's own group set.
aggregation = SqlStdOperatorTable.MIN;
newArgList = ImmutableIntList.of(x++);
newFilterArg = filters.get(aggregate.getGroupSet());
switch(aggCall.getAggregation().getKind()) {
case COUNT:
needDefaultValueAggCalls.add(aggCallIdx);
break;
default:
}
} else {
// Apply the original function to the (now de-duplicated) arguments, restricted to rows
// of this call's group set (same group set as computed in the first loop).
aggregation = aggCall.getAggregation();
newArgList = remap(fullGroupSet, aggCall.getArgList());
newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
}
final AggregateCall newCall = AggregateCall.create(aggregation, false, aggCall.isApproximate(), false, newArgList, newFilterArg, RelCollations.EMPTY, aggregate.getGroupCount(), distinct, null, aggCall.name);
newCalls.add(newCall);
aggCallIdx++;
}
// Upper aggregate: original group keys, remapped to their positions in the lower aggregate.
relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
// No group keys: wrap each COUNT result as CASE(result IS NOT NULL, result, 0).
final Aggregate newAgg = (Aggregate) relBuilder.peek();
final List<RexNode> nodes = new ArrayList<>();
for (int i = 0; i < newAgg.getGroupCount(); ++i) {
nodes.add(RexInputRef.of(i, newAgg.getRowType()));
}
for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
final RexNode inputRef = RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
RexNode newNode = inputRef;
if (needDefaultValueAggCalls.contains(i)) {
SqlKind originalFunKind = aggregate.getAggCallList().get(i).getAggregation().getKind();
switch(originalFunKind) {
case COUNT:
newNode = relBuilder.call(SqlStdOperatorTable.CASE, relBuilder.isNotNull(inputRef), inputRef, relBuilder.literal(BigDecimal.ZERO));
break;
default:
}
}
nodes.add(newNode);
}
relBuilder.project(nodes);
}
// Convert the result to the original aggregate's row type before registering it.
relBuilder.convert(aggregate.getRowType(), true);
call.transformTo(relBuilder.build());
}
Usage of org.apache.calcite.sql.SqlAggFunction in the Apache Drill project.
Class DrillReduceAggregatesRule, method reduceAvg:
/**
 * Reduces an {@code AVG(x)} call to an expression over SUM and COUNT calls computed by the new
 * aggregate.
 *
 * <p>The produced expression is
 * {@code CastHigh(CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM(x) END) / COUNT(x)}.
 * When Drill type inference is enabled, the division uses Drill's {@code divide} operator typed
 * with the original call's return type. Otherwise it uses Calcite's {@code DIVIDE}, and the
 * result is cast to {@code ANY}.
 *
 * @param oldAggRel the aggregate that contains the AVG call
 * @param oldCall the AVG call being reduced; the type of its first argument is used as the
 *     input type when registering the new calls
 * @param newCalls aggregate calls of the replacement aggregate; the SUM and COUNT calls are
 *     registered into it through {@code RexBuilder.addAggCall}
 * @param aggCallMapping registered aggregate call -> its reference in the new aggregate's output
 * @return an expression over the new aggregate's output that replaces the AVG call
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
  final int nGroups = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
  final int iAvgInput = oldCall.getArgList().get(0);
  final RelDataType avgInputType = getFieldType(oldAggRel.getInput(), iAvgInput);

  // SUM's return type as inferred by Drill, forced nullable when there are no group keys.
  RelDataType sumType = TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of())
      .inferReturnType(oldCall.createBinding(oldAggRel));
  sumType = typeFactory.createTypeWithNullability(sumType, sumType.isNullable() || nGroups == 0);
  final SqlAggFunction sumAgg = new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);

  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countAgg.getReturnType(typeFactory);
  final AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

  // NOTE: these references are with respect to the output of the new aggregate.
  final RexNode sumRef = rexBuilder.addAggCall(sumCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(avgInputType));
  // COUNT is registered once; the same reference serves as the CASE guard and the denominator.
  final RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(avgInputType));

  // Yield NULL rather than the SUM when COUNT is 0 (presumably because the SUM variant used here,
  // SqlSumEmptyIsZeroAggFunction, returns 0 on empty input).
  final RexNode sumOrNull = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
      rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
      rexBuilder.constantNull(),
      sumRef);
  final RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, sumOrNull);
  final RexNode denominatorRef = countRef;

  if (isInferenceEnabled) {
    return rexBuilder.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numeratorRef, denominatorRef);
  }
  final RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
  return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), divideRef);
}
Usage of org.apache.calcite.sql.SqlAggFunction in the Apache Drill project.
Class DrillOperatorTable, method populateWrappedCalciteOperators:
/**
 * Wraps every operator of the inner Calcite operator table in its Drill wrapper and records the
 * pairing in {@code calciteToWrapper}.
 *
 * <p>Aggregate functions and plain functions get function wrappers. BETWEEN operators get a
 * dedicated wrapper because {@code StandardConvertletTable.convertBetween} expects a
 * {@code SqlBetweenOperator} subclass during RexNode conversion. Any other operator is wrapped
 * only when Drill has at least one function holder registered under its Drill name; otherwise
 * it is left unmapped.
 */
private void populateWrappedCalciteOperators() {
  for (SqlOperator op : inner.getOperatorList()) {
    final String opName = op.getName();
    SqlOperator wrapped = null;

    if (op instanceof SqlAggFunction) {
      wrapped = new DrillCalciteSqlAggFunctionWrapper((SqlAggFunction) op, getFunctionListWithInference(opName));
    } else if (op instanceof SqlFunction) {
      wrapped = new DrillCalciteSqlFunctionWrapper((SqlFunction) op, getFunctionListWithInference(opName));
    } else if (op instanceof SqlBetweenOperator) {
      // Convertlets downstream require an actual SqlBetweenOperator subclass.
      wrapped = new DrillCalciteSqlBetweenOperatorWrapper((SqlBetweenOperator) op);
    } else {
      // Unary minus/plus keep their raw names instead of being converted to Drill function
      // names. Otherwise, Calcite will mix them up with binary operator subtract (-) or add (+).
      final boolean isUnarySign = op == SqlStdOperatorTable.UNARY_MINUS || op == SqlStdOperatorTable.UNARY_PLUS;
      final String drillOpName = isUnarySign ? opName : FunctionCallFactory.convertToDrillFunctionName(opName);
      final List<DrillFuncHolder> holders = getFunctionListWithInference(drillOpName);
      if (!holders.isEmpty()) {
        wrapped = new DrillCalciteSqlOperatorWrapper(op, drillOpName, holders);
      }
    }

    // Operators without any matching Drill implementation are skipped.
    if (wrapped != null) {
      calciteToWrapper.put(op, wrapped);
    }
  }
}
Aggregations