Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveAggregateProjectMergeRule, method apply.
/**
 * Merges a {@link HiveProject} into the {@link HiveAggregate} that consumes it, producing
 * an aggregate that reads directly from the project's input.
 *
 * <p>The merge is only possible when every group key, every aggregate-call argument and
 * every aggregate-call filter argument refers to a project expression that is a plain
 * {@link RexInputRef}. If the resulting group keys are in a different order, or several
 * original keys collapse onto the same input column, a project is added on top so that
 * the output columns are in the original order.
 *
 * @param aggregate the aggregate to rewrite
 * @param project the project that is the aggregate's input
 * @return the rewritten relational expression, or {@code null} if the aggregate uses a
 *     project expression that is not a plain input reference
 */
public static RelNode apply(HiveAggregate aggregate, HiveProject project) {
  // Input column of each original group key, in the original key order. May contain
  // duplicates when two keys are projections of the same input column.
  final List<Integer> newKeys = Lists.newArrayList();
  final Map<Integer, Integer> map = new HashMap<>();
  for (int key : aggregate.getGroupSet()) {
    final RexNode rex = project.getProjects().get(key);
    if (rex instanceof RexInputRef) {
      final int newKey = ((RexInputRef) rex).getIndex();
      newKeys.add(newKey);
      map.put(key, newKey);
    } else {
      // Cannot handle "GROUP BY expression"
      return null;
    }
  }
  final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
  ImmutableList<ImmutableBitSet> newGroupingSets = null;
  if (aggregate.indicator) {
    newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(
        ImmutableBitSet.permute(aggregate.getGroupSets(), map));
  }

  final ImmutableList.Builder<AggregateCall> aggCalls = ImmutableList.builder();
  for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
    final ImmutableList.Builder<Integer> newArgs = ImmutableList.builder();
    for (int arg : aggregateCall.getArgList()) {
      final RexNode rex = project.getProjects().get(arg);
      if (rex instanceof RexInputRef) {
        newArgs.add(((RexInputRef) rex).getIndex());
      } else {
        // Cannot handle "AGG(expression)"
        return null;
      }
    }
    final int newFilterArg;
    if (aggregateCall.filterArg >= 0) {
      final RexNode rex = project.getProjects().get(aggregateCall.filterArg);
      if (!(rex instanceof RexInputRef)) {
        // Cannot handle "AGG(...) FILTER (WHERE expression)"
        return null;
      }
      newFilterArg = ((RexInputRef) rex).getIndex();
    } else {
      newFilterArg = -1;
    }
    aggCalls.add(aggregateCall.copy(newArgs.build(), newFilterArg));
  }

  final Aggregate newAggregate = aggregate.copy(aggregate.getTraitSet(), project.getInput(),
      aggregate.indicator, newGroupSet, newGroupingSets, aggCalls.build());

  // Add a project if the group set is not in the same order or
  // contains duplicates.
  RelNode rel = newAggregate;
  if (!newKeys.equals(newGroupSet.asList())) {
    final List<Integer> posList = Lists.newArrayList();
    for (int newKey : newKeys) {
      posList.add(newGroupSet.indexOf(newKey));
    }
    if (aggregate.indicator) {
      // The indicator columns of the NEW aggregate start right after its own group
      // columns. When duplicate keys were collapsed, newAggregate has fewer group
      // columns than the original aggregate, so the original group count must not be
      // used as the offset here.
      final int indicatorOffset = newAggregate.getGroupCount();
      for (int newKey : newKeys) {
        posList.add(indicatorOffset + newGroupSet.indexOf(newKey));
      }
    }
    // Aggregate-call columns are passed through in order.
    for (int i = newAggregate.getGroupCount() + newAggregate.getIndicatorCount();
        i < newAggregate.getRowType().getFieldCount(); i++) {
      posList.add(i);
    }
    rel = HiveRelOptUtil.createProject(
        HiveRelFactories.HIVE_BUILDER.create(aggregate.getCluster(), null), rel, posList);
  }
  return rel;
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveAggregate, method deriveRowType.
/**
 * Builds the output row type of an aggregate.
 *
 * <p>The fields are, in order:
 * <ol>
 *   <li>each group-key field, copied unchanged from the input row type;</li>
 *   <li>if {@code indicator} is set, one non-nullable BOOLEAN field per group key, named
 *       {@code "i$"} followed by the key's field name;</li>
 *   <li>one field per aggregate call, with the call's type, named after the call or
 *       {@code "$f" + (groupKeyCount + callPosition)} when the call has no name.</li>
 * </ol>
 * Indicator and aggregate-call names are made unique by repeatedly appending
 * {@code "_" + counter} while the candidate clashes with a name already emitted.
 * {@code groupSets} is not consulted.
 */
