Search in sources :

Example 86 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

In class FlinkAggregateExpandDistinctAggregatesRule, method rewriteUsingGroupingSets.

/*
	public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
											   Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
		// For example,
		//	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
		//	FROM emp
		//	GROUP BY deptno
		//
		// becomes
		//
		//	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
		//	FROM (
		//		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
		//		  FROM EMP
		//		  GROUP BY deptno, sal)			// Aggregate B
		//	GROUP BY deptno						// Aggregate A
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		SortedSet<Integer> newGroupSet = new TreeSet<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

		// Add the distinct aggregate column(s) to the group-by columns,
		// if not already a part of the group-by
		newGroupSet.addAll(aggregate.getGroupSet().asList());
		for (Pair<List<Integer>, Integer> argList : argLists) {
			newGroupSet.addAll(argList.getKey());
		}

		// Re-map the arguments to the aggregate A. These arguments will get
		// remapped because of the intermediate aggregate B generated as part of the
		// transformation.
		for (int arg : newGroupSet) {
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		// Generate the intermediate aggregate B
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final List<Integer> fakeArgs = new ArrayList<>();
		final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
		// First identify the real arguments, then use the rest for fake arguments
		// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.put(arg, projects.size());
					}
				}
			}
		}
		int fakeArg0 = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// We will deal with non-distinct aggregates below
			if (!aggCall.isDistinct()) {
				boolean isGroupKeyUsedInAgg = false;
				for (int arg : aggCall.getArgList()) {
					if (sourceOf.containsKey(arg)) {
						isGroupKeyUsedInAgg = true;
						break;
					}
				}
				if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
					while (sourceOf.get(fakeArg0) != null) {
						++fakeArg0;
					}
					fakeArgs.add(fakeArg0);
				}
			}
		}
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.remove(arg);
					}
				}
			}
		}
		// Compute the remapped arguments using fake arguments for non-distinct
		// aggregates with no arguments e.g. count(*).
		int fakeArgIdx = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// Project the column corresponding to the distinct aggregate. Project
			// as-is all the non-distinct aggregates
			if (!aggCall.isDistinct()) {
				final AggregateCall newCall =
						AggregateCall.create(aggCall.getAggregation(), false,
								aggCall.getArgList(), -1,
								ImmutableBitSet.of(newGroupSet).cardinality(),
								relBuilder.peek(), null, aggCall.name);
				newAggCalls.add(newCall);
				if (newCall.getArgList().size() == 0) {
					int fakeArg = fakeArgs.get(fakeArgIdx);
					callArgMap.put(newCall, fakeArg);
					sourceOf.put(fakeArg, projects.size());
					projects.add(
							Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
									newCall.getName()));
					++fakeArgIdx;
				} else {
					for (int arg : newCall.getArgList()) {
						if (sourceOf.containsKey(arg)) {
							int fakeArg = fakeArgs.get(fakeArgIdx);
							callArgMap.put(newCall, fakeArg);
							sourceOf.put(fakeArg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
											newCall.getName()));
							++fakeArgIdx;
						} else {
							sourceOf.put(arg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
											newCall.getName()));
						}
					}
				}
			}
		}
		// Generate the aggregate B (see the reference example above)
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
		// Convert the existing aggregate to aggregate A (see the reference example above)
		final List<AggregateCall> newTopAggCalls =
				Lists.newArrayList(aggregate.getAggCallList());
		// Use the remapped arguments for the (non)distinct aggregate calls
		for (int i = 0; i < newTopAggCalls.size(); i++) {
			// Re-map arguments.
			final AggregateCall aggCall = newTopAggCalls.get(i);
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			final AggregateCall newCall;


			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				if (callArgMap.containsKey(aggCall)) {
					newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
				}
				else {
					newArgs.add(sourceOf.get(arg));
				}
			}
			if (aggCall.isDistinct()) {
				newCall =
						AggregateCall.create(aggCall.getAggregation(), false, newArgs,
								-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
								aggCall.getType(), aggCall.name);
			} else {
				// If aggregate B had a COUNT aggregate call the corresponding aggregate at
				// aggregate A must be SUM. For other aggregates, it remains the same.
				if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
					if (aggCall.getArgList().size() == 0) {
						newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
					}
					if (hasGroupBy) {
						SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					} else {
						SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					}
				} else {
					newCall =
							AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
									aggregate.getGroupSet().cardinality(),
									relBuilder.peek(), aggCall.getType(), aggCall.name);
				}
			}
			newTopAggCalls.set(i, newCall);
		}
		// Populate the group-by keys with the remapped arguments for aggregate A
		newGroupSet.clear();
		for (int arg : aggregate.getGroupSet()) {
			newGroupSet.add(sourceOf.get(arg));
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
		return relBuilder;
	}
	*/
