use of org.apache.calcite.rel.core.Aggregate in project flink by apache.
the class FlinkAggregateJoinTransposeRule method onMatch.
/**
 * Tries to push an aggregate below an inner equi-join.
 *
 * <p>Each join input is either kept as it is (when it is already unique on the columns needed
 * above it) or pre-aggregated on those columns. The join is then rebuilt over the two new inputs,
 * and a final aggregate - or, where possible, a plain projection - is placed on top to combine the
 * per-side sub-results.
 *
 * <p>The method returns without calling {@code transformTo} when any of the following holds:
 * an aggregate call is not a {@code SqlSplittableAggFunction}, has a filter argument or is
 * distinct; the join is not an INNER join; {@code allowFunctions} is false while aggregate calls
 * are present; the join condition has a non-equi part; or both join inputs are already unique.
 *
 * @param call rule call whose {@code rel(0)} is the aggregate and whose {@code rel(1)} is the
 *     join below it
 */
public void onMatch(RelOptRuleCall call) {
final Aggregate origAgg = call.rel(0);
final Join join = call.rel(1);
final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
// converts an aggregate with AUXILIARY_GROUP to a regular aggregate.
// if the converted aggregate can be pushed down,
// AggregateReduceGroupingRule will try to reduce the grouping of the new aggregates created by
// this rule
final Pair<Aggregate, List<RexNode>> newAggAndProject = toRegularAggregate(origAgg);
final Aggregate aggregate = newAggAndProject.left;
final List<RexNode> projectAfterAgg = newAggAndProject.right;
// If any aggregate call is not splittable, has a filter, or is distinct, bail out
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
return;
}
if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
return;
}
}
// Only inner joins are handled
if (join.getJoinType() != JoinRelType.INNER) {
return;
}
// When functions are not allowed, only aggregates without any aggregate call are handled
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return;
}
// Do the columns used by the join appear in the output of the aggregate?
final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
final RelMetadataQuery mq = call.getMetadataQuery();
final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
// Columns that must survive below the new join: the group keys plus the join columns
final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
// Split join condition
final List<Integer> leftKeys = com.google.common.collect.Lists.newArrayList();
final List<Integer> rightKeys = com.google.common.collect.Lists.newArrayList();
final List<Boolean> filterNulls = com.google.common.collect.Lists.newArrayList();
RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
// If it contains non-equi join conditions, we bail out
if (!nonEquiConj.isAlwaysTrue()) {
return;
}
// Push each aggregate function down to each side that contains all of its
// arguments. Note that COUNT(*), because it has no arguments, can go to
// both sides.
// map: field index in the original join output -> field index in the new join output
final Map<Integer, Integer> map = new HashMap<>();
final List<Side> sides = new ArrayList<>();
int uniqueCount = 0;
// offset: start of the current input's fields in the original join output;
// belowOffset: start of the current input's fields in the new join output
int offset = 0;
int belowOffset = 0;
for (int s = 0; s < 2; s++) {
final Side side = new Side();
final RelNode joinInput = join.getInput(s);
int fieldCount = joinInput.getRowType().getFieldCount();
// Fields of the original join output that come from this input
final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
// Needed columns of this input, still expressed as join-output indices
final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
map.put(c.e, belowOffset + c.i);
}
final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
// The same columns expressed as indices into this input's own row type
final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
final boolean unique;
if (!allowFunctions) {
assert aggregate.getAggCallList().isEmpty();
// If there are no functions, it doesn't matter as much whether we
// aggregate the inputs before the join, because there will not be
// any functions experiencing a cartesian product effect.
//
// But finding out whether the input is already unique requires a call
// to areColumnsUnique that currently (until [CALCITE-1048] "Make
// metadata more robust" is fixed) places a heavy load on
// the metadata system.
//
// So we choose to imagine the input is already unique, which is
// untrue but harmless.
//
Util.discard(Bug.CALCITE_1048_FIXED);
unique = true;
} else {
// A null answer from the metadata query is treated as "not unique"
final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
unique = unique0 != null && unique0;
}
if (unique) {
// This input is not aggregated; it is only projected onto the needed columns plus
// one expression per aggregate call whose arguments all come from this input.
++uniqueCount;
side.aggregate = false;
relBuilder.push(joinInput);
// Input field index -> position of that field in the new projection
final Map<Integer, Integer> belowAggregateKeyToNewProjectMap = new HashMap<>();
final List<RexNode> projects = new ArrayList<>();
for (Integer i : belowAggregateKey) {
belowAggregateKeyToNewProjectMap.put(i, projects.size());
projects.add(relBuilder.field(i));
}
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
// Calls without arguments are not projected on a unique side
if (!aggCall.e.getArgList().isEmpty() && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
// NOTE(review): singleton(...) presumably yields the call's value over a single
// row - confirm against SqlSplittableAggFunction
final RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggCall.e.transform(mapping));
final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false);
if (targetSingleton instanceof RexInputRef) {
final int index = ((RexInputRef) targetSingleton).getIndex();
if (!belowAggregateKey.get(index)) {
projects.add(targetSingleton);
side.split.put(aggCall.i, projects.size() - 1);
} else {
// The referenced field is already projected as a key column; reuse its position
side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index));
}
} else {
projects.add(targetSingleton);
side.split.put(aggCall.i, projects.size() - 1);
}
}
}
relBuilder.project(projects);
side.newInput = relBuilder.build();
} else {
// This input is pre-aggregated, grouped by the needed columns
side.aggregate = true;
List<AggregateCall> belowAggCalls = new ArrayList<>();
final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
final int oldGroupKeyCount = aggregate.getGroupCount();
final int newGroupKeyCount = belowAggregateKey.cardinality();
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final AggregateCall call1;
if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
// All arguments come from this input: use the split form of the call
final AggregateCall splitCall = splitter.split(aggCall.e, mapping);
call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
} else {
// Arguments live on the other input: ask the splitter what (if anything) this
// side has to contribute
call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
}
if (call1 != null) {
// The sub-result column follows the group key columns in the new input
side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
}
}
side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
}
offset += fieldCount;
belowOffset += side.newInput.getRowType().getFieldCount();
sides.add(side);
}
if (uniqueCount == 2) {
// Both inputs to the join are unique. There is nothing to be gained by
// this rule. In fact, this aggregate+join may be the result of a previous
// invocation of this rule; if we continue we might loop forever.
return;
}
// Update condition
final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
// Create new join
relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
// Aggregate above to sum up the sub-totals
final List<AggregateCall> newAggCalls = new ArrayList<>();
final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
// -1 marks a side without a sub-result; right-side positions are shifted past the left
// input's fields because they follow them in the new join output
newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
}
relBuilder.project(projects);
boolean aggConvertedToProjects = false;
if (allColumnsInAggregate) {
// let's see if we can convert aggregate into projects
List<RexNode> projects2 = new ArrayList<>();
for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
projects2.add(relBuilder.field(key));
}
// Index of the current aggregate call's field in the original aggregate's row type
int aggCallIdx = projects2.size();
for (AggregateCall newAggCall : newAggCalls) {
final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
if (splitter != null) {
final RelDataType rowType = relBuilder.peek().getRowType();
final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall);
final RelDataType originalAggCallType = aggregate.getRowType().getFieldList().get(aggCallIdx).getType();
final RexNode targetSingleton = rexBuilder.ensureType(originalAggCallType, singleton, false);
projects2.add(targetSingleton);
}
aggCallIdx += 1;
}
// The projection is only usable when every aggregate call produced an expression
if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
// We successfully converted agg calls into projects.
relBuilder.project(projects2);
aggConvertedToProjects = true;
}
}
if (!aggConvertedToProjects) {
relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
}
// Apply the projection returned by toRegularAggregate, if any, using the original
// aggregate's field names
if (projectAfterAgg != null) {
relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames());
}
call.transformTo(relBuilder.build());
}
use of org.apache.calcite.rel.core.Aggregate in project flink by apache.
the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.
/**
 * Rewrites an aggregate that contains distinct aggregate calls into a stack of relational
 * expressions built on grouping sets, and registers the result via {@code call.transformTo}.
 *
 * <p>The generated plan consists of, from bottom to top:
 *
 * <ol>
 *   <li>an aggregate over the original input that groups by one grouping set per distinct
 *       argument combination (plus the original group set when non-distinct calls exist),
 *       evaluates the non-distinct calls and adds a GROUPING call named {@code "$g"};
 *   <li>a projection that replaces the {@code "$g"} column with one boolean column
 *       {@code "$g_<value>"} per grouping set;
 *   <li>an aggregate on the original group keys whose calls are filtered by those boolean
 *       columns;
 *   <li>when the original aggregate has no group keys and contains non-distinct COUNT calls, a
 *       projection that turns a NULL count into 0;
 *   <li>a conversion to the original aggregate's row type.
 * </ol>
 *
 * @param call rule call used to obtain the builder and to register the rewritten plan
 * @param aggregate aggregate to rewrite
 */
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
// Grouping set of a distinct call -> filterArg of that call
final Map<ImmutableBitSet, Integer> groupSetToDistinctAggCallFilterArg = new HashMap<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (!aggCall.isDistinct()) {
groupSetTreeSet.add(aggregate.getGroupSet());
} else {
// Distinct call: group by its arguments, its filter column (if any) and the original keys
ImmutableBitSet groupSet = ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet());
groupSetToDistinctAggCallFilterArg.put(groupSet, aggCall.filterArg);
groupSetTreeSet.add(groupSet);
}
}
final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets = com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
// Calls of the bottom aggregate: the non-distinct calls, followed by the GROUPING call
final List<AggregateCall> distinctAggCalls = new ArrayList<>();
for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
if (!aggCall.left.isDistinct()) {
AggregateCall newAggCall = aggCall.left.adaptTo(aggregate.getInput(), aggCall.left.getArgList(), aggCall.left.filterArg, aggregate.getGroupCount(), fullGroupSet.cardinality());
distinctAggCalls.add(newAggCall.rename(aggCall.right));
}
}
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
final int groupCount = fullGroupSet.cardinality();
// Grouping set -> index of its boolean filter column in the projection built below
final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
// Index of the "$g" column in the bottom aggregate's output; after the projection below the
// filter column of the i-th grouping set sits at z + i
final int z = groupCount + distinctAggCalls.size();
distinctAggCalls.add(AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false, ImmutableIntList.copyOf(fullGroupSet), -1, RelCollations.EMPTY, groupSets.size(), relBuilder.peek(), null, "$g"));
for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
filters.put(groupSet.e, z + groupSet.i);
}
relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), distinctAggCalls);
final RelNode distinct = relBuilder.peek();
// The GROUPING call yields a numeric value per grouping set. Add a project that converts
// those values to BOOLEAN.
if (!filters.isEmpty()) {
final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
// Drop the trailing "$g" column; it is replaced by the boolean columns appended below
final RexNode nodeZ = nodes.remove(nodes.size() - 1);
for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
final long v = groupValue(fullGroupSet, entry.getKey());
// Get and remap the filterArg of the distinct aggregate call.
int distinctAggCallFilterArg = remap(fullGroupSet, groupSetToDistinctAggCallFilterArg.getOrDefault(entry.getKey(), -1));
RexNode expr;
if (distinctAggCallFilterArg < 0) {
expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
} else {
RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
// merge the filter of the distinct aggregate call itself.
expr = relBuilder.and(relBuilder.equals(nodeZ, relBuilder.literal(v)), rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, relBuilder.field(distinctAggCallFilterArg)));
}
nodes.add(relBuilder.alias(expr, "$g_" + v));
}
relBuilder.project(nodes);
}
int aggCallIdx = 0;
// Index of the next non-distinct call's result column in the bottom aggregate's output
int x = groupCount;
final List<AggregateCall> newCalls = new ArrayList<>();
// TODO supports more aggCalls (currently only supports COUNT)
// Some aggregate functions (e.g. COUNT) have the special property that they can return a
// non-null result without any input. We need to make sure we return a result in this case.
final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
final int newFilterArg;
final List<Integer> newArgList;
final SqlAggFunction aggregation;
if (!aggCall.isDistinct()) {
// The non-distinct call was already evaluated by the bottom aggregate (column x); the top
// aggregate takes MIN of that column, restricted to the rows of the original group set
aggregation = SqlStdOperatorTable.MIN;
newArgList = ImmutableIntList.of(x++);
newFilterArg = filters.get(aggregate.getGroupSet());
switch(aggCall.getAggregation().getKind()) {
case COUNT:
needDefaultValueAggCalls.add(aggCallIdx);
break;
default:
}
} else {
// The distinct call is evaluated here, without the distinct flag, restricted to the rows
// of its own grouping set
aggregation = aggCall.getAggregation();
newArgList = remap(fullGroupSet, aggCall.getArgList());
newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
}
final AggregateCall newCall = AggregateCall.create(aggregation, false, aggCall.isApproximate(), false, newArgList, newFilterArg, RelCollations.EMPTY, aggregate.getGroupCount(), distinct, null, aggCall.name);
newCalls.add(newCall);
aggCallIdx++;
}
relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
// Only for an aggregate without group keys: replace a NULL result of the recorded COUNT
// calls with 0
if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
final Aggregate newAgg = (Aggregate) relBuilder.peek();
final List<RexNode> nodes = new ArrayList<>();
for (int i = 0; i < newAgg.getGroupCount(); ++i) {
nodes.add(RexInputRef.of(i, newAgg.getRowType()));
}
for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
final RexNode inputRef = RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
RexNode newNode = inputRef;
if (needDefaultValueAggCalls.contains(i)) {
SqlKind originalFunKind = aggregate.getAggCallList().get(i).getAggregation().getKind();
switch(originalFunKind) {
case COUNT:
// CASE WHEN value IS NOT NULL THEN value ELSE 0 END
newNode = relBuilder.call(SqlStdOperatorTable.CASE, relBuilder.isNotNull(inputRef), inputRef, relBuilder.literal(BigDecimal.ZERO));
break;
default:
}
}
nodes.add(newNode);
}
relBuilder.project(nodes);
}
// Make the result's row type match the original aggregate's row type
relBuilder.convert(aggregate.getRowType(), true);
call.transformTo(relBuilder.build());
}
use of org.apache.calcite.rel.core.Aggregate in project flink by apache.
the class RelTimeIndicatorConverter method gatherIndicesToMaterialize.
/**
 * Collects the indices of the fields of {@code newInput} that are of time-indicator type and
 * are referenced by the aggregate, either as an argument of an aggregate call or as a member
 * of one of its grouping sets.
 *
 * <p>For a grouping set that {@code groupingContainsWindowStartEnd} accepts, the window time
 * columns reported by the window properties are excluded before the grouping set is inspected.
 *
 * @param agg aggregate whose referenced fields are inspected
 * @param newInput input whose row type provides the field types
 * @return indices of the referenced time-indicator fields
 */
