Search in sources :

Example 96 with RexNode

use of org.apache.calcite.rex.RexNode in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method convertSingletonDistinct.

/**
 * Converts an aggregate with one distinct aggregate and one or more
 * non-distinct aggregates to multi-phase aggregates (see reference example
 * below).
 *
 * <p>Two aggregates are pushed onto {@code relBuilder}: an intermediate
 * aggregate B that groups by the original group keys plus the distinct
 * arguments, and on top of it aggregate A that groups by the original keys.
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate  Original aggregate
 * @param argLists   Arguments and filters to the distinct aggregate function
 * @return the {@code relBuilder} passed in, with aggregate A on top of its stack
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // For example,
    //	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //	FROM emp
    //	GROUP BY deptno
    //
    // becomes
    //
    //	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //	FROM (
    //		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //		  FROM EMP
    //		  GROUP BY deptno, sal)			// Aggregate B
    //	GROUP BY deptno						// Aggregate A
    // Start from the input of the original aggregate.
    relBuilder.push(aggregate.getInput());
    // NOTE: within this method only projects.size() is ever read; the list acts
    // as a running counter of output positions and is never turned into a
    // Project node here.
    final List<Pair<RexNode, String>> projects = new ArrayList<>();
    // Maps an argument ordinal to the position (projects.size()) recorded for it.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    SortedSet<Integer> newGroupSet = new TreeSet<>();
    final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
    // Decides below whether COUNT is re-aggregated with SqlSumAggFunction or
    // SqlSumEmptyIsZeroAggFunction.
    final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
    // The original group keys only (without the distinct arguments).
    SortedSet<Integer> groupSet = new TreeSet<>(aggregate.getGroupSet().asList());
    // Add the distinct aggregate column(s) to the group-by columns,
    // if not already a part of the group-by
    newGroupSet.addAll(aggregate.getGroupSet().asList());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        newGroupSet.addAll(argList.getKey());
    }
    // Re-map the arguments to the aggregate A. These arguments will get
    // remapped because of the intermediate aggregate B generated as part of the
    // transformation.
    for (int arg : newGroupSet) {
        sourceOf.put(arg, projects.size());
        projects.add(RexInputRef.of2(arg, childFields));
    }
    // Generate the intermediate aggregate B
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final List<Integer> fakeArgs = new ArrayList<>();
    // Remembers, per aggregate call, the fake argument assigned to it.
    final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
    // First identify the real arguments, then use the rest for fake arguments
    // e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
    //
    // The entries added here are temporary markers (projects does not grow in
    // this loop); they only make the fake-argument search below skip ordinals
    // that are real arguments, and are removed again afterwards.
    for (final AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            for (int arg : aggCall.getArgList()) {
                if (!groupSet.contains(arg)) {
                    sourceOf.put(arg, projects.size());
                }
            }
        }
    }
    // Reserve one fake argument for every non-distinct call that either has no
    // arguments or uses a group key as an argument. A fake argument is the next
    // ordinal that has no entry in sourceOf.
    int fakeArg0 = 0;
    for (final AggregateCall aggCall : aggCalls) {
        // We will deal with non-distinct aggregates below
        if (!aggCall.isDistinct()) {
            boolean isGroupKeyUsedInAgg = false;
            for (int arg : aggCall.getArgList()) {
                if (groupSet.contains(arg)) {
                    isGroupKeyUsedInAgg = true;
                    break;
                }
            }
            if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
                while (sourceOf.get(fakeArg0) != null) {
                    ++fakeArg0;
                }
                fakeArgs.add(fakeArg0);
                // Step past the ordinal just handed out so it is not reused.
                ++fakeArg0;
            }
        }
    }
    // Drop the temporary markers for non-group-key arguments added above.
    for (final AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            for (int arg : aggCall.getArgList()) {
                if (!groupSet.contains(arg)) {
                    sourceOf.remove(arg);
                }
            }
        }
    }
    // Compute the remapped arguments using fake arguments for non-distinct
    // aggregates with no arguments e.g. count(*).
    int fakeArgIdx = 0;
    for (final AggregateCall aggCall : aggCalls) {
        // Project the column corresponding to the distinct aggregate. Project
        // as-is all the non-distinct aggregates
        if (!aggCall.isDistinct()) {
            // Same function and arguments as the original call, non-distinct,
            // no filter (-1), created against the grouping of aggregate B.
            final AggregateCall newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.getArgList(), -1, ImmutableBitSet.of(newGroupSet).cardinality(), relBuilder.peek(), null, aggCall.name);
            newAggCalls.add(newCall);
            if (newCall.getArgList().size() == 0) {
                // No arguments (e.g. COUNT(*)): record the call's position under
                // its reserved fake argument.
                int fakeArg = fakeArgs.get(fakeArgIdx);
                callArgMap.put(newCall, fakeArg);
                sourceOf.put(fakeArg, projects.size());
                projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
                ++fakeArgIdx;
            } else {
                for (int arg : newCall.getArgList()) {
                    if (groupSet.contains(arg)) {
                        // The argument is a group key, whose sourceOf entry must
                        // keep pointing at the key column; use a fake argument
                        // for the call's position instead.
                        // NOTE(review): the reservation loop above adds at most
                        // one fake argument per call, while this branch consumes
                        // one per group-key argument -- confirm that a call with
                        // several group-key arguments cannot reach this point.
                        int fakeArg = fakeArgs.get(fakeArgIdx);
                        callArgMap.put(newCall, fakeArg);
                        sourceOf.put(fakeArg, projects.size());
                        projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
                        ++fakeArgIdx;
                    } else {
                        sourceOf.put(arg, projects.size());
                        projects.add(Pair.of((RexNode) new RexInputRef(arg, newCall.getType()), newCall.getName()));
                    }
                }
            }
        }
    }
    // Generate the aggregate B (see the reference example above)
    // It is created without indicator columns and without grouping sets.
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
    // Convert the existing aggregate to aggregate A (see the reference example above)
    final List<AggregateCall> newTopAggCalls = Lists.newArrayList(aggregate.getAggCallList());
    // Use the remapped arguments for the (non)distinct aggregate calls
    for (int i = 0; i < newTopAggCalls.size(); i++) {
        // Re-map arguments.
        final AggregateCall aggCall = newTopAggCalls.get(i);
        final int argCount = aggCall.getArgList().size();
        final List<Integer> newArgs = new ArrayList<>(argCount);
        final AggregateCall newCall;
        for (int j = 0; j < argCount; j++) {
            final Integer arg = aggCall.getArgList().get(j);
            // A call registered in callArgMap is addressed through its fake
            // argument; any other call through its own argument ordinal.
            if (callArgMap.containsKey(aggCall)) {
                newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
            } else {
                newArgs.add(sourceOf.get(arg));
            }
        }
        if (aggCall.isDistinct()) {
            // The distinct argument is a group key of aggregate B, so the call
            // is re-created here as a non-distinct call.
            newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
        } else {
            // If aggregate B had a COUNT aggregate call the corresponding aggregate at
            // aggregate A must be SUM. For other aggregates, it remains the same.
            if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
                if (aggCall.getArgList().size() == 0) {
                    // COUNT(*) had no arguments to remap in the loop above; point
                    // it at the position recorded for its fake argument.
                    newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
                }
                if (hasGroupBy) {
                    SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
                    newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
                } else {
                    // NOTE(review): presumably the "empty is zero" variant is used
                    // so that a global COUNT over no rows still yields 0 -- confirm.
                    SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
                    newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
                }
            } else {
                newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
            }
        }
        newTopAggCalls.set(i, newCall);
    }
    // Populate the group-by keys with the remapped arguments for aggregate A
    newGroupSet.clear();
    for (int arg : aggregate.getGroupSet()) {
        newGroupSet.add(sourceOf.get(arg));
    }
    // Aggregate A keeps the original aggregate's indicator flag.
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
    return relBuilder;
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) TreeSet(java.util.TreeSet) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) RexInputRef(org.apache.calcite.rex.RexInputRef) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)

Example 97 with RexNode

use of org.apache.calcite.rex.RexNode in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method doRewrite.

/**
 * Rewrites every distinct aggregate call of {@code aggregate} whose argument
 * list equals {@code argList}.
 *
 * <p>The method is invoked once per distinct argument list. Each invocation
 * builds a SELECT DISTINCT relation via {@code createSelectDistinct},
 * aggregates it with non-distinct versions of the matching calls and, unless
 * this is the very first relation ({@code n == 0}), joins the result to the
 * relation already on top of {@code relBuilder}.
 *
 * @param relBuilder Builder holding the relation built so far (the left side
 *                   of the join when {@code n != 0})
 * @param aggregate  Original aggregate
 * @param n          Ordinal of this rewrite in the join; 0 means there is no
 *                   left relation yet, so no join is created
 * @param argList    Arguments to the distinct aggregate function
 * @param filterArg  Argument that filters input to aggregate function, or -1
 * @param refs       Expressions that the caller will project; the entries
 *                   belonging to the rewritten calls must still be
 *                   {@code null} and are set by this method
 */
