Search in sources:

Example 81 with Join

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.

Method decorrelate of the class RelDecorrelator.

/**
 * Removes correlation variables from the given plan.
 *
 * <p>Works in three steps:
 *
 * <ol>
 *   <li>a pre-processing HepProgram that adjusts COUNT() aggregates, pushes filters into joins,
 *       pushes non-correlated filters below projects, and pushes filters into correlates;
 *   <li>the actual decorrelation via {@code getInvoke};
 *   <li>if step 2 rewrote the plan, a post-processing HepProgram that pushes filters into joins
 *       and pushes join conditions down.
 * </ol>
 *
 * @param root the plan to decorrelate
 * @return the post-processed decorrelated plan, or the pre-processed plan from step 1 when
 *     decorrelation did not rewrite anything
 */
protected RelNode decorrelate(RelNode root) {
    final RelBuilderFactory f = relBuilderFactory();
    // Step 1: normalize the plan (first adjust count() expressions, if any) so the
    // decorrelator sees filters in the positions it can handle.
    final HepProgram preProgram =
            HepProgram.builder()
                    .addRuleInstance(AdjustProjectForCountAggregateRule.config(false, this, f).toRule())
                    .addRuleInstance(AdjustProjectForCountAggregateRule.config(true, this, f).toRule())
                    .addRuleInstance(
                            FilterJoinRule.FilterIntoJoinRule.Config.DEFAULT
                                    .withRelBuilderFactory(f)
                                    .withOperandSupplier(
                                            b0 -> b0.operand(Filter.class).oneInput(b1 -> b1.operand(Join.class).anyInputs()))
                                    .withDescription("FilterJoinRule:filter")
                                    .as(FilterJoinRule.FilterIntoJoinRule.Config.class)
                                    .withSmart(true)
                                    .withPredicate((join, joinType, exp) -> true)
                                    .as(FilterJoinRule.FilterIntoJoinRule.Config.class)
                                    .toRule())
                    .addRuleInstance(
                            CoreRules.FILTER_PROJECT_TRANSPOSE.config
                                    .withRelBuilderFactory(f)
                                    .as(FilterProjectTransposeRule.Config.class)
                                    // only filters without correlation variables are transposed
                                    .withOperandFor(
                                            Filter.class,
                                            filter -> !RexUtil.containsCorrelation(filter.getCondition()),
                                            Project.class,
                                            project -> true)
                                    .withCopyFilter(true)
                                    .withCopyProject(true)
                                    .toRule())
                    .addRuleInstance(FilterCorrelateRule.Config.DEFAULT.withRelBuilderFactory(f).toRule())
                    .build();
    final HepPlanner prePlanner = createPlanner(preProgram);
    prePlanner.setRoot(root);
    final RelNode normalized = prePlanner.findBestExp();

    // Step 2: perform decorrelation, starting from a clean frame map.
    map.clear();
    final Frame frame = getInvoke(normalized, null);
    if (frame == null) {
        // Nothing was rewritten; hand back the normalized plan unchanged.
        return normalized;
    }

    // Step 3: the plan has been rewritten; tidy it up with post-decorrelation rules.
    final HepProgram postProgram =
            HepProgram.builder()
                    .addRuleInstance(CoreRules.FILTER_INTO_JOIN.config.withRelBuilderFactory(f).toRule())
                    .addRuleInstance(CoreRules.JOIN_CONDITION_PUSH.config.withRelBuilderFactory(f).toRule())
                    .build();
    final HepPlanner postPlanner = createPlanner(postProgram);
    postPlanner.setRoot(frame.r);
    return postPlanner.findBestExp();
}
Also used : RelBuilderFactory(org.apache.calcite.tools.RelBuilderFactory) Values(org.apache.calcite.rel.core.Values) Mappings(org.apache.calcite.util.mapping.Mappings) MultimapBuilder(com.google.common.collect.MultimapBuilder) BigDecimal(java.math.BigDecimal) CorrelationId(org.apache.calcite.rel.core.CorrelationId) Map(java.util.Map) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ImmutableBeans(org.apache.calcite.util.ImmutableBeans) Function2(org.apache.calcite.linq4j.function.Function2) SqlKind(org.apache.calcite.sql.SqlKind) RexVisitorImpl(org.apache.calcite.rex.RexVisitorImpl) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) Set(java.util.Set) SqlExplainLevel(org.apache.calcite.sql.SqlExplainLevel) RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) SqlStdOperatorTable(org.apache.calcite.sql.fun.SqlStdOperatorTable) RelCollation(org.apache.calcite.rel.RelCollation) RexCorrelVariable(org.apache.calcite.rex.RexCorrelVariable) HepRelVertex(org.apache.calcite.plan.hep.HepRelVertex) RexCall(org.apache.calcite.rex.RexCall) Iterables(com.google.common.collect.Iterables) Holder(org.apache.calcite.util.Holder) Correlate(org.apache.calcite.rel.core.Correlate) Ord(org.apache.calcite.linq4j.Ord) Filter(org.apache.calcite.rel.core.Filter) Join(org.apache.calcite.rel.core.Join) ArrayList(java.util.ArrayList) LogicalCorrelate(org.apache.calcite.rel.logical.LogicalCorrelate) ReflectiveVisitor(org.apache.calcite.util.ReflectiveVisitor) ImmutableSortedMap(com.google.common.collect.ImmutableSortedMap) ImmutableSortedSet(com.google.common.collect.ImmutableSortedSet) RelDataType(org.apache.calcite.rel.type.RelDataType) BiRel(org.apache.calcite.rel.BiRel) FilterProjectTransposeRule(org.apache.calcite.rel.rules.FilterProjectTransposeRule) RelRule(org.apache.calcite.plan.RelRule) Aggregate(org.apache.calcite.rel.core.Aggregate) RelHomogeneousShuttle(org.apache.calcite.rel.RelHomogeneousShuttle) 
LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) SqlFunction(org.apache.calcite.sql.SqlFunction) TreeMap(java.util.TreeMap) CalciteTrace(org.apache.calcite.util.trace.CalciteTrace) JoinRelType(org.apache.calcite.rel.core.JoinRelType) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) LogicalFilter(org.apache.calcite.rel.logical.LogicalFilter) RelFactories(org.apache.calcite.rel.core.RelFactories) RelOptCostImpl(org.apache.calcite.plan.RelOptCostImpl) LogicalTableFunctionScan(org.apache.calcite.rel.logical.LogicalTableFunctionScan) RelMdUtil(org.apache.calcite.rel.metadata.RelMdUtil) RexUtil(org.apache.calcite.rex.RexUtil) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) RexNode(org.apache.calcite.rex.RexNode) RelBuilder(org.apache.calcite.tools.RelBuilder) HepProgram(org.apache.calcite.plan.hep.HepProgram) RelOptCluster(org.apache.calcite.plan.RelOptCluster) Litmus(org.apache.calcite.util.Litmus) ImmutableSet(com.google.common.collect.ImmutableSet) SortedSetMultimap(com.google.common.collect.SortedSetMultimap) ImmutableMap(com.google.common.collect.ImmutableMap) RexLiteral(org.apache.calcite.rex.RexLiteral) Collection(java.util.Collection) Context(org.apache.calcite.plan.Context) NavigableMap(java.util.NavigableMap) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) RexInputRef(org.apache.calcite.rex.RexInputRef) Objects(java.util.Objects) List(java.util.List) FilterJoinRule(org.apache.calcite.rel.rules.FilterJoinRule) Sort(org.apache.calcite.rel.core.Sort) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) SqlExplainFormat(org.apache.calcite.sql.SqlExplainFormat) SortedMap(java.util.SortedMap) Project(org.apache.calcite.rel.core.Project) RexFieldAccess(org.apache.calcite.rex.RexFieldAccess) HashMap(java.util.HashMap) Multimap(com.google.common.collect.Multimap) RelOptUtil(org.apache.calcite.plan.RelOptUtil) 
SqlSingleValueAggFunction(org.apache.calcite.sql.fun.SqlSingleValueAggFunction) HashSet(java.util.HashSet) ImmutableList(com.google.common.collect.ImmutableList) LogicalSnapshot(org.apache.calcite.rel.logical.LogicalSnapshot) Pair(org.apache.calcite.util.Pair) RelBuilderFactory(org.apache.calcite.tools.RelBuilderFactory) SqlOperator(org.apache.calcite.sql.SqlOperator) Nonnull(javax.annotation.Nonnull) Logger(org.slf4j.Logger) LogicalProject(org.apache.calcite.rel.logical.LogicalProject) RexBuilder(org.apache.calcite.rex.RexBuilder) RexSubQuery(org.apache.calcite.rex.RexSubQuery) HepPlanner(org.apache.calcite.plan.hep.HepPlanner) RelNode(org.apache.calcite.rel.RelNode) RelOptRuleCall(org.apache.calcite.plan.RelOptRuleCall) ReflectUtil(org.apache.calcite.util.ReflectUtil) CoreRules(org.apache.calcite.rel.rules.CoreRules) FilterCorrelateRule(org.apache.calcite.rel.rules.FilterCorrelateRule) RexShuttle(org.apache.calcite.rex.RexShuttle) Util(org.apache.calcite.util.Util) Collections(java.util.Collections) HepProgram(org.apache.calcite.plan.hep.HepProgram) RelNode(org.apache.calcite.rel.RelNode) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) FilterJoinRule(org.apache.calcite.rel.rules.FilterJoinRule) HepPlanner(org.apache.calcite.plan.hep.HepPlanner)

