Search in sources :

Example 31 with RexBuilder

use of org.apache.calcite.rex.RexBuilder in project drill by apache.

the class DrillReduceAggregatesRule method reduceAgg.

/**
 * Reduces a single aggregate call to an expression over simpler aggregate calls.
 *
 * <p>SUM and the AVG family (AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP) are handed to
 * the corresponding {@code reduceXxx} helper. Any other aggregate function is kept unchanged
 * and simply registered through {@code RexBuilder.addAggCall}.
 *
 * @param oldAggRel      aggregate relational expression that {@code oldCall} belongs to
 * @param oldCall        aggregate call to reduce
 * @param newCalls       collection of new aggregate calls, passed through to the helpers
 * @param aggCallMapping mapping from aggregate call to expression, passed through to the helpers
 * @param inputExprs     input expressions; looked up by the argument ordinals of {@code oldCall}
 * @return the expression produced by the helper that handled {@code oldCall}
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final SqlAggFunction unwrappedAgg = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());

    if (unwrappedAgg instanceof SqlSumAggFunction) {
        // SUM(x) becomes: CASE COUNT(x) WHEN 0 THEN NULL ELSE SUM0(x) END
        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    }

    if (!(unwrappedAgg instanceof SqlAvgAggFunction)) {
        // Not reducible: keep the original call, registering it with the types of its arguments.
        final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        final int groupCount = oldAggRel.getGroupCount();
        final List<Integer> argOrdinals = oldCall.getArgList();
        assert argOrdinals.size() <= inputExprs.size();
        final List<RelDataType> argTypes = new ArrayList<>();
        for (int argOrdinal : argOrdinals) {
            argTypes.add(inputExprs.get(argOrdinal).getType());
        }
        return rexBuilder.addAggCall(oldCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
    }

    final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) unwrappedAgg).getSubtype();
    // The two flags forwarded to reduceStddev. Judging by the subtypes, the first selects the
    // population (true) versus sample (false) divisor and the second selects standard
    // deviation (true) versus variance (false) -- confirm against reduceStddev's signature.
    final boolean population;
    final boolean takeSqrt;
    switch (subtype) {
        case AVG:
            // AVG(x) becomes SUM(x) / COUNT(x)
            return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
        case STDDEV_POP:
            // divisor: COUNT(x)
            population = true;
            takeSqrt = true;
            break;
        case STDDEV_SAMP:
            // divisor: CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
            population = false;
            takeSqrt = true;
            break;
        case VAR_POP:
            // divisor: COUNT(x)
            population = true;
            takeSqrt = false;
            break;
        case VAR_SAMP:
            // divisor: CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
            population = false;
            takeSqrt = false;
            break;
        default:
            throw Util.unexpected(subtype);
    }
    return reduceStddev(oldAggRel, oldCall, population, takeSqrt, newCalls, aggCallMapping, inputExprs);
}
Also used : SqlAvgAggFunction(org.apache.calcite.sql.fun.SqlAvgAggFunction) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

Example 32 with RexBuilder

use of org.apache.calcite.rex.RexBuilder in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method doRewrite.

/**
 * Rewrites every distinct aggregate call that uses one particular argument list.
 *
 * <p>This method is called once per distinct argument list. Each call builds a
 * SELECT DISTINCT relational expression for that argument list, aggregates on
 * top of it, joins the result to the expression already held by
 * {@code relBuilder} (except on the very first call, {@code n == 0}), and
 * records in {@code refs} where each rewritten call can be read from.
 *
 * @param relBuilder builder holding the input relational expression: either
 *                   the original aggregate or the output of the previous call
 *                   to this method; the result is left on the builder
 * @param aggregate  original aggregate
 * @param n          ordinal of this rewrite in the chain of joins; 0 when the
 *                   first distinct aggregate is converted in a query with no
 *                   non-distinct aggregates, in which case no join is created
 * @param argList    arguments to the distinct aggregate function
 * @param filterArg  argument that filters input to the aggregate function, or -1
 * @param refs       expressions that will be projected by the result of this
 *                   rule; the entries relating to this argument list are set
 */