private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList, int filterArg, List<RexInputRef> refs) {
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();

    // Capture the left side's fields before the distinct relation is pushed.
    // The very first rewrite has no left side at all.
    final List<RelDataTypeField> leftFields =
            (n == 0) ? null : relBuilder.peek().getRowType().getFieldList();
    final int leftFieldCount = (leftFields == null) ? 0 : leftFields.size();

    // Shape of the rewrite, e.g.
    //   SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
    //   FROM Emps
    //   GROUP BY deptno
    //
    // turns into
    //
    //   SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
    //   FROM (
    //     SELECT deptno, MAX(age) AS max_age
    //     FROM Emps GROUP BY deptno) AS e
    //   JOIN (
    //     SELECT deptno, COUNT(gender) AS count_gender FROM (
    //       SELECT DISTINCT deptno, gender FROM Emps) AS dgender
    //     GROUP BY deptno) AS adgender
    //     ON e.deptno = adgender.deptno
    //   JOIN (
    //     SELECT deptno, SUM(sal) AS sum_sal FROM (
    //       SELECT DISTINCT deptno, sal FROM Emps) AS dsal
    //     GROUP BY deptno) AS adsal
    //     ON e.deptno = adsal.deptno
    //   GROUP BY e.deptno
    //
    // When the query has no non-distinct aggregates the first sub-select
    // ("e") does not exist and the distinct sub-selects are joined to one
    // another directly.

    // Build the SELECT DISTINCT over the group keys plus the aggregate
    // arguments; sourceOf is handed over to be filled with the mapping from
    // original column ordinals to ordinals in that relation.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

    // Each matching distinct call becomes a non-distinct call on the distinct
    // relation, e.g. "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
    final List<AggregateCall> originalCalls = aggregate.getAggCallList();
    final List<AggregateCall> rewrittenCalls = new ArrayList<>();
    final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    for (int callIdx = 0; callIdx < originalCalls.size(); callIdx++) {
        final AggregateCall original = originalCalls.get(callIdx);
        // Skip non-distinct calls and distinct calls on other argument lists.
        final boolean handledHere = original.isDistinct() && original.getArgList().equals(argList);
        if (!handledHere) {
            continue;
        }
        // Slot of this call in refs: aggregate calls follow the group and
        // indicator columns.
        final int refIdx = groupAndIndicatorCount + callIdx;

        // Translate arguments and filter into ordinals of the distinct relation.
        final List<Integer> mappedArgs = new ArrayList<>(original.getArgList().size());
        for (Integer arg : original.getArgList()) {
            mappedArgs.add(sourceOf.get(arg));
        }
        final int mappedFilter;
        if (original.filterArg >= 0) {
            mappedFilter = sourceOf.get(original.filterArg);
        } else {
            mappedFilter = -1;
        }
        final AggregateCall rewritten = AggregateCall.create(original.getAggregation(), false, mappedArgs, mappedFilter, original.getType(), original.getName());

        assert refs.get(refIdx) == null;
        // After the join the new column sits behind all left-side fields.
        refs.set(refIdx, new RexInputRef(leftFieldCount + groupAndIndicatorCount + rewrittenCalls.size(), rewritten.getType()));
        rewrittenCalls.add(rewritten);
    }

    // Renumber the group keys densely (0, 1, 2, ...) in iteration order.
    final Map<Integer, Integer> groupKeyMapping = new HashMap<>();
    int nextOrdinal = 0;
    for (Integer key : aggregate.getGroupSet()) {
        groupKeyMapping.put(key, nextOrdinal++);
    }
    final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(groupKeyMapping);
    assert newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
    final ImmutableList<ImmutableBitSet> newGroupingSets = aggregate.indicator
            ? ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(aggregate.getGroupSets(), groupKeyMapping))
            : null;
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, newGroupSet, newGroupingSets, rewrittenCalls));

    // Without a left side there is nothing to join to.
    if (n == 0) {
        return;
    }

    // Join condition: 'left.f0 = right.f0 and left.f1 = right.f1 and ...'
    // over the GROUP BY (and indicator) fields. NULL keys form their own
    // group, hence IS NOT DISTINCT FROM so that NULLs match each other.
    final List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
    final List<RexNode> conditions = new ArrayList<>();
    for (int key = 0; key < groupAndIndicatorCount; key++) {
        final RexNode leftKey = RexInputRef.of(key, leftFields);
        final RexNode rightKey = new RexInputRef(leftFieldCount + key, distinctFields.get(key).getType());
        conditions.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, leftKey, rightKey));
    }

    // Join in the new 'select distinct' relation.
    relBuilder.join(JoinRelType.INNER, conditions);
}
Also used : ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RexBuilder(org.apache.calcite.rex.RexBuilder) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 98 with RexNode