Example 82 with Join

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.

Method onMatch of the class FlinkAggregateJoinTransposeRule.

/**
 * Pushes an {@link Aggregate} below an INNER {@link Join} whose condition is a pure equi-join.
 *
 * <p>For each join input, the rule either aggregates that input on its grouping/join columns
 * with split (partial) aggregate calls, or, if the input is already unique on those columns, just
 * projects them. It then re-joins the inputs and combines the sub-totals above the new join. The
 * combining step is either an aggregate or, when every join column is among the aggregate's key
 * columns, a plain project.
 *
 * <p>No transformation is registered when any of these hold:
 *
 * <ul>
 *   <li>an aggregate function does not implement {@link SqlSplittableAggFunction};
 *   <li>an aggregate call has a filter or is DISTINCT;
 *   <li>the join is not INNER;
 *   <li>{@code allowFunctions} is false and the aggregate has calls;
 *   <li>the join condition has non-equi conjuncts;
 *   <li>both join inputs are already unique on their keys.
 * </ul>
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and {@code call.rel(1)} the join
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate origAgg = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // Convert an aggregate with AUXILIARY_GROUP to a regular aggregate, plus project
    // expressions (may be null) that are applied on top at the very end.
    // If the converted aggregate can be pushed down,
    // AggregateReduceGroupingRule will try to reduce the grouping of new aggregates
    // created by this rule.
    final Pair<Aggregate, List<RexNode>> newAggAndProject = toRegularAggregate(origAgg);
    final Aggregate aggregate = newAggAndProject.left;
    final List<RexNode> projectAfterAgg = newAggAndProject.right;
    // Bail out if any aggregate call is not splittable, has a filter, or is DISTINCT.
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
            return;
        }
    }
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    // NOTE(review): keyColumns(...) is defined elsewhere; presumably it widens the group set
    // with columns that the pulled-up join predicates equate to group columns. Confirm.
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    final RelMetadataQuery mq = call.getMetadataQuery();
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns that must survive below the join: grouping columns plus join columns.
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split the join condition into equi-join key pairs and a residual condition.
    final List<Integer> leftKeys = com.google.common.collect.Lists.newArrayList();
    final List<Integer> rightKeys = com.google.common.collect.Lists.newArrayList();
    final List<Boolean> filterNulls = com.google.common.collect.Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // map: join-level field index -> field index in the new join built from the new inputs.
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // offset: first join-level field index of input s.
    int offset = 0;
    // belowOffset: first field index of input s's replacement in the new join.
    int belowOffset = 0;
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        // The needed columns of this input become the leading fields of its replacement.
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Translates join-level field indexes of this input to input-local indexes.
        final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
        // The needed columns of this input, in input-local indexes.
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            //
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            //
            // So we choose to imagine the input is already unique, which is
            // untrue but harmless.
            //
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // Input is already unique on its key: no aggregate is needed. Project the key
            // columns, followed by a per-call "singleton" expression for each aggregate call
            // whose (non-empty) arguments all come from this input.
            ++uniqueCount;
            side.aggregate = false;
            relBuilder.push(joinInput);
            final Map<Integer, Integer> belowAggregateKeyToNewProjectMap = new HashMap<>();
            final List<RexNode> projects = new ArrayList<>();
            for (Integer i : belowAggregateKey) {
                belowAggregateKeyToNewProjectMap.put(i, projects.size());
                projects.add(relBuilder.field(i));
            }
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                if (!aggCall.e.getArgList().isEmpty() && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    final RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggCall.e.transform(mapping));
                    final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false);
                    if (targetSingleton instanceof RexInputRef) {
                        final int index = ((RexInputRef) targetSingleton).getIndex();
                        // Reuse an already-projected key column instead of projecting it twice.
                        if (!belowAggregateKey.get(index)) {
                            projects.add(targetSingleton);
                            side.split.put(aggCall.i, projects.size() - 1);
                        } else {
                            side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index));
                        }
                    } else {
                        projects.add(targetSingleton);
                        side.split.put(aggCall.i, projects.size() - 1);
                    }
                }
            }
            relBuilder.project(projects);
            side.newInput = relBuilder.build();
        } else {
            // Aggregate this input on its key with split (partial) aggregate calls. The
            // index of each call's sub-total in the new input is recorded in side.split.
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            final int oldGroupKeyCount = aggregate.getGroupCount();
            final int newGroupKeyCount = belowAggregateKey.cardinality();
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    final AggregateCall splitCall = splitter.split(aggCall.e, mapping);
                    call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
                } else {
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both inputs to the join are already unique on their keys, so there is nothing to be
        // gained by this rule. In fact, this aggregate+join may be the result of a previous
        // invocation of this rule; if we continue we might loop forever.
        return;
    }
    // Update the condition to reference the fields of the new inputs.
    final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    // Identity projects over the new join. topSplit receives registry(projects) and may
    // append further expressions to this list.
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        // -1 signals "no sub-total from this side"; right-side indexes are shifted past the left input.
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // Every join column is among the aggregate's key columns, so see if we can
        // convert the top aggregate into projects.
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        int aggCallIdx = projects2.size();
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                final RelDataType rowType = relBuilder.peek().getRowType();
                final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall);
                // Cast to the original aggregate's output type for this call.
                final RelDataType originalAggCallType = aggregate.getRowType().getFieldList().get(aggCallIdx).getType();
                final RexNode targetSingleton = rexBuilder.ensureType(originalAggCallType, singleton, false);
                projects2.add(targetSingleton);
            }
            aggCallIdx += 1;
        }
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    // Re-apply the project produced by toRegularAggregate, using the original field names.
    if (projectAfterAgg != null) {
        relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames());
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RelDataType(org.apache.calcite.rel.type.RelDataType) RexBuilder(org.apache.calcite.rex.RexBuilder) ArrayList(java.util.ArrayList) List(java.util.List) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) RexInputRef(org.apache.calcite.rex.RexInputRef) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 83 with Join

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.

Method onMatch of the class FlinkJoinPushExpressionsRule.

/**
 * Pushes expressions that appear in the join condition into Projects below the join.
 *
 * <p>No transformation is registered when the rewrite yields a {@link Join} whose condition is
 * equal to the original one, i.e. nothing was pushed down.
 *
 * @param call rule call; {@code call.rel(0)} is the join
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Join originalJoin = call.rel(0);
    final RelNode rewritten = RelOptUtil.pushDownJoinConditions(originalJoin, call.builder());
    // A Join with an identical condition means the rewrite was a no-op.
    final boolean unchanged =
            rewritten instanceof Join
                    && originalJoin.getCondition().equals(((Join) rewritten).getCondition());
    if (!unchanged) {
        call.transformTo(rewritten);
    }
}
Also used : RelNode(org.apache.calcite.rel.RelNode) Join(org.apache.calcite.rel.core.Join) RexNode(org.apache.calcite.rex.RexNode)

Example 84 with Join

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.

Method onMatch of the class FlinkProjectJoinTransposeRule.

// ~ Methods ----------------------------------------------------------------
// implement RelOptRule
/**
 * Pushes the {@link Project} sitting on a {@link Join} down into both join inputs.
 *
 * <p>Each input is projected to only the fields referenced by the project and the join
 * condition, plus any expressions {@code PushProjector} decides to preserve per
 * {@code preserveExprCondition}. The original projection is then rebuilt on top of the new join.
 *
 * <p>No transformation is registered for joins that do not project their right input
 * (SEMI/ANTI). The same holds when {@code PushProjector.locateAllRefs()} reports that all fields
 * are referenced and there are no special expressions.
 *
 * @param call rule call; {@code call.rel(0)} is the project and {@code call.rel(1)} the join
 */
