Search in sources :

Example 26 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.

the class HiveAggregateSplitRule method onMatch.

/**
 * Splits the matched aggregate into a bottom aggregate that keeps the original
 * group set and aggregate calls, and a top aggregate that rolls the partial
 * results up with the function returned by {@code HiveRelBuilder.getRollup}.
 *
 * <p>The rule returns without transforming when any aggregate call is distinct,
 * is the grouping id, carries a filter, or has no roll-up function, and also
 * when the input is already known to be unique on the grouping columns.
 *
 * @param call the rule call; {@code call.rel(0)} is the aggregate to split
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final RelBuilder relBuilder = call.builder();
    final ImmutableBitSet bottomAggregateGroupSet = aggregate.getGroupSet();
    final int bottomGroupCount = bottomAggregateGroupSet.cardinality();
    final List<AggregateCall> bottomAggregateCalls = aggregate.getAggCallList();
    final List<AggregateCall> topAggregateCalls = new ArrayList<>(bottomAggregateCalls.size());
    for (int i = 0; i < bottomAggregateCalls.size(); i++) {
        final AggregateCall aggregateCall = bottomAggregateCalls.get(i);
        // If any aggregate is distinct, bail out
        if (aggregateCall.isDistinct()) {
            return;
        }
        // If any aggregate is the grouping id, bail out
        if (aggregateCall.getAggregation().equals(HiveGroupingID.INSTANCE)) {
            return;
        }
        // If any aggregate call has a filter, bail out
        if (aggregateCall.filterArg >= 0) {
            return;
        }
        // If any aggregate functions do not support splitting, bail out
        final SqlAggFunction aggFunction = HiveRelBuilder.getRollup(aggregateCall.getAggregation());
        if (aggFunction == null) {
            return;
        }
        // The top call reads the i-th aggregate output of the bottom aggregate,
        // which sits right after the bottom aggregate's group columns.
        topAggregateCalls.add(AggregateCall.create(aggFunction, aggregateCall.isDistinct(), aggregateCall.isApproximate(), ImmutableList.of(bottomGroupCount + i), -1, aggregateCall.collation, aggregateCall.type, aggregateCall.name));
    }
    // areColumnsUnique yields a nullable Boolean; unboxing it directly in the
    // condition would throw a NullPointerException when uniqueness is unknown.
    // Unknown is treated as "not unique", so the split still goes ahead.
    final Boolean groupKeysUnique = aggregate.getCluster().getMetadataQuery().areColumnsUnique(aggregate.getInput(), bottomAggregateGroupSet);
    if (Boolean.TRUE.equals(groupKeysUnique)) {
        // Nothing to do, probably already pushed
        return;
    }
    final ImmutableBitSet topAggregateGroupSet = ImmutableBitSet.range(0, bottomGroupCount);
    // Map each original group column to its position in the bottom aggregate's
    // output so the grouping sets can be re-expressed for the top aggregate.
    final Map<Integer, Integer> map = new HashMap<>();
    bottomAggregateGroupSet.forEach(k -> map.put(k, map.size()));
    final ImmutableList<ImmutableBitSet> topAggregateGroupSets = ImmutableBitSet.ORDERING.immutableSortedCopy(ImmutableBitSet.permute(aggregate.groupSets, map));
    relBuilder.push(aggregate.getInput()).aggregate(relBuilder.groupKey(bottomAggregateGroupSet), bottomAggregateCalls).aggregate(relBuilder.groupKey(topAggregateGroupSet, topAggregateGroupSets), topAggregateCalls);
    call.transformTo(relBuilder.build());
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelBuilder(org.apache.calcite.tools.RelBuilder) HiveRelBuilder(org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) Aggregate(org.apache.calcite.rel.core.Aggregate) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate)

Example 27 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.

/*
	public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
											   Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
		// For example,
		//	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
		//	FROM emp
		//	GROUP BY deptno
		//
		// becomes
		//
		//	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
		//	FROM (
		//		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
		//		  FROM EMP
		//		  GROUP BY deptno, sal)			// Aggregate B
		//	GROUP BY deptno						// Aggregate A
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		SortedSet<Integer> newGroupSet = new TreeSet<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

		// Add the distinct aggregate column(s) to the group-by columns,
		// if not already a part of the group-by
		newGroupSet.addAll(aggregate.getGroupSet().asList());
		for (Pair<List<Integer>, Integer> argList : argLists) {
			newGroupSet.addAll(argList.getKey());
		}

		// Re-map the arguments to the aggregate A. These arguments will get
		// remapped because of the intermediate aggregate B generated as part of the
		// transformation.
		for (int arg : newGroupSet) {
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		// Generate the intermediate aggregate B
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final List<Integer> fakeArgs = new ArrayList<>();
		final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
		// First identify the real arguments, then use the rest for fake arguments
		// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.put(arg, projects.size());
					}
				}
			}
		}
		int fakeArg0 = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// We will deal with non-distinct aggregates below
			if (!aggCall.isDistinct()) {
				boolean isGroupKeyUsedInAgg = false;
				for (int arg : aggCall.getArgList()) {
					if (sourceOf.containsKey(arg)) {
						isGroupKeyUsedInAgg = true;
						break;
					}
				}
				if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
					while (sourceOf.get(fakeArg0) != null) {
						++fakeArg0;
					}
					fakeArgs.add(fakeArg0);
				}
			}
		}
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.remove(arg);
					}
				}
			}
		}
		// Compute the remapped arguments using fake arguments for non-distinct
		// aggregates with no arguments e.g. count(*).
		int fakeArgIdx = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// Project the column corresponding to the distinct aggregate. Project
			// as-is all the non-distinct aggregates
			if (!aggCall.isDistinct()) {
				final AggregateCall newCall =
						AggregateCall.create(aggCall.getAggregation(), false,
								aggCall.getArgList(), -1,
								ImmutableBitSet.of(newGroupSet).cardinality(),
								relBuilder.peek(), null, aggCall.name);
				newAggCalls.add(newCall);
				if (newCall.getArgList().size() == 0) {
					int fakeArg = fakeArgs.get(fakeArgIdx);
					callArgMap.put(newCall, fakeArg);
					sourceOf.put(fakeArg, projects.size());
					projects.add(
							Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
									newCall.getName()));
					++fakeArgIdx;
				} else {
					for (int arg : newCall.getArgList()) {
						if (sourceOf.containsKey(arg)) {
							int fakeArg = fakeArgs.get(fakeArgIdx);
							callArgMap.put(newCall, fakeArg);
							sourceOf.put(fakeArg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
											newCall.getName()));
							++fakeArgIdx;
						} else {
							sourceOf.put(arg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
											newCall.getName()));
						}
					}
				}
			}
		}
		// Generate the aggregate B (see the reference example above)
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
		// Convert the existing aggregate to aggregate A (see the reference example above)
		final List<AggregateCall> newTopAggCalls =
				Lists.newArrayList(aggregate.getAggCallList());
		// Use the remapped arguments for the (non)distinct aggregate calls
		for (int i = 0; i < newTopAggCalls.size(); i++) {
			// Re-map arguments.
			final AggregateCall aggCall = newTopAggCalls.get(i);
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			final AggregateCall newCall;


			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				if (callArgMap.containsKey(aggCall)) {
					newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
				}
				else {
					newArgs.add(sourceOf.get(arg));
				}
			}
			if (aggCall.isDistinct()) {
				newCall =
						AggregateCall.create(aggCall.getAggregation(), false, newArgs,
								-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
								aggCall.getType(), aggCall.name);
			} else {
				// If aggregate B had a COUNT aggregate call the corresponding aggregate at
				// aggregate A must be SUM. For other aggregates, it remains the same.
				if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
					if (aggCall.getArgList().size() == 0) {
						newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
					}
					if (hasGroupBy) {
						SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					} else {
						SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					}
				} else {
					newCall =
							AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
									aggregate.getGroupSet().cardinality(),
									relBuilder.peek(), aggCall.getType(), aggCall.name);
				}
			}
			newTopAggCalls.set(i, newCall);
		}
		// Populate the group-by keys with the remapped arguments for aggregate A
		newGroupSet.clear();
		for (int arg : aggregate.getGroupSet()) {
			newGroupSet.add(sourceOf.get(arg));
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
		return relBuilder;
	}
	*/