use of org.apache.calcite.rex.RexNode in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.

/*
	public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
											   Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
		// For example,
		//	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
		//	FROM emp
		//	GROUP BY deptno
		//
		// becomes
		//
		//	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
		//	FROM (
		//		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
		//		  FROM EMP
		//		  GROUP BY deptno, sal)			// Aggregate B
		//	GROUP BY deptno						// Aggregate A
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		SortedSet<Integer> newGroupSet = new TreeSet<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

		// Add the distinct aggregate column(s) to the group-by columns,
		// if not already a part of the group-by
		newGroupSet.addAll(aggregate.getGroupSet().asList());
		for (Pair<List<Integer>, Integer> argList : argLists) {
			newGroupSet.addAll(argList.getKey());
		}

		// Re-map the arguments to the aggregate A. These arguments will get
		// remapped because of the intermediate aggregate B generated as part of the
		// transformation.
		for (int arg : newGroupSet) {
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		// Generate the intermediate aggregate B
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final List<Integer> fakeArgs = new ArrayList<>();
		final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
		// First identify the real arguments, then use the rest for fake arguments
		// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.put(arg, projects.size());
					}
				}
			}
		}
		int fakeArg0 = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// We will deal with non-distinct aggregates below
			if (!aggCall.isDistinct()) {
				boolean isGroupKeyUsedInAgg = false;
				for (int arg : aggCall.getArgList()) {
					if (sourceOf.containsKey(arg)) {
						isGroupKeyUsedInAgg = true;
						break;
					}
				}
				if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
					while (sourceOf.get(fakeArg0) != null) {
						++fakeArg0;
					}
					fakeArgs.add(fakeArg0);
				}
			}
		}
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.remove(arg);
					}
				}
			}
		}
		// Compute the remapped arguments using fake arguments for non-distinct
		// aggregates with no arguments e.g. count(*).
		int fakeArgIdx = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// Project the column corresponding to the distinct aggregate. Project
			// as-is all the non-distinct aggregates
			if (!aggCall.isDistinct()) {
				final AggregateCall newCall =
						AggregateCall.create(aggCall.getAggregation(), false,
								aggCall.getArgList(), -1,
								ImmutableBitSet.of(newGroupSet).cardinality(),
								relBuilder.peek(), null, aggCall.name);
				newAggCalls.add(newCall);
				if (newCall.getArgList().size() == 0) {
					int fakeArg = fakeArgs.get(fakeArgIdx);
					callArgMap.put(newCall, fakeArg);
					sourceOf.put(fakeArg, projects.size());
					projects.add(
							Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
									newCall.getName()));
					++fakeArgIdx;
				} else {
					for (int arg : newCall.getArgList()) {
						if (sourceOf.containsKey(arg)) {
							int fakeArg = fakeArgs.get(fakeArgIdx);
							callArgMap.put(newCall, fakeArg);
							sourceOf.put(fakeArg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
											newCall.getName()));
							++fakeArgIdx;
						} else {
							sourceOf.put(arg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
											newCall.getName()));
						}
					}
				}
			}
		}
		// Generate the aggregate B (see the reference example above)
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
		// Convert the existing aggregate to aggregate A (see the reference example above)
		final List<AggregateCall> newTopAggCalls =
				Lists.newArrayList(aggregate.getAggCallList());
		// Use the remapped arguments for the (non)distinct aggregate calls
		for (int i = 0; i < newTopAggCalls.size(); i++) {
			// Re-map arguments.
			final AggregateCall aggCall = newTopAggCalls.get(i);
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			final AggregateCall newCall;


			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				if (callArgMap.containsKey(aggCall)) {
					newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
				}
				else {
					newArgs.add(sourceOf.get(arg));
				}
			}
			if (aggCall.isDistinct()) {
				newCall =
						AggregateCall.create(aggCall.getAggregation(), false, newArgs,
								-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
								aggCall.getType(), aggCall.name);
			} else {
				// If aggregate B had a COUNT aggregate call the corresponding aggregate at
				// aggregate A must be SUM. For other aggregates, it remains the same.
				if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
					if (aggCall.getArgList().size() == 0) {
						newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
					}
					if (hasGroupBy) {
						SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					} else {
						SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					}
				} else {
					newCall =
							AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
									aggregate.getGroupSet().cardinality(),
									relBuilder.peek(), aggCall.getType(), aggCall.name);
				}
			}
			newTopAggCalls.set(i, newCall);
		}
		// Populate the group-by keys with the remapped arguments for aggregate A
		newGroupSet.clear();
		for (int arg : aggregate.getGroupSet()) {
			newGroupSet.add(sourceOf.get(arg));
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
		return relBuilder;
	}
	*/
