Search in sources:

Example 21 with Mapping

use of org.apache.calcite.util.mapping.Mapping in project hive by apache.

the class HiveCardinalityPreservingJoinOptimization method trim.

/**
 * Tries to rewrite the single input of {@code root} so that non-key columns of selected source
 * tables are trimmed out of the original plan and fetched again by joining those tables back on
 * top of the trimmed plan.
 *
 * <p>{@code root} is returned unchanged when it does not have exactly one input, when the lineage
 * of some projected field cannot be determined, when no table qualifies for a join back, or when
 * trimming leaves the row type of the input unchanged.
 *
 * @param relBuilder builder used to create the joins and the final project; it is also published
 *                   through {@code REL_BUILDER} for the duration of the call
 * @param root root of the plan to rewrite
 * @return {@code root} itself, or a copy of {@code root} placed on top of the rewritten input
 */
@Override
public RelNode trim(RelBuilder relBuilder, RelNode root) {
    try {
        if (root.getInputs().size() != 1) {
            LOG.debug("Only plans where root has one input are supported. Root: {}", root);
            return root;
        }
        // Published for the code invoked below (tableScan.project reads it back);
        // always cleared again in the finally block.
        REL_BUILDER.set(relBuilder);
        RexBuilder rexBuilder = relBuilder.getRexBuilder();
        RelNode rootInput = root.getInput(0);
        // Build the list of RexInputRef from root input RowType
        // and remember the field names so the final Project can reuse them.
        List<RexInputRef> rootFieldList = new ArrayList<>(rootInput.getRowType().getFieldCount());
        List<String> newColumnNames = new ArrayList<>();
        for (int i = 0; i < rootInput.getRowType().getFieldList().size(); ++i) {
            RelDataTypeField relDataTypeField = rootInput.getRowType().getFieldList().get(i);
            rootFieldList.add(rexBuilder.makeInputRef(relDataTypeField.getType(), i));
            newColumnNames.add(relDataTypeField.getName());
        }
        // Bit set to gather the refs that backtrack to constant values
        BitSet constants = new BitSet();
        List<JoinedBackFields> lineages = getExpressionLineageOf(rootFieldList, rootInput, constants);
        if (lineages == null) {
            LOG.debug("Some projected field lineage can not be determined");
            return root;
        }
        // 1. Collect candidate tables for join back and map RexNodes coming from those tables to their index in the
        // rootInput row type
        // Collect all used fields from original plan
        // (the fields backed by constants are the starting point).
        ImmutableBitSet fieldsUsed = ImmutableBitSet.of(constants.stream().toArray());
        List<TableToJoinBack> tableToJoinBackList = new ArrayList<>(lineages.size());
        Map<Integer, RexNode> rexNodesToShuttle = new HashMap<>(rootInput.getRowType().getFieldCount());
        for (JoinedBackFields joinedBackFields : lineages) {
            // First non-nullable key of the source table whose columns are all among the projected fields.
            Optional<ImmutableBitSet> projectedKeys = joinedBackFields.relOptHiveTable.getNonNullableKeys().stream().filter(joinedBackFields.fieldsInSourceTable::contains).findFirst();
            // The table is a join-back candidate only if it contributes more than just that key.
            if (projectedKeys.isPresent() && !projectedKeys.get().equals(joinedBackFields.fieldsInSourceTable)) {
                TableToJoinBack tableToJoinBack = new TableToJoinBack(projectedKeys.get(), joinedBackFields);
                tableToJoinBackList.add(tableToJoinBack);
                // NOTE(review): getSource presumably translates the key columns to their positions in
                // the original row type - confirm against JoinedBackFields.
                fieldsUsed = fieldsUsed.union(joinedBackFields.getSource(projectedKeys.get()));
                // Every field of this table that is not marked as used so far will be produced by
                // the joined back table, so remember its expression for rewriting in step 4.
                for (TableInputRefHolder mapping : joinedBackFields.mapping) {
                    if (!fieldsUsed.get(mapping.indexInOriginalRowType)) {
                        rexNodesToShuttle.put(mapping.indexInOriginalRowType, mapping.rexNode);
                    }
                }
            } else {
                // Not a candidate: all of its fields stay in the original plan.
                fieldsUsed = fieldsUsed.union(joinedBackFields.fieldsInOriginalRowType);
            }
        }
        if (tableToJoinBackList.isEmpty()) {
            LOG.debug("None of the tables has keys projected, unable to join back");
            return root;
        }
        // 2. Trim out non-key fields of joined back tables
        Set<RelDataTypeField> extraFields = Collections.emptySet();
        TrimResult trimResult = dispatchTrimFields(rootInput, fieldsUsed, extraFields);
        RelNode newInput = trimResult.left;
        if (newInput.getRowType().equals(rootInput.getRowType())) {
            LOG.debug("Nothing was trimmed out.");
            return root;
        }
        // 3. Join back tables to the top of original plan
        // newInputMapping translates field indexes of rootInput to field indexes of the trimmed input.
        Mapping newInputMapping = trimResult.right;
        Map<RexTableInputRef, Integer> tableInputRefMapping = new HashMap<>();
        for (TableToJoinBack tableToJoinBack : tableToJoinBackList) {
            LOG.debug("Joining back table {}", tableToJoinBack.joinedBackFields.relOptHiveTable.getName());
            // 3.1. Create new TableScan of tables to join back
            RelOptHiveTable relOptTable = tableToJoinBack.joinedBackFields.relOptHiveTable;
            RelOptCluster cluster = relBuilder.getCluster();
            HiveTableScan tableScan = new HiveTableScan(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), relOptTable, relOptTable.getHiveTableMD().getTableName(), null, false, false);
            // 3.2. Create Project with the required fields from this table
            RelNode projectTableAccessRel = tableScan.project(tableToJoinBack.joinedBackFields.fieldsInSourceTable, new HashSet<>(0), REL_BUILDER.get());
            // 3.3. Create mapping between the Project and TableScan
            // (table scan field index -> position in the Project, in bit set iteration order)
            Mapping projectMapping = Mappings.create(MappingType.INVERSE_SURJECTION, tableScan.getRowType().getFieldCount(), tableToJoinBack.joinedBackFields.fieldsInSourceTable.cardinality());
            int projectIndex = 0;
            for (int i : tableToJoinBack.joinedBackFields.fieldsInSourceTable) {
                projectMapping.set(i, projectIndex);
                ++projectIndex;
            }
            // Fields of the joined back table follow the fields of the current left input.
            int offset = newInput.getRowType().getFieldCount();
            // 3.4. Map rexTableInputRef to the index where it can be found in the new Input row type
            for (TableInputRefHolder mapping : tableToJoinBack.joinedBackFields.mapping) {
                int indexInSourceTable = mapping.tableInputRef.getIndex();
                if (!tableToJoinBack.keys.get(indexInSourceTable)) {
                    // 3.5. if this is not a key field it is shifted by the left input field count
                    tableInputRefMapping.put(mapping.tableInputRef, offset + projectMapping.getTarget(indexInSourceTable));
                }
            }
            // 3.6. Create Join
            // The join result becomes the left input of the next join back.
            relBuilder.push(newInput);
            relBuilder.push(projectTableAccessRel);
            RexNode joinCondition = joinCondition(newInput, newInputMapping, tableToJoinBack, projectTableAccessRel, projectMapping, rexBuilder);
            newInput = relBuilder.join(JoinRelType.INNER, joinCondition).build();
        }
        // 4. Collect rexNodes for Project
        // One expression per field of the original input, in the original order.
        TableInputRefMapper mapper = new TableInputRefMapper(tableInputRefMapping, rexBuilder, newInput);
        List<RexNode> rexNodeList = new ArrayList<>(rootInput.getRowType().getFieldCount());
        for (int i = 0; i < rootInput.getRowType().getFieldCount(); i++) {
            RexNode rexNode = rexNodesToShuttle.get(i);
            if (rexNode != null) {
                // Field recorded in step 1: rewrite its expression through the mapper.
                rexNodeList.add(mapper.apply(rexNode));
            } else {
                // Field kept by the trimmed plan: reference it at its new position.
                int target = newInputMapping.getTarget(i);
                rexNodeList.add(rexBuilder.makeInputRef(newInput.getRowType().getFieldList().get(target).getType(), target));
            }
        }
        // 5. Create Project on top of all Join backs
        relBuilder.push(newInput);
        relBuilder.project(rexNodeList, newColumnNames);
        return root.copy(root.getTraitSet(), singletonList(relBuilder.build()));
    } finally {
        REL_BUILDER.remove();
    }
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RexBuilder(org.apache.calcite.rex.RexBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) BitSet(java.util.BitSet) RexTableInputRef(org.apache.calcite.rex.RexTableInputRef) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelOptHiveTable(org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable) HiveRelNode(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) HiveTableScan(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableScan) RexNode(org.apache.calcite.rex.RexNode)