/**
 * Rewrites an aggregate containing distinct calls over several different argument lists into a
 * two-level plan built on grouping sets.
 *
 * <p>Plan shape produced:
 * <ol>
 *   <li>A bottom aggregate over {@code aggregate.getInput()} that groups by every "group set of
 *       interest" and evaluates the non-distinct calls. The group sets are: the aggregate's own
 *       group set, plus, for each entry of {@code argLists}, its argument ordinals (and its
 *       right-hand ordinal when {@code >= 0}) unioned with the aggregate's group set.</li>
 *   <li>A projection that appends one {@code EQUALS} predicate per group set; the predicate is
 *       meant to identify the rows that the bottom aggregate produced for that group set.</li>
 *   <li>A top aggregate that evaluates each original call, using the matching predicate column
 *       as its filter, followed by a conversion back to {@code aggregate.getRowType()}.</li>
 * </ol>
 *
 * @param call      rule call; the rewritten tree is registered via {@code call.transformTo}
 * @param aggregate the aggregate being rewritten
 * @param argLists  pairs whose left side is an argument list of a distinct call and whose right
 *                  side is presumably that call's filter argument (or -1); compare the lookup
 *                  keyed by {@code aggCall.filterArg} in the distinct branch below
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Collect the distinct group sets in a deterministic (ImmutableBitSet.ORDERING) order.
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    groupSetTreeSet.add(aggregate.getGroupSet());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        groupSetTreeSet.add(ImmutableBitSet.of(argList.left).setIf(argList.right, argList.right >= 0).union(aggregate.getGroupSet()));
    }
    final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
    // Union of all group sets: the GROUP BY key of the bottom aggregate.
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // NOTE: despite its name, this list holds the NON-distinct calls (see the negated
    // isDistinct() check); they are computed directly by the bottom aggregate.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            distinctAggCalls.add(aggCall.left.rename(aggCall.right));
        }
    }
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    // Bottom aggregate; the indicator flag is set only when there is more than one group set.
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    final int groupCount = fullGroupSet.cardinality();
    // Indicator columns follow the group keys only when the indicator flag was set above.
    final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
    // (predicate, column name) pairs appended by the projection below.
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    // Group set -> ordinal of its predicate column in the projection.
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
    /** Function to register a filter for a group set. */
    class Registrar {

        // Lazily built expression combining all indicator columns into one number.
        RexNode group = null;

        /**
         * Adds the predicate {@code group = toNumber(remap(fullGroupSet, groupSet))} and returns
         * the ordinal its column will have after the projection (all bottom-aggregate columns
         * first, then predicates in registration order).
         */
        private int register(ImmutableBitSet groupSet) {
            if (group == null) {
                group = makeGroup(groupCount - 1);
            }
            final RexNode node = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet))));
            predicates.add(Pair.of(node, toString(groupSet)));
            return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
        }

        /**
         * Builds {@code CASE($ind0, 0, 2^0) + ... + CASE($ind_i, 0, 2^i)}, reading the boolean
         * column at {@code groupCount + k} for each k in [0, i]. Assuming an indicator is true
         * when its key is rolled up, the sum equals toNumber() of the keys present in the row's
         * group set.
         *
         * <p>NOTE(review): this reads columns {@code groupCount..2*groupCount-1} as indicators
         * even when {@code indicatorCount == 0} (a single group set). Confirm callers only reach
         * this method with more than one group set.
         */
        private RexNode makeGroup(int i) {
            final RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
            final RexNode kase = rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
            if (i == 0) {
                return kase;
            } else {
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), kase);
            }
        }

        /** Sum of 2^key over the keys of {@code bitSet}. */
        private BigDecimal toNumber(ImmutableBitSet bitSet) {
            BigDecimal n = BigDecimal.ZERO;
            for (int key : bitSet) {
                n = n.add(TWO.pow(key));
            }
            return n;
        }

        /**
         * Column name for a predicate: "$i" followed by the keys joined with '_'
         * (e.g. {0, 2} -> "$i0_2"). The trailing character is dropped, so an empty set yields "$".
         */
        private String toString(ImmutableBitSet bitSet) {
            final StringBuilder buf = new StringBuilder("$i");
            for (int key : bitSet) {
                buf.append(key).append('_');
            }
            return buf.substring(0, buf.length() - 1);
        }
    }
    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
        filters.put(groupSet, registrar.register(groupSet));
    }
    // groupSets always contains aggregate.getGroupSet(), so at least one predicate exists here.
    if (!predicates.isEmpty()) {
        // Project every bottom-aggregate column unchanged, then append the predicates.
        List<Pair<RexNode, String>> nodes = new ArrayList<>();
        for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
            final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
            nodes.add(Pair.of(node, f.getName()));
        }
        nodes.addAll(predicates);
        relBuilder.project(Pair.left(nodes), Pair.right(nodes));
    }
    // Ordinal of the first aggregate-call column of the bottom aggregate. It is advanced once per
    // non-distinct call, in the same order distinctAggCalls was built.
    int x = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // Already computed by the bottom aggregate; MIN re-selects it from the rows that
            // belong to the original group set (chosen via that group set's predicate).
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
        } else {
            // Evaluate the distinct call on the rows of the group set holding its arguments (and
            // filter column); remap() presumably translates ordinals into fullGroupSet positions.
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
        }
        final AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
        newCalls.add(newCall);
    }
    // Top aggregate over the original group key(s), remapped into bottom-aggregate positions.
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    // Make the result's row type match the original aggregate's (rename = true).
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) TreeSet(java.util.TreeSet) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 87 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.