/**
 * Rewrites an aggregate with distinct calls as two stacked aggregates: a
 * lower one that groups by the union of the original group keys and all
 * distinct arguments (using grouping sets when there is more than one
 * combination), and an upper one whose calls are each filtered to the rows
 * of the grouping set they belong to. The result is converted to the row
 * type of the original aggregate and registered on {@code call}.
 *
 * @param call      Rule call that supplies the builder and receives the result
 * @param aggregate Original aggregate
 * @param argLists  Argument list / filter argument pairs of the distinct calls
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

    // One grouping set for the original keys, plus one per distinct argument
    // list (its arguments, its filter column if any, and the original keys).
    final Set<ImmutableBitSet> orderedGroupSets = new TreeSet<>(ImmutableBitSet.ORDERING);
    orderedGroupSets.add(originalGroupSet);
    for (Pair<List<Integer>, Integer> distinctArgs : argLists) {
        final ImmutableBitSet argBits = ImmutableBitSet.of(distinctArgs.left).setIf(distinctArgs.right, distinctArgs.right >= 0);
        orderedGroupSets.add(argBits.union(originalGroupSet));
    }
    final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(orderedGroupSets);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);

    // The calls that are NOT distinct are evaluated by the lower aggregate.
    final List<AggregateCall> nonDistinctCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> namedCall : aggregate.getNamedAggCalls()) {
        if (namedCall.left.isDistinct()) {
            continue;
        }
        nonDistinctCalls.add(namedCall.left.rename(namedCall.right));
    }

    // Lower aggregate.
    final boolean multipleGroupSets = groupSets.size() > 1;
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, multipleGroupSets, groupSets), nonDistinctCalls);
    final RelNode distinct = relBuilder.peek();

    final int groupCount = fullGroupSet.cardinality();
    // Indicator columns are only accounted for when grouping sets are in use.
    final int indicatorCount = multipleGroupSets ? groupCount : 0;

    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);

    // Filter predicates (with their column names), appended by the registrar.
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    // Grouping set -> ordinal of the filter column created for it.
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();

    /** Function to register a filter for a group set. */
    class Registrar {

        // Numeric encoding of the indicator columns; built on first use.
        RexNode group = null;

        /** Adds the predicate selecting {@code groupSet} and returns the ordinal of its column. */
        private int register(ImmutableBitSet groupSet) {
            if (group == null) {
                group = makeGroup(groupCount - 1);
            }
            final RexNode expected = rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet)));
            final RexNode predicate = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, expected);
            // The predicate columns follow the group, indicator and
            // non-distinct aggregate columns, in registration order.
            final int ordinal = groupCount + indicatorCount + nonDistinctCalls.size() + predicates.size();
            predicates.add(Pair.of(predicate, toString(groupSet)));
            return ordinal;
        }

        /** Sums, for indicator columns 0..i, a CASE that is 0 when the indicator is set and 2^k otherwise. */
        private RexNode makeGroup(int i) {
            final RexInputRef indicator = rexBuilder.makeInputRef(booleanType, groupCount + i);
            final RexNode term = rexBuilder.makeCall(SqlStdOperatorTable.CASE, indicator, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
            if (i != 0) {
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), term);
            }
            return term;
        }

        /** Encodes a bit set as the sum of 2^key over its members. */
        private BigDecimal toNumber(ImmutableBitSet bitSet) {
            BigDecimal total = BigDecimal.ZERO;
            for (int bit : bitSet) {
                total = total.add(TWO.pow(bit));
            }
            return total;
        }

        /** Names the filter column: "$i" followed by the keys joined with '_'. */
        private String toString(ImmutableBitSet bitSet) {
            final StringBuilder name = new StringBuilder("$i");
            for (int bit : bitSet) {
                name.append(bit).append('_');
            }
            // Drop the last character (the trailing '_' when the set is non-empty).
            name.setLength(name.length() - 1);
            return name.toString();
        }
    }

    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
        filters.put(groupSet, registrar.register(groupSet));
    }

    // Project every column of the lower aggregate followed by the predicates.
    if (!predicates.isEmpty()) {
        final List<RexNode> exprs = new ArrayList<>();
        final List<String> names = new ArrayList<>();
        for (RelDataTypeField field : relBuilder.peek().getRowType().getFieldList()) {
            exprs.add(rexBuilder.makeInputRef(field.getType(), field.getIndex()));
            names.add(field.getName());
        }
        for (Pair<RexNode, String> predicate : predicates) {
            exprs.add(predicate.left);
            names.add(predicate.right);
        }
        relBuilder.project(exprs, names);
    }

    // Upper aggregate: one call per original call, each with a filter column.
    int nextNonDistinctField = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall original : aggregate.getAggCallList()) {
        final SqlAggFunction aggregation;
        final List<Integer> newArgList;
        final int newFilterArg;
        if (original.isDistinct()) {
            // Same function on the remapped arguments, restricted to the rows
            // of the grouping set built from this call's arguments and filter.
            aggregation = original.getAggregation();
            newArgList = remap(fullGroupSet, original.getArgList());
            final ImmutableBitSet callGroupSet = ImmutableBitSet.of(original.getArgList()).setIf(original.filterArg, original.filterArg >= 0).union(originalGroupSet);
            newFilterArg = filters.get(callGroupSet);
        } else {
            // Already computed below: pick the value up with MIN from the rows
            // of the original grouping set.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(nextNonDistinctField++);
            newFilterArg = filters.get(originalGroupSet);
        }
        newCalls.add(AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, original.name));
    }
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, originalGroupSet), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) TreeSet(java.util.TreeSet) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 99 with RexNode