/**
 * Rewrites an aggregate that contains distinct aggregate calls into a
 * grouping-sets aggregate, a projection that adds one filter predicate per
 * grouping set, and a final aggregate whose calls each use one of those
 * predicates as their filter argument. The result is handed to
 * {@code call.transformTo}.
 *
 * @param call the rule call, used to obtain the builder and register the result
 * @param aggregate the aggregate being rewritten
 * @param argLists pairs whose left side is an argument list and whose right
 *     side is an ordinal that is only added to the group set when it is
 *     {@code >= 0} (presumably the filter argument of a distinct call - confirm
 *     against the caller)
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Collect the grouping sets: the aggregate's own group set, plus one per
    // entry of argLists (arguments + optional ordinal + the group set).
    // The TreeSet orders them by ImmutableBitSet.ORDERING and removes duplicates.
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    groupSetTreeSet.add(aggregate.getGroupSet());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        groupSetTreeSet.add(ImmutableBitSet.of(argList.left).setIf(argList.right, argList.right >= 0).union(aggregate.getGroupSet()));
    }
    final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // NOTE: despite its name, this list holds the calls that are NOT distinct
    // (see the negated test below); they are evaluated by the bottom aggregate.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            distinctAggCalls.add(aggCall.left.rename(aggCall.right));
        }
    }
    // Bottom aggregate: group by the union of all grouping sets; the indicator
    // flag is only requested when there is more than one grouping set.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    final int groupCount = fullGroupSet.cardinality();
    // One indicator column per group column when several grouping sets exist.
    final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
    // Filter predicates to project on top of the bottom aggregate, and the
    // output ordinal assigned to the predicate of each grouping set.
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
    /** Function to register a filter for a group set. */
    class Registrar {

        // Lazily built expression that encodes the indicator columns as a number.
        RexNode group = null;

        // Adds the predicate "group = <number for groupSet>" and returns the
        // ordinal that predicate will have once it is projected after the group,
        // indicator and aggregate-call columns of the bottom aggregate.
        private int register(ImmutableBitSet groupSet) {
            if (group == null) {
                group = makeGroup(groupCount - 1);
            }
            final RexNode node = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet))));
            predicates.add(Pair.of(node, toString(groupSet)));
            return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
        }

        // Recursively sums one CASE term per column groupCount + 0 .. groupCount + i.
        // Each term is CASE(ref, 0, 2^i) - presumably 0 when the referenced
        // boolean is true and 2^i otherwise.
        private RexNode makeGroup(int i) {
            final RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
            final RexNode kase = rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
            if (i == 0) {
                return kase;
            } else {
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), kase);
            }
        }

        // Encodes a bit set as the sum of 2^key over its set bits.
        private BigDecimal toNumber(ImmutableBitSet bitSet) {
            BigDecimal n = BigDecimal.ZERO;
            for (int key : bitSet) {
                n = n.add(TWO.pow(key));
            }
            return n;
        }

        // Builds the predicate's field name: "$i" followed by the keys joined
        // with '_'. The last character of the buffer is always dropped, so an
        // empty bit set yields "$" rather than "$i".
        private String toString(ImmutableBitSet bitSet) {
            final StringBuilder buf = new StringBuilder("$i");
            for (int key : bitSet) {
                buf.append(key).append('_');
            }
            return buf.substring(0, buf.length() - 1);
        }
    }
    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
        filters.put(groupSet, registrar.register(groupSet));
    }
    // Project every existing field unchanged, followed by the filter predicates,
    // so that the ordinals returned by register() line up with this projection.
    if (!predicates.isEmpty()) {
        List<Pair<RexNode, String>> nodes = new ArrayList<>();
        for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
            final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
            nodes.add(Pair.of(node, f.getName()));
        }
        nodes.addAll(predicates);
        relBuilder.project(Pair.left(nodes), Pair.right(nodes));
    }
    // x walks over the columns that follow the group and indicator columns,
    // i.e. the outputs of the non-distinct calls, in their original order.
    int x = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // Non-distinct call: take MIN over its already computed column,
            // filtered to the rows of the aggregate's own group set.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
        } else {
            // Distinct call: keep the function, remap its arguments into the
            // full group set, and filter to the rows of its own grouping set.
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
        }
        // The new call is created as non-distinct (second argument is false).
        final AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
        newCalls.add(newCall);
    }
    // Top aggregate over the original group keys, remapped into the full group
    // set, then convert back to the original row type before registering.
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) TreeSet(java.util.TreeSet) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 28 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.

