use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.
the class RelDecorrelator method decorrelate.
/**
 * Decorrelates the given plan in three steps:
 *
 * <ol>
 *   <li>a HepPlanner pre-pass that adjusts count() projections and moves filters around;
 *   <li>the decorrelation itself, via {@code getInvoke};
 *   <li>if decorrelation produced a frame, a HepPlanner post-pass over the rewritten tree.
 * </ol>
 *
 * @param root plan to decorrelate
 * @return the post-processed rewritten plan when {@code getInvoke} yields a frame, otherwise the
 *     result of the pre-pass
 */
protected RelNode decorrelate(RelNode root) {
    final RelBuilderFactory f = relBuilderFactory();

    // Pre-pass: first adjust count() expression if any, then push filters so that
    // correlated conditions end up where the decorrelation step can handle them.
    final HepProgram preProgram =
            HepProgram.builder()
                    .addRuleInstance(
                            AdjustProjectForCountAggregateRule.config(false, this, f).toRule())
                    .addRuleInstance(
                            AdjustProjectForCountAggregateRule.config(true, this, f).toRule())
                    .addRuleInstance(
                            FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT
                                    .withRelBuilderFactory(f)
                                    .withOperandSupplier(
                                            filterOperand ->
                                                    filterOperand
                                                            .operand(Filter.class)
                                                            .oneInput(
                                                                    joinOperand ->
                                                                            joinOperand
                                                                                    .operand(Join.class)
                                                                                    .anyInputs()))
                                    .withDescription("FilterJoinRule:filter")
                                    .as(FilterJoinRule.FilterIntoJoinRule.Config.class)
                                    .withSmart(true)
                                    .withPredicate((joinRel, joinRelType, condition) -> true)
                                    .as(FilterJoinRule.FilterIntoJoinRule.Config.class)
                                    .toRule())
                    .addRuleInstance(
                            CoreRules.FILTER_PROJECT_TRANSPOSE
                                    .config
                                    .withRelBuilderFactory(f)
                                    .as(FilterProjectTransposeRule.Config.class)
                                    // only filters without correlation are transposed
                                    .withOperandFor(
                                            Filter.class,
                                            candidate ->
                                                    !RexUtil.containsCorrelation(
                                                            candidate.getCondition()),
                                            Project.class,
                                            anyProject -> true)
                                    .withCopyFilter(true)
                                    .withCopyProject(true)
                                    .toRule())
                    .addRuleInstance(
                            FilterCorrelateRule.Config.DEFAULT.withRelBuilderFactory(f).toRule())
                    .build();
    final HepPlanner prePlanner = createPlanner(preProgram);
    prePlanner.setRoot(root);
    final RelNode adjustedRoot = prePlanner.findBestExp();

    // Perform decorrelation.
    map.clear();
    final Frame frame = getInvoke(adjustedRoot, null);
    if (frame == null) {
        // nothing was rewritten; hand back the pre-processed plan
        return adjustedRoot;
    }

    // has been rewritten; apply rules post-decorrelation
    final HepProgram postProgram =
            HepProgram.builder()
                    .addRuleInstance(
                            CoreRules.FILTER_INTO_JOIN.config.withRelBuilderFactory(f).toRule())
                    .addRuleInstance(
                            CoreRules.JOIN_CONDITION_PUSH.config.withRelBuilderFactory(f).toRule())
                    .build();
    final HepPlanner postPlanner = createPlanner(postProgram);
    postPlanner.setRoot(frame.r);
    return postPlanner.findBestExp();
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.
the class FlinkAggregateJoinTransposeRule method onMatch.
/**
 * Tries to push an aggregate that sits on top of a join below that join.
 *
 * <p>The method returns without calling {@code call.transformTo} when any of these holds:
 * an aggregate call is not a {@code SqlSplittableAggFunction}, has a filter or is distinct;
 * the join is not an INNER join; aggregate calls exist but {@code allowFunctions} is false;
 * the join condition has a non-equi part; or both join inputs were treated as unique.
 *
 * <p>Otherwise each join input is replaced by either a projection (input treated as unique on the
 * needed keys) or a partial aggregate, the join is rebuilt over the new inputs with a remapped
 * condition, and a final aggregate (or an equivalent projection) is placed on top.
 *
 * @param call rule call; {@code rel(0)} is the aggregate and {@code rel(1)} is the join below it
 */
public void onMatch(RelOptRuleCall call) {
final Aggregate origAgg = call.rel(0);
final Join join = call.rel(1);
final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
// converts an aggregate with AUXILIARY_GROUP to a regular aggregate.
// if the converted aggregate can be push down,
// AggregateReduceGroupingRule will try reduce grouping of new aggregates created by this
// rule
final Pair<Aggregate, List<RexNode>> newAggAndProject = toRegularAggregate(origAgg);
final Aggregate aggregate = newAggAndProject.left;
// may be null; when non-null it is applied as a final projection at the end of this method
final List<RexNode> projectAfterAgg = newAggAndProject.right;
// Bail out if any aggregate call is not splittable, has a filter, or is distinct
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
return;
}
if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
return;
}
}
// Only INNER joins are handled
if (join.getJoinType() != JoinRelType.INNER) {
return;
}
// When functions are not allowed, only aggregates without aggregate calls are handled
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return;
}
// Do the columns used by the join appear in the output of the aggregate?
final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
final RelMetadataQuery mq = call.getMetadataQuery();
final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
// Columns that must survive below the join: the group keys plus the join condition columns
final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
// Split join condition
final List<Integer> leftKeys = com.google.common.collect.Lists.newArrayList();
final List<Integer> rightKeys = com.google.common.collect.Lists.newArrayList();
final List<Boolean> filterNulls = com.google.common.collect.Lists.newArrayList();
RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
// If it contains non-equi join conditions, we bail out
if (!nonEquiConj.isAlwaysTrue()) {
return;
}
// Push each aggregate function down to each side that contains all of its
// arguments. Note that COUNT(*), because it has no arguments, can go to
// both sides.
// map: field index in the original join output -> field index in the new join output
final Map<Integer, Integer> map = new HashMap<>();
final List<Side> sides = new ArrayList<>();
// number of join inputs that took the "unique" (projection-only) branch below
int uniqueCount = 0;
// offset: start of the current input's fields in the original join row type
int offset = 0;
// belowOffset: start of the current input's fields in the new join row type
int belowOffset = 0;
// s == 0 is the left join input, s == 1 the right one
for (int s = 0; s < 2; s++) {
final Side side = new Side();
final RelNode joinInput = join.getInput(s);
int fieldCount = joinInput.getRowType().getFieldCount();
final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
// the kept key columns become the leading fields of this side's new input, in order
for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
map.put(c.e, belowOffset + c.i);
}
// translates join-output field indexes into indexes local to this input
final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
final boolean unique;
if (!allowFunctions) {
assert aggregate.getAggCallList().isEmpty();
// If there are no functions, it doesn't matter as much whether we
// aggregate the inputs before the join, because there will not be
// any functions experiencing a cartesian product effect.
//
// But finding out whether the input is already unique requires a call
// to areColumnsUnique that currently (until [CALCITE-1048] "Make
// metadata more robust" is fixed) places a heavy load on
// the metadata system.
//
// So we choose to imagine the input is already unique, which is
// untrue but harmless.
//
Util.discard(Bug.CALCITE_1048_FIXED);
unique = true;
} else {
// a null answer from the metadata query is treated as "not unique"
final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
unique = unique0 != null && unique0;
}
if (unique) {
// Input is treated as unique on the keys: no aggregate is created for this side, only a
// projection of the keys plus the singleton form of each aggregate call it can evaluate.
++uniqueCount;
side.aggregate = false;
relBuilder.push(joinInput);
// input field index of a key -> its position in the new projection
final Map<Integer, Integer> belowAggregateKeyToNewProjectMap = new HashMap<>();
final List<RexNode> projects = new ArrayList<>();
for (Integer i : belowAggregateKey) {
belowAggregateKeyToNewProjectMap.put(i, projects.size());
projects.add(relBuilder.field(i));
}
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
// only calls that have arguments, all of which come from this input, are handled here
if (!aggCall.e.getArgList().isEmpty() && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
final RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggCall.e.transform(mapping));
final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false);
if (targetSingleton instanceof RexInputRef) {
final int index = ((RexInputRef) targetSingleton).getIndex();
if (!belowAggregateKey.get(index)) {
projects.add(targetSingleton);
side.split.put(aggCall.i, projects.size() - 1);
} else {
// the referenced field is already projected as a key; reuse that position
side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index));
}
} else {
projects.add(targetSingleton);
side.split.put(aggCall.i, projects.size() - 1);
}
}
}
relBuilder.project(projects);
side.newInput = relBuilder.build();
} else {
// Input is not known to be unique: build a partial aggregate on this side.
side.aggregate = true;
List<AggregateCall> belowAggCalls = new ArrayList<>();
final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
final int oldGroupKeyCount = aggregate.getGroupCount();
final int newGroupKeyCount = belowAggregateKey.cardinality();
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final AggregateCall call1;
if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
// all arguments (possibly none) come from this input: split the call onto this side
final AggregateCall splitCall = splitter.split(aggCall.e, mapping);
call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
} else {
// arguments are not all from this input: ask the splitter for the call to use here (may be null)
call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
}
if (call1 != null) {
// output position = number of group keys + index of the call in belowAggCalls
side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
}
}
side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
}
offset += fieldCount;
belowOffset += side.newInput.getRowType().getFieldCount();
sides.add(side);
}
if (uniqueCount == 2) {
// Both join inputs took the unique branch, so no aggregate was pushed below the join.
// NOTE(review): the original comment here appears truncated; presumably this plan may be the
// result of a previous invocation of this rule; if we continue we might loop forever.
return;
}
// Update condition
final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
// Create new join
relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
// Aggregate above to sum up the sub-totals
final List<AggregateCall> newAggCalls = new ArrayList<>();
final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
// starts as an identity projection of the new join; registry(projects) is handed to topSplit below
final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
// -1 marks a side without a sub-total; right-side positions are shifted past the left input's fields
newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
}
relBuilder.project(projects);
boolean aggConvertedToProjects = false;
if (allColumnsInAggregate) {
// let's see if we can convert aggregate into projects
List<RexNode> projects2 = new ArrayList<>();
for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
projects2.add(relBuilder.field(key));
}
// index of the current aggregate call's field in the aggregate's row type (group keys come first)
int aggCallIdx = projects2.size();
for (AggregateCall newAggCall : newAggCalls) {
final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
if (splitter != null) {
final RelDataType rowType = relBuilder.peek().getRowType();
final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall);
final RelDataType originalAggCallType = aggregate.getRowType().getFieldList().get(aggCallIdx).getType();
final RexNode targetSingleton = rexBuilder.ensureType(originalAggCallType, singleton, false);
projects2.add(targetSingleton);
}
aggCallIdx += 1;
}
// every new aggregate call must have produced a projection, otherwise keep the aggregate
if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
// We successfully converted agg calls into projects.
relBuilder.project(projects2);
aggConvertedToProjects = true;
}
}
if (!aggConvertedToProjects) {
relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
}
// apply the projection returned by toRegularAggregate, using the original aggregate's field names
if (projectAfterAgg != null) {
relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames());
}
call.transformTo(relBuilder.build());
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.
the class FlinkJoinPushExpressionsRule method onMatch.
/**
 * Pushes expressions used in the join condition into projections below the join.
 *
 * <p>No transformation is registered when the result is still a {@code Join} whose condition
 * equals the original one.
 *
 * @param call rule call whose {@code rel(0)} is the join to rewrite
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Join originalJoin = call.rel(0);
    final RelNode rewritten = RelOptUtil.pushDownJoinConditions(originalJoin, call.builder());

    // The rewrite is a no-op if it still yields a join with an identical condition.
    final boolean unchanged =
            rewritten instanceof Join
                    && originalJoin.getCondition().equals(((Join) rewritten).getCondition());
    if (!unchanged) {
        call.transformTo(rewritten);
    }
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.
the class FlinkProjectJoinTransposeRule method onMatch.
// ~ Methods ----------------------------------------------------------------
// implement RelOptRule
/**
 * Pushes a projection below a join: each join input gets a projection produced by
 * {@code PushProjector}, the join condition is rewritten against the new inputs, and a new
 * project is created on top of the new join.
 *
 * <p>Nothing is done for join types that do not project their right side, or when
 * {@code PushProjector.locateAllRefs()} returns true.
 *
 * @param call rule call; {@code rel(0)} is the project and {@code rel(1)} is the join below it
 */