use of org.apache.calcite.rex.RexNode in project flink by apache.

the class FlinkAggregateJoinTransposeRule method onMatch.

/**
 * Tries to push an {@link Aggregate} that sits on top of a {@link Join} down
 * into the join's inputs.
 *
 * <p>The rewrite is abandoned (the method returns without calling
 * {@code call.transformTo}) when any aggregate call is not splittable or has
 * a filter, when the join is not an inner join, when functions are present
 * but {@code allowFunctions} is false, when the join condition has a
 * non-equi part, or when both join inputs are treated as already unique on
 * the pushed-down key.
 *
 * <p>Otherwise each non-unique join input is wrapped in an aggregate, the
 * join is rebuilt on the new inputs with a remapped condition, and either a
 * projection or a top-level aggregate combines the per-side sub-totals.
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and
 *             {@code call.rel(1)} is the join
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // If any aggregate function does not support splitting, or any aggregate
    // call has a filter, bail out.
    // NOTE(review): distinct aggregate calls are not rejected here; confirm
    // that splitting them across the join is intended.
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0) {
            return;
        }
    }
    // If it is not an inner join, we do not push the
    // aggregate operator
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    // Aggregate functions are present but this rule instance is configured
    // not to handle them.
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    // Group columns plus columns that the join's pulled-up predicates tie to them.
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns the pushed-down aggregates must group by: the original group
    // columns plus every column referenced by the join condition.
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split join condition
    final List<Integer> leftKeys = Lists.newArrayList();
    final List<Integer> rightKeys = Lists.newArrayList();
    final List<Boolean> filterNulls = Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // map: field index in the original join output -> field index in the
    // output of the join over the new inputs (filled per side below).
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // offset: first field of the current side in the original join row;
    // belowOffset: first field of the current side in the new join row.
    int offset = 0;
    int belowOffset = 0;
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        // Key columns belonging to this side, still numbered relative to the join row.
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Same key, renumbered relative to this side's own input.
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            //
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            //
            // So we choose to imagine that the input is already unique, which is
            // untrue but harmless.
            //
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            // A null answer from the metadata query is treated as "not unique".
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // Input is treated as unique on the key: use it unchanged.
            ++uniqueCount;
            side.aggregate = false;
            side.newInput = joinInput;
        } else {
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            // Left side keeps its field numbering; for the right side a shift
            // mapping built from offset and fieldCount is used instead.
            final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    // All arguments come from this side.
                    call1 = splitter.split(aggCall.e, mapping);
                } else {
                    // Arguments are not all on this side; ask the splitter for
                    // the call (possibly none) to place on the "other" side.
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    // Record where this call's sub-total lands in the new input:
                    // after the key columns, at the index returned by the
                    // registry (presumably the call's position in belowAggCalls
                    // - confirm against registry()).
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both inputs to the join are unique. There is nothing to be gained by
        // this rule. In fact, this aggregate+join may be the result of a previous
        // invocation of this rule; if we continue we might loop forever.
        return;
    }
    // Update condition
    final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {

        public Integer apply(Integer a0) {
            return map.get(a0);
        }
    }, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    // Start from an identity projection of the new join; topSplit receives a
    // registry over this list and may add expressions to it.
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        // -1 marks a side with no sub-total; right-side indexes are shifted
        // past the left input's fields.
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // let's see if we can convert aggregate into projects
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
            }
        }
        // Only succeeds if every new aggregate call yielded a projection.
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        // Fall back to a real aggregate on top, with group keys remapped to
        // the new join's field numbering.
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) Function(com.google.common.base.Function) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 100 with RexNode

use of org.apache.calcite.rex.RexNode in project flink by apache.

the class FlinkAggregateJoinTransposeRule method keyColumns.

/**
 * Expands a set of columns using a list of constraints.
 *
 * <p>The constraints are first folded into an equivalence table by
 * {@code populateEquivalences}; an 'x = y' constraint makes y a member of
 * the result whenever x is, and vice versa. The result always contains the
 * given columns themselves.
 *
 * @param aggregateColumns columns to start from
 * @param predicates       constraints to derive equivalences from
 * @return the starting columns plus every column recorded as equivalent to
 *         one of them
 */