the class FlinkAggregateJoinTransposeRule method onMatch.

/**
 * Tries to push the matched aggregate below the matched join. Each join input
 * that is not already unique on the needed key columns gets its own aggregate,
 * the join is rebuilt over the new inputs with a remapped condition, and a
 * final aggregate (or, when possible, a plain projection) combines the
 * per-side sub-totals. The result is handed to {@code call.transformTo}.
 *
 * <p>Returns without transforming when an aggregate call is not splittable or
 * has a filter, when the join is not an inner join, when functions are not
 * allowed but aggregate calls are present, when the join condition has a
 * non-equi part, or when both join inputs are already unique.
 *
 * @param call the rule call; {@code call.rel(0)} is the aggregate and
 *     {@code call.rel(1)} is the join
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // If any aggregate function cannot be split, or any aggregate call has a
    // filter, bail out
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0) {
            return;
        }
    }
    // Only inner joins are handled: for any other join type, do not push the
    // aggregate operator
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    // When functions are not allowed, only aggregates without calls qualify.
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns the pushed-down aggregates must keep: group keys plus join keys.
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split join condition
    final List<Integer> leftKeys = Lists.newArrayList();
    final List<Integer> rightKeys = Lists.newArrayList();
    final List<Boolean> filterNulls = Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // map: original join-output ordinal -> ordinal in the new join's output.
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // offset: first ordinal of the current side in the original join output;
    // belowOffset: first ordinal of the current side in the new join output.
    int offset = 0;
    int belowOffset = 0;
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Same key columns, expressed relative to this join input.
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            //
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            //
            // So we choose to imagine that the input is already unique, which is
            // untrue but harmless.
            //
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            // areColumnsUnique may return null; null is treated as "not unique".
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // Already unique on the key: reuse the join input as-is.
            ++uniqueCount;
            side.aggregate = false;
            side.newInput = joinInput;
        } else {
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            // Translates argument ordinals from the join output to this input;
            // the left side (s == 0) needs no shift.
            final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    // All arguments come from this side: split the call here.
                    call1 = splitter.split(aggCall.e, mapping);
                } else {
                    // Arguments belong elsewhere: ask the splitter for the call
                    // to use on the other side (may be null).
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    // Record which output column of this side's aggregate holds
                    // the sub-total for the original call at index aggCall.i.
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both join inputs are already unique on their keys, so no aggregate
        // was pushed below the join and nothing would be gained. This plan may
        // itself be the result of a previous invocation of this rule; if we
        // continue we might loop forever.
        return;
    }
    // Update condition
    final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {

        public Integer apply(Integer a0) {
            return map.get(a0);
        }
    }, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        // -1 marks a side without a sub-total; right-side ordinals are shifted
        // by the width of the new left input.
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // let's see if we can convert aggregate into projects
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
            }
        }
        // The conversion only succeeded if every group key and every aggregate
        // call produced exactly one projected expression.
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) Function(com.google.common.base.Function) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 29 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project storm by apache.

the class RelNodeCompiler method aggregateResult.

/**
 * Emits the code that produces the final result of an aggregate call.
 *
 * <p>User-defined aggregate functions supply their own implementation; every
 * other function is looked up by name in {@code BuiltinAggregateFunctions.TABLE}
 * and matched against the Java type of the call.
 *
 * @param call the aggregate call whose result is to be generated
 * @param pw the writer handed on to {@code doAggregateResult}
 * @return whatever {@code doAggregateResult} returns for the resolved implementation
 * @throws UnsupportedOperationException if the function is neither user-defined
 *     nor present in the built-in table
 */
