use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin in project hive by apache.
the class HiveRelFieldTrimmer method trimFields.
/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin}.
 *
 * <p>The fields referenced by the join condition are added to {@code fieldsUsed},
 * every input is trimmed through {@code trimChild} against its own share of that
 * set, and the multi-join is rebuilt on top of the trimmed inputs with its
 * condition and per-join filters rewritten for the new field positions.
 *
 * @param join multi-join to trim
 * @param fieldsUsed fields of the join output that are required
 * @param extraFields extra fields used to seed the input finder
 * @return the original join paired with an identity mapping when no input was
 *         replaced and no field moved; otherwise the rebuilt join paired with
 *         the mapping from old field positions to new ones
 */
public TrimResult trimFields(HiveMultiJoin join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
  final int originalFieldCount = join.getRowType().getFieldCount();
  final RexNode condition = join.getCondition();
  final List<RexNode> originalFilters = join.getJoinFilters();

  // Required fields = requested fields + every field the join condition reads.
  final Set<RelDataTypeField> finderExtraFields = new LinkedHashSet<RelDataTypeField>(extraFields);
  final RelOptUtil.InputFinder finder = new RelOptUtil.InputFinder(finderExtraFields);
  finder.inputBitSet.addAll(fieldsUsed);
  condition.accept(finder);
  final ImmutableBitSet requiredFields = finder.inputBitSet.build();

  // Trim each input against the slice of required fields that falls in its window.
  final List<RelNode> trimmedInputs = new ArrayList<RelNode>();
  final List<Mapping> perInputMappings = new ArrayList<Mapping>();
  boolean anyInputReplaced = false;
  int trimmedFieldCount = 0;
  int windowStart = 0;
  for (RelNode input : join.getInputs()) {
    final int windowEnd = windowStart + input.getRowType().getFieldCount();
    final ImmutableBitSet.Builder requiredOfInput = ImmutableBitSet.builder();
    for (int field : requiredFields) {
      if (windowStart <= field && field < windowEnd) {
        // Re-base the join-level ordinal to the input's own numbering.
        requiredOfInput.set(field - windowStart);
      }
    }
    final TrimResult trimmed =
        trimChild(join, input, requiredOfInput.build(), Collections.<RelDataTypeField>emptySet());
    final RelNode trimmedInput = trimmed.left;
    final Mapping trimmedMapping = trimmed.right;
    trimmedInputs.add(trimmedInput);
    perInputMappings.add(trimmedMapping);
    anyInputReplaced |= trimmedInput != input;
    trimmedFieldCount += trimmedMapping.getTargetCount();
    // The next input's window begins where this one ends.
    windowStart = windowEnd;
  }

  // Stitch the per-input mappings into one mapping over the whole join row.
  final Mapping combined =
      Mappings.create(MappingType.INVERSE_SURJECTION, originalFieldCount, trimmedFieldCount);
  int sourceBase = 0;
  int targetBase = 0;
  for (Mapping perInput : perInputMappings) {
    for (IntPair pair : perInput) {
      combined.set(sourceBase + pair.source, targetBase + pair.target);
    }
    sourceBase += perInput.getSourceCount();
    targetBase += perInput.getTargetCount();
  }

  if (!anyInputReplaced && combined.isIdentity()) {
    return new TrimResult(join, Mappings.createIdentity(originalFieldCount));
  }

  // Rebuild the join: rewrite the condition and each join filter for the new layout.
  final RelNode[] inputArray = trimmedInputs.toArray(new RelNode[trimmedInputs.size()]);
  final RexVisitor<RexNode> permuter = new RexPermuteInputsShuttle(combined, inputArray);
  final RexNode rewrittenCondition = condition.accept(permuter);
  final List<RexNode> rewrittenFilters = Lists.newArrayList();
  for (RexNode filter : originalFilters) {
    rewrittenFilters.add(filter.accept(permuter));
  }
  final RelDataType trimmedRowType =
      RelOptUtil.permute(join.getCluster().getTypeFactory(), join.getRowType(), combined);
  final RelNode rebuiltJoin = new HiveMultiJoin(join.getCluster(), trimmedInputs, rewrittenCondition,
      trimmedRowType, join.getJoinInputs(), join.getJoinTypes(), rewrittenFilters);
  return new TrimResult(rebuiltJoin, combined);
}
use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin in project hive by apache.
the class HiveOpConverter method translateJoin.
/**
 * Translates a join rel (join, multi-join or semi-join) into a Hive join operator.
 *
 * <p>Every input of {@code joinRel} is converted through {@code dispatch}; the
 * first operator of each converted input is cast to {@link ReduceSinkOperator}
 * and tagged with its position. Join keys are read from the
 * {@link HiveSortExchange} inputs and the remaining join filters are converted
 * to expression descriptors before delegating to {@code genJoin}.
 *
 * @param joinRel the join rel to translate
 * @return the translated join operator with its table alias and virtual columns
 * @throws SemanticException if {@code joinRel} is of an unsupported join class
 */