private void doRewrite(RelBuilder relBuilder, Aggregate aggregate, int n, List<Integer> argList, int filterArg, List<RexInputRef> refs) {
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    // Fields of the relation built so far; there is no such relation on the first call.
    final List<RelDataTypeField> leftFields = (n == 0) ? null : relBuilder.peek().getRowType().getFieldList();

    // Shape of the rewrite:
    //
    //   LogicalAggregate(
    //     child,
    //     {COUNT(DISTINCT 1), SUM(DISTINCT 1), SUM(2)})
    //
    // becomes
    //
    //   LogicalAggregate(
    //     LogicalJoin(
    //       child,
    //       LogicalAggregate(child, < all columns > {}),
    //       INNER,
    //       <f2 = f5>))
    //
    // For example,
    //
    //   SELECT deptno, SUM(DISTINCT sal), COUNT(DISTINCT gender), MAX(age)
    //   FROM Emps
    //   GROUP BY deptno
    //
    // becomes
    //
    //   SELECT e.deptno, adsal.sum_sal, adgender.count_gender, e.max_age
    //   FROM (
    //     SELECT deptno, MAX(age) as max_age
    //     FROM Emps GROUP BY deptno) AS e
    //   JOIN (
    //     SELECT deptno, COUNT(gender) AS count_gender FROM (
    //       SELECT DISTINCT deptno, gender FROM Emps) AS dgender
    //     GROUP BY deptno) AS adgender
    //     ON e.deptno = adgender.deptno
    //   JOIN (
    //     SELECT deptno, SUM(sal) AS sum_sal FROM (
    //       SELECT DISTINCT deptno, sal FROM Emps) AS dsal
    //     GROUP BY deptno) AS adsal
    //   ON e.deptno = adsal.deptno
    //   GROUP BY e.deptno
    //
    // If the query has no non-distinct aggregates, the very first join/group by
    // is omitted: in the example, dropping MAX(age) makes the sub-select "e"
    // unnecessary, and the two remaining group-bys are joined to each other.

    // Step 1: project the GROUP BY columns plus the aggregate arguments as a
    // distinct dataset. sourceOf is filled in by createSelectDistinct and is
    // used below to translate original argument ordinals.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    createSelectDistinct(relBuilder, aggregate, argList, filterArg, sourceOf);

    // Step 2: aggregate over the distinct dataset. Each distinct call turns into
    // a non-distinct call on the corresponding field, e.g.
    // "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
    final List<AggregateCall> rewrittenCalls = new ArrayList<>();
    final List<AggregateCall> originalCalls = aggregate.getAggCallList();
    final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    for (int callIdx = 0; callIdx < originalCalls.size(); callIdx++) {
        final AggregateCall candidate = originalCalls.get(callIdx);
        // Only distinct calls over exactly this argument list are handled here.
        if (!candidate.isDistinct() || !candidate.getArgList().equals(argList)) {
            continue;
        }
        // Position of this call in refs: aggregate calls follow the group and indicator columns.
        final int refIdx = groupAndIndicatorCount + callIdx;

        // Translate the arguments to ordinals of the distinct dataset.
        final List<Integer> oldArgs = candidate.getArgList();
        final List<Integer> newArgs = new ArrayList<>(oldArgs.size());
        for (Integer oldArg : oldArgs) {
            newArgs.add(sourceOf.get(oldArg));
        }
        final int newFilterArg;
        if (candidate.filterArg >= 0) {
            newFilterArg = sourceOf.get(candidate.filterArg);
        } else {
            newFilterArg = -1;
        }
        final AggregateCall rewritten = AggregateCall.create(candidate.getAggregation(), false, newArgs, newFilterArg, candidate.getType(), candidate.getName());

        assert refs.get(refIdx) == null;
        // After a join, the new aggregate's fields sit to the right of the left input's fields.
        final int leftWidth = (n == 0) ? 0 : leftFields.size();
        refs.set(refIdx, new RexInputRef(leftWidth + groupAndIndicatorCount + rewrittenCalls.size(), rewritten.getType()));
        rewrittenCalls.add(rewritten);
    }

    // Renumber the group keys densely (0, 1, 2, ...) in iteration order.
    final Map<Integer, Integer> groupKeyToOrdinal = new HashMap<>();
    for (Integer groupKey : aggregate.getGroupSet()) {
        groupKeyToOrdinal.put(groupKey, groupKeyToOrdinal.size());
    }
    final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(groupKeyToOrdinal);
    assert newGroupSet.equals(ImmutableBitSet.range(aggregate.getGroupSet().cardinality()));
    final ImmutableList<ImmutableBitSet> newGroupingSets;
    if (aggregate.indicator) {
        newGroupingSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(aggregate.getGroupSets(), groupKeyToOrdinal));
    } else {
        newGroupingSets = null;
    }
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, newGroupSet, newGroupingSets, rewrittenCalls));

    // Step 3: join with the left input, unless this is the first call and there is none yet.
    if (n != 0) {
        // The condition has the form
        //   'left.f0 = right.f0 and left.f1 = right.f1 and ...'
        // where {f0, f1, ...} are the GROUP BY fields. Null values form their
        // own group, so "is not distinct from" is used to let nulls match.
        final List<RelDataTypeField> distinctFields = relBuilder.peek().getRowType().getFieldList();
        final List<RexNode> conditions = Lists.newArrayList();
        for (int k = 0; k < groupAndIndicatorCount; k++) {
            final RexInputRef leftKey = RexInputRef.of(k, leftFields);
            final RexInputRef rightKey = new RexInputRef(leftFields.size() + k, distinctFields.get(k).getType());
            conditions.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, leftKey, rightKey));
        }
        relBuilder.join(JoinRelType.INNER, conditions);
    }
}
Also used : ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RexBuilder(org.apache.calcite.rex.RexBuilder) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 33 with RexBuilder