private String aggregateResult(AggregateCall call, PrintWriter pw) {
    final SqlAggFunction function = call.getAggregation();
    final String functionName = function.getName();
    final Type javaType = typeFactory.getJavaClass(call.getType());

    // Resolve the implementation first; the variable name is reserved only
    // once an implementation is known to exist.
    final AggregateFunctionImpl implementation;
    if (function instanceof SqlUserDefinedAggFunction) {
        implementation = (AggregateFunctionImpl) ((SqlUserDefinedAggFunction) function).function;
    } else {
        final List<BuiltinAggregateFunctions.TypeClass> candidates = BuiltinAggregateFunctions.TABLE.get(functionName);
        if (candidates == null) {
            throw new UnsupportedOperationException(functionName + " Not implemented");
        }
        implementation = AggregateFunctionImpl.create(findMatchingClass(functionName, candidates, javaType));
    }
    return doAggregateResult(implementation, reserveAggVarName(call), javaType, pw);
}
Also used : RelDataType(org.apache.calcite.rel.type.RelDataType) Type(java.lang.reflect.Type) AggregateFunction(org.apache.calcite.schema.AggregateFunction) AggregateFunctionImpl(org.apache.calcite.schema.impl.AggregateFunctionImpl) SqlUserDefinedAggFunction(org.apache.calcite.sql.validate.SqlUserDefinedAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

Example 30 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project drill by axbaretto.

the class AggPrelBase method createKeysAndExprs.

/**
 * Populates {@code keys} with one named expression per grouping column and
 * {@code aggExprs} with one named expression per aggregate call.
 *
 * <p>When this operator is phase 1 of a two-phase aggregation, a matching
 * phase-2 call is also appended to {@code phase2AggCallList} for every
 * aggregate call. A COUNT is replaced by a {@code SqlSumCountAggFunction} so
 * that phase 2 sums the partial counts; any other function is reused as-is.
 */
protected void createKeysAndExprs() {
    final List<String> inputNames = getInput().getRowType().getFieldNames();
    final List<String> outputNames = getRowType().getFieldNames();

    // Grouping keys: each key is exposed under the name of its input column.
    for (int keyOrdinal : BitSets.toIter(groupSet)) {
        final FieldReference keyRef = FieldReference.getWithQuotedRef(inputNames.get(keyOrdinal));
        keys.add(new NamedExpression(keyRef, keyRef));
    }

    for (Ord<AggregateCall> indexed : Ord.zip(aggCalls)) {
        final AggregateCall original = indexed.e;
        // Aggregate outputs follow the grouping columns in the output row.
        final int outputOrdinal = groupSet.cardinality() + indexed.i;
        final FieldReference outputRef = FieldReference.getWithQuotedRef(outputNames.get(outputOrdinal));
        final LogicalExpression drillExpr = toDrill(original, inputNames);
        aggExprs.add(new NamedExpression(drillExpr, outputRef));

        if (getOperatorPhase() != OperatorPhase.PHASE_1of2) {
            continue;
        }
        // If we are doing a COUNT aggregate in Phase1of2, then in Phase2of2 we
        // should SUM the COUNTs; every other function is applied again unchanged.
        final SqlAggFunction phase2Function;
        if (original.getAggregation().getName().equals("COUNT")) {
            phase2Function = new SqlSumCountAggFunction(original.getType());
        } else {
            phase2Function = original.getAggregation();
        }
        phase2AggCallList.add(AggregateCall.create(phase2Function, original.isDistinct(), original.isApproximate(), Collections.singletonList(outputOrdinal), original.filterArg, original.getType(), original.getName()));
    }
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) LogicalExpression(org.apache.drill.common.expression.LogicalExpression) FieldReference(org.apache.drill.common.expression.FieldReference) NamedExpression(org.apache.drill.common.logical.data.NamedExpression) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

Aggregations

SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)47 RelDataType (org.apache.calcite.rel.type.RelDataType)30 ArrayList (java.util.ArrayList)25 RexNode (org.apache.calcite.rex.RexNode)24 AggregateCall (org.apache.calcite.rel.core.AggregateCall)20 RexBuilder (org.apache.calcite.rex.RexBuilder)15 RelNode (org.apache.calcite.rel.RelNode)13 List (java.util.List)11 Aggregate (org.apache.calcite.rel.core.Aggregate)10 RelBuilder (org.apache.calcite.tools.RelBuilder)10 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)9 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)9 HashMap (java.util.HashMap)8 ImmutableList (com.google.common.collect.ImmutableList)7 SqlOperator (org.apache.calcite.sql.SqlOperator)6 Type (java.lang.reflect.Type)5 RelMetadataQuery (org.apache.calcite.rel.metadata.RelMetadataQuery)5 RexInputRef (org.apache.calcite.rex.RexInputRef)5 SqlKind (org.apache.calcite.sql.SqlKind)5 Mappings (org.apache.calcite.util.mapping.Mappings)5