private OpAttr translateJoin(RelNode joinRel) throws SemanticException {
  final int inputCount = joinRel.getInputs().size();

  // 0. Additional data structures needed for the join optimization through Hive.
  //    The alias is requested before any input is converted.
  final String[] baseSrc = new String[inputCount];
  final String tabAlias = getHiveDerivedTableAlias();

  // 1. Convert inputs
  final OpAttr[] inputs = new OpAttr[inputCount];
  final List<Operator<?>> children = new ArrayList<Operator<?>>(inputCount);
  for (int pos = 0; pos < inputCount; pos++) {
    final OpAttr converted = dispatch(joinRel.getInput(pos));
    inputs[pos] = converted;
    children.add(converted.inputs.get(0));
    baseSrc[pos] = converted.tabAlias;
  }

  // 2. Generate tags: each child gets its position as tag
  int nextTag = 0;
  for (Operator<?> child : children) {
    ((ReduceSinkOperator) child).getConf().setTag(nextTag++);
  }

  // 3. Virtual columns: those of the first input are taken as is; unless this is a
  //    semi-join that is not a multi-join, those of the other inputs are shifted
  //    by the width of the inputs before them.
  final Set<Integer> joinVcols = new HashSet<Integer>(inputs[0].vcolsInCalcite);
  final boolean plainSemiJoin = joinRel instanceof SemiJoin && !(joinRel instanceof HiveMultiJoin);
  if (!plainSemiJoin) {
    int columnsBefore = inputs[0].inputs.get(0).getSchema().getSignature().size();
    for (int pos = 1; pos < inputs.length; pos++) {
      joinVcols.addAll(HiveCalciteUtil.shiftVColsSet(inputs[pos].vcolsInCalcite, columnsBefore));
      columnsBefore += inputs[pos].inputs.get(0).getSchema().getSignature().size();
    }
  }

  if (LOG.isDebugEnabled()) {
    LOG.debug("Translating operator rel#" + joinRel.getId() + ":" + joinRel.getRelTypeName() + " with row type: [" + joinRel.getRowType() + "]");
  }

  // 4. Extract join key expressions from HiveSortExchange
  final ExprNodeDesc[][] joinExpressions = new ExprNodeDesc[inputs.length][];
  for (int pos = 0; pos < inputs.length; pos++) {
    final HiveSortExchange exchange = (HiveSortExchange) joinRel.getInput(pos);
    joinExpressions[pos] = exchange.getJoinExpressions();
  }

  // 5. Extract the rest of the join predicate: the parts of the join condition
  //    that are not join keys and will be added to the filters.
  final List<RexNode> joinFilters;
  if (joinRel instanceof HiveJoin) {
    joinFilters = ImmutableList.of(((HiveJoin) joinRel).getJoinFilter());
  } else if (joinRel instanceof HiveMultiJoin) {
    joinFilters = ((HiveMultiJoin) joinRel).getJoinFilters();
  } else if (joinRel instanceof HiveSemiJoin) {
    joinFilters = ImmutableList.of(((HiveSemiJoin) joinRel).getJoinFilter());
  } else {
    throw new SemanticException("Can't handle join type: " + joinRel.getClass().getName());
  }
  final List<List<ExprNodeDesc>> filterExpressions = Lists.newArrayList();
  for (RexNode joinFilter : joinFilters) {
    // A null filter still yields an (empty) entry so positions stay aligned.
    final List<ExprNodeDesc> convertedConjuncts = new ArrayList<ExprNodeDesc>();
    if (joinFilter != null) {
      for (RexNode conjunct : RelOptUtil.conjunctions(joinFilter)) {
        convertedConjuncts.add(convertToExprNode(conjunct, joinRel, null, joinVcols));
      }
    }
    filterExpressions.add(convertedConjuncts);
  }

  // 6. Generate Join operator
  final JoinOperator joinOp = genJoin(joinRel, joinExpressions, filterExpressions, children, baseSrc, tabAlias);

  // 7. Return result
  return new OpAttr(tabAlias, joinVcols, joinOp);
}
use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin in project hive by apache.
the class HiveInsertExchange4JoinRule method onMatch.
/**
 * Places a {@link HiveSortExchange} on top of every input of the matched join.
 *
 * <p>Nothing is done when the matched rel is not a join, or when any of its
 * inputs is already an {@link Exchange}. Each new exchange is hash-distributed
 * and sorted on the positions of the equi-join keys of its input.
 */
