Search in sources:

Example 11 with Pair

use of org.apache.calcite.util.Pair in project hive by apache.

Class HiveUnionPullUpConstantsRule, method onMatch.

/**
 * Pulls columns that are constant on every row of a Union above it.
 *
 * <p>Each Union input is wrapped in a Project that keeps only the non-constant
 * columns. The narrower Union is followed by a Project that re-inserts the
 * constants, and a final convert restores the original row type.
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Union union = call.rel(0);
    final int fieldCount = union.getRowType().getFieldCount();
    // A single-column Union is left untouched.
    if (fieldCount == 1) {
        return;
    }

    final RexBuilder rexBuilder = union.getCluster().getRexBuilder();
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    final RelOptPredicateList pulledUp = mq.getPulledUpPredicates(union);
    if (pulledUp == null) {
        return;
    }

    // Constants implied by the pulled-up predicates, restricted to plain
    // references to the Union's own output columns.
    final Map<RexNode, RexNode> extracted =
        HiveReduceExpressionsRule.predicateConstants(RexNode.class, rexBuilder, pulledUp);
    final Map<RexNode, RexNode> constantByColumn = new HashMap<>();
    for (int col = 0; col < fieldCount; col++) {
        final RexNode columnRef = rexBuilder.makeInputRef(union, col);
        if (extracted.containsKey(columnRef)) {
            constantByColumn.put(columnRef, extracted.get(columnRef));
        }
    }
    if (constantByColumn.isEmpty()) {
        // No column is known to be constant: nothing to pull up.
        return;
    }

    // Top Project: the constant in place of each constant column and a
    // reference for every other column. Only the referenced ("kept") columns
    // survive in the new Union.
    final List<RelDataTypeField> unionFields = union.getRowType().getFieldList();
    List<RexNode> topExprs = new ArrayList<>(fieldCount);
    final List<String> topNames = new ArrayList<>(fieldCount);
    final List<RexNode> keptRefs = new ArrayList<>();
    final ImmutableBitSet.Builder keptBuilder = ImmutableBitSet.builder();
    for (int col = 0; col < fieldCount; col++) {
        final RexNode columnRef = rexBuilder.makeInputRef(union, col);
        topNames.add(unionFields.get(col).getName());
        final boolean isConstant = constantByColumn.containsKey(columnRef);
        topExprs.add(isConstant ? constantByColumn.get(columnRef) : columnRef);
        if (!isConstant) {
            keptRefs.add(columnRef);
            keptBuilder.set(col);
        }
    }
    final ImmutableBitSet keptColumns = keptBuilder.build();

    // Renumber the top Project's references onto the narrowed Union, which
    // only carries the kept columns.
    final Mappings.TargetMapping mapping =
        RelOptUtil.permutation(keptRefs, union.getInput(0).getRowType()).inverse();
    topExprs = ImmutableList.copyOf(RexUtil.apply(mapping, topExprs));

    // Rebuild as Project -> Union -> Project.
    final RelBuilder relBuilder = call.builder();
    final List<RelNode> inputs = union.getInputs();
    for (RelNode input : inputs) {
        final List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
        final List<RexNode> childExprs = new ArrayList<>();
        final List<String> childNames = new ArrayList<>();
        for (int pos : keptColumns) {
            childExprs.add(rexBuilder.makeInputRef(input, pos));
            childNames.add(inputFields.get(pos).getName());
        }
        if (childExprs.isEmpty()) {
            // Every column was constant, but a Project needs at least one expression.
            childExprs.add(topExprs.get(0));
            childNames.add(topNames.get(0));
        }
        relBuilder.push(input);
        relBuilder.project(childExprs, childNames);
    }
    relBuilder.union(union.all, inputs.size());
    // The top Project re-inserts the constants; convert fixes field nullability.
    relBuilder.project(topExprs, topNames);
    relBuilder.convert(union.getRowType(), false);
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) RelBuilder(org.apache.calcite.tools.RelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) HiveUnion(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion) Union(org.apache.calcite.rel.core.Union) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) Mappings(org.apache.calcite.util.mapping.Mappings) RelNode(org.apache.calcite.rel.RelNode) RelOptPredicateList(org.apache.calcite.plan.RelOptPredicateList) RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode) Pair(org.apache.calcite.util.Pair)

Example 12 with Pair

use of org.apache.calcite.util.Pair in project drill by apache.

Class NumberingRelWriter, method done.

/**
 * Finishes describing {@code node}: sanity-checks the recorded values, then
 * emits them and flushes the underlying writer.
 *
 * <p>With assertions enabled, this checks that the recorded values list the
 * node's inputs and then its child expressions, in order and by identity,
 * after an optional leading {@code "subset"} entry. The recorded values are
 * snapshotted and cleared before {@code explain_} runs.
 *
 * @param node relational expression being explained
 * @return this writer
 */