Example 22 with Mapping

use of org.apache.calcite.util.mapping.Mapping in project flink by apache.

the class FlinkAggregateJoinTransposeRule method onMatch.

/**
 * Attempts to push the matched {@link Aggregate} below the matched {@link Join}: each join input
 * that is not already unique on the needed columns is pre-aggregated, the join is rebuilt on top
 * of the new inputs, and a final aggregate (or a plain project when possible) combines the
 * per-side results.
 *
 * <p>The method returns without transforming anything when an aggregate call is not splittable
 * or has a filter, when the join is not an inner join, when aggregate functions are present but
 * {@code allowFunctions} is false, when the join condition has a non-equi part, or when both
 * join inputs are considered unique.
 *
 * @param call rule call; {@code call.rel(0)} is the aggregate and {@code call.rel(1)} is the join
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // If any aggregate function is not splittable, or any aggregate call has a filter, bail out
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0) {
            return;
        }
    }
    // Only inner joins are handled; for any other join type we do not push down the
    // aggregate operator
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns the pushed-down aggregates must keep: the group columns plus the join columns.
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split join condition
    final List<Integer> leftKeys = Lists.newArrayList();
    final List<Integer> rightKeys = Lists.newArrayList();
    final List<Boolean> filterNulls = Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // map: field index in the original join output -> field index in the new join output.
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // offset: start of the current input within the original join row type;
    // belowOffset: start of the current (possibly aggregated) input within the new join row type.
    int offset = 0;
    int belowOffset = 0;
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Same key columns, expressed relative to this input instead of the join output.
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            //
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            //
            // So we choose to imagine the input is already unique, which is
            // untrue but harmless.
            //
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            // A null answer (uniqueness unknown) is treated as "not unique".
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // Input is used as is; no aggregate is pushed to this side.
            ++uniqueCount;
            side.aggregate = false;
            side.newInput = joinInput;
        } else {
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            // Translates argument ordinals from the join row type to this input's row type:
            // identity for the left input, a shift by the left width for the right input.
            final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    // All arguments come from this side.
                    call1 = splitter.split(aggCall.e, mapping);
                } else {
                    // Arguments come from the other side; the splitter decides what, if anything,
                    // this side has to compute.
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    // Record where this call's sub-total lands in the pushed-down aggregate:
                    // after the group key columns, at the position assigned by the registry.
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both inputs were treated as unique, so nothing was pushed down.
        // NOTE(review): the original comment here was truncated; presumably this plan may itself
        // be the result of a previous
        // invocation of this rule; if we continue we might loop forever.
        return;
    }
    // Update condition
    final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {

        public Integer apply(Integer a0) {
            return map.get(a0);
        }
    }, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        // -1 marks a side without a sub-total; right-side ordinals are shifted past the new left input.
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // let's see if we can convert aggregate into projects
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
            }
        }
        // The sizes match only if every new aggregate call was splittable.
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) Function(com.google.common.base.Function) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 23 with Mapping

use of org.apache.calcite.util.mapping.Mapping in project hive by apache.

the class HiveRelMdPredicates method getPredicates.

/**
 * Derives the predicates that hold on the output of an {@link Aggregate}.
 *
 * <p>A predicate pulled up from the aggregate's input is kept only when it references at least
 * one input column and every column it references belongs to the group set. Kept predicates are
 * rewritten so that each group column is referenced by its position among the group keys.
 * For example:
 *
 * <pre>
 * inputPullUpExprs : { a &gt; 7, b + c &lt; 10, a + e = 9}
 * groupSet         : { a, b}
 * pulledUpExprs    : { a &gt; 7}
 * </pre>
 *
 * @param agg aggregate whose output predicates are requested
 * @param mq metadata query used to obtain the predicates of the aggregate's input
 * @return the predicates of the input that survive the aggregation, expressed on its output
 */