public void onMatch(RelOptRuleCall call) {
    final Project project = call.rel(0);
    final Join join = call.rel(1);
    if (!join.getJoinType().projectsRight()) {
        // TODO: support SEMI/ANTI join later
        return;
    }
    final RexNode condition = join.getCondition();

    // Collect the fields referenced by the projection and the join condition; if every field
    // is used and there are no special expressions, there is nothing to push down.
    final PushProjector pushProjector = new PushProjector(project, condition, join, preserveExprCondition, call.builder());
    if (pushProjector.locateAllRefs()) {
        return;
    }

    // One projection per side, keeping only the fields referenced on that side.
    final RelNode newLeft = pushProjector.createProjectRefsAndExprs(join.getLeft(), true, false);
    final RelNode newRight = pushProjector.createProjectRefsAndExprs(join.getRight(), true, true);
    final int[] adjustments = pushProjector.getAdjustments();

    // Rewrite the join condition against the fields of the projected inputs.
    RexNode newCondition = null;
    if (condition != null) {
        final List<RelDataTypeField> newJoinFields = new ArrayList<>(join.getSystemFieldList());
        newJoinFields.addAll(newLeft.getRowType().getFieldList());
        newJoinFields.addAll(newRight.getRowType().getFieldList());
        newCondition = pushProjector.convertRefsAndExprs(condition, newJoinFields, adjustments);
    }

    final Join newJoin = join.copy(join.getTraitSet(), newCondition, newLeft, newRight, join.getJoinType(), join.isSemiJoinDone());
    // Rebuild the original projection on top, remapped to the new join's fields.
    call.transformTo(pushProjector.createNewProject(newJoin, adjustments));
}
Also used : Project(org.apache.calcite.rel.core.Project) PushProjector(org.apache.calcite.rel.rules.PushProjector) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) ArrayList(java.util.ArrayList) Join(org.apache.calcite.rel.core.Join) RexNode(org.apache.calcite.rex.RexNode)