public RelWriter done(RelNode node) {
    int pos = (!values.isEmpty() && values.get(0).left.equals("subset")) ? 1 : 0;
    for (RelNode input : node.getInputs()) {
        assert values.get(pos).right == input;
        pos++;
    }
    for (RexNode expr : node.getChildExps()) {
        assert values.get(pos).right == expr;
        pos++;
    }
    // Snapshot first so the live list can be reset before emitting.
    final List<Pair<String, Object>> recorded = ImmutableList.copyOf(values);
    values.clear();
    explain_(node, recorded);
    pw.flush();
    return this;
}
Also used : RelNode(org.apache.calcite.rel.RelNode) RexNode(org.apache.calcite.rex.RexNode) Pair(org.apache.calcite.util.Pair)

Example 13 with Pair

use of org.apache.calcite.util.Pair in project lucene-solr by apache.

Class SolrAggregate, method toSolrMetric.

/**
 * Translates an aggregate call into a Solr metric, as a pair of
 * (aggregation name, field name).
 *
 * <p>Two shapes are supported:
 * <ul>
 *   <li>{@code COUNT(*)}: no arguments, mapped to the field {@code "*"};</li>
 *   <li>a single-argument call to an aggregation in
 *       {@code SUPPORTED_AGGREGATIONS}, whose input name is resolved through
 *       {@code implementor.fieldMappings} and falls back to the input name
 *       when unmapped.</li>
 * </ul>
 *
 * @param implementor supplies the field-name mappings
 * @param aggCall     aggregate call to translate
 * @param inNames     input field names, indexed by argument ordinal
 * @return the aggregation name paired with the Solr field it applies to
 * @throws AssertionError if the aggregation or its arity is not supported
 */
private Pair<String, String> toSolrMetric(Implementor implementor, AggregateCall aggCall, List<String> inNames) {
    final SqlAggFunction aggregation = aggCall.getAggregation();
    final List<Integer> args = aggCall.getArgList();
    switch (args.size()) {
        case 0:
            if (aggregation.equals(SqlStdOperatorTable.COUNT)) {
                return new Pair<>(aggregation.getName(), "*");
            }
            // Other zero-argument calls are unsupported. Break explicitly:
            // falling into case 1 would evaluate args.get(0) on an empty list.
            break;
        case 1:
            if (SUPPORTED_AGGREGATIONS.contains(aggregation)) {
                final String inName = inNames.get(args.get(0));
                final String name = implementor.fieldMappings.getOrDefault(inName, inName);
                return new Pair<>(aggregation.getName(), name);
            }
            break;
        default:
            break;
    }
    throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names " + inNames);
}
Also used : SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) Pair(org.apache.calcite.util.Pair)

Example 14 with Pair

use of org.apache.calcite.util.Pair in project flink by apache.

Class FlinkAggregateExpandDistinctAggregatesRule, method convertSingletonDistinct.