public RelOptPredicateList getPredicates(Aggregate agg, RelMetadataQuery mq) {
    final RelNode child = agg.getInput();
    final RelOptPredicateList childPredicates = mq.getPulledUpPredicates(child);
    final ImmutableBitSet groupSet = agg.getGroupSet();

    // Map every grouped input column to its ordinal among the group keys.
    final Mapping inputToOutput = Mappings.create(
        MappingType.PARTIAL_FUNCTION,
        child.getRowType().getFieldCount(),
        agg.getRowType().getFieldCount());
    int outputOrdinal = 0;
    for (int inputOrdinal : groupSet) {
        inputToOutput.set(inputOrdinal, outputOrdinal);
        ++outputOrdinal;
    }

    final List<RexNode> pulledUp = new ArrayList<>();
    for (RexNode predicate : childPredicates.pulledUpPredicates) {
        final ImmutableBitSet referenced = RelOptUtil.InputFinder.bits(predicate);
        // Skip predicates without column references and those touching non-group columns.
        if (referenced.isEmpty() || !groupSet.contains(referenced)) {
            continue;
        }
        pulledUp.add(predicate.accept(new RexPermuteInputsShuttle(inputToOutput, child)));
    }
    return RelOptPredicateList.of(agg.getCluster().getRexBuilder(), pulledUp);
}
Also used : RelNode(org.apache.calcite.rel.RelNode) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) RelOptPredicateList(org.apache.calcite.plan.RelOptPredicateList) ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) Mapping(org.apache.calcite.util.mapping.Mapping) RexPermuteInputsShuttle(org.apache.calcite.rex.RexPermuteInputsShuttle) RexNode(org.apache.calcite.rex.RexNode)