private Set<Integer> gatherIndicesToMaterialize(Aggregate agg, RelNode newInput) {
    final List<RelDataType> fieldTypes = RelOptUtil.getFieldTypeList(newInput.getRowType());
    final Set<Integer> indices = new HashSet<>();

    // arguments of agg calls
    for (AggregateCall aggCall : agg.getAggCallList()) {
        for (Integer argIndex : aggCall.getArgList()) {
            if (isTimeIndicatorType(fieldTypes.get(argIndex))) {
                indices.add(argIndex);
            }
        }
    }

    final RelWindowProperties windowProps =
            FlinkRelMetadataQuery.reuseOrCreate(agg.getCluster().getMetadataQuery())
                    .getRelWindowProperties(newInput);

    // grouping sets
    agg.getGroupSets().forEach(grouping -> {
        // for a grouping that contains window start/end, the window_time columns are left out
        final boolean dropWindowTime =
                windowProps != null && groupingContainsWindowStartEnd(grouping, windowProps);
        final List<Integer> candidates =
                (dropWindowTime ? grouping.except(windowProps.getWindowTimeColumns()) : grouping)
                        .asList();
        for (Integer column : candidates) {
            if (isTimeIndicatorType(fieldTypes.get(column))) {
                indices.add(column);
            }
        }
    });
    return indices;
}
use of org.apache.calcite.rel.core.Aggregate in project drill by apache.
the class PluginConverterRule method matches.
/**
 * Accepts the rule call only when the plugin implementor reports that it can implement the
 * matched relational expression and the parent rule's own check passes as well.
 *
 * @param call rule call whose {@code rel(0)} is the candidate expression
 * @return {@code true} if the rule may fire for this call
 */