use of org.apache.calcite.rex.RexBuilder in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.

/*
	public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
											   Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
		// For example,
		//	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
		//	FROM emp
		//	GROUP BY deptno
		//
		// becomes
		//
		//	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
		//	FROM (
		//		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
		//		  FROM EMP
		//		  GROUP BY deptno, sal)			// Aggregate B
		//	GROUP BY deptno						// Aggregate A
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		SortedSet<Integer> newGroupSet = new TreeSet<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

		// Add the distinct aggregate column(s) to the group-by columns,
		// if not already a part of the group-by
		newGroupSet.addAll(aggregate.getGroupSet().asList());
		for (Pair<List<Integer>, Integer> argList : argLists) {
			newGroupSet.addAll(argList.getKey());
		}

		// Re-map the arguments to the aggregate A. These arguments will get
		// remapped because of the intermediate aggregate B generated as part of the
		// transformation.
		for (int arg : newGroupSet) {
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		// Generate the intermediate aggregate B
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final List<Integer> fakeArgs = new ArrayList<>();
		final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
		// First identify the real arguments, then use the rest for fake arguments
		// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.put(arg, projects.size());
					}
				}
			}
		}
		int fakeArg0 = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// We will deal with non-distinct aggregates below
			if (!aggCall.isDistinct()) {
				boolean isGroupKeyUsedInAgg = false;
				for (int arg : aggCall.getArgList()) {
					if (sourceOf.containsKey(arg)) {
						isGroupKeyUsedInAgg = true;
						break;
					}
				}
				if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
					while (sourceOf.get(fakeArg0) != null) {
						++fakeArg0;
					}
					fakeArgs.add(fakeArg0);
				}
			}
		}
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.remove(arg);
					}
				}
			}
		}
		// Compute the remapped arguments using fake arguments for non-distinct
		// aggregates with no arguments e.g. count(*).
		int fakeArgIdx = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// Project the column corresponding to the distinct aggregate. Project
			// as-is all the non-distinct aggregates
			if (!aggCall.isDistinct()) {
				final AggregateCall newCall =
						AggregateCall.create(aggCall.getAggregation(), false,
								aggCall.getArgList(), -1,
								ImmutableBitSet.of(newGroupSet).cardinality(),
								relBuilder.peek(), null, aggCall.name);
				newAggCalls.add(newCall);
				if (newCall.getArgList().size() == 0) {
					int fakeArg = fakeArgs.get(fakeArgIdx);
					callArgMap.put(newCall, fakeArg);
					sourceOf.put(fakeArg, projects.size());
					projects.add(
							Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
									newCall.getName()));
					++fakeArgIdx;
				} else {
					for (int arg : newCall.getArgList()) {
						if (sourceOf.containsKey(arg)) {
							int fakeArg = fakeArgs.get(fakeArgIdx);
							callArgMap.put(newCall, fakeArg);
							sourceOf.put(fakeArg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
											newCall.getName()));
							++fakeArgIdx;
						} else {
							sourceOf.put(arg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
											newCall.getName()));
						}
					}
				}
			}
		}
		// Generate the aggregate B (see the reference example above)
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
		// Convert the existing aggregate to aggregate A (see the reference example above)
		final List<AggregateCall> newTopAggCalls =
				Lists.newArrayList(aggregate.getAggCallList());
		// Use the remapped arguments for the (non)distinct aggregate calls
		for (int i = 0; i < newTopAggCalls.size(); i++) {
			// Re-map arguments.
			final AggregateCall aggCall = newTopAggCalls.get(i);
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			final AggregateCall newCall;


			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				if (callArgMap.containsKey(aggCall)) {
					newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
				}
				else {
					newArgs.add(sourceOf.get(arg));
				}
			}
			if (aggCall.isDistinct()) {
				newCall =
						AggregateCall.create(aggCall.getAggregation(), false, newArgs,
								-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
								aggCall.getType(), aggCall.name);
			} else {
				// If aggregate B had a COUNT aggregate call the corresponding aggregate at
				// aggregate A must be SUM. For other aggregates, it remains the same.
				if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
					if (aggCall.getArgList().size() == 0) {
						newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
					}
					if (hasGroupBy) {
						SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					} else {
						SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					}
				} else {
					newCall =
							AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
									aggregate.getGroupSet().cardinality(),
									relBuilder.peek(), aggCall.getType(), aggCall.name);
				}
			}
			newTopAggCalls.set(i, newCall);
		}
		// Populate the group-by keys with the remapped arguments for aggregate A
		newGroupSet.clear();
		for (int arg : aggregate.getGroupSet()) {
			newGroupSet.add(sourceOf.get(arg));
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
		return relBuilder;
	}
	*/