In class FlinkAggregateJoinTransposeRule, method onMatch.

/**
 * Pushes an aggregate below an inner equi-join.
 *
 * <p>Each join input that is not known to be unique on the columns needed above it (the group
 * keys plus the join keys from that side) gets its own aggregate, whose calls are obtained by
 * splitting the original calls with {@link SqlSplittableAggFunction}. The join is rebuilt on the
 * new inputs. A top aggregate then combines the sub-totals; when every join column is a key
 * column, the rule first tries to replace that top aggregate with a plain projection. The rule
 * does nothing if a call cannot be split, a call has a filter, the join is not INNER, the join
 * condition has non-equi parts, or both inputs are already unique.
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and {@code call.rel(1)} the join
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // Bail out if any aggregate call cannot be split (no SqlSplittableAggFunction) or has a filter
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0) {
            return;
        }
    }
    // Only inner joins are handled; for any other join type, do not push the
    // aggregate operator
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    // allowFunctions is a rule setting defined elsewhere; when false, only call-free aggregates
    // are pushed.
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    // NOTE(review): RelMetadataQuery.instance() is the older API; newer Calcite versions expose
    // call.getMetadataQuery() instead. Confirm which one this Calcite version expects.
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    // keyColumns(...) is defined elsewhere; presumably it widens the group keys with columns
    // that pulled-up predicates equate to them.
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns each below-join aggregate must keep: the group keys plus the join keys.
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split join condition
    final List<Integer> leftKeys = Lists.newArrayList();
    final List<Integer> rightKeys = Lists.newArrayList();
    final List<Boolean> filterNulls = Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // map: join-level field ordinal -> ordinal in the rebuilt join's row.
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // offset: first ordinal of the current input in the original join row.
    int offset = 0;
    // belowOffset: first ordinal of the current (possibly aggregated) input in the new join row.
    int belowOffset = 0;
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Same key columns, expressed as ordinals of this input alone.
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            //
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            //
            // So we choose to imagine the the input is already unique, which is
            // untrue but harmless.
            //
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            // A null answer ("unknown") is treated as not unique.
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // Input is already unique on the keys: use it unchanged, no aggregate below the join.
            ++uniqueCount;
            side.aggregate = false;
            side.newInput = joinInput;
        } else {
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            // registry(...) is defined elsewhere; its register() return value is added to the key
            // count below to form the call's output ordinal in the below-join aggregate.
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            // Maps join-level argument ordinals onto this input's own ordinals: identity for the
            // left input, shifted down by `offset` for the right input.
            final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    // All arguments live on this side: split the call itself.
                    call1 = splitter.split(aggCall.e, mapping);
                } else {
                    // Arguments come from the other side: ask the splitter what (if anything)
                    // this side must contribute.
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both inputs are already unique on their keys, so no aggregate would be pushed below the
        // join and nothing is gained. This aggregate+join may itself be the result of a previous
        // invocation of this rule; if we continue we might loop forever.
        return;
    }
    // Update condition
    final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {

        public Integer apply(Integer a0) {
            return map.get(a0);
        }
    }, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    // Start from an identity projection of the new join; topSplit may register extra expressions.
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        // Sub-total ordinals on each side; -1 when that side contributed nothing. The right-side
        // ordinal is shifted past the left input's columns.
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // let's see if we can convert aggregate into projects
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
            }
        }
        // Only use the projection if every call produced a singleton expression.
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) Function(com.google.common.base.Function) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 88 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in project storm by apache.

In class RelNodeCompiler, method emitAggregateStmts.

/**
 * Generates the code that emits one aggregated output row.
 *
 * <p>For every aggregate call, {@code aggregateResult} receives a writer over a shared buffer
 * and returns a result expression. The returned text is the buffered output followed by a
 * {@code ctx.emit(new Values(...))} line listing the group values and then those expressions.
 *
 * @param aggregate the aggregate whose calls are compiled
 * @return the generated statements, joined with NEW_LINE_JOINER
 */