Example 85 with Join

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Join in project flink by apache.

Method onMatch of the class FlinkSemiAntiJoinJoinTransposeRule.

// ~ Methods ----------------------------------------------------------------
// implement RelOptRule
/**
 * Transposes a SEMI/ANTI {@link LogicalJoin} with the INNER {@link Join} directly beneath it.
 * The semi/anti join is applied to whichever input of the inner join supplies all of its
 * left-side keys.
 *
 * <pre>
 *   semiAnti(join(X, Y), Z)  ==&gt;  join(semiAnti(X, Z), Y)   keys come from X
 *   semiAnti(join(X, Y), Z)  ==&gt;  join(X, semiAnti(Y, Z))   keys come from Y
 * </pre>
 *
 * <p>No transformation is registered when any of these hold:
 *
 * <ul>
 *   <li>the top join is not SEMI/ANTI;
 *   <li>the bottom join is not INNER;
 *   <li>either side of the semi/anti condition has no input references;
 *   <li>the left-side keys come from both X and Y.
 * </ul>
 *
 * @param call rule call; {@code call.rel(0)} is the semi/anti join and {@code call.rel(1)} the
 *     join below it
 */
public void onMatch(RelOptRuleCall call) {
    final LogicalJoin semiAntiJoin = call.rel(0);
    final JoinRelType semiAntiType = semiAntiJoin.getJoinType();
    if (semiAntiType != JoinRelType.SEMI && semiAntiType != JoinRelType.ANTI) {
        return;
    }
    final Join join = call.rel(1);
    // Only an INNER join below is supported (this also excludes SEMI/ANTI).
    // TODO support other join type
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }

    // Unsupported cases:
    // 1. (NOT) EXISTS without correlation
    // 2. keys in the semi/anti join condition come from both the join's left and right input
    final Pair<ImmutableBitSet, ImmutableBitSet> inputRefs = getSemiAntiJoinConditionInputRefs(semiAntiJoin);
    final ImmutableBitSet leftInputRefs = inputRefs.left;
    final ImmutableBitSet rightInputRefs = inputRefs.right;
    // TODO: conditions with no input references on one side are not handled yet
    if (leftInputRefs.isEmpty() || rightInputRefs.isEmpty()) {
        return;
    }

    // X / Y: left / right input of the join below the semi/anti join; Z: right input of the
    // semi/anti join.
    final int nFieldsX = join.getLeft().getRowType().getFieldList().size();
    final int nFieldsY = join.getRight().getRowType().getFieldList().size();
    final List<RelDataTypeField> zFields = semiAntiJoin.getRight().getRowType().getFieldList();
    final int nFieldsZ = zFields.size();

    // Field list of the full X + Y + Z row. The semi/anti join's own row type only contains
    // its left-hand (X + Y) fields, so Z's fields are appended from its right input.
    final List<RelDataTypeField> fields = new ArrayList<>(semiAntiJoin.getRowType().getFieldList().subList(0, nFieldsX + nFieldsY));
    fields.addAll(zFields);

    // Split the left-side keys of the semi/anti condition by origin: X or Y.
    int nKeysFromX = 0;
    for (int key : leftInputRefs) {
        if (key < nFieldsX) {
            nKeysFromX++;
        }
    }
    final int nKeysFromY = leftInputRefs.cardinality() - nKeysFromX;
    // e.g. SELECT * FROM x, y WHERE x.c = y.f AND x.a IN (SELECT z.i FROM z WHERE y.e = z.j)
    if (nKeysFromX > 0 && nKeysFromY > 0) {
        return;
    }
    // otherwise, a semi-join wouldn't have been created
    assert (nKeysFromX == 0) || (nKeysFromX == leftInputRefs.cardinality());
    assert (nKeysFromY == 0) || (nKeysFromY == leftInputRefs.cardinality());

    // Compute field adjustments for the semi/anti condition and choose its new left input.
    final boolean keysFromX = nKeysFromX > 0;
    final int[] adjustments = new int[nFieldsX + nFieldsY + nFieldsZ];
    final RelNode newSemiAntiJoinLeft;
    if (keysFromX) {
        // (X, Y, Z) --> (X, Z, Y), i.e. semiJoin(X, Z).
        // Y gets a 0 adjustment because the semi/anti condition should not reference Y.
        setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, 0, -nFieldsY);
        newSemiAntiJoinLeft = join.getLeft();
    } else {
        // (X, Y, Z) --> (X, Y, Z), i.e. semiJoin(Y, Z)
        setJoinAdjustments(adjustments, nFieldsX, nFieldsY, nFieldsZ, -nFieldsX, -nFieldsX);
        newSemiAntiJoinLeft = join.getRight();
    }
    final RexNode newSemiAntiJoinFilter = semiAntiJoin.getCondition().accept(new RelOptUtil.RexInputConverter(semiAntiJoin.getCluster().getRexBuilder(), fields, adjustments));

    // NOTE(review): the new semi/anti join is given the hints of the bottom join
    // (join.getHints()) rather than semiAntiJoin.getHints(); confirm this is intended.
    final Join newSemiAntiJoin = LogicalJoin.create(newSemiAntiJoinLeft, semiAntiJoin.getRight(), join.getHints(), newSemiAntiJoinFilter, semiAntiJoin.getVariablesSet(), semiAntiType);

    // Rebuild the inner join with the new semi/anti join in place of the input it absorbed.
    final RelNode newJoinRel =
            keysFromX
                    ? join.copy(join.getTraitSet(), join.getCondition(), newSemiAntiJoin, join.getRight(), join.getJoinType(), join.isSemiJoinDone())
                    : join.copy(join.getTraitSet(), join.getCondition(), join.getLeft(), newSemiAntiJoin, join.getJoinType(), join.isSemiJoinDone());
    call.transformTo(newJoinRel);
}
Also used : RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) RelNode(org.apache.calcite.rel.RelNode) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) ArrayList(java.util.ArrayList) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

Join (org.apache.calcite.rel.core.Join)73 RelNode (org.apache.calcite.rel.RelNode)45 RexNode (org.apache.calcite.rex.RexNode)40 ArrayList (java.util.ArrayList)31 LogicalJoin (org.apache.calcite.rel.logical.LogicalJoin)25 Project (org.apache.calcite.rel.core.Project)22 RexBuilder (org.apache.calcite.rex.RexBuilder)20 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)18 RelBuilder (org.apache.calcite.tools.RelBuilder)17 HiveJoin (org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin)14 Aggregate (org.apache.calcite.rel.core.Aggregate)13 Test (org.junit.Test)13 Filter (org.apache.calcite.rel.core.Filter)12 RelNode (org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.RelNode)11 SemiJoin (org.apache.calcite.rel.core.SemiJoin)11 RelOptCluster (org.apache.calcite.plan.RelOptCluster)10 JoinRelType (org.apache.calcite.rel.core.JoinRelType)9 RelMetadataQuery (org.apache.calcite.rel.metadata.RelMetadataQuery)9 Mappings (org.apache.calcite.util.mapping.Mappings)9 List (java.util.List)8