/**
 * Converts an aggregate with one distinct aggregate call and one or more
 * non-distinct aggregate calls into two stacked aggregates (see the example
 * at the top of the method body).
 *
 * <p>Aggregate B (bottom) groups the original input by the original group
 * keys plus the distinct call's argument column(s), and evaluates every
 * non-distinct call as-is. Aggregate A (top) groups B's output by the original
 * keys and computes the following:
 * <ul>
 *   <li>the distinct call becomes a plain call over B's group-key column(s);</li>
 *   <li>a non-distinct COUNT becomes a SUM of the partial counts, using an
 *       empty-is-zero SUM when there is no GROUP BY;</li>
 *   <li>every other non-distinct call is re-applied to its partial result.</li>
 * </ul>
 *
 * <p>Positions in B's output are derived from B's row layout: group keys in
 * ascending column order, then the non-distinct calls in their original order.
 * They are not looked up by input column, because one input column can feed
 * several outputs of B. Examples: {@code SUM(sal), MAX(sal)}; or
 * {@code SUM(sal), SUM(DISTINCT sal)}, where {@code sal} is both a group key
 * of B and the argument of a partial aggregate.
 *
 * @param relBuilder builder; the aggregate's input is pushed onto it
 * @param aggregate  original aggregate
 * @param argLists   argument lists and filter arguments of the distinct
 *                   call(s); only the argument lists are read here
 * @return {@code relBuilder}, with aggregate A (over aggregate B) on top
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // For example,
    //   SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //   FROM emp
    //   GROUP BY deptno
    //
    // becomes
    //
    //   SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //   FROM (
    //       SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //       FROM EMP
    //       GROUP BY deptno, sal)          // Aggregate B
    //   GROUP BY deptno                    // Aggregate A
    relBuilder.push(aggregate.getInput());
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final boolean hasGroupBy = !groupSet.isEmpty();
    final int topGroupCount = groupSet.cardinality();

    // Group keys of aggregate B: the original group-by columns plus the
    // distinct aggregate column(s), if not already part of the group-by.
    final SortedSet<Integer> bottomGroupSet = new TreeSet<>(groupSet.asList());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        bottomGroupSet.addAll(argList.getKey());
    }
    final ImmutableBitSet bottomGroupBits = ImmutableBitSet.of(bottomGroupSet);

    // B emits its group keys first, in ascending column order. Map each
    // group-key input column to its output position in B.
    final Map<Integer, Integer> keyPositionInBottom = new HashMap<>();
    for (int column : bottomGroupSet) {
        keyPositionInBottom.put(column, keyPositionInBottom.size());
    }

    // Generate aggregate B. Every non-distinct call is evaluated as-is over the
    // original input, and its result follows the group keys in B's output.
    // NOTE(review): filter arguments are not carried over (-1 is passed).
    // Confirm that filtered calls are never routed to this rewrite.
    final List<AggregateCall> bottomCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            bottomCalls.add(AggregateCall.create(aggCall.getAggregation(), false, aggCall.getArgList(), -1, bottomGroupBits.cardinality(), relBuilder.peek(), null, aggCall.name));
        }
    }
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, bottomGroupBits, null, bottomCalls));

    // Build aggregate A's calls over B's output, preserving the original call order.
    final List<AggregateCall> topCalls = new ArrayList<>(aggCalls.size());
    int nextPartialResult = bottomGroupSet.size();
    for (AggregateCall aggCall : aggCalls) {
        final AggregateCall newCall;
        if (aggCall.isDistinct()) {
            // The distinct arguments are group keys of B, so B already holds each
            // distinct value combination once per A-group. A plain call suffices.
            final List<Integer> newArgs = new ArrayList<>(aggCall.getArgList().size());
            for (int arg : aggCall.getArgList()) {
                newArgs.add(keyPositionInBottom.get(arg));
            }
            newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, topGroupCount, relBuilder.peek(), aggCall.getType(), aggCall.name);
        } else {
            // This call's partial result is the next aggregate column of B.
            final List<Integer> newArgs = new ArrayList<>(1);
            newArgs.add(nextPartialResult++);
            if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
                // Partial counts from B must be added up at A.
                if (hasGroupBy) {
                    newCall = AggregateCall.create(new SqlSumAggFunction(null), false, newArgs, -1, topGroupCount, relBuilder.peek(), aggCall.getType(), aggCall.getName());
                } else {
                    // Without GROUP BY, presumably so that empty input yields 0
                    // (as COUNT would) rather than NULL.
                    newCall = AggregateCall.create(new SqlSumEmptyIsZeroAggFunction(), false, newArgs, -1, topGroupCount, relBuilder.peek(), aggCall.getType(), aggCall.getName());
                }
            } else {
                // Other aggregations are re-applied unchanged to their partial results.
                newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, topGroupCount, relBuilder.peek(), aggCall.getType(), aggCall.name);
            }
        }
        topCalls.add(newCall);
    }

    // Group aggregate A by the original group keys, at their positions in B.
    final List<Integer> topGroupKeys = new ArrayList<>(topGroupCount);
    for (int arg : groupSet) {
        topGroupKeys.add(keyPositionInBottom.get(arg));
    }
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.of(topGroupKeys), null, topCalls));
    return relBuilder;
}
Also used: HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) TreeSet(java.util.TreeSet) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) RexInputRef(org.apache.calcite.rex.RexInputRef) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)

Example 15 with Pair

use of org.apache.calcite.util.Pair in project flink by apache.

Class FlinkAggregateExpandDistinctAggregatesRule, method rewriteUsingGroupingSets.

/*
	public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
											   Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
		// For example,
		//	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
		//	FROM emp
		//	GROUP BY deptno
		//
		// becomes
		//
		//	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
		//	FROM (
		//		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
		//		  FROM EMP
		//		  GROUP BY deptno, sal)			// Aggregate B
		//	GROUP BY deptno						// Aggregate A
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		SortedSet<Integer> newGroupSet = new TreeSet<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

		// Add the distinct aggregate column(s) to the group-by columns,
		// if not already a part of the group-by
		newGroupSet.addAll(aggregate.getGroupSet().asList());
		for (Pair<List<Integer>, Integer> argList : argLists) {
			newGroupSet.addAll(argList.getKey());
		}

		// Re-map the arguments to the aggregate A. These arguments will get
		// remapped because of the intermediate aggregate B generated as part of the
		// transformation.
		for (int arg : newGroupSet) {
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		// Generate the intermediate aggregate B
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final List<Integer> fakeArgs = new ArrayList<>();
		final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
		// First identify the real arguments, then use the rest for fake arguments
		// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.put(arg, projects.size());
					}
				}
			}
		}
		int fakeArg0 = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// We will deal with non-distinct aggregates below
			if (!aggCall.isDistinct()) {
				boolean isGroupKeyUsedInAgg = false;
				for (int arg : aggCall.getArgList()) {
					if (sourceOf.containsKey(arg)) {
						isGroupKeyUsedInAgg = true;
						break;
					}
				}
				if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
					while (sourceOf.get(fakeArg0) != null) {
						++fakeArg0;
					}
					fakeArgs.add(fakeArg0);
				}
			}
		}
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.remove(arg);
					}
				}
			}
		}
		// Compute the remapped arguments using fake arguments for non-distinct
		// aggregates with no arguments e.g. count(*).
		int fakeArgIdx = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// Project the column corresponding to the distinct aggregate. Project
			// as-is all the non-distinct aggregates
			if (!aggCall.isDistinct()) {
				final AggregateCall newCall =
						AggregateCall.create(aggCall.getAggregation(), false,
								aggCall.getArgList(), -1,
								ImmutableBitSet.of(newGroupSet).cardinality(),
								relBuilder.peek(), null, aggCall.name);
				newAggCalls.add(newCall);
				if (newCall.getArgList().size() == 0) {
					int fakeArg = fakeArgs.get(fakeArgIdx);
					callArgMap.put(newCall, fakeArg);
					sourceOf.put(fakeArg, projects.size());
					projects.add(
							Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
									newCall.getName()));
					++fakeArgIdx;
				} else {
					for (int arg : newCall.getArgList()) {
						if (sourceOf.containsKey(arg)) {
							int fakeArg = fakeArgs.get(fakeArgIdx);
							callArgMap.put(newCall, fakeArg);
							sourceOf.put(fakeArg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
											newCall.getName()));
							++fakeArgIdx;
						} else {
							sourceOf.put(arg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
											newCall.getName()));
						}
					}
				}
			}
		}
		// Generate the aggregate B (see the reference example above)
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
		// Convert the existing aggregate to aggregate A (see the reference example above)
		final List<AggregateCall> newTopAggCalls =
				Lists.newArrayList(aggregate.getAggCallList());
		// Use the remapped arguments for the (non)distinct aggregate calls
		for (int i = 0; i < newTopAggCalls.size(); i++) {
			// Re-map arguments.
			final AggregateCall aggCall = newTopAggCalls.get(i);
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			final AggregateCall newCall;


			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				if (callArgMap.containsKey(aggCall)) {
					newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
				}
				else {
					newArgs.add(sourceOf.get(arg));
				}
			}
			if (aggCall.isDistinct()) {
				newCall =
						AggregateCall.create(aggCall.getAggregation(), false, newArgs,
								-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
								aggCall.getType(), aggCall.name);
			} else {
				// If aggregate B had a COUNT aggregate call the corresponding aggregate at
				// aggregate A must be SUM. For other aggregates, it remains the same.
				if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
					if (aggCall.getArgList().size() == 0) {
						newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
					}
					if (hasGroupBy) {
						SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					} else {
						SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					}
				} else {
					newCall =
							AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
									aggregate.getGroupSet().cardinality(),
									relBuilder.peek(), aggCall.getType(), aggCall.name);
				}
			}
			newTopAggCalls.set(i, newCall);
		}
		// Populate the group-by keys with the remapped arguments for aggregate A
		newGroupSet.clear();
		for (int arg : aggregate.getGroupSet()) {
			newGroupSet.add(sourceOf.get(arg));
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
		return relBuilder;
	}
	*/