private String emitAggregateStmts(Aggregate aggregate) {
    final StringWriter buffer = new StringWriter();
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final List<String> resultExprs = new ArrayList<>(aggCalls.size());
    for (AggregateCall aggCall : aggCalls) {
        // Each call gets its own writer over the same buffer, as before.
        final PrintWriter writer = new PrintWriter(buffer);
        resultExprs.add(aggregateResult(aggCall, writer));
    }
    // Read the buffer before building the emit line, preserving the original evaluation order.
    final String preamble = buffer.toString();
    final String groupValues = groupValueEmitStr("groupValues", aggregate.getGroupSet().cardinality());
    final String aggValues = Joiner.on(", ").join(resultExprs);
    final String emitLine = String.format("          ctx.emit(new Values(%s, %s));", groupValues, aggValues);
    return NEW_LINE_JOINER.join(preamble, emitLine);
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) StringWriter(java.io.StringWriter) ArrayList(java.util.ArrayList) PrintWriter(java.io.PrintWriter)

Example 89 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in project druid by druid-io.

In class GroupByRules, method applyAggregate.

/**
 * Applies a filter -> project -> aggregate chain to a druidRel. Do not call this method unless
 * {@link #canApplyAggregate(DruidRel, Filter, Project, Aggregate)} returns true.
 *
 * <p>If druidRel already has a grouping, the result is a nested groupBy over it. Otherwise
 * druidRel is rewritten in place, inheriting its existing filter and projection when the chain
 * does not supply its own.
 *
 * @param druidRel                 rel the chain is applied to
 * @param filter0                  optional filter in the chain, or null
 * @param project0                 optional project in the chain, or null
 * @param aggregate                aggregate whose group keys and calls are translated
 * @param operatorTable            operator table used to translate expressions
 * @param approximateCountDistinct passed through to translateAggregateCall
 *
 * @return new rel, or null if the chain cannot be applied
 *
 * @throws IllegalStateException if canApplyAggregate(...) is false
 * @throws ISE if a GROUP BY expression's SQL type has no Druid value type
 */