Example 24 with Mapping

use of org.apache.calcite.util.mapping.Mapping in project hive by apache.

the class RelFieldTrimmer method trimFields.

/**
 * Overload of {@link #trimFields(RelNode, ImmutableBitSet, Set)} that handles
 * {@link org.apache.calcite.rel.logical.LogicalTableFunctionScan}.
 *
 * <p>Every field of every input is requested and the returned mapping is the identity over the
 * scan's own row type, so {@code fieldsUsed} and {@code extraFields} are not consulted. The scan
 * is copied only when at least one of its inputs was replaced.
 */
public TrimResult trimFields(LogicalTableFunctionScan tabFun, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
    final int outputWidth = tabFun.getRowType().getFieldCount();

    final List<RelNode> trimmedInputs = new ArrayList<>();
    for (RelNode child : tabFun.getInputs()) {
        // Request all fields of the input and no extra fields.
        final ImmutableBitSet everyField = ImmutableBitSet.range(child.getRowType().getFieldCount());
        final TrimResult childResult =
            trimChildRestore(tabFun, child, everyField, Collections.<RelDataTypeField>emptySet());
        assert childResult.right.isIdentity();
        trimmedInputs.add(childResult.left);
    }

    final boolean inputsUnchanged = tabFun.getInputs().equals(trimmedInputs);
    final LogicalTableFunctionScan rebuilt = inputsUnchanged
        ? tabFun
        : tabFun.copy(tabFun.getTraitSet(), trimmedInputs, tabFun.getCall(), tabFun.getElementType(), tabFun.getRowType(), tabFun.getColumnMappings());
    assert rebuilt.getClass() == tabFun.getClass();

    // All output fields are always projected.
    final Mapping identity = Mappings.createIdentity(outputWidth);
    return result(rebuilt, identity);
}
Also used : LogicalTableFunctionScan(org.apache.calcite.rel.logical.LogicalTableFunctionScan) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) Mapping(org.apache.calcite.util.mapping.Mapping)