/**
 * Rewrites an aggregate that has distinct calls into two aggregates.
 *
 * <p>The first aggregate groups the input by every group set that is needed:
 * <ul>
 *   <li>the original group set;</li>
 *   <li>for each distinct argument list, that list plus its filter column
 *       (when non-negative) plus the original group set.</li>
 * </ul>
 * It evaluates the non-distinct calls. When there is more than one group set,
 * it also has indicator columns.
 *
 * <p>A Project then appends one boolean predicate column per group set. Each
 * predicate compares a bitmask computed from the indicator columns with that
 * group set's bitmask.
 *
 * <p>The second aggregate rebuilds each original call:
 * <ul>
 *   <li>a non-distinct call becomes MIN over its precomputed value, filtered
 *       by the original group set's predicate;</li>
 *   <li>a distinct call becomes a non-distinct call over its remapped
 *       arguments, filtered by its own group set's predicate.</li>
 * </ul>
 * The result is converted to the original row type and handed to
 * {@code call.transformTo}.
 *
 * @param call      rule call that receives the rewritten expression
 * @param aggregate original aggregate
 * @param argLists  argument lists and filter argument (negative when absent)
 *                  of the distinct calls
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Distinct group sets, ordered by ImmutableBitSet.ORDERING.
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    groupSetTreeSet.add(aggregate.getGroupSet());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        groupSetTreeSet.add(ImmutableBitSet.of(argList.left).setIf(argList.right, argList.right >= 0).union(aggregate.getGroupSet()));
    }
    final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // Despite its name, this list collects the NON-distinct calls, renamed to
    // their output field names. They are evaluated in the first aggregate.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            distinctAggCalls.add(aggCall.left.rename(aggCall.right));
        }
    }
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    // First aggregate: grouping sets over fullGroupSet. The boolean argument
    // requests indicator columns only when there are several group sets.
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    final int groupCount = fullGroupSet.cardinality();
    final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
    // (predicate, name) columns to append after the first aggregate's fields.
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    // Group set -> index of its predicate column, used as a filter argument.
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
    /** Function to register a filter for a group set. */
    class Registrar {

        // Bitmask expression over the indicator columns, built lazily once.
        RexNode group = null;

        // Appends the predicate "group bitmask == bitmask of groupSet" and
        // returns the column index it will have in the Project built below.
        private int register(ImmutableBitSet groupSet) {
            if (group == null) {
                // NOTE(review): this reads columns groupCount + i as boolean
                // indicators. Indicators exist only when groupSets.size() > 1
                // (indicatorCount), yet register() is also called when there is a
                // single group set. Confirm that this path is never taken then.
                group = makeGroup(groupCount - 1);
            }
            final RexNode node = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet))));
            predicates.add(Pair.of(node, toString(groupSet)));
            return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
        }

        // Builds the sum over j = 0..i of CASE(indicator_j, 0, 2^j).
        private RexNode makeGroup(int i) {
            final RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
            final RexNode kase = rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
            if (i == 0) {
                return kase;
            } else {
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), kase);
            }
        }

        // Sum of 2^key over the set bits.
        private BigDecimal toNumber(ImmutableBitSet bitSet) {
            BigDecimal n = BigDecimal.ZERO;
            for (int key : bitSet) {
                n = n.add(TWO.pow(key));
            }
            return n;
        }

        // Field name such as "$i0_2" for bits {0, 2}.
        // NOTE(review): an empty bit set yields "$", because the trailing
        // character removed is the 'i'.
        private String toString(ImmutableBitSet bitSet) {
            final StringBuilder buf = new StringBuilder("$i");
            for (int key : bitSet) {
                buf.append(key).append('_');
            }
            return buf.substring(0, buf.length() - 1);
        }
    }
    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
        filters.put(groupSet, registrar.register(groupSet));
    }
    if (!predicates.isEmpty()) {
        // Project: pass through every field of the first aggregate, then
        // append the predicate columns.
        List<Pair<RexNode, String>> nodes = new ArrayList<>();
        for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
            final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
            nodes.add(Pair.of(node, f.getName()));
        }
        nodes.addAll(predicates);
        relBuilder.project(Pair.left(nodes), Pair.right(nodes));
    }
    // Position of the first non-distinct call's result, after the group keys and indicators.
    int x = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // Pick the precomputed value from the rows of the original group set.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
        } else {
            // Aggregate the distinct arguments over the rows of their own group set.
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
        }
        final AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
        newCalls.add(newCall);
    }
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used: RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) TreeSet(java.util.TreeSet) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

Pair (org.apache.calcite.util.Pair)26 RexNode (org.apache.calcite.rex.RexNode)22 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)17 RelNode (org.apache.calcite.rel.RelNode)16 ArrayList (java.util.ArrayList)14 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)9 HashMap (java.util.HashMap)8 RexInputRef (org.apache.calcite.rex.RexInputRef)8 RexBuilder (org.apache.calcite.rex.RexBuilder)7 AggregateCall (org.apache.calcite.rel.core.AggregateCall)6 JoinRelType (org.apache.calcite.rel.core.JoinRelType)5 RelBuilder (org.apache.calcite.tools.RelBuilder)5 ImmutableList (com.google.common.collect.ImmutableList)4 ImmutableMap (com.google.common.collect.ImmutableMap)4 ImmutableSortedMap (com.google.common.collect.ImmutableSortedMap)4 List (java.util.List)4 Map (java.util.Map)4 NavigableMap (java.util.NavigableMap)4 SortedMap (java.util.SortedMap)4 TreeMap (java.util.TreeMap)4