public static RelDataType deriveRowType(RelDataTypeFactory typeFactory, final RelDataType inputRowType, boolean indicator, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, final List<AggregateCall> aggCalls) {
  final List<Integer> keys = groupSet.asList();
  assert keys.size() == groupSet.cardinality();
  final List<RelDataTypeField> inputFields = inputRowType.getFieldList();
  final Set<String> usedNames = Sets.newHashSet();
  final RelDataTypeFactory.FieldInfoBuilder fields = typeFactory.builder();

  // Group-key columns keep the input field (name and type) as-is.
  for (int key : keys) {
    final RelDataTypeField keyField = inputFields.get(key);
    usedNames.add(keyField.getName());
    fields.add(keyField);
  }

  if (indicator) {
    for (int key : keys) {
      final RelDataType flagType = typeFactory.createTypeWithNullability(
          typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
      String flagName = "i$" + inputFields.get(key).getName();
      for (int suffix = 0; usedNames.contains(flagName); suffix++) {
        flagName += "_" + suffix;
      }
      usedNames.add(flagName);
      fields.add(flagName, flagType);
    }
  }

  for (int pos = 0; pos < aggCalls.size(); pos++) {
    final AggregateCall call = aggCalls.get(pos);
    String callName = call.name != null ? call.name : "$f" + (keys.size() + pos);
    for (int suffix = 0; usedNames.contains(callName); suffix++) {
      callName += "_" + suffix;
    }
    usedNames.add(callName);
    fields.add(callName, call.type);
  }
  return fields.build();
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.
Class FlinkAggregateExpandDistinctAggregatesRule, method convertSingletonDistinct.
/**
 * Converts an aggregate with one distinct aggregate and one or more
 * non-distinct aggregates to multi-phase aggregates (see reference example
 * below).
 *
 * <p>Aggregate B (bottom) groups by the original group keys plus the distinct
 * aggregate's argument column(s) and computes every non-distinct aggregate
 * call as-is. Aggregate A (top) groups B's output by the original keys,
 * applies the distinct aggregate (now non-distinct) to the argument columns,
 * and re-aggregates each non-distinct partial result: COUNT becomes a SUM of
 * the partial counts, and every other function is re-applied unchanged.
 *
 * <p>Columns of B are addressed purely by position. B's output is its group
 * columns in ascending input order, followed by one column per non-distinct
 * call in call order. A map keyed by input column cannot express this. It
 * picks the wrong B column when two non-distinct calls share an argument
 * (e.g. {@code SUM(bonus), MAX(bonus)}), when a non-distinct call uses the
 * distinct argument, or when a call has more than one argument.
 *
 * <p>Every rebuilt call has filterArg -1; the filter half of each
 * {@code argLists} entry is not used (presumably the caller only takes this
 * path when no call is filtered — confirm).
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate Original aggregate
 * @param argLists Arguments and filters to the distinct aggregate function
 * @return {@code relBuilder}, with aggregate A on top of its stack
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
  // For example,
  // SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
  // FROM emp
  // GROUP BY deptno
  //
  // becomes
  //
  // SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
  // FROM (
  // SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
  // FROM EMP
  // GROUP BY deptno, sal) // Aggregate B
  // GROUP BY deptno // Aggregate A
  relBuilder.push(aggregate.getInput());
  final List<AggregateCall> aggCalls = aggregate.getAggCallList();
  final ImmutableBitSet groupSet = aggregate.getGroupSet();
  final int topGroupCount = groupSet.cardinality();
  final boolean hasGroupBy = !groupSet.isEmpty();

  // Group keys of aggregate B: the original group keys plus the distinct
  // aggregate column(s), sorted by input position.
  final SortedSet<Integer> bottomGroupSet = new TreeSet<>(groupSet.asList());
  for (Pair<List<Integer>, Integer> argList : argLists) {
    bottomGroupSet.addAll(argList.getKey());
  }
  final ImmutableBitSet bottomGroupBits = ImmutableBitSet.of(bottomGroupSet);

  // Generate the intermediate aggregate B: every non-distinct call is
  // computed as-is over the finer grouping; the distinct call is dropped
  // because its argument(s) became group keys.
  final List<AggregateCall> bottomAggCalls = new ArrayList<>();
  for (AggregateCall aggCall : aggCalls) {
    if (!aggCall.isDistinct()) {
      bottomAggCalls.add(AggregateCall.create(aggCall.getAggregation(), false,
          aggCall.getArgList(), -1, bottomGroupBits.cardinality(), relBuilder.peek(),
          null, aggCall.name));
    }
  }
  relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false,
      bottomGroupBits, null, bottomAggCalls));

  // Convert the existing aggregate to aggregate A (see the reference example above).
  final List<AggregateCall> topAggCalls = new ArrayList<>(aggCalls.size());
  int nonDistinctOrdinal = 0;
  for (AggregateCall aggCall : aggCalls) {
    final AggregateCall newCall;
    if (aggCall.isDistinct()) {
      // A distinct argument is a group column of B; its position in B's output
      // is the number of B group keys that precede it.
      final List<Integer> newArgs = new ArrayList<>();
      for (int arg : aggCall.getArgList()) {
        newArgs.add(bottomGroupSet.headSet(arg).size());
      }
      newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
          topGroupCount, relBuilder.peek(), aggCall.getType(), aggCall.name);
    } else {
      // The partial result of the k-th non-distinct call is B's column
      // (number of B group columns + k).
      final List<Integer> newArgs = new ArrayList<>();
      newArgs.add(bottomGroupSet.size() + nonDistinctOrdinal);
      nonDistinctOrdinal++;
      if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
        // Partial counts are combined by summing them.
        if (hasGroupBy) {
          SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
          newCall = AggregateCall.create(sumAgg, false, newArgs, -1, topGroupCount,
              relBuilder.peek(), aggCall.getType(), aggCall.getName());
        } else {
          // Without GROUP BY, A still emits one row when B is empty; use the SUM
          // variant that yields 0 for empty input, as COUNT would.
          SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
          newCall = AggregateCall.create(sumAgg, false, newArgs, -1, topGroupCount,
              relBuilder.peek(), aggCall.getType(), aggCall.getName());
        }
      } else {
        newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
            topGroupCount, relBuilder.peek(), aggCall.getType(), aggCall.name);
      }
    }
    topAggCalls.add(newCall);
  }

  // Group keys of aggregate A: the positions of the original group keys within
  // B's (sorted) group columns.
  final SortedSet<Integer> topGroupSet = new TreeSet<>();
  for (int key : groupSet) {
    topGroupSet.add(bottomGroupSet.headSet(key).size());
  }
  relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(),
      aggregate.indicator, ImmutableBitSet.of(topGroupSet), null, topAggCalls));
  return relBuilder;
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.
Class FlinkAggregateExpandDistinctAggregatesRule, method doRewrite.
/**
 * Converts all distinct aggregate calls that have a given set of arguments and
 * a given filter.
 *
 * <p>This method is called several times, once for each distinct
 * (arguments, filter) combination. Each time, it builds a SELECT DISTINCT over
 * the group keys and arguments (via {@code createSelectDistinct}). It then
 * aggregates that relation with the matching calls turned into non-distinct
 * calls. Unless {@code n} is 0, it joins the result to the relational
 * expression already on top of {@code relBuilder}.
 *
 * @param relBuilder Builder whose top is the relational expression built so
 *     far (either the aggregate of the non-distinct calls, or the output of
 *     the previous call to this method). On return its top is the new
 *     aggregate, joined to that expression unless {@code n} is 0
 * @param aggregate Original aggregate
 * @param n Ordinal of this in a join; 0 if we're converting the first
 *     distinct aggregate in a query with no non-distinct aggregates, in which
 *     case no join is created
 * @param argList Arguments to the distinct aggregate function
 * @param filterArg Argument that filters input to aggregate function, or -1;
 *     only calls with exactly this filter are converted here
 * @param refs Expressions which will be projected by the result of this rule,
 *     one per output field of the original aggregate. The (previously null)
 *     entries of the calls converted here are set to the field of the new
 *     relational expression that computes them
 */
private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList, int filterArg, List<RexInputRef> refs) {
  final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
  // Fields of the left input of the join; there is no left input when n == 0.
  final List<RelDataTypeField> leftFields;
  if (n == 0) {
    leftFields = null;
  } else {
    leftFields = relBuilder.peek().getRowType().getFieldList();
  }
  // LogicalAggregate(
  // child,
  // {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
  //
  // becomes
  //
  // LogicalAggregate(
  // LogicalJoin(
  // child,
  // LogicalAggregate(child, < all columns > {}),
  // INNER,
  // <f2 = f5>))
  //
  // E.g.
  // SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
  // FROM Emps
  // GROUP BY deptno
  //
  // becomes
  //
  // SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
  // FROM (
  // SELECT deptno, MAX(age) as max_age
  // FROM Emps GROUP BY deptno) AS e
  // JOIN (
  // SELECT deptno, COUNT(gender) AS count_gender FROM (
  // SELECT DISTINCT deptno, gender FROM Emps) AS dgender
  // GROUP BY deptno) AS adgender
  // ON e.deptno = adgender.deptno
  // JOIN (
  // SELECT deptno, SUM(sal) AS sum_sal FROM (
  // SELECT DISTINCT deptno, sal FROM Emps) AS dsal
  // GROUP BY deptno) AS adsal
  // ON e.deptno = adsal.deptno
  // GROUP BY e.deptno
  //
  // Note that if a query contains no non-distinct aggregates, then the
  // very first join/group by is omitted. In the example above, if
  // MAX(age) is removed, then the sub-select of "e" is not needed, and
  // instead the two other group by's are joined to one another.
  // Project the columns of the GROUP BY plus the arguments
  // to the agg function.
  final Map<Integer, Integer> sourceOf = new HashMap<>();
  createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);
  // Now compute the aggregate functions on top of the distinct dataset.
  // Each distinct agg becomes a non-distinct call to the corresponding
  // field from the right; for example,
  // "COUNT(DISTINCT e.sal)"
  // becomes
  // "COUNT(distinct_e.sal)".
  final List<AggregateCall> aggCallList = new ArrayList<>();
  final List<AggregateCall> aggCalls = aggregate.getAggCallList();
  final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
  // Once joined, the new aggregate's fields come after all left fields.
  final int rightOffset = n == 0 ? 0 : leftFields.size();
  // Output position, in the original aggregate, of the current call.
  int i = groupAndIndicatorCount - 1;
  for (AggregateCall aggCall : aggCalls) {
    ++i;
    // Skip non-distinct calls (e.g. SUM(sal)) and distinct calls over other
    // arguments (e.g. COUNT(DISTINCT gender) when argList is {sal}).
    if (!aggCall.isDistinct()) {
      continue;
    }
    if (!aggCall.getArgList().equals(argList)) {
      continue;
    }
    // A call with the same arguments but a different filter belongs to the
    // invocation for its own (argList, filterArg) pair. Converting it here
    // would look up a filter column that this SELECT DISTINCT may not project,
    // and would trip the refs assertion when its own invocation runs.
    if (aggCall.filterArg != filterArg) {
      continue;
    }
    // Re-map arguments.
    final int argCount = aggCall.getArgList().size();
    final List<Integer> newArgs = new ArrayList<>(argCount);
    for (int j = 0; j < argCount; j++) {
      final Integer arg = aggCall.getArgList().get(j);
      newArgs.add(sourceOf.get(arg));
    }
    final int newFilterArg = aggCall.filterArg >= 0 ? sourceOf.get(aggCall.filterArg) : -1;
    final AggregateCall newAggCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, newFilterArg, aggCall.getType(), aggCall.getName());
    assert refs.get(i) == null;
    refs.set(i, new RexInputRef(rightOffset + groupAndIndicatorCount + aggCallList.size(), newAggCall.getType()));
    aggCallList.add(newAggCall);
  }
  // The select-distinct relation's first columns are the group keys, so the
  // new group set is {0, ..., groupCount - 1}.
  final Map<Integer, Integer> map = new HashMap<>();
  for (Integer key : aggregate.getGroupSet()) {
    map.put(key, map.size());
  }
  final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(map);
  assert newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
  ImmutableList<ImmutableBitSet> newGroupingSets = null;
  if (aggregate.indicator) {
    newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(aggregate.getGroupSets(), map));
  }
  relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, newGroupSet, newGroupingSets, aggCallList));
  // If there's no left child yet, no need to create the join
  if (n == 0) {
    return;
  }
  // Create the join condition. It is of the form
  // 'left.f0 = right.f0 and left.f1 = right.f1 and ...'
  // where {f0, f1, ...} are the GROUP BY fields.
  final List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
  final List<RexNode> conditions = Lists.newArrayList();
  for (int k = 0; k < groupAndIndicatorCount; ++k) {
    // Null values form their own group, so use "IS NOT DISTINCT FROM" to let
    // the join condition match nulls.
    conditions.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, RexInputRef.of(k, leftFields), new RexInputRef(leftFields.size() + k, distinctFields.get(k).getType())));
  }
  // Join in the new 'select distinct' relation.
  relBuilder.join(JoinRelType.INNER, conditions);
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.
Class FlinkAggregateExpandDistinctAggregatesRule, method rewriteAggCalls.
/**
 * Rewrites in place every DISTINCT call in {@code newAggCalls} whose argument
 * list equals {@code argList}. Each one becomes a non-distinct call over the
 * remapped columns: argument {@code a} is replaced by {@code sourceOf.get(a)}.
 * The replacement has filterArg -1 and keeps the original aggregation, type
 * and name. All other calls are left untouched.
 *
 * <p>For example, "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
 */
private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
  for (int pos = 0; pos < newAggCalls.size(); pos++) {
    final AggregateCall call = newAggCalls.get(pos);
    // Only DISTINCT calls over exactly this argument list qualify; e.g. with
    // argList {sal}, COUNT(DISTINCT gender) and SUM(sal) are left alone.
    final boolean matches = call.isDistinct() && call.getArgList().equals(argList);
    if (matches) {
      final List<Integer> remapped = new ArrayList<>(call.getArgList().size());
      for (Integer arg : call.getArgList()) {
        remapped.add(sourceOf.get(arg));
      }
      newAggCalls.set(pos, AggregateCall.create(call.getAggregation(), false, remapped, -1, call.getType(), call.getName()));
    }
  }
}
Aggregations