public void onMatch(RelOptRuleCall call) {
    final Project project = call.rel(0);
    final Join join = call.rel(1);
    if (!join.getJoinType().projectsRight()) {
        // TODO: support SEMI/ANTI join later
        return;
    }

    // Find every field referenced by the projection and the join condition. When all fields
    // are referenced and there are no special expressions, there is no point in going on.
    final RexNode joinCondition = join.getCondition();
    final PushProjector projector =
            new PushProjector(
                    project, joinCondition, join, preserveExprCondition, call.builder());
    if (projector.locateAllRefs()) {
        return;
    }

    // Project, on each side, only the fields that are referenced there.
    final RelNode newLeft = projector.createProjectRefsAndExprs(join.getLeft(), true, false);
    final RelNode newRight = projector.createProjectRefsAndExprs(join.getRight(), true, true);

    // Re-target the join condition at the projected columns.
    final int[] adjustments = projector.getAdjustments();
    final RexNode newCondition;
    if (joinCondition == null) {
        newCondition = null;
    } else {
        final List<RelDataTypeField> newJoinFields = new ArrayList<>(join.getSystemFieldList());
        newJoinFields.addAll(newLeft.getRowType().getFieldList());
        newJoinFields.addAll(newRight.getRowType().getFieldList());
        newCondition = projector.convertRefsAndExprs(joinCondition, newJoinFields, adjustments);
    }

    // Rebuild the join over the projected inputs, then put the original projection back on
    // top, converted to reference the modified projection list.
    final Join newJoin =
            join.copy(
                    join.getTraitSet(),
                    newCondition,
                    newLeft,
                    newRight,
                    join.getJoinType(),
                    join.isSemiJoinDone());
    call.transformTo(projector.createNewProject(newJoin, adjustments));
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.
the class FlinkSemiAntiJoinJoinTransposeRule method onMatch.
// ~ Methods ----------------------------------------------------------------
// implement RelOptRule
/**
 * Pushes a SEMI/ANTI join below the INNER join that forms its left input.
 *
 * <p>With X and Y the left and right inputs of the inner join and Z the right input of the
 * semi/anti join, the semi/anti join is re-created over (X, Z) or (Y, Z), depending on which
 * inner-join input its condition references, and the inner join is rebuilt on top.
 *
 * <p>Nothing is done when: the top join is neither SEMI nor ANTI; the lower join is not INNER;
 * either side of the semi/anti condition references no input fields; or the condition
 * references fields of both X and Y.
 *
 * @param call rule call; {@code rel(0)} is the semi/anti join, {@code rel(1)} the join below it
 */
public void onMatch(RelOptRuleCall call) {
    final LogicalJoin semiAntiJoin = call.rel(0);
    final JoinRelType semiAntiType = semiAntiJoin.getJoinType();
    if (!(semiAntiType == JoinRelType.SEMI || semiAntiType == JoinRelType.ANTI)) {
        return;
    }

    final Join join = call.rel(1);
    // A SEMI/ANTI lower join is never handled, and (TODO support other join type) neither is
    // any other non-INNER join, so a single check covers both.
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }

    // unsupported cases:
    // 1. (NOT) EXISTS with uncorrelation
    // 2. keys in SemiJoin condition are from both Join's left and Join's right
    final Pair<ImmutableBitSet, ImmutableBitSet> conditionRefs =
            getSemiAntiJoinConditionInputRefs(semiAntiJoin);
    final ImmutableBitSet leftRefs = conditionRefs.left;
    final ImmutableBitSet rightRefs = conditionRefs.right;
    // TODO currently we does not handle this
    if (leftRefs.isEmpty() || rightRefs.isEmpty()) {
        return;
    }

    // X is the left child of the join below the semi-join
    // Y is the right child of the join below the semi-join
    // Z is the right child of the semi-join
    final int xWidth = join.getLeft().getRowType().getFieldList().size();
    final int yWidth = join.getRight().getRowType().getFieldList().size();
    final int zWidth = semiAntiJoin.getRight().getRowType().getFieldList().size();

    // Field list of the full (X, Y, Z) row. The semi-join's own row type only covers its left
    // hand side, so Z's fields are appended from the semi-join's right input.
    final List<RelDataTypeField> xyzFields = new ArrayList<>();
    xyzFields.addAll(semiAntiJoin.getRowType().getFieldList().subList(0, xWidth + yWidth));
    xyzFields.addAll(semiAntiJoin.getRight().getRowType().getFieldList().subList(0, zWidth));

    // Work out which input of the lower join the semi/anti condition actually refers to.
    int keysFromX = 0;
    for (int ref : leftRefs) {
        if (ref < xWidth) {
            keysFromX++;
        }
    }
    final int keysFromY = leftRefs.cardinality() - keysFromX;

    // e.g. SELECT * FROM x, y WHERE x.c = y.f AND x.a IN (SELECT z.i FROM z WHERE y.e = z.j)
    if (keysFromX > 0 && keysFromY > 0) {
        return;
    }

    // otherwise, a semi-join wouldn't have been created
    assert (keysFromX == 0) || (keysFromX == leftRefs.cardinality());
    assert (keysFromY == 0) || (keysFromY == leftRefs.cardinality());

    // need to convert the semi/anti join condition and possibly the keys
    final boolean pushToX = keysFromX > 0;
    final int[] adjustments = new int[xWidth + yWidth + zWidth];
    if (pushToX) {
        // (X, Y, Z) --> (X, Z, Y)
        // semiJoin(X, Z)
        // pass 0 as Y's adjustment because there shouldn't be any
        // references to Y in the semi/anti join filter
        setJoinAdjustments(adjustments, xWidth, yWidth, zWidth, 0, -yWidth);
    } else {
        // (X, Y, Z) --> (X, Y, Z)
        // semiJoin(Y, Z)
        setJoinAdjustments(adjustments, xWidth, yWidth, zWidth, -xWidth, -xWidth);
    }
    final RexNode pushedCondition =
            semiAntiJoin
                    .getCondition()
                    .accept(
                            new RelOptUtil.RexInputConverter(
                                    semiAntiJoin.getCluster().getRexBuilder(),
                                    xyzFields,
                                    adjustments));

    // create the new semi/anti join over the referenced input and Z
    // NOTE(review): the hints are taken from the lower join, not from the semi/anti join —
    // confirm this is intended.
    final RelNode pushedLeft = pushToX ? join.getLeft() : join.getRight();
    final Join pushedSemiAntiJoin =
            LogicalJoin.create(
                    pushedLeft,
                    semiAntiJoin.getRight(),
                    join.getHints(),
                    pushedCondition,
                    semiAntiJoin.getVariablesSet(),
                    semiAntiJoin.getJoinType());

    // rebuild the inner join with the pushed semi/anti join replacing the matching input
    final RelNode newLeft = pushToX ? pushedSemiAntiJoin : join.getLeft();
    final RelNode newRight = pushToX ? join.getRight() : pushedSemiAntiJoin;
    call.transformTo(
            join.copy(
                    join.getTraitSet(),
                    join.getCondition(),
                    newLeft,
                    newRight,
                    join.getJoinType(),
                    join.isSemiJoinDone()));
}
Aggregations