private static ImmutableBitSet keyColumns(ImmutableBitSet aggregateColumns, ImmutableList<RexNode> predicates) {
    final SortedMap<Integer, BitSet> equivalence = new TreeMap<>();
    for (RexNode predicate : predicates) {
        populateEquivalences(equivalence, predicate);
    }
    // Accumulate into a fresh reference; the input set itself is immutable.
    ImmutableBitSet expanded = aggregateColumns;
    for (Integer column : aggregateColumns) {
        final BitSet equivalents = equivalence.get(column);
        if (equivalents == null) {
            // No constraint mentions this column.
            continue;
        }
        expanded = expanded.union(equivalents);
    }
    return expanded;
}
Also used : ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) BitSet(java.util.BitSet) TreeMap(java.util.TreeMap) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

RexNode (org.apache.calcite.rex.RexNode) 204, RelNode (org.apache.calcite.rel.RelNode) 75, ArrayList (java.util.ArrayList) 68, RexBuilder (org.apache.calcite.rex.RexBuilder) 52, RelDataType (org.apache.calcite.rel.type.RelDataType) 48, RexInputRef (org.apache.calcite.rex.RexInputRef) 43, RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField) 41, RexCall (org.apache.calcite.rex.RexCall) 32, ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet) 28, Pair (org.apache.calcite.util.Pair) 22, HashMap (java.util.HashMap) 21, AggregateCall (org.apache.calcite.rel.core.AggregateCall) 21, RexLiteral (org.apache.calcite.rex.RexLiteral) 19, ImmutableList (com.google.common.collect.ImmutableList) 18, Project (org.apache.calcite.rel.core.Project) 17, RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory) 14, Filter (org.apache.calcite.rel.core.Filter) 12, HashSet (java.util.HashSet) 11, CalciteSemanticException (org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException) 11, BigDecimal (java.math.BigDecimal) 10