@Override
public void onMatch(RelOptRuleCall call) {
  final RelNode rel = call.rel(0);

  // Obtain the join predicate info; the most specific join classes are checked first.
  final JoinPredicateInfo joinPredInfo;
  if (rel instanceof HiveMultiJoin) {
    joinPredInfo = ((HiveMultiJoin) rel).getJoinPredicateInfo();
  } else if (rel instanceof HiveJoin) {
    joinPredInfo = ((HiveJoin) rel).getJoinPredicateInfo();
  } else if (rel instanceof Join) {
    try {
      joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo((Join) rel);
    } catch (CalciteSemanticException e) {
      throw new RuntimeException(e);
    }
  } else {
    return;
  }

  // Bail out if an exchange is already present below the join.
  for (RelNode child : rel.getInputs()) {
    final RelNode current = ((HepRelVertex) child).getCurrentRel();
    if (current instanceof Exchange) {
      return;
    }
  }

  // Get key columns from inputs. Those are the columns we will distribute on,
  // and also the columns we will sort on.
  final int inputCount = rel.getInputs().size();
  final List<RelNode> newInputs = new ArrayList<RelNode>();
  for (int input = 0; input < inputCount; input++) {
    final List<Integer> keyPositions = new ArrayList<Integer>();
    final ImmutableList.Builder<RexNode> keyExprs = new ImmutableList.Builder<RexNode>();
    final ImmutableList.Builder<RelFieldCollation> collations = new ImmutableList.Builder<RelFieldCollation>();
    // Key expressions are de-duplicated on their string form.
    final Set<String> seenKeyExprs = Sets.newHashSet();
    for (JoinLeafPredicateInfo leaf : joinPredInfo.getEquiJoinPredicateElements()) {
      for (RexNode keyExpr : leaf.getJoinExprs(input)) {
        if (seenKeyExprs.add(keyExpr.toString())) {
          keyExprs.add(keyExpr);
        }
      }
      for (int pos : leaf.getProjsJoinKeysInChildSchema(input)) {
        if (keyPositions.contains(pos)) {
          continue;
        }
        keyPositions.add(pos);
        collations.add(new RelFieldCollation(pos));
      }
    }
    final HiveRelDistribution distribution =
        new HiveRelDistribution(RelDistribution.Type.HASH_DISTRIBUTED, keyPositions);
    final HiveRelCollation collation = new HiveRelCollation(collations.build());
    newInputs.add(HiveSortExchange.create(rel.getInput(input), distribution, collation, keyExprs.build()));
  }

  // Re-create the join on top of the new exchanges.
  final RelNode newOp;
  if (rel instanceof HiveMultiJoin) {
    final HiveMultiJoin multiJoin = (HiveMultiJoin) rel;
    newOp = multiJoin.copy(multiJoin.getTraitSet(), newInputs);
  } else if (rel instanceof Join) {
    final Join join = (Join) rel;
    newOp = join.copy(join.getTraitSet(), join.getCondition(), newInputs.get(0), newInputs.get(1),
        join.getJoinType(), join.isSemiJoinDone());
  } else {
    return;
  }
  call.getPlanner().onCopy(rel, newOp);
  call.transformTo(newOp);
}
use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin in project hive by apache.
the class HiveJoinToMultiJoinRule method mergeJoin.
/**
 * Tries to merge {@code join} with its left child into a single
 * {@link HiveMultiJoin}. The left child must itself be a {@link HiveJoin} or a
 * {@link HiveMultiJoin} for this to happen.
 *
 * <p>The merge is abandoned, and {@code null} returned, when:
 * <ul>
 * <li>the left child is not a join or multi-join;</li>
 * <li>{@code isCombinableJoin} rejects the pair or throws;</li>
 * <li>the join condition cannot be split into keys and filters;</li>
 * <li>{@code join} is not an inner join and its keys reference more than one
 *     input on either side;</li>
 * <li>the merged condition has fewer equi-join predicate elements than
 *     (number of inputs - 1);</li>
 * <li>some input is joined on different key expressions by different
 *     equi-join predicate elements.</li>
 * </ul>
 *
 * @param join the join to merge with its left child
 * @param left left input of {@code join}
 * @param right right input of {@code join}; becomes the last input of the result
 * @return the merged multi-join, or {@code null} if the join cannot be merged
 * @throws RuntimeException if the join predicate info of the merged condition
 *         cannot be constructed
 */
