Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Calcite project.
Class MutableRels, method toMutable.
/**
 * Translates a {@link RelNode} into its {@link MutableRel} counterpart,
 * converting the inputs recursively.
 *
 * <p>A {@code HepRelVertex} is replaced by its current rel and a
 * {@code RelSubset} by whatever {@code Util.first} selects from its best and
 * original rels before the operator type is examined.
 *
 * @param rel relational expression to translate
 * @return the mutable equivalent of {@code rel}
 * @throws RuntimeException if {@code rel} is not one of the operator types
 *     handled by this method
 */
public static MutableRel toMutable(RelNode rel) {
  // Planner wrappers: look through them to the rel they stand for.
  if (rel instanceof HepRelVertex) {
    final HepRelVertex vertex = (HepRelVertex) rel;
    return toMutable(vertex.getCurrentRel());
  }
  if (rel instanceof RelSubset) {
    final RelSubset subset = (RelSubset) rel;
    return toMutable(Util.first(subset.getBest(), subset.getOriginal()));
  }

  // Operators whose mutable form is built directly from the rel.
  if (rel instanceof TableScan) {
    return MutableScan.of((TableScan) rel);
  }
  if (rel instanceof Values) {
    return MutableValues.of((Values) rel);
  }

  // Single-input operators.
  if (rel instanceof Project) {
    final Project p = (Project) rel;
    return MutableProject.of(toMutable(p.getInput()), p.getProjects(),
        p.getRowType().getFieldNames());
  }
  if (rel instanceof Filter) {
    final Filter f = (Filter) rel;
    return MutableFilter.of(toMutable(f.getInput()), f.getCondition());
  }
  if (rel instanceof Aggregate) {
    final Aggregate agg = (Aggregate) rel;
    return MutableAggregate.of(toMutable(agg.getInput()), agg.getGroupSet(),
        agg.getGroupSets(), agg.getAggCallList());
  }
  if (rel instanceof Sort) {
    final Sort s = (Sort) rel;
    return MutableSort.of(toMutable(s.getInput()), s.getCollation(), s.offset, s.fetch);
  }
  if (rel instanceof Calc) {
    final Calc c = (Calc) rel;
    return MutableCalc.of(toMutable(c.getInput()), c.getProgram());
  }
  if (rel instanceof Exchange) {
    final Exchange ex = (Exchange) rel;
    return MutableExchange.of(toMutable(ex.getInput()), ex.getDistribution());
  }
  if (rel instanceof Collect) {
    final Collect c = (Collect) rel;
    final MutableRel child = toMutable(c.getInput());
    return MutableCollect.of(c.getRowType(), child, c.getFieldName());
  }
  if (rel instanceof Uncollect) {
    final Uncollect u = (Uncollect) rel;
    final MutableRel child = toMutable(u.getInput());
    return MutableUncollect.of(u.getRowType(), child, u.withOrdinality);
  }
  if (rel instanceof Window) {
    final Window w = (Window) rel;
    final MutableRel child = toMutable(w.getInput());
    return MutableWindow.of(w.getRowType(), child, w.groups, w.getConstants());
  }
  if (rel instanceof TableModify) {
    final TableModify tm = (TableModify) rel;
    final MutableRel child = toMutable(tm.getInput());
    return MutableTableModify.of(tm.getRowType(), child, tm.getTable(),
        tm.getCatalogReader(), tm.getOperation(), tm.getUpdateColumnList(),
        tm.getSourceExpressionList(), tm.isFlattened());
  }
  if (rel instanceof Sample) {
    final Sample s = (Sample) rel;
    return MutableSample.of(toMutable(s.getInput()), s.getSamplingParameters());
  }

  if (rel instanceof TableFunctionScan) {
    final TableFunctionScan scan = (TableFunctionScan) rel;
    final List<MutableRel> children = toMutables(scan.getInputs());
    return MutableTableFunctionScan.of(scan.getCluster(), scan.getRowType(), children,
        scan.getCall(), scan.getElementType(), scan.getColumnMappings());
  }

  // Two-input operators. SemiJoin has to be tested before Join because it
  // is a sub-class of Join.
  if (rel instanceof SemiJoin) {
    final SemiJoin sj = (SemiJoin) rel;
    final MutableRel lhs = toMutable(sj.getLeft());
    final MutableRel rhs = toMutable(sj.getRight());
    return MutableSemiJoin.of(sj.getRowType(), lhs, rhs, sj.getCondition(),
        sj.getLeftKeys(), sj.getRightKeys());
  }
  if (rel instanceof Join) {
    final Join j = (Join) rel;
    final MutableRel lhs = toMutable(j.getLeft());
    final MutableRel rhs = toMutable(j.getRight());
    return MutableJoin.of(j.getRowType(), lhs, rhs, j.getCondition(),
        j.getJoinType(), j.getVariablesSet());
  }
  if (rel instanceof Correlate) {
    final Correlate cor = (Correlate) rel;
    final MutableRel lhs = toMutable(cor.getLeft());
    final MutableRel rhs = toMutable(cor.getRight());
    return MutableCorrelate.of(cor.getRowType(), lhs, rhs, cor.getCorrelationId(),
        cor.getRequiredColumns(), cor.getJoinType());
  }

  // Set operators.
  if (rel instanceof Union) {
    final Union u = (Union) rel;
    final List<MutableRel> children = toMutables(u.getInputs());
    return MutableUnion.of(u.getRowType(), children, u.all);
  }
  if (rel instanceof Minus) {
    final Minus m = (Minus) rel;
    final List<MutableRel> children = toMutables(m.getInputs());
    return MutableMinus.of(m.getRowType(), children, m.all);
  }
  if (rel instanceof Intersect) {
    final Intersect in = (Intersect) rel;
    final List<MutableRel> children = toMutables(in.getInputs());
    return MutableIntersect.of(in.getRowType(), children, in.all);
  }

  throw new RuntimeException("cannot translate " + rel + " to MutableRel");
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Calcite project.
Class RelMetadataTest, method testPullUpPredicatesFromAggregation.
/**
 * Unit test for
 * {@link org.apache.calcite.rel.metadata.RelMdPredicates#getPredicates(Aggregate, RelMetadataQuery)}.
 *
 * <p>Groups by a column that the sub-query defines as the literal 1 and
 * expects the pulled-up predicates of the aggregate to render as
 * {@code [=($0, 1)]}.
 */
@Test
public void testPullUpPredicatesFromAggregation() {
  final String query = "select a, max(b) from (\n"
      + " select 1 as a, 2 as b from emp)subq\n"
      + "group by a";
  final Aggregate aggregate = (Aggregate) convertSql(query);
  final RelOptPredicateList predicateList =
      RelMetadataQuery.instance().getPulledUpPredicates(aggregate);
  assertThat(predicateList.pulledUpPredicates, sortsAs("[=($0, 1)]"));
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Calcite project.
Class AggregateExtractProjectRule, method onMatch.
/**
 * Inserts a projection beneath the matched aggregate that keeps only the
 * input fields the aggregate references, then rebuilds the aggregate on top
 * of that projection with all field ordinals remapped.
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and
 *     {@code call.rel(1)} the rel pushed onto the builder as its new input
 */
public void onMatch(RelOptRuleCall call) {
  final Aggregate agg = call.rel(0);
  final RelNode child = call.rel(1);

  // Gather every referenced input field: the group keys to begin with,
  // followed by the arguments and filter column of each aggregate call.
  final ImmutableBitSet.Builder usedFields = agg.getGroupSet().rebuild();
  for (AggregateCall aggregateCall : agg.getAggCallList()) {
    for (int arg : aggregateCall.getArgList()) {
      usedFields.set(arg);
    }
    if (aggregateCall.filterArg >= 0) {
      usedFields.set(aggregateCall.filterArg);
    }
  }

  final RelBuilder builder = call.builder().push(child);
  final List<RexNode> projectedFields = new ArrayList<>();
  // Maps an ordinal in the aggregate's input to its ordinal in the projection.
  final Mapping fieldMapping =
      Mappings.create(MappingType.INVERSE_SURJECTION,
          agg.getInput().getRowType().getFieldCount(),
          usedFields.cardinality());
  int target = 0;
  for (int source : usedFields.build()) {
    projectedFields.add(builder.field(source));
    fieldMapping.set(source, target);
    target++;
  }
  builder.project(projectedFields);

  // Re-express the group set and each grouping set in projected ordinals.
  final ImmutableBitSet mappedGroupSet = Mappings.apply(fieldMapping, agg.getGroupSet());
  final Function<ImmutableBitSet, ImmutableBitSet> remap =
      new Function<ImmutableBitSet, ImmutableBitSet>() {
        public ImmutableBitSet apply(ImmutableBitSet groupSet) {
          return Mappings.apply(fieldMapping, groupSet);
        }
      };
  final ImmutableList<ImmutableBitSet> mappedGroupSets =
      ImmutableList.copyOf(Iterables.transform(agg.getGroupSets(), remap));

  // Re-create each aggregate call against the projected fields.
  final List<RelBuilder.AggCall> rebuiltCalls = new ArrayList<>();
  for (AggregateCall aggregateCall : agg.getAggCallList()) {
    final ImmutableList<RexNode> operands =
        builder.fields(Mappings.apply2(fieldMapping, aggregateCall.getArgList()));
    final RexNode filter;
    if (aggregateCall.filterArg < 0) {
      filter = null;
    } else {
      filter = builder.field(Mappings.apply(fieldMapping, aggregateCall.filterArg));
    }
    rebuiltCalls.add(
        builder.aggregateCall(aggregateCall.getAggregation(), aggregateCall.isDistinct(),
            aggregateCall.isApproximate(), filter, aggregateCall.name, operands));
  }

  final RelBuilder.GroupKey key = builder.groupKey(mappedGroupSet, mappedGroupSets);
  builder.aggregate(key, rebuiltCalls);
  call.transformTo(builder.build());
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Calcite project.
Class AggregateFilterTransposeRule, method onMatch.
/**
 * Moves the matched aggregate beneath the matched filter.
 *
 * <p>The pushed-down aggregate groups by the original group keys plus every
 * column the filter condition reads. When the filter only reads group keys
 * and the aggregate's group type is {@code Group.SIMPLE}, the result is the
 * filter over the pushed-down aggregate. Otherwise a second aggregate is
 * placed on top to roll the extra columns away; the rule gives up if any
 * aggregate call has no rollup function or is distinct.
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and
 *     {@code call.rel(1)} the filter
 */
public void onMatch(RelOptRuleCall call) {
  final Aggregate agg = call.rel(0);
  final Filter filterRel = call.rel(1);

  // Do the columns used by the filter appear in the output of the aggregate?
  final ImmutableBitSet conditionColumns =
      RelOptUtil.InputFinder.bits(filterRel.getCondition());
  final ImmutableBitSet pushedGroupSet = agg.getGroupSet().union(conditionColumns);
  final RelNode filterInput = filterRel.getInput();
  final RelMetadataQuery metadataQuery = call.getMetadataQuery();
  final Boolean alreadyUnique = metadataQuery.areColumnsUnique(filterInput, pushedGroupSet);
  if (Boolean.TRUE.equals(alreadyUnique)) {
    // The input is known to be unique on the widened group set. Stop here;
    // otherwise the rule fires forever:
    // A-F => A-F-A => A-A-F-A => A-A-A-F-A => ...
    return;
  }

  final boolean filterReadsOnlyGroupKeys = agg.getGroupSet().contains(conditionColumns);
  final Aggregate pushedAggregate =
      agg.copy(agg.getTraitSet(), filterInput, false, pushedGroupSet, null,
          agg.getAggCallList());

  // Translate the filter condition from input ordinals to the ordinals of
  // the pushed-down aggregate's group columns.
  final Mappings.TargetMapping columnMapping = Mappings.target(
      new Function<Integer, Integer>() {
        public Integer apply(Integer column) {
          return pushedGroupSet.indexOf(column);
        }
      },
      filterInput.getRowType().getFieldCount(),
      pushedGroupSet.cardinality());
  final RexNode pushedCondition = RexUtil.apply(columnMapping, filterRel.getCondition());
  final Filter pulledFilter =
      filterRel.copy(filterRel.getTraitSet(), pushedAggregate, pushedCondition);

  if (filterReadsOnlyGroupKeys && agg.getGroupType() == Group.SIMPLE) {
    // Everything needed by the filter is returned by the aggregate.
    assert pushedGroupSet.equals(agg.getGroupSet());
    call.transformTo(pulledFilter);
    return;
  }

  // If aggregate uses grouping sets, we always need to split it.
  // Otherwise, it means that grouping sets are not used, but the
  // filter needs at least one extra column, and now aggregate it away.
  final ImmutableBitSet.Builder topGroupSet = ImmutableBitSet.builder();
  for (int column : agg.getGroupSet()) {
    topGroupSet.set(pushedGroupSet.indexOf(column));
  }

  final ImmutableList<ImmutableBitSet> topGroupSets;
  if (agg.getGroupType() == Group.SIMPLE) {
    topGroupSets = null;
  } else {
    final ImmutableList.Builder<ImmutableBitSet> collected = ImmutableList.builder();
    for (ImmutableBitSet original : agg.getGroupSets()) {
      final ImmutableBitSet.Builder translated = ImmutableBitSet.builder();
      for (int column : original) {
        translated.set(pushedGroupSet.indexOf(column));
      }
      collected.add(translated.build());
    }
    topGroupSets = collected.build();
  }

  // Each top-level call rolls up the matching output column of the
  // pushed-down aggregate; those columns follow its group columns.
  final List<AggregateCall> topCalls = Lists.newArrayList();
  int ordinal = pushedGroupSet.cardinality();
  for (AggregateCall original : agg.getAggCallList()) {
    final SqlAggFunction rollup = SubstitutionVisitor.getRollup(original.getAggregation());
    if (rollup == null || original.isDistinct()) {
      // Either this aggregate cannot be rolled up, or it is distinct,
      // which cannot be rolled up either.
      return;
    }
    topCalls.add(
        AggregateCall.create(rollup, original.isDistinct(), original.isApproximate(),
            ImmutableList.of(ordinal), -1, original.type, original.name));
    ordinal++;
  }

  final Aggregate topAggregate =
      agg.copy(agg.getTraitSet(), pulledFilter, agg.indicator, topGroupSet.build(),
          topGroupSets, topCalls);
  call.transformTo(topAggregate);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Calcite project.
Class AggregateProjectPullUpConstantsRule, method onMatch.
// ~ Methods ----------------------------------------------------------------
/**
 * Drops constant keys from the matched aggregate's group set and re-creates
 * them as expressions in a projection on top of the new aggregate.
 *
 * <p>A group key is treated as constant when the predicates pulled up from
 * the aggregate's input contain an entry for its input reference in their
 * {@code constantMap}. The method returns without transforming when the
 * aggregate has exactly one group key, when no predicates are available, or
 * when none of the group keys is constant. If every group key is constant,
 * the lowest-numbered one is kept so that the group set never becomes empty.
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and
 *     {@code call.rel(1)} the rel used as input of the rewritten aggregate
 */
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
final RelNode input = call.rel(1);
assert !aggregate.indicator : "predicate ensured no grouping sets";
final int groupCount = aggregate.getGroupCount();
if (groupCount == 1) {
// No room for optimization: a non-empty GROUP BY list cannot be
// converted to the empty one.
return;
}
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final RelMetadataQuery mq = call.getMetadataQuery();
final RelOptPredicateList predicates = mq.getPulledUpPredicates(aggregate.getInput());
if (predicates == null) {
return;
}
// Constant group keys, keyed by input field ordinal and sorted ascending.
final NavigableMap<Integer, RexNode> map = new TreeMap<>();
for (int key : aggregate.getGroupSet()) {
final RexInputRef ref = rexBuilder.makeInputRef(aggregate.getInput(), key);
if (predicates.constantMap.containsKey(ref)) {
map.put(key, predicates.constantMap.get(ref));
}
}
// None of the group expressions are constant. Nothing to do.
if (map.isEmpty()) {
return;
}
if (groupCount == map.size()) {
// At least a single item in group by is required.
// Otherwise "GROUP BY 1, 2" might be altered to "GROUP BY ()".
// Removing of the first element is not optimal here,
// however it will allow us to use fast path below (just trim
// groupCount).
// NOTE(review): no separate fast path is visible in this method; the
// sentence above may be left over from an earlier version.
map.remove(map.navigableKeySet().first());
}
// Remove every remaining constant key from the group set.
ImmutableBitSet newGroupSet = aggregate.getGroupSet();
for (int key : map.keySet()) {
newGroupSet = newGroupSet.clear(key);
}
final int newGroupCount = newGroupSet.cardinality();
// Build the replacement aggregate over the input using the reduced
// group set.
final RelBuilder relBuilder = call.builder();
relBuilder.push(input);
// Clone aggregate calls.
final List<AggregateCall> newAggCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
newAggCalls.add(aggCall.adaptTo(input, aggCall.getArgList(), aggCall.filterArg, groupCount, newGroupCount));
}
relBuilder.aggregate(relBuilder.groupKey(newGroupSet, null), newAggCalls);
// Create a projection back again.
List<Pair<RexNode, String>> projects = new ArrayList<>();
// Ordinal, in the new aggregate's output, of the next surviving group key.
int source = 0;
for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
RexNode expr;
final int i = field.getIndex();
if (i >= groupCount) {
// Aggregate expressions' names and positions are unchanged.
// Their ordinals shift down by the number of removed group keys.
expr = relBuilder.field(i - map.size());
} else {
// Input field ordinal of the i-th group key.
int pos = aggregate.getGroupSet().nth(i);
if (map.containsKey(pos)) {
// Re-generate the constant expression in the project.
RelDataType originalType = aggregate.getRowType().getFieldList().get(projects.size()).getType();
// Cast when the constant's type differs from the original field type.
if (!originalType.equals(map.get(pos).getType())) {
expr = rexBuilder.makeCast(originalType, map.get(pos), true);
} else {
expr = map.get(pos);
}
} else {
// Project the aggregation expression, in its original
// position.
expr = relBuilder.field(source);
++source;
}
}
projects.add(Pair.of(expr, field.getName()));
}
// Project the expressions under the original aggregate's field names.
relBuilder.project(Pair.left(projects), Pair.right(projects));
call.transformTo(relBuilder.build());
}
Aggregations