Example 25 with Mapping

use of org.apache.calcite.util.mapping.Mapping in project hive by apache.

the class RelFieldTrimmer method trimFields.

/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.calcite.rel.logical.LogicalJoin}.
 *
 * <p>The fields referenced by the join condition are added to {@code fieldsUsed}, each input is
 * trimmed to the fields it has to provide, and the join is rebuilt on top of the trimmed inputs
 * with its condition rewritten to the new field positions. When no input changed and the combined
 * mapping is the identity, the original join is returned.
 *
 * @param join join to trim
 * @param fieldsUsed fields of the join output required by the consumer
 * @param extraFields extra fields requested by the consumer
 * @return the new (or original) join together with a mapping from the original field ordinals to
 *         the ordinals in the returned relational expression; for SEMI and ANTI joins the mapping
 *         only has targets for system fields and left-input fields
 */
public TrimResult trimFields(Join join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
    // Width of system fields + left input + right input.
    final int fieldCount = join.getSystemFieldList().size() + join.getLeft().getRowType().getFieldCount() + join.getRight().getRowType().getFieldCount();
    final RexNode conditionExpr = join.getCondition();
    final int systemFieldCount = join.getSystemFieldList().size();
    // Add in fields used in the condition.
    final Set<RelDataTypeField> combinedInputExtraFields = new LinkedHashSet<>(extraFields);
    RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(combinedInputExtraFields, fieldsUsed);
    conditionExpr.accept(inputFinder);
    final ImmutableBitSet fieldsUsedPlus = inputFinder.build();
    // If no system fields are used, we can remove them.
    int systemFieldUsedCount = 0;
    for (int i = 0; i < systemFieldCount; ++i) {
        if (fieldsUsed.get(i)) {
            ++systemFieldUsedCount;
        }
    }
    // System fields are kept all together or dropped all together.
    final int newSystemFieldCount;
    if (systemFieldUsedCount == 0) {
        newSystemFieldCount = 0;
    } else {
        newSystemFieldCount = systemFieldCount;
    }
    // offset: position of the current input's first field within the original join row type.
    int offset = systemFieldCount;
    int changeCount = 0;
    int newFieldCount = newSystemFieldCount;
    final List<RelNode> newInputs = new ArrayList<>(2);
    final List<Mapping> inputMappings = new ArrayList<>();
    final List<Integer> inputExtraFieldCounts = new ArrayList<>();
    for (RelNode input : join.getInputs()) {
        final RelDataType inputRowType = input.getRowType();
        final int inputFieldCount = inputRowType.getFieldCount();
        // Compute required mapping.
        // Select the used fields that fall into this input and express them relative to it.
        ImmutableBitSet.Builder inputFieldsUsed = ImmutableBitSet.builder();
        for (int bit : fieldsUsedPlus) {
            if (bit >= offset && bit < offset + inputFieldCount) {
                inputFieldsUsed.set(bit - offset);
            }
        }
        // If there are system fields, we automatically use the
        // corresponding field in each input.
        inputFieldsUsed.set(0, newSystemFieldCount);
        // FIXME: We ought to collect extra fields for each input
        // individually. For now, we assume that just one input has
        // on-demand fields.
        Set<RelDataTypeField> inputExtraFields = RelDataTypeImpl.extra(inputRowType) == null ? Collections.emptySet() : combinedInputExtraFields;
        inputExtraFieldCounts.add(inputExtraFields.size());
        TrimResult trimResult = trimChild(join, input, inputFieldsUsed.build(), inputExtraFields);
        newInputs.add(trimResult.left);
        // Reference comparison: counts inputs for which trimming produced a different node.
        if (trimResult.left != input) {
            ++changeCount;
        }
        final Mapping inputMapping = trimResult.right;
        inputMappings.add(inputMapping);
        // Move offset to point to start of next input.
        offset += inputFieldCount;
        newFieldCount += inputMapping.getTargetCount() + inputExtraFields.size();
    }
    // Combine the per-input mappings into one mapping over the whole join row type.
    Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
    for (int i = 0; i < newSystemFieldCount; ++i) {
        mapping.set(i, i);
    }
    offset = systemFieldCount;
    int newOffset = newSystemFieldCount;
    for (int i = 0; i < inputMappings.size(); i++) {
        Mapping inputMapping = inputMappings.get(i);
        for (IntPair pair : inputMapping) {
            mapping.set(pair.source + offset, pair.target + newOffset);
        }
        offset += inputMapping.getSourceCount();
        newOffset += inputMapping.getTargetCount() + inputExtraFieldCounts.get(i);
    }
    // Nothing was trimmed anywhere: keep the original join.
    if (changeCount == 0 && mapping.isIdentity()) {
        return result(join, Mappings.createIdentity(fieldCount));
    }
    // Build new join.
    // The condition is rewritten to reference fields at their new positions.
    final RexVisitor<RexNode> shuttle = new RexPermuteInputsShuttle(mapping, newInputs.get(0), newInputs.get(1));
    RexNode newConditionExpr = conditionExpr.accept(shuttle);
    final RelBuilder relBuilder = REL_BUILDER.get();
    relBuilder.push(newInputs.get(0));
    relBuilder.push(newInputs.get(1));
    switch(join.getJoinType()) {
        case SEMI:
        case ANTI:
            // For SemiJoins and AntiJoins only map fields from the left-side
            if (join.getJoinType() == JoinRelType.SEMI) {
                relBuilder.semiJoin(newConditionExpr);
            } else {
                relBuilder.antiJoin(newConditionExpr);
            }
            // Replace the combined mapping by one sized to the join's own row type that only
            // contains the system fields and the fields of the left input.
            Mapping inputMapping = inputMappings.get(0);
            mapping = Mappings.create(MappingType.INVERSE_SURJECTION, join.getRowType().getFieldCount(), newSystemFieldCount + inputMapping.getTargetCount());
            for (int i = 0; i < newSystemFieldCount; ++i) {
                mapping.set(i, i);
            }
            offset = systemFieldCount;
            newOffset = newSystemFieldCount;
            for (IntPair pair : inputMapping) {
                mapping.set(pair.source + offset, pair.target + newOffset);
            }
            break;
        default:
            relBuilder.join(join.getJoinType(), newConditionExpr);
    }
    return result(relBuilder.build(), mapping);
}
Also used : LinkedHashSet(java.util.LinkedHashSet) RelBuilder(org.apache.calcite.tools.RelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) RelOptUtil(org.apache.calcite.plan.RelOptUtil) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RelDataType(org.apache.calcite.rel.type.RelDataType) IntPair(org.apache.calcite.util.mapping.IntPair) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexPermuteInputsShuttle(org.apache.calcite.rex.RexPermuteInputsShuttle) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

Mapping (org.apache.calcite.util.mapping.Mapping) 44, RelNode (org.apache.calcite.rel.RelNode) 36, ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet) 29, RelDataType (org.apache.calcite.rel.type.RelDataType) 26, RexNode (org.apache.calcite.rex.RexNode) 25, ArrayList (java.util.ArrayList) 22, RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField) 20, RelBuilder (org.apache.calcite.tools.RelBuilder) 15, RexPermuteInputsShuttle (org.apache.calcite.rex.RexPermuteInputsShuttle) 13, RexBuilder (org.apache.calcite.rex.RexBuilder) 11, RelOptUtil (org.apache.calcite.plan.RelOptUtil) 9, LinkedHashSet (java.util.LinkedHashSet) 8, AggregateCall (org.apache.calcite.rel.core.AggregateCall) 8, Aggregate (org.apache.calcite.rel.core.Aggregate) 6, Join (org.apache.calcite.rel.core.Join) 6, RexInputRef (org.apache.calcite.rex.RexInputRef) 6, RexLiteral (org.apache.calcite.rex.RexLiteral) 6, Mappings (org.apache.calcite.util.mapping.Mappings) 6, ImmutableList (com.google.common.collect.ImmutableList) 5, HashMap (java.util.HashMap) 5