private static DruidRel applyAggregate(final DruidRel druidRel, final Filter filter0, final Project project0, final Aggregate aggregate, final DruidOperatorTable operatorTable, final boolean approximateCountDistinct) {
    Preconditions.checkState(canApplyAggregate(druidRel, filter0, project0, aggregate), "Cannot applyAggregate.");
    final RowSignature sourceRowSignature;
    final boolean isNestedQuery = druidRel.getQueryBuilder().getGrouping() != null;
    if (isNestedQuery) {
        // Nested groupBy; source row signature is the output signature of druidRel.
        sourceRowSignature = druidRel.getOutputRowSignature();
    } else {
        sourceRowSignature = druidRel.getSourceRowSignature();
    }
    // Filter that should be applied before aggregating.
    final DimFilter filter;
    if (filter0 != null) {
        filter = Expressions.toFilter(operatorTable, druidRel.getPlannerContext(), sourceRowSignature, filter0.getCondition());
        if (filter == null) {
            // Can't plan this filter.
            return null;
        }
    } else if (druidRel.getQueryBuilder().getFilter() != null && !isNestedQuery) {
        // We're going to replace the existing druidRel, so inherit its filter.
        filter = druidRel.getQueryBuilder().getFilter();
    } else {
        filter = null;
    }
    // Projection that should be applied before aggregating.
    final Project project;
    if (project0 != null) {
        project = project0;
    } else if (druidRel.getQueryBuilder().getSelectProjection() != null && !isNestedQuery) {
        // We're going to replace the existing druidRel, so inherit its projection.
        project = druidRel.getQueryBuilder().getSelectProjection().getProject();
    } else {
        project = null;
    }
    final List<DimensionSpec> dimensions = Lists.newArrayList();
    final List<Aggregation> aggregations = Lists.newArrayList();
    // Output column names in row order: group keys first, then aggregations.
    final List<String> rowOrder = Lists.newArrayList();
    // Translate groupSet.
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    int dimOutputNameCounter = 0;
    for (int i : groupSet) {
        if (project != null && project.getChildExps().get(i) instanceof RexLiteral) {
            // Ignore literals in GROUP BY, so a user can write e.g. "GROUP BY 'dummy'" to group everything into a single
            // row. Add dummy rowOrder entry so NULLs come out. This is not strictly correct but it works as long as
            // nobody actually expects to see the literal.
            rowOrder.add(dimOutputName(dimOutputNameCounter++));
        } else {
            final RexNode rexNode = Expressions.fromFieldAccess(sourceRowSignature, project, i);
            final RowExtraction rex = Expressions.toRowExtraction(operatorTable, druidRel.getPlannerContext(), sourceRowSignature.getRowOrder(), rexNode);
            if (rex == null) {
                return null;
            }
            final SqlTypeName sqlTypeName = rexNode.getType().getSqlTypeName();
            final ValueType outputType = Calcites.getValueTypeForSqlTypeName(sqlTypeName);
            if (outputType == null) {
                // `i` is an ordinal of the aggregate's input, not an index into rowOrder (which only
                // holds the output names built so far). rowOrder.get(i) could throw
                // IndexOutOfBoundsException and mask this error, or name the wrong column.
                throw new ISE("Cannot translate sqlTypeName[%s] to Druid type for field[%s]", sqlTypeName, aggregate.getInput().getRowType().getFieldNames().get(i));
            }
            final DimensionSpec dimensionSpec = rex.toDimensionSpec(sourceRowSignature, dimOutputName(dimOutputNameCounter++), outputType);
            if (dimensionSpec == null) {
                return null;
            }
            dimensions.add(dimensionSpec);
            rowOrder.add(dimensionSpec.getOutputName());
        }
    }
    // Translate aggregates.
    for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
        final AggregateCall aggCall = aggregate.getAggCallList().get(i);
        final Aggregation aggregation = translateAggregateCall(druidRel.getPlannerContext(), sourceRowSignature, project, aggCall, operatorTable, aggregations, i, approximateCountDistinct);
        if (aggregation == null) {
            return null;
        }
        aggregations.add(aggregation);
        rowOrder.add(aggregation.getOutputName());
    }
    if (isNestedQuery) {
        // Nested groupBy.
        return DruidNestedGroupBy.from(druidRel, filter, Grouping.create(dimensions, aggregations), aggregate.getRowType(), rowOrder);
    } else {
        // groupBy on a base dataSource.
        return druidRel.withQueryBuilder(druidRel.getQueryBuilder().withFilter(filter).withGrouping(Grouping.create(dimensions, aggregations), aggregate.getRowType(), rowOrder));
    }
}
Also used : DimensionSpec(io.druid.query.dimension.DimensionSpec) RexLiteral(org.apache.calcite.rex.RexLiteral) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) SqlTypeName(org.apache.calcite.sql.type.SqlTypeName) ValueType(io.druid.segment.column.ValueType) RowExtraction(io.druid.sql.calcite.expression.RowExtraction) Aggregation(io.druid.sql.calcite.aggregation.Aggregation) AggregateCall(org.apache.calcite.rel.core.AggregateCall) Project(org.apache.calcite.rel.core.Project) ISE(io.druid.java.util.common.ISE) RowSignature(io.druid.sql.calcite.table.RowSignature) DimFilter(io.druid.query.filter.DimFilter) NotDimFilter(io.druid.query.filter.NotDimFilter) AndDimFilter(io.druid.query.filter.AndDimFilter) RexNode(org.apache.calcite.rex.RexNode)

Example 90 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.

In class ConvertCountToDirectScan, method onMatch.