private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) {
  final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
  // We check whether the join can be combined with its left child
  final List<RelNode> newInputs = Lists.newArrayList();
  final List<RexNode> newJoinCondition = Lists.newArrayList();
  final List<Pair<Integer, Integer>> joinInputs = Lists.newArrayList();
  final List<JoinRelType> joinTypes = Lists.newArrayList();
  final List<RexNode> joinFilters = Lists.newArrayList();

  // Left child
  if (left instanceof HiveJoin || left instanceof HiveMultiJoin) {
    final RexNode leftCondition;
    final List<Pair<Integer, Integer>> leftJoinInputs;
    final List<JoinRelType> leftJoinTypes;
    final List<RexNode> leftJoinFilters;
    boolean combinable;
    if (left instanceof HiveJoin) {
      // A binary join is described as a single (0, 1) join spec.
      HiveJoin hj = (HiveJoin) left;
      leftCondition = hj.getCondition();
      leftJoinInputs = ImmutableList.of(Pair.of(0, 1));
      leftJoinTypes = ImmutableList.of(hj.getJoinType());
      leftJoinFilters = ImmutableList.of(hj.getJoinFilter());
      try {
        combinable = isCombinableJoin(join, hj);
      } catch (CalciteSemanticException e) {
        LOG.trace("Failed to merge join-join", e);
        combinable = false;
      }
    } else {
      HiveMultiJoin hmj = (HiveMultiJoin) left;
      leftCondition = hmj.getCondition();
      leftJoinInputs = hmj.getJoinInputs();
      leftJoinTypes = hmj.getJoinTypes();
      leftJoinFilters = hmj.getJoinFilters();
      try {
        combinable = isCombinableJoin(join, hmj);
      } catch (CalciteSemanticException e) {
        LOG.trace("Failed to merge join-multijoin", e);
        combinable = false;
      }
    }
    if (combinable) {
      // Carry over the join specs of the left child; its inputs become our first inputs.
      newJoinCondition.add(leftCondition);
      for (int i = 0; i < leftJoinInputs.size(); i++) {
        joinInputs.add(leftJoinInputs.get(i));
        joinTypes.add(leftJoinTypes.get(i));
        joinFilters.add(leftJoinFilters.get(i));
      }
      newInputs.addAll(left.getInputs());
    } else {
      // The join operation in the child is not on the same keys
      return null;
    }
  } else {
    // The left child is not a join or multijoin operator
    return null;
  }
  final int numberLeftInputs = newInputs.size();

  // Right child
  newInputs.add(right);

  // The left condition was added above, so from here on the new condition
  // always consists of the left child's condition plus this join's condition.
  newJoinCondition.add(join.getCondition());

  // Split this join's condition into per-input key expressions and residual filters.
  final List<RelDataTypeField> systemFieldList = ImmutableList.of();
  List<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>();
  List<Integer> filterNulls = new ArrayList<Integer>();
  for (int i = 0; i < newInputs.size(); i++) {
    joinKeyExprs.add(new ArrayList<RexNode>());
  }
  RexNode filters;
  try {
    filters = HiveRelOptUtil.splitHiveJoinCondition(systemFieldList, newInputs, join.getCondition(),
        joinKeyExprs, filterNulls, null);
  } catch (CalciteSemanticException e) {
    LOG.trace("Failed to merge joins", e);
    return null;
  }

  // Record which inputs contribute at least one join key.
  ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder();
  for (int i = 0; i < newInputs.size(); i++) {
    List<RexNode> partialCondition = joinKeyExprs.get(i);
    if (!partialCondition.isEmpty()) {
      keysInInputsBuilder.set(i);
    }
  }

  // If we cannot merge, we bail out: a non-inner join may reference at most
  // one input on each side.
  ImmutableBitSet keysInInputs = keysInInputsBuilder.build();
  ImmutableBitSet leftReferencedInputs =
      keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs));
  ImmutableBitSet rightReferencedInputs =
      keysInInputs.intersect(ImmutableBitSet.range(numberLeftInputs, newInputs.size()));
  if (join.getJoinType() != JoinRelType.INNER
      && (leftReferencedInputs.cardinality() > 1 || rightReferencedInputs.cardinality() > 1)) {
    return null;
  }

  // Otherwise, we add to the join specs
  if (join.getJoinType() != JoinRelType.INNER) {
    int leftInput = keysInInputs.nextSetBit(0);
    int rightInput = keysInInputs.nextSetBit(numberLeftInputs);
    joinInputs.add(Pair.of(leftInput, rightInput));
    joinTypes.add(join.getJoinType());
    joinFilters.add(filters);
  } else {
    // Inner join: one join spec per (left input, right input) pair that carries keys.
    for (int i : leftReferencedInputs) {
      for (int j : rightReferencedInputs) {
        joinInputs.add(Pair.of(i, j));
        joinTypes.add(join.getJoinType());
        joinFilters.add(filters);
      }
    }
  }

  // We can now create a multijoin operator
  RexNode newCondition = RexUtil.flatten(rexBuilder,
      RexUtil.composeConjunction(rexBuilder, newJoinCondition, false));
  List<RelNode> newInputsArray = Lists.newArrayList(newInputs);
  final JoinPredicateInfo joinPredInfo;
  try {
    joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(newInputsArray,
        systemFieldList, newCondition);
  } catch (CalciteSemanticException e) {
    throw new RuntimeException(e);
  }

  // If the number of joins < number of input tables-1, this is not a star join.
  if (joinPredInfo.getEquiJoinPredicateElements().size() < newInputs.size() - 1) {
    return null;
  }
  // Validate that the multi-join is a valid star join before returning it.
  for (int i = 0; i < newInputs.size(); i++) {
    List<RexNode> joinKeys = null;
    for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); j++) {
      List<RexNode> currJoinKeys = joinPredInfo.getEquiJoinPredicateElements().get(j).getJoinExprs(i);
      if (currJoinKeys.isEmpty()) {
        continue;
      }
      if (joinKeys == null) {
        joinKeys = currJoinKeys;
      } else {
        // Input i is joined on a different set of key expressions by another
        // predicate element. Bail out if this is the case.
        if (!joinKeys.containsAll(currJoinKeys) || !currJoinKeys.containsAll(joinKeys)) {
          return null;
        }
      }
    }
  }
  return new HiveMultiJoin(join.getCluster(), newInputsArray, newCondition, join.getRowType(),
      joinInputs, joinTypes, joinFilters, joinPredInfo);
}
use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin in project hive by apache.
the class HiveOpConverter method genJoin.
/**
* Builds the Hive {@link JoinOperator} for a join rel on top of its already
* translated inputs.
*
* @param join the join rel; handled as a {@link HiveMultiJoin}, a {@link SemiJoin}, or otherwise cast to {@link Join}
* @param joinExpressions join key expressions per input; passed through to the join descriptor
* @param filterExpressions residual filter expressions, one list per join condition
* (entry i is paired with join condition i)
* @param children the inputs of the join; every element is cast to {@link ReduceSinkOperator}
* @param baseSrc per-input base source names, indexed by input position
* @param tabAlias table alias set on every output column of the join
* @return the generated join operator
* @throws SemanticException if one of the reduce sinks does not have exactly one parent
*/
private static JoinOperator genJoin(RelNode join, ExprNodeDesc[][] joinExpressions, List<List<ExprNodeDesc>> filterExpressions, List<Operator<?>> children, String[] baseSrc, String tabAlias) throws SemanticException {
// 1. Extract join type
JoinCondDesc[] joinCondns;
boolean semiJoin;
boolean noOuterJoin;
if (join instanceof HiveMultiJoin) {
// Multi-join: one join condition per (left, right) pair of joined inputs.
HiveMultiJoin hmj = (HiveMultiJoin) join;
joinCondns = new JoinCondDesc[hmj.getJoinInputs().size()];
for (int i = 0; i < hmj.getJoinInputs().size(); i++) {
joinCondns[i] = new JoinCondDesc(new JoinCond(hmj.getJoinInputs().get(i).left, hmj.getJoinInputs().get(i).right, transformJoinType(hmj.getJoinTypes().get(i))));
}
semiJoin = false;
noOuterJoin = !hmj.isOuterJoin();
} else {
// Binary join: a single condition between input 0 and input 1.
joinCondns = new JoinCondDesc[1];
semiJoin = join instanceof SemiJoin;
JoinType joinType;
if (semiJoin) {
joinType = JoinType.LEFTSEMI;
} else {
joinType = extractJoinType((Join) join);
}
joinCondns[0] = new JoinCondDesc(new JoinCond(0, 1, joinType));
noOuterJoin = joinType != JoinType.FULLOUTER && joinType != JoinType.LEFTOUTER && joinType != JoinType.RIGHTOUTER;
}
// 2. We create the join aux structures
ArrayList<ColumnInfo> outputColumns = new ArrayList<ColumnInfo>();
ArrayList<String> outputColumnNames = new ArrayList<String>(join.getRowType().getFieldNames());
Operator<?>[] childOps = new Operator[children.size()];
// Output column name -> tag of the input that produces it.
Map<String, Byte> reversedExprs = new HashMap<String, Byte>();
// Tag -> value expressions contributed by that input.
Map<Byte, List<ExprNodeDesc>> exprMap = new HashMap<Byte, List<ExprNodeDesc>>();
// Tag -> filter expressions assigned to that input (filled in step 3).
Map<Byte, List<ExprNodeDesc>> filters = new HashMap<Byte, List<ExprNodeDesc>>();
Map<String, ExprNodeDesc> colExprMap = new HashMap<String, ExprNodeDesc>();
HashMap<Integer, Set<String>> posToAliasMap = new HashMap<Integer, Set<String>>();
// Running position in outputColumnNames across all inputs.
int outputPos = 0;
for (int pos = 0; pos < children.size(); pos++) {
// 2.1. Backtracking from RS
ReduceSinkOperator inputRS = (ReduceSinkOperator) children.get(pos);
if (inputRS.getNumParent() != 1) {
throw new SemanticException("RS should have single parent");
}
Operator<?> parent = inputRS.getParentOperators().get(0);
ReduceSinkDesc rsDesc = inputRS.getConf();
int[] index = inputRS.getValueIndex();
Byte tag = (byte) rsDesc.getTag();
// 2.1.1. If semijoin, inputs other than the first contribute no output
// columns: register an empty expression list and skip the rest.
if (semiJoin && pos != 0) {
exprMap.put(tag, new ArrayList<ExprNodeDesc>());
childOps[pos] = inputRS;
continue;
}
posToAliasMap.put(pos, new HashSet<String>(inputRS.getSchema().getTableNames()));
List<String> keyColNames = rsDesc.getOutputKeyColumnNames();
List<String> valColNames = rsDesc.getOutputValueColumnNames();
// NOTE(review): presumably maps each join output column name to an expression
// over the reduce sink output -- confirm against buildBacktrackFromReduceSinkForJoin.
Map<String, ExprNodeDesc> descriptors = buildBacktrackFromReduceSinkForJoin(outputPos, outputColumnNames, keyColNames, valColNames, index, parent, baseSrc[pos]);
List<ColumnInfo> parentColumns = parent.getSchema().getSignature();
// One output column per entry of index: copy the i-th parent column and
// rename it to the next join output column name.
for (int i = 0; i < index.length; i++) {
ColumnInfo info = new ColumnInfo(parentColumns.get(i));
info.setInternalName(outputColumnNames.get(outputPos));
info.setTabAlias(tabAlias);
outputColumns.add(info);
reversedExprs.put(outputColumnNames.get(outputPos), tag);
outputPos++;
}
exprMap.put(tag, new ArrayList<ExprNodeDesc>(descriptors.values()));
colExprMap.putAll(descriptors);
childOps[pos] = inputRS;
}
// 3. We populate the filters and filterMap structure needed in the join descriptor
List<List<ExprNodeDesc>> filtersPerInput = Lists.newArrayList();
int[][] filterMap = new int[children.size()][];
for (int i = 0; i < children.size(); i++) {
filtersPerInput.add(new ArrayList<ExprNodeDesc>());
}
// 3.1. We assign every filter expression of join condition i to one input
for (int i = 0; i < filterExpressions.size(); i++) {
int leftPos = joinCondns[i].getLeft();
int rightPos = joinCondns[i].getRight();
for (ExprNodeDesc expr : filterExpressions.get(i)) {
// We need to update the exprNode, as currently
// they refer to columns in the output of the join;
// they should refer to the columns output by the RS
int inputPos = updateExprNode(expr, reversedExprs, colExprMap);
// NOTE(review): -1 presumably means the expression could not be tied to
// an input; it is then attached to the left input -- confirm.
if (inputPos == -1) {
inputPos = leftPos;
}
filtersPerInput.get(inputPos).add(expr);
// Only outer joins get an entry in the filter map, keyed by the input the
// filter was assigned to, with the other side of the join as second argument.
if (joinCondns[i].getType() == JoinDesc.FULL_OUTER_JOIN || joinCondns[i].getType() == JoinDesc.LEFT_OUTER_JOIN || joinCondns[i].getType() == JoinDesc.RIGHT_OUTER_JOIN) {
if (inputPos == leftPos) {
updateFilterMap(filterMap, leftPos, rightPos);
} else {
updateFilterMap(filterMap, rightPos, leftPos);
}
}
}
}
// 3.2. Re-key the per-position filter lists by the tag of each reduce sink
for (int pos = 0; pos < children.size(); pos++) {
ReduceSinkOperator inputRS = (ReduceSinkOperator) children.get(pos);
ReduceSinkDesc rsDesc = inputRS.getConf();
Byte tag = (byte) rsDesc.getTag();
filters.put(tag, filtersPerInput.get(pos));
}
// 4. We create the join operator with its descriptor
JoinDesc desc = new JoinDesc(exprMap, outputColumnNames, noOuterJoin, joinCondns, filters, joinExpressions, null);
desc.setReversedExprs(reversedExprs);
desc.setFilterMap(filterMap);
JoinOperator joinOp = (JoinOperator) OperatorFactory.getAndMakeChild(childOps[0].getCompilationOpContext(), desc, new RowSchema(outputColumns), childOps);
joinOp.setColumnExprMap(colExprMap);
joinOp.setPosToAliasMap(posToAliasMap);
joinOp.getConf().setBaseSrc(baseSrc);
if (LOG.isDebugEnabled()) {
LOG.debug("Generated " + joinOp + " with row schema: [" + joinOp.getSchema() + "]");
}
return joinOp;
}
Aggregations