/**
 * Rewrites an aggregate that contains distinct aggregate calls as two stacked
 * aggregates, the lower of which uses grouping sets.
 *
 * <p>Steps, in order:
 * <ol>
 * <li>Collect the group sets: the aggregate's own group set plus, for every
 * entry of {@code argLists}, the union of that entry's arguments, its filter
 * argument (when non-negative) and the aggregate's group set.
 * <li>Aggregate the input over the union of all group sets, using those group
 * sets as grouping sets and computing the non-distinct calls.
 * <li>Project one boolean predicate column per group set on top of that.
 * <li>Aggregate again over the original group keys, using the predicate
 * columns as filter arguments, convert to the original row type, and register
 * the result with {@code call}.
 * </ol>
 *
 * @param call      rule call that supplies the builder and receives the result
 * @param aggregate aggregate to rewrite
 * @param argLists  (argument list, filter argument) pairs of the distinct
 *                  aggregate calls; a negative filter argument is ignored
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Sorted, de-duplicated collection of every group set the lower aggregate must compute.
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    groupSetTreeSet.add(aggregate.getGroupSet());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        groupSetTreeSet.add(ImmutableBitSet.of(argList.left).setIf(argList.right, argList.right >= 0).union(aggregate.getGroupSet()));
    }
    final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // Despite its name, this list receives the NON-distinct calls (see the condition below);
    // they are evaluated by the lower aggregate.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            distinctAggCalls.add(aggCall.left.rename(aggCall.right));
        }
    }
    // Lower aggregate: group by the union of all keys; indicator columns are requested
    // only when there is more than one group set.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    final int groupCount = fullGroupSet.cardinality();
    final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
    // Predicate expressions (with field names) to append to the projection, and for each
    // group set the index its predicate column will have in that projection.
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
    /** Function to register a filter for a group set. */
    class Registrar {

        // Lazily built expression shared by all predicates; see makeGroup.
        RexNode group = null;

        // Adds the predicate "group = <number of groupSet>" and returns the index of the
        // column that will carry it: the predicates are appended after the group, indicator
        // and aggregate-call columns of the lower aggregate.
        private int register(ImmutableBitSet groupSet) {
            if (group == null) {
                group = makeGroup(groupCount - 1);
            }
            final RexNode node = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet))));
            predicates.add(Pair.of(node, toString(groupSet)));
            return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
        }

        // Builds, recursively for positions i down to 0, the sum of
        // CASE <boolean column groupCount + k> THEN 0 ELSE 2^k END.
        // NOTE(review): column groupCount + k is presumably the indicator column of
        // group key k -- confirm against the row type of the lower aggregate.
        private RexNode makeGroup(int i) {
            final RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
            final RexNode kase = rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
            if (i == 0) {
                return kase;
            } else {
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), kase);
            }
        }

        // Encodes a bit set as the sum of 2^key over its keys.
        private BigDecimal toNumber(ImmutableBitSet bitSet) {
            BigDecimal n = BigDecimal.ZERO;
            for (int key : bitSet) {
                n = n.add(TWO.pow(key));
            }
            return n;
        }

        // Field name for a group set's predicate: "$i" followed by the keys joined with '_',
        // e.g. "$i0_2". The keys are joined with a separator instead of trimming a trailing
        // '_' afterwards, so that an empty set yields "$i" rather than a truncated prefix.
        private String toString(ImmutableBitSet bitSet) {
            final StringBuilder buf = new StringBuilder("$i");
            String separator = "";
            for (int key : bitSet) {
                buf.append(separator).append(key);
                separator = "_";
            }
            return buf.toString();
        }
    }
    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
        filters.put(groupSet, registrar.register(groupSet));
    }
    // Project every field of the lower aggregate followed by the predicate columns.
    if (!predicates.isEmpty()) {
        List<Pair<RexNode, String>> nodes = new ArrayList<>();
        for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
            final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
            nodes.add(Pair.of(node, f.getName()));
        }
        nodes.addAll(predicates);
        relBuilder.project(Pair.left(nodes), Pair.right(nodes));
    }
    // x walks over the columns holding the already-computed non-distinct calls, which
    // start right after the group and indicator columns.
    int x = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // Non-distinct call: take MIN of its pre-computed column, filtered to the rows
            // of the aggregate's own group set.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
        } else {
            // Distinct call: same function over the remapped arguments, filtered to the rows
            // of the group set built from its arguments (the same formula as at the top).
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
        }
        final AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
        newCalls.add(newCall);
    }
    // Upper aggregate over the original group keys, expressed in the lower aggregate's ordinals.
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) TreeSet(java.util.TreeSet) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 34 with RexBuilder