/**
 * Replaces a single ungrouped, non-distinct COUNT over a scan with a direct scan of the
 * precomputed count.
 *
 * <p>COUNT(*) and COUNT of a non-nullable column use the scan's exact record count. COUNT of a
 * nullable column uses the group scan's column value count, when that statistic is available.
 * In every other case the rule does nothing.
 *
 * @param call rule call: the aggregate first, the scan last, and optionally a project at index 1
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final DrillAggregateRel agg = (DrillAggregateRel) call.rel(0);
    final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length - 1);
    // call.rels has three entries when a project sits between the aggregate and the scan.
    final DrillProjectRel proj = call.rels.length == 3 ? (DrillProjectRel) call.rel(1) : null;
    final GroupScan oldGrpScan = scan.getGroupScan();
    final PlannerSettings settings = PrelUtil.getPlannerSettings(call.getPlanner());
    // Only apply the rule when:
    //    1) the scan reports an exact row count,
    //    2) there is no GROUP BY key,
    //    3) there is exactly one aggregate call (checked to be COUNT below),
    //    4) No distinct agg call.
    if (!(oldGrpScan.getScanStats(settings).getGroupScanProperty().hasExactRowCount() && agg.getGroupCount() == 0 && agg.getAggCallList().size() == 1 && !agg.containsDistinctCall())) {
        return;
    }
    AggregateCall aggCall = agg.getAggCallList().get(0);
    if (aggCall.getAggregation().getName().equals("COUNT")) {
        long cnt = 0;
        //  count(*) (no arguments) ==> rowCount
        //  count(Not-null-input) ==> rowCount
        if (aggCall.getArgList().isEmpty() || (aggCall.getArgList().size() == 1 && !agg.getInput().getRowType().getFieldList().get(aggCall.getArgList().get(0).intValue()).getType().isNullable())) {
            cnt = (long) oldGrpScan.getScanStats(settings).getRecordCount();
        } else if (aggCall.getArgList().size() == 1) {
            // count(columnName) ==> Agg ( Scan )) ==> columnValueCount
            int index = aggCall.getArgList().get(0);
            if (proj != null) {
                // Resolve the argument through the project; only plain column references qualify.
                if (proj.getProjects().get(index) instanceof RexInputRef) {
                    index = ((RexInputRef) proj.getProjects().get(index)).getIndex();
                } else {
                    // do not apply for all other cases.
                    return;
                }
            }
            // Lower-case with Locale.ROOT: the default locale (e.g. Turkish) would map 'I' to a
            // dotless 'ı' and produce a column path that does not exist.
            String columnName = scan.getRowType().getFieldNames().get(index).toLowerCase(java.util.Locale.ROOT);
            cnt = oldGrpScan.getColumnValueCount(SchemaPath.getSimplePath(columnName));
            if (cnt == GroupScan.NO_COLUMN_STATS) {
                // if column stats are not available don't apply this rule
                return;
            }
        } else {
            // do nothing.
            return;
        }
        // Replace Agg(...Scan) with Project(DirectScan(cnt)) carrying the aggregate's row type.
        RelDataType scanRowType = getCountDirectScanRowType(agg.getCluster().getTypeFactory());
        final ScanPrel newScan = ScanPrel.create(scan, scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON), getCountDirectScan(cnt), scanRowType);
        List<RexNode> exprs = Lists.newArrayList();
        exprs.add(RexInputRef.of(0, scanRowType));
        final ProjectPrel newProj = new ProjectPrel(agg.getCluster(), agg.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON), newScan, exprs, agg.getRowType());
        call.transformTo(newProj);
    }
}
Also used : DrillScanRel(org.apache.drill.exec.planner.logical.DrillScanRel) DrillProjectRel(org.apache.drill.exec.planner.logical.DrillProjectRel) DrillAggregateRel(org.apache.drill.exec.planner.logical.DrillAggregateRel) RelDataType(org.apache.calcite.rel.type.RelDataType) DirectGroupScan(org.apache.drill.exec.store.direct.DirectGroupScan) GroupScan(org.apache.drill.exec.physical.base.GroupScan) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall)158 ArrayList (java.util.ArrayList)82 RexNode (org.apache.calcite.rex.RexNode)78 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)57 RelNode (org.apache.calcite.rel.RelNode)54 RexBuilder (org.apache.calcite.rex.RexBuilder)52 RelDataType (org.apache.calcite.rel.type.RelDataType)42 Aggregate (org.apache.calcite.rel.core.Aggregate)37 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)36 RexInputRef (org.apache.calcite.rex.RexInputRef)33 RelBuilder (org.apache.calcite.tools.RelBuilder)29 HashMap (java.util.HashMap)28 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)28 List (java.util.List)27 RexLiteral (org.apache.calcite.rex.RexLiteral)23 Pair (org.apache.calcite.util.Pair)20 ImmutableList (com.google.common.collect.ImmutableList)19 Project (org.apache.calcite.rel.core.Project)17 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)17 LogicalAggregate (org.apache.calcite.rel.logical.LogicalAggregate)16