Use of org.apache.calcite.util.mapping.Mapping in project hive by apache.
The class HiveCardinalityPreservingJoinOptimization, method trim:
@Override
public RelNode trim(RelBuilder relBuilder, RelNode root) {
  try {
    if (root.getInputs().size() != 1) {
      LOG.debug("Only plans where root has one input are supported. Root: {}", root);
      return root;
    }
    REL_BUILDER.set(relBuilder);
    RexBuilder rexBuilder = relBuilder.getRexBuilder();
    RelNode rootInput = root.getInput(0);

    // Build the list of RexInputRef from root input RowType
    List<RexInputRef> rootFieldList = new ArrayList<>(rootInput.getRowType().getFieldCount());
    List<String> newColumnNames = new ArrayList<>();
    for (int i = 0; i < rootInput.getRowType().getFieldList().size(); ++i) {
      RelDataTypeField relDataTypeField = rootInput.getRowType().getFieldList().get(i);
      rootFieldList.add(rexBuilder.makeInputRef(relDataTypeField.getType(), i));
      newColumnNames.add(relDataTypeField.getName());
    }

    // Bit set to gather the refs that backtrack to constant values
    BitSet constants = new BitSet();
    List<JoinedBackFields> lineages = getExpressionLineageOf(rootFieldList, rootInput, constants);
    if (lineages == null) {
      LOG.debug("Some projected field lineage can not be determined");
      return root;
    }

    // 1. Collect candidate tables for join back and map RexNodes coming from
    // those tables to their index in the rootInput row type.
    // Collect all used fields from original plan
    ImmutableBitSet fieldsUsed = ImmutableBitSet.of(constants.stream().toArray());
    List<TableToJoinBack> tableToJoinBackList = new ArrayList<>(lineages.size());
    Map<Integer, RexNode> rexNodesToShuttle = new HashMap<>(rootInput.getRowType().getFieldCount());
    for (JoinedBackFields joinedBackFields : lineages) {
      Optional<ImmutableBitSet> projectedKeys = joinedBackFields.relOptHiveTable.getNonNullableKeys().stream()
          .filter(joinedBackFields.fieldsInSourceTable::contains)
          .findFirst();
      if (projectedKeys.isPresent() && !projectedKeys.get().equals(joinedBackFields.fieldsInSourceTable)) {
        TableToJoinBack tableToJoinBack = new TableToJoinBack(projectedKeys.get(), joinedBackFields);
        tableToJoinBackList.add(tableToJoinBack);
        fieldsUsed = fieldsUsed.union(joinedBackFields.getSource(projectedKeys.get()));
        for (TableInputRefHolder mapping : joinedBackFields.mapping) {
          if (!fieldsUsed.get(mapping.indexInOriginalRowType)) {
            rexNodesToShuttle.put(mapping.indexInOriginalRowType, mapping.rexNode);
          }
        }
      } else {
        fieldsUsed = fieldsUsed.union(joinedBackFields.fieldsInOriginalRowType);
      }
    }
    if (tableToJoinBackList.isEmpty()) {
      LOG.debug("None of the tables has keys projected, unable to join back");
      return root;
    }

    // 2. Trim out non-key fields of joined back tables
    Set<RelDataTypeField> extraFields = Collections.emptySet();
    TrimResult trimResult = dispatchTrimFields(rootInput, fieldsUsed, extraFields);
    RelNode newInput = trimResult.left;
    if (newInput.getRowType().equals(rootInput.getRowType())) {
      LOG.debug("Nothing was trimmed out.");
      return root;
    }

    // 3. Join back tables to the top of original plan
    Mapping newInputMapping = trimResult.right;
    Map<RexTableInputRef, Integer> tableInputRefMapping = new HashMap<>();
    for (TableToJoinBack tableToJoinBack : tableToJoinBackList) {
      LOG.debug("Joining back table {}", tableToJoinBack.joinedBackFields.relOptHiveTable.getName());

      // 3.1. Create new TableScan of tables to join back
      RelOptHiveTable relOptTable = tableToJoinBack.joinedBackFields.relOptHiveTable;
      RelOptCluster cluster = relBuilder.getCluster();
      HiveTableScan tableScan = new HiveTableScan(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
          relOptTable, relOptTable.getHiveTableMD().getTableName(), null, false, false);

      // 3.2. Create Project with the required fields from this table
      RelNode projectTableAccessRel = tableScan.project(
          tableToJoinBack.joinedBackFields.fieldsInSourceTable, new HashSet<>(0), REL_BUILDER.get());

      // 3.3. Create mapping between the Project and TableScan
      Mapping projectMapping = Mappings.create(MappingType.INVERSE_SURJECTION,
          tableScan.getRowType().getFieldCount(),
          tableToJoinBack.joinedBackFields.fieldsInSourceTable.cardinality());
      int projectIndex = 0;
      for (int i : tableToJoinBack.joinedBackFields.fieldsInSourceTable) {
        projectMapping.set(i, projectIndex);
        ++projectIndex;
      }

      int offset = newInput.getRowType().getFieldCount();

      // 3.4. Map rexTableInputRef to the index where it can be found in the new Input row type
      for (TableInputRefHolder mapping : tableToJoinBack.joinedBackFields.mapping) {
        int indexInSourceTable = mapping.tableInputRef.getIndex();
        if (!tableToJoinBack.keys.get(indexInSourceTable)) {
          // 3.5. if this is not a key field, it is shifted by the left input field count
          tableInputRefMapping.put(mapping.tableInputRef,
              offset + projectMapping.getTarget(indexInSourceTable));
        }
      }

      // 3.6. Create Join
      relBuilder.push(newInput);
      relBuilder.push(projectTableAccessRel);
      RexNode joinCondition = joinCondition(
          newInput, newInputMapping, tableToJoinBack, projectTableAccessRel, projectMapping, rexBuilder);
      newInput = relBuilder.join(JoinRelType.INNER, joinCondition).build();
    }

    // 4. Collect rexNodes for Project
    TableInputRefMapper mapper = new TableInputRefMapper(tableInputRefMapping, rexBuilder, newInput);
    List<RexNode> rexNodeList = new ArrayList<>(rootInput.getRowType().getFieldCount());
    for (int i = 0; i < rootInput.getRowType().getFieldCount(); i++) {
      RexNode rexNode = rexNodesToShuttle.get(i);
      if (rexNode != null) {
        rexNodeList.add(mapper.apply(rexNode));
      } else {
        int target = newInputMapping.getTarget(i);
        rexNodeList.add(rexBuilder.makeInputRef(
            newInput.getRowType().getFieldList().get(target).getType(), target));
      }
    }

    // 5. Create Project on top of all Join backs
    relBuilder.push(newInput);
    relBuilder.project(rexNodeList, newColumnNames);
    return root.copy(root.getTraitSet(), singletonList(relBuilder.build()));
  } finally {
    REL_BUILDER.remove();
  }
}
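The join-back logic above hinges on the INVERSE_SURJECTION mapping built in step 3.3, which records where each surviving table field lands in the Project. The following is a minimal, self-contained sketch of that pattern, not Hive code: the table width and the projected-field bit set are invented for illustration, and only calcite-core is assumed on the classpath.

import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;

public class ProjectMappingSketch {
  public static void main(String[] args) {
    // Hypothetical table scan with 6 fields, of which {1, 3, 4} are projected.
    int tableFieldCount = 6;
    ImmutableBitSet fieldsInSourceTable = ImmutableBitSet.of(1, 3, 4);

    // Same pattern as step 3.3: map each table field index to the dense
    // position of its column in the Project.
    Mapping projectMapping = Mappings.create(MappingType.INVERSE_SURJECTION,
        tableFieldCount, fieldsInSourceTable.cardinality());
    int projectIndex = 0;
    for (int i : fieldsInSourceTable) {
      projectMapping.set(i, projectIndex++);
    }

    // Table field 3 is the second projected column.
    System.out.println(projectMapping.getTarget(3)); // prints 1
  }
}

With the mapping in hand, step 3.4 can translate a RexTableInputRef's index in the source table into an index in the joined row type by adding the left input's field count as an offset.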
Use of org.apache.calcite.util.mapping.Mapping in project flink by apache.
The class FlinkAggregateJoinTransposeRule, method onMatch:
public void onMatch(RelOptRuleCall call) {
  final Aggregate aggregate = call.rel(0);
  final Join join = call.rel(1);
  final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
  final RelBuilder relBuilder = call.builder();

  // If any aggregate function does not support splitting, bail out.
  // If any aggregate call has a filter, bail out.
  for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
    if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
      return;
    }
    if (aggregateCall.filterArg >= 0) {
      return;
    }
  }

  // If it is not an inner join, we do not push the aggregate operator.
  if (join.getJoinType() != JoinRelType.INNER) {
    return;
  }
  if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
    return;
  }

  // Do the columns used by the join appear in the output of the aggregate?
  final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
  final RelMetadataQuery mq = RelMetadataQuery.instance();
  final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
  final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
  final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
  final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);

  // Split join condition
  final List<Integer> leftKeys = Lists.newArrayList();
  final List<Integer> rightKeys = Lists.newArrayList();
  final List<Boolean> filterNulls = Lists.newArrayList();
  RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(),
      join.getCondition(), leftKeys, rightKeys, filterNulls);
  // If it contains non-equi join conditions, we bail out
  if (!nonEquiConj.isAlwaysTrue()) {
    return;
  }

  // Push each aggregate function down to each side that contains all of its
  // arguments. Note that COUNT(*), because it has no arguments, can go to
  // both sides.
  final Map<Integer, Integer> map = new HashMap<>();
  final List<Side> sides = new ArrayList<>();
  int uniqueCount = 0;
  int offset = 0;
  int belowOffset = 0;
  for (int s = 0; s < 2; s++) {
    final Side side = new Side();
    final RelNode joinInput = join.getInput(s);
    int fieldCount = joinInput.getRowType().getFieldCount();
    final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
    final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
    for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
      map.put(c.e, belowOffset + c.i);
    }
    final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
    final boolean unique;
    if (!allowFunctions) {
      assert aggregate.getAggCallList().isEmpty();
      // If there are no functions, it doesn't matter as much whether we
      // aggregate the inputs before the join, because there will not be
      // any functions experiencing a cartesian product effect.
      //
      // But finding out whether the input is already unique requires a call
      // to areColumnsUnique that currently (until [CALCITE-1048] "Make
      // metadata more robust" is fixed) places a heavy load on
      // the metadata system.
      //
      // So we choose to imagine the input is already unique, which is
      // untrue but harmless.
      //
      Util.discard(Bug.CALCITE_1048_FIXED);
      unique = true;
    } else {
      final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
      unique = unique0 != null && unique0;
    }
    if (unique) {
      ++uniqueCount;
      side.aggregate = false;
      side.newInput = joinInput;
    } else {
      side.aggregate = true;
      List<AggregateCall> belowAggCalls = new ArrayList<>();
      final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
      final Mappings.TargetMapping mapping = s == 0
          ? Mappings.createIdentity(fieldCount)
          : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
      for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter =
            Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final AggregateCall call1;
        if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
          call1 = splitter.split(aggCall.e, mapping);
        } else {
          call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
        }
        if (call1 != null) {
          side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
        }
      }
      side.newInput = relBuilder.push(joinInput)
          .aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls)
          .build();
    }
    offset += fieldCount;
    belowOffset += side.newInput.getRowType().getFieldCount();
    sides.add(side);
  }

  if (uniqueCount == 2) {
    // Both inputs to the join are unique. There is nothing to be gained by
    // this rule. In fact, this aggregate+join may have been produced by a
    // previous invocation of this rule; if we continue we might loop forever.
    return;
  }

  // Update condition
  final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {
    public Integer apply(Integer a0) {
      return map.get(a0);
    }
  }, join.getRowType().getFieldCount(), belowOffset);
  final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());

  // Create new join
  relBuilder.push(sides.get(0).newInput)
      .push(sides.get(1).newInput)
      .join(join.getJoinType(), newCondition);

  // Aggregate above to sum up the sub-totals
  final List<AggregateCall> newAggCalls = new ArrayList<>();
  final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
  final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
  final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
  for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
    final SqlAggFunction aggregation = aggCall.e.getAggregation();
    final SqlSplittableAggFunction splitter =
        Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
    final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
    final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
    newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount,
        relBuilder.peek().getRowType(), aggCall.e,
        leftSubTotal == null ? -1 : leftSubTotal,
        rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
  }
  relBuilder.project(projects);

  boolean aggConvertedToProjects = false;
  if (allColumnsInAggregate) {
    // let's see if we can convert aggregate into projects
    List<RexNode> projects2 = new ArrayList<>();
    for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
      projects2.add(relBuilder.field(key));
    }
    for (AggregateCall newAggCall : newAggCalls) {
      final SqlSplittableAggFunction splitter =
          newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
      if (splitter != null) {
        projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
      }
    }
    if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
      // We successfully converted agg calls into projects.
      relBuilder.project(projects2);
      aggConvertedToProjects = true;
    }
  }

  if (!aggConvertedToProjects) {
    relBuilder.aggregate(
        relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()),
            aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())),
        newAggCalls);
  }
  call.transformTo(relBuilder.build());
}
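One subtle piece above is Mappings.createShiftMapping in the s == 1 branch: its varargs are read as (target offset, source offset, count) triples, so the call maps the right input's columns from their join-level positions back to input-level positions. A standalone sketch with invented field counts (only calcite-core assumed):

import org.apache.calcite.util.mapping.Mappings;

public class ShiftMappingSketch {
  public static void main(String[] args) {
    // Hypothetical join: 3 fields on the left, 2 on the right, so the
    // right input's columns sit at join-level indices 3..4.
    int leftFieldCount = 3;
    int rightFieldCount = 2;

    Mappings.TargetMapping mapping = Mappings.createShiftMapping(
        leftFieldCount + rightFieldCount, // source count (join-level width)
        0,                                // target offset
        leftFieldCount,                   // source offset
        rightFieldCount);                 // number of columns mapped

    System.out.println(mapping.getTargetOpt(3)); // 0
    System.out.println(mapping.getTargetOpt(4)); // 1
    System.out.println(mapping.getTargetOpt(0)); // -1: left columns have no image
  }
}

This is why splitter.split(aggCall.e, mapping) can produce a below-join aggregate call whose argument indices are valid against the right input alone.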
Use of org.apache.calcite.util.mapping.Mapping in project hive by apache.
The class HiveRelMdPredicates, method getPredicates:
/**
 * Infers predicates for an Aggregate.
 *
 * <p>Pulls up predicates that only contain references to columns in the
 * GroupSet. For example:
 *
 * <pre>
 * inputPullUpExprs : { a > 7, b + c < 10, a + e = 9 }
 * groupSet         : { a, b }
 * pulledUpExprs    : { a > 7 }
 * </pre>
 */
public RelOptPredicateList getPredicates(Aggregate agg, RelMetadataQuery mq) {
  final RelNode input = agg.getInput();
  final RelOptPredicateList inputInfo = mq.getPulledUpPredicates(input);
  final List<RexNode> aggPullUpPredicates = new ArrayList<>();
  final RexBuilder rexBuilder = agg.getCluster().getRexBuilder();
  ImmutableBitSet groupKeys = agg.getGroupSet();

  // Map each group key to its position in the aggregate's output row type.
  Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION,
      input.getRowType().getFieldCount(), agg.getRowType().getFieldCount());
  int i = 0;
  for (int j : groupKeys) {
    m.set(j, i++);
  }
  for (RexNode r : inputInfo.pulledUpPredicates) {
    ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
    if (!rCols.isEmpty() && groupKeys.contains(rCols)) {
      r = r.accept(new RexPermuteInputsShuttle(m, input));
      aggPullUpPredicates.add(r);
    }
  }
  return RelOptPredicateList.of(rexBuilder, aggPullUpPredicates);
}
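To make the PARTIAL_FUNCTION mapping concrete, here is a standalone sketch (the field positions are invented): group keys get an image in the Aggregate's output row type, every other input field does not, which is exactly why only predicates whose columns are all contained in the group set can be rewritten by RexPermuteInputsShuttle and pulled up.

import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;

public class GroupKeyMappingSketch {
  public static void main(String[] args) {
    // Hypothetical input with 5 fields, grouped on fields {1, 3};
    // the aggregate's output row type starts with those 2 group columns.
    ImmutableBitSet groupKeys = ImmutableBitSet.of(1, 3);
    Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION, 5, 2);
    int i = 0;
    for (int j : groupKeys) {
      m.set(j, i++);
    }

    System.out.println(m.getTargetOpt(3)); // 1: second group column
    System.out.println(m.getTargetOpt(0)); // -1: not a group key, no image
  }
}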
Use of org.apache.calcite.util.mapping.Mapping in project hive by apache.
The class RelFieldTrimmer, method trimFields:
/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.calcite.rel.logical.LogicalTableFunctionScan}.
 */
public TrimResult trimFields(LogicalTableFunctionScan tabFun, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
  final RelDataType rowType = tabFun.getRowType();
  final int fieldCount = rowType.getFieldCount();
  final List<RelNode> newInputs = new ArrayList<>();
  for (RelNode input : tabFun.getInputs()) {
    final int inputFieldCount = input.getRowType().getFieldCount();
    ImmutableBitSet inputFieldsUsed = ImmutableBitSet.range(inputFieldCount);

    // Create input with trimmed columns.
    final Set<RelDataTypeField> inputExtraFields = Collections.emptySet();
    TrimResult trimResult = trimChildRestore(tabFun, input, inputFieldsUsed, inputExtraFields);
    assert trimResult.right.isIdentity();
    newInputs.add(trimResult.left);
  }
  LogicalTableFunctionScan newTabFun = tabFun;
  if (!tabFun.getInputs().equals(newInputs)) {
    newTabFun = tabFun.copy(tabFun.getTraitSet(), newInputs, tabFun.getCall(),
        tabFun.getElementType(), tabFun.getRowType(), tabFun.getColumnMappings());
  }
  assert newTabFun.getClass() == tabFun.getClass();

  // Always project all fields.
  Mapping mapping = Mappings.createIdentity(fieldCount);
  return result(newTabFun, mapping);
}
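Because this variant keeps every field, the TrimResult it returns always carries an identity mapping, which tells the caller that no column indices shifted. A trivial sketch of that contract (the field count is invented):

import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;

public class IdentityMappingSketch {
  public static void main(String[] args) {
    Mapping mapping = Mappings.createIdentity(4);
    System.out.println(mapping.isIdentity()); // true
    System.out.println(mapping.getTarget(2)); // 2: every field maps to itself
  }
}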
Use of org.apache.calcite.util.mapping.Mapping in project hive by apache.
The class RelFieldTrimmer, method trimFields:
/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.calcite.rel.logical.LogicalJoin}.
 */
public TrimResult trimFields(Join join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
  final int fieldCount = join.getSystemFieldList().size()
      + join.getLeft().getRowType().getFieldCount()
      + join.getRight().getRowType().getFieldCount();
  final RexNode conditionExpr = join.getCondition();
  final int systemFieldCount = join.getSystemFieldList().size();

  // Add in fields used in the condition.
  final Set<RelDataTypeField> combinedInputExtraFields = new LinkedHashSet<>(extraFields);
  RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(combinedInputExtraFields, fieldsUsed);
  conditionExpr.accept(inputFinder);
  final ImmutableBitSet fieldsUsedPlus = inputFinder.build();

  // If no system fields are used, we can remove them.
  int systemFieldUsedCount = 0;
  for (int i = 0; i < systemFieldCount; ++i) {
    if (fieldsUsed.get(i)) {
      ++systemFieldUsedCount;
    }
  }
  final int newSystemFieldCount;
  if (systemFieldUsedCount == 0) {
    newSystemFieldCount = 0;
  } else {
    newSystemFieldCount = systemFieldCount;
  }

  int offset = systemFieldCount;
  int changeCount = 0;
  int newFieldCount = newSystemFieldCount;
  final List<RelNode> newInputs = new ArrayList<>(2);
  final List<Mapping> inputMappings = new ArrayList<>();
  final List<Integer> inputExtraFieldCounts = new ArrayList<>();
  for (RelNode input : join.getInputs()) {
    final RelDataType inputRowType = input.getRowType();
    final int inputFieldCount = inputRowType.getFieldCount();

    // Compute required mapping.
    ImmutableBitSet.Builder inputFieldsUsed = ImmutableBitSet.builder();
    for (int bit : fieldsUsedPlus) {
      if (bit >= offset && bit < offset + inputFieldCount) {
        inputFieldsUsed.set(bit - offset);
      }
    }

    // If there are system fields, we automatically use the
    // corresponding field in each input.
    inputFieldsUsed.set(0, newSystemFieldCount);

    // FIXME: We ought to collect extra fields for each input
    // individually. For now, we assume that just one input has
    // on-demand fields.
    Set<RelDataTypeField> inputExtraFields = RelDataTypeImpl.extra(inputRowType) == null
        ? Collections.emptySet()
        : combinedInputExtraFields;
    inputExtraFieldCounts.add(inputExtraFields.size());
    TrimResult trimResult = trimChild(join, input, inputFieldsUsed.build(), inputExtraFields);
    newInputs.add(trimResult.left);
    if (trimResult.left != input) {
      ++changeCount;
    }
    final Mapping inputMapping = trimResult.right;
    inputMappings.add(inputMapping);

    // Move offset to point to start of next input.
    offset += inputFieldCount;
    newFieldCount += inputMapping.getTargetCount() + inputExtraFields.size();
  }

  Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
  for (int i = 0; i < newSystemFieldCount; ++i) {
    mapping.set(i, i);
  }
  offset = systemFieldCount;
  int newOffset = newSystemFieldCount;
  for (int i = 0; i < inputMappings.size(); i++) {
    Mapping inputMapping = inputMappings.get(i);
    for (IntPair pair : inputMapping) {
      mapping.set(pair.source + offset, pair.target + newOffset);
    }
    offset += inputMapping.getSourceCount();
    newOffset += inputMapping.getTargetCount() + inputExtraFieldCounts.get(i);
  }

  if (changeCount == 0 && mapping.isIdentity()) {
    return result(join, Mappings.createIdentity(fieldCount));
  }

  // Build new join.
  final RexVisitor<RexNode> shuttle = new RexPermuteInputsShuttle(mapping, newInputs.get(0), newInputs.get(1));
  RexNode newConditionExpr = conditionExpr.accept(shuttle);
  final RelBuilder relBuilder = REL_BUILDER.get();
  relBuilder.push(newInputs.get(0));
  relBuilder.push(newInputs.get(1));
  switch (join.getJoinType()) {
  case SEMI:
  case ANTI:
    // For semi-joins and anti-joins, only map fields from the left side.
    if (join.getJoinType() == JoinRelType.SEMI) {
      relBuilder.semiJoin(newConditionExpr);
    } else {
      relBuilder.antiJoin(newConditionExpr);
    }
    Mapping inputMapping = inputMappings.get(0);
    mapping = Mappings.create(MappingType.INVERSE_SURJECTION,
        join.getRowType().getFieldCount(),
        newSystemFieldCount + inputMapping.getTargetCount());
    for (int i = 0; i < newSystemFieldCount; ++i) {
      mapping.set(i, i);
    }
    offset = systemFieldCount;
    newOffset = newSystemFieldCount;
    for (IntPair pair : inputMapping) {
      mapping.set(pair.source + offset, pair.target + newOffset);
    }
    break;
  default:
    relBuilder.join(join.getJoinType(), newConditionExpr);
  }
  return result(relBuilder.build(), mapping);
}
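The offset/newOffset bookkeeping above stitches the per-input trim mappings into one mapping over the join's row type. A standalone sketch with two invented input mappings (only calcite-core assumed, no system fields):

import java.util.Arrays;

import org.apache.calcite.util.mapping.IntPair;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;

public class JoinTrimMappingSketch {
  public static void main(String[] args) {
    // Left input: 3 fields, of which {0, 2} survive trimming.
    Mapping left = Mappings.create(MappingType.INVERSE_SURJECTION, 3, 2);
    left.set(0, 0);
    left.set(2, 1);
    // Right input: 2 fields, of which {1} survives.
    Mapping right = Mappings.create(MappingType.INVERSE_SURJECTION, 2, 1);
    right.set(1, 0);

    // Stitch them together, shifting sources by the old input widths and
    // targets by the new ones, exactly as in the loop above.
    Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, 5, 3);
    int offset = 0;
    int newOffset = 0;
    for (Mapping inputMapping : Arrays.asList(left, right)) {
      for (IntPair pair : inputMapping) {
        mapping.set(pair.source + offset, pair.target + newOffset);
      }
      offset += inputMapping.getSourceCount();
      newOffset += inputMapping.getTargetCount();
    }

    // Old join field 4 (right input's field 1) is field 2 of the trimmed join.
    System.out.println(mapping.getTarget(4)); // 2
  }
}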