use of org.apache.calcite.rex.RexBuilder in project flink by apache.

the class FlinkAggregateJoinTransposeRule method onMatch.

/**
 * Tries to push an aggregate (rel 0 of the call) below the join it sits on (rel 1).
 *
 * <p>Each join input that is not known to be unique on the needed columns gets its own
 * aggregate; the join is then rebuilt on the new inputs, and a top aggregate (or, when
 * possible, a plain projection) combines the per-side sub-totals. The method returns
 * without transforming anything when a precondition is not met.
 *
 * @param call rule call providing the matched aggregate and join, and the builder
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // Bail out if any aggregate function cannot be split, or if any aggregate call has a filter
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0) {
            return;
        }
    }
    // Only inner joins are handled; for any other join type, do not push the
    // aggregate operator
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    // allowFunctions is declared outside this block; when it is false, only aggregates
    // without aggregate calls are accepted
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns the pushed-down aggregates must group by: group keys plus join columns
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split join condition
    final List<Integer> leftKeys = Lists.newArrayList();
    final List<Integer> rightKeys = Lists.newArrayList();
    final List<Boolean> filterNulls = Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // map: ordinal of a column in the original join output -> ordinal in the new join output
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // offset: first field of the current side in the original join output;
    // belowOffset: first field of the current side in the new join output
    int offset = 0;
    int belowOffset = 0;
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Same key, expressed in ordinals local to this join input
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            //
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            //
            // So we choose to imagine the input is already unique, which is
            // untrue but harmless.
            //
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            // A null answer (uniqueness unknown) is treated as "not unique"
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // Input is used as-is; no aggregate is pushed to this side
            ++uniqueCount;
            side.aggregate = false;
            side.newInput = joinInput;
        } else {
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            // Maps argument ordinals of the original join output to ordinals of this input
            final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                // split() when all arguments come from this side, other() otherwise
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    call1 = splitter.split(aggCall.e, mapping);
                } else {
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    // Remember which output column of this side holds the sub-total:
                    // the aggregate calls follow the group-key columns
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both inputs to the join were treated as unique, so no aggregate was pushed
        // to either side and nothing would be gained. This aggregate+join may even be
        // the result of a previous
        // invocation of this rule; if we continue we might loop forever.
        return;
    }
    // Update condition
    final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {

        public Integer apply(Integer a0) {
            return map.get(a0);
        }
    }, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        // -1 marks a side without a sub-total; the right-hand ordinal is shifted past the left input's fields
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // let's see if we can convert aggregate into projects
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
            }
        }
        // The conversion only succeeded if every group key and every call produced a projection
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) Function(com.google.common.base.Function) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 35 with RexBuilder

use of org.apache.calcite.rex.RexBuilder in project hive by apache.

the class HiveRelMdPredicates method getPredicates.

/**
   * Infers predicates for a project.
   *
   * <ol>
   * <li>Create a mapping from input to projection. Map only positions that
   * directly reference an input column.
   * <li>Expressions that only contain the above columns are retained in the
   * Project's pullExpressions list.
   * <li>For example, the expression 'a + e = 9' below will not be pulled up
   * because 'e' is not in the projection list.
   *
   * <pre>
   * childPullUpExprs:      {a &gt; 7, b + c &lt; 10, a + e = 9}
   * projectionExprs:       {a, b, c, e / 2}
   * projectionPullupExprs: {a &gt; 7, b + c &lt; 10}
   * </pre>
   *
   * </ol>
   *
   * <p>In addition, a predicate is generated for each projected constant:
   * {@code IS NULL(col)} for a null literal, and {@code col = expr} for any
   * other literal or for a call accepted by
   * {@code HiveCalciteUtil.isDeterministicFuncOnLiterals}.
   *
   * @param project the Project whose output predicates are inferred
   * @param mq metadata query used to obtain the predicates pulled up from the
   *           project's input
   * @return the list of predicates expressed over the project's output columns
   */
public RelOptPredicateList getPredicates(Project project, RelMetadataQuery mq) {
    RelNode child = project.getInput();
    final RexBuilder rexBuilder = project.getCluster().getRexBuilder();
    RelOptPredicateList childInfo = mq.getPulledUpPredicates(child);
    List<RexNode> projectPullUpPredicates = new ArrayList<RexNode>();
    ImmutableBitSet.Builder columnsMappedBuilder = ImmutableBitSet.builder();
    // Partial mapping: input field index -> output field index of the project.
    Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION, child.getRowType().getFieldCount(), project.getRowType().getFieldCount());
    // Only projections that are bare input references contribute to the
    // mapping; computed expressions (e.g. 'e / 2') are left unmapped.
    for (Ord<RexNode> o : Ord.zip(project.getProjects())) {
        if (o.e instanceof RexInputRef) {
            int sIdx = ((RexInputRef) o.e).getIndex();
            m.set(sIdx, o.i);
            columnsMappedBuilder.set(sIdx);
        }
    }
    // Go over childPullUpPredicates. If a predicate only contains columns in
    // 'columnsMapped' construct a new predicate based on mapping.
    final ImmutableBitSet columnsMapped = columnsMappedBuilder.build();
    for (RexNode r : childInfo.pulledUpPredicates) {
        ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
        if (columnsMapped.contains(rCols)) {
            // Rewrite input references so the predicate is expressed over
            // the project's output columns.
            r = r.accept(new RexPermuteInputsShuttle(m, child));
            projectPullUpPredicates.add(r);
        }
    }
    // Project can also generate constants. We need to include them.
    for (Ord<RexNode> expr : Ord.zip(project.getProjects())) {
        if (RexLiteral.isNullLiteral(expr.e)) {
            // NULL constant: emit 'col IS NULL' rather than 'col = NULL'.
            projectPullUpPredicates.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, rexBuilder.makeInputRef(project, expr.i)));
        } else if (expr.e instanceof RexLiteral) {
            final RexLiteral literal = (RexLiteral) expr.e;
            projectPullUpPredicates.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, rexBuilder.makeInputRef(project, expr.i), literal));
        } else if (expr.e instanceof RexCall && HiveCalciteUtil.isDeterministicFuncOnLiterals(expr.e)) {
            //TODO: Move this to calcite
            projectPullUpPredicates.add(rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, rexBuilder.makeInputRef(project, expr.i), expr.e));
        }
    }
    return RelOptPredicateList.of(projectPullUpPredicates);
}
Also used : RexLiteral(org.apache.calcite.rex.RexLiteral) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RexCall(org.apache.calcite.rex.RexCall) RelNode(org.apache.calcite.rel.RelNode) RelOptPredicateList(org.apache.calcite.plan.RelOptPredicateList) RexBuilder(org.apache.calcite.rex.RexBuilder) RexInputRef(org.apache.calcite.rex.RexInputRef) RexPermuteInputsShuttle(org.apache.calcite.rex.RexPermuteInputsShuttle) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

RexBuilder (org.apache.calcite.rex.RexBuilder)60 RexNode (org.apache.calcite.rex.RexNode)52 ArrayList (java.util.ArrayList)32 RelDataType (org.apache.calcite.rel.type.RelDataType)26 RelNode (org.apache.calcite.rel.RelNode)24 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)20 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)13 AggregateCall (org.apache.calcite.rel.core.AggregateCall)11 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)11 RexInputRef (org.apache.calcite.rex.RexInputRef)10 RelOptCluster (org.apache.calcite.plan.RelOptCluster)9 HashMap (java.util.HashMap)8 RelBuilder (org.apache.calcite.tools.RelBuilder)8 ImmutableList (com.google.common.collect.ImmutableList)6 BigDecimal (java.math.BigDecimal)6 RelOptPredicateList (org.apache.calcite.plan.RelOptPredicateList)6 RexLiteral (org.apache.calcite.rex.RexLiteral)6 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)6 Pair (org.apache.calcite.util.Pair)6 CalciteSemanticException (org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException)6