@Override
public boolean matches(RelOptRuleCall call) {
    final RelNode candidate = call.rel(0);

    // cannot use visitor pattern here, since RelShuttle supports only logical rel implementations
    final boolean supported;
    if (candidate instanceof Aggregate) {
        supported = pluginImplementor.canImplement((Aggregate) candidate);
    } else if (candidate instanceof Filter) {
        supported = pluginImplementor.canImplement((Filter) candidate);
    } else if (candidate instanceof DrillLimitRelBase) {
        supported = pluginImplementor.canImplement((DrillLimitRelBase) candidate);
    } else if (candidate instanceof Project) {
        supported = pluginImplementor.canImplement((Project) candidate);
    } else if (candidate instanceof Sort) {
        supported = pluginImplementor.canImplement((Sort) candidate);
    } else if (candidate instanceof Union) {
        supported = pluginImplementor.canImplement((Union) candidate);
    } else if (candidate instanceof Join) {
        supported = pluginImplementor.canImplement((Join) candidate);
    } else {
        // any other kind of expression is not handled by this rule
        supported = false;
    }

    if (!supported) {
        return false;
    }
    return super.matches(call);
}
use of org.apache.calcite.rel.core.Aggregate in project drill by apache.
the class PhoenixImplementor method visit.
/**
 * Converts a filter into its SQL representation.
 *
 * <p>The input is visited first and passed to {@code parseCorrelTable} together with the
 * filter. A filter whose input is an {@code Aggregate} is then handed to the parent
 * implementation; any other filter becomes a WHERE clause with an explicit select list.
 *
 * @param e filter to convert
 * @return conversion result for the filter
 */
@Override
public Result visit(Filter e) {
    final RelNode child = e.getInput();
    final Result childResult = visitChild(0, child);
    parseCorrelTable(e, childResult);

    if (input_isAggregate(child)) {
        return super.visit(e);
    }

    final Builder sqlBuilder = childResult.builder(e, Clause.WHERE);
    sqlBuilder.setWhere(sqlBuilder.context.toSql(null, e.getCondition()));

    /*
     * phoenix does not support the wildcard in the subqueries,
     * expand the wildcard at this time.
     */
    final List<SqlNode> selectItems = new ArrayList<>();
    for (String fieldName : e.getRowType().getFieldNames()) {
        addSelect(selectItems, new SqlIdentifier(fieldName, POS), e.getRowType());
    }
    sqlBuilder.setSelect(new SqlNodeList(selectItems, POS));
    return sqlBuilder.result();
}

/** Tells whether the given filter input is an aggregate. */
private static boolean input_isAggregate(RelNode node) {
    return node instanceof Aggregate;
}
Aggregations