Use of org.apache.calcite.plan.RelOptCluster in project hive by apache.
The class HiveProjectSortTransposeRule, method onMatch.
// ~ Methods ----------------------------------------------------------------
// implement RelOptRule
public void onMatch(RelOptRuleCall call) {
  final HiveProject project = call.rel(0);
  final HiveSortLimit sort = call.rel(1);
  final RelOptCluster cluster = project.getCluster();
  List<RelFieldCollation> fieldCollations =
      getNewRelFieldCollations(project, sort.getCollation(), cluster);
  if (fieldCollations == null) {
    return;
  }
  RelTraitSet traitSet = sort.getCluster().traitSetOf(HiveRelNode.CONVENTION);
  RelCollation newCollation = traitSet.canonize(RelCollationImpl.of(fieldCollations));
  // New operators
  final RelNode newProject = project.copy(sort.getInput().getTraitSet(),
      ImmutableList.of(sort.getInput()));
  final HiveSortLimit newSort = sort.copy(newProject.getTraitSet(), newProject,
      newCollation, sort.offset, sort.fetch);
  call.transformTo(newSort);
}
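The transposition only makes sense when every sort key survives the projection, which is why the method bails out when the remapped collations are null. A minimal sketch of that check and remapping, written inline against the project and sort locals above (hypothetical code, not the actual getNewRelFieldCollations implementation, which may handle more cases), could look like this:

// Hypothetical sketch: rewrite the sort keys in terms of the project's output.
// A key that is dropped or computed by the project cannot be remapped, so the
// rule gives up (this mirrors the null check on fieldCollations above).
List<RelFieldCollation> remapped = new ArrayList<>();
for (RelFieldCollation fc : sort.getCollation().getFieldCollations()) {
  int newIndex = -1;
  List<RexNode> exprs = project.getProjects();
  for (int i = 0; i < exprs.size(); i++) {
    RexNode e = exprs.get(i);
    if (e instanceof RexInputRef && ((RexInputRef) e).getIndex() == fc.getFieldIndex()) {
      newIndex = i;
      break;
    }
  }
  if (newIndex < 0) {
    return;   // a sort key is not a plain column of the project's output
  }
  remapped.add(new RelFieldCollation(newIndex, fc.getDirection(), fc.nullDirection));
}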
Use of org.apache.calcite.plan.RelOptCluster in project hive by apache.
The class HiveCardinalityPreservingJoinOptimization, method trim.
@Override
public RelNode trim(RelBuilder relBuilder, RelNode root) {
  try {
    if (root.getInputs().size() != 1) {
      LOG.debug("Only plans where root has one input are supported. Root: {}", root);
      return root;
    }
    REL_BUILDER.set(relBuilder);
    RexBuilder rexBuilder = relBuilder.getRexBuilder();
    RelNode rootInput = root.getInput(0);

    // Build the list of RexInputRef from root input RowType
    List<RexInputRef> rootFieldList = new ArrayList<>(rootInput.getRowType().getFieldCount());
    List<String> newColumnNames = new ArrayList<>();
    for (int i = 0; i < rootInput.getRowType().getFieldList().size(); ++i) {
      RelDataTypeField relDataTypeField = rootInput.getRowType().getFieldList().get(i);
      rootFieldList.add(rexBuilder.makeInputRef(relDataTypeField.getType(), i));
      newColumnNames.add(relDataTypeField.getName());
    }

    // Bit set to gather the refs that backtrack to constant values
    BitSet constants = new BitSet();
    List<JoinedBackFields> lineages = getExpressionLineageOf(rootFieldList, rootInput, constants);
    if (lineages == null) {
      LOG.debug("Some projected field lineage can not be determined");
      return root;
    }

    // 1. Collect candidate tables for join back and map RexNodes coming from those tables
    //    to their index in the rootInput row type.
    //    Collect all used fields from the original plan.
    ImmutableBitSet fieldsUsed = ImmutableBitSet.of(constants.stream().toArray());
    List<TableToJoinBack> tableToJoinBackList = new ArrayList<>(lineages.size());
    Map<Integer, RexNode> rexNodesToShuttle = new HashMap<>(rootInput.getRowType().getFieldCount());

    for (JoinedBackFields joinedBackFields : lineages) {
      Optional<ImmutableBitSet> projectedKeys =
          joinedBackFields.relOptHiveTable.getNonNullableKeys().stream()
              .filter(joinedBackFields.fieldsInSourceTable::contains)
              .findFirst();

      if (projectedKeys.isPresent()
          && !projectedKeys.get().equals(joinedBackFields.fieldsInSourceTable)) {
        TableToJoinBack tableToJoinBack = new TableToJoinBack(projectedKeys.get(), joinedBackFields);
        tableToJoinBackList.add(tableToJoinBack);
        fieldsUsed = fieldsUsed.union(joinedBackFields.getSource(projectedKeys.get()));

        for (TableInputRefHolder mapping : joinedBackFields.mapping) {
          if (!fieldsUsed.get(mapping.indexInOriginalRowType)) {
            rexNodesToShuttle.put(mapping.indexInOriginalRowType, mapping.rexNode);
          }
        }
      } else {
        fieldsUsed = fieldsUsed.union(joinedBackFields.fieldsInOriginalRowType);
      }
    }

    if (tableToJoinBackList.isEmpty()) {
      LOG.debug("None of the tables has keys projected, unable to join back");
      return root;
    }

    // 2. Trim out non-key fields of joined back tables
    Set<RelDataTypeField> extraFields = Collections.emptySet();
    TrimResult trimResult = dispatchTrimFields(rootInput, fieldsUsed, extraFields);
    RelNode newInput = trimResult.left;
    if (newInput.getRowType().equals(rootInput.getRowType())) {
      LOG.debug("Nothing was trimmed out.");
      return root;
    }

    // 3. Join back tables to the top of the original plan
    Mapping newInputMapping = trimResult.right;
    Map<RexTableInputRef, Integer> tableInputRefMapping = new HashMap<>();

    for (TableToJoinBack tableToJoinBack : tableToJoinBackList) {
      LOG.debug("Joining back table {}", tableToJoinBack.joinedBackFields.relOptHiveTable.getName());

      // 3.1. Create new TableScan of tables to join back
      RelOptHiveTable relOptTable = tableToJoinBack.joinedBackFields.relOptHiveTable;
      RelOptCluster cluster = relBuilder.getCluster();
      HiveTableScan tableScan = new HiveTableScan(cluster,
          cluster.traitSetOf(HiveRelNode.CONVENTION), relOptTable,
          relOptTable.getHiveTableMD().getTableName(), null, false, false);

      // 3.2. Create Project with the required fields from this table
      RelNode projectTableAccessRel = tableScan.project(
          tableToJoinBack.joinedBackFields.fieldsInSourceTable, new HashSet<>(0), REL_BUILDER.get());

      // 3.3. Create mapping between the Project and TableScan
      Mapping projectMapping = Mappings.create(MappingType.INVERSE_SURJECTION,
          tableScan.getRowType().getFieldCount(),
          tableToJoinBack.joinedBackFields.fieldsInSourceTable.cardinality());
      int projectIndex = 0;
      for (int i : tableToJoinBack.joinedBackFields.fieldsInSourceTable) {
        projectMapping.set(i, projectIndex);
        ++projectIndex;
      }

      int offset = newInput.getRowType().getFieldCount();

      // 3.4. Map rexTableInputRef to the index where it can be found in the new Input row type
      for (TableInputRefHolder mapping : tableToJoinBack.joinedBackFields.mapping) {
        int indexInSourceTable = mapping.tableInputRef.getIndex();
        if (!tableToJoinBack.keys.get(indexInSourceTable)) {
          // 3.5. if this is not a key field it is shifted by the left input field count
          tableInputRefMapping.put(mapping.tableInputRef,
              offset + projectMapping.getTarget(indexInSourceTable));
        }
      }

      // 3.7. Create Join
      relBuilder.push(newInput);
      relBuilder.push(projectTableAccessRel);
      RexNode joinCondition = joinCondition(newInput, newInputMapping, tableToJoinBack,
          projectTableAccessRel, projectMapping, rexBuilder);
      newInput = relBuilder.join(JoinRelType.INNER, joinCondition).build();
    }

    // 4. Collect rexNodes for Project
    TableInputRefMapper mapper = new TableInputRefMapper(tableInputRefMapping, rexBuilder, newInput);
    List<RexNode> rexNodeList = new ArrayList<>(rootInput.getRowType().getFieldCount());
    for (int i = 0; i < rootInput.getRowType().getFieldCount(); i++) {
      RexNode rexNode = rexNodesToShuttle.get(i);
      if (rexNode != null) {
        rexNodeList.add(mapper.apply(rexNode));
      } else {
        int target = newInputMapping.getTarget(i);
        rexNodeList.add(rexBuilder.makeInputRef(
            newInput.getRowType().getFieldList().get(target).getType(), target));
      }
    }

    // 5. Create Project on top of all Join backs
    relBuilder.push(newInput);
    relBuilder.project(rexNodeList, newColumnNames);
    return root.copy(root.getTraitSet(), singletonList(relBuilder.build()));
  } finally {
    REL_BUILDER.remove();
  }
}
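To make step 3.3 concrete, here is a small standalone example with made-up numbers showing what the INVERSE_SURJECTION mapping records: for a 10-column table where only columns 2, 5 and 7 are kept by the Project, getTarget translates a scan ordinal into its position in the trimmed Project output, which step 3.4 then shifts by the left input's field count.

// Hypothetical example: 10 scan columns, of which 2, 5 and 7 survive the Project.
ImmutableBitSet fieldsInSourceTable = ImmutableBitSet.of(2, 5, 7);
Mapping projectMapping = Mappings.create(MappingType.INVERSE_SURJECTION, 10,
    fieldsInSourceTable.cardinality());
int projectIndex = 0;
for (int i : fieldsInSourceTable) {
  projectMapping.set(i, projectIndex++);
}
// projectMapping.getTarget(2) == 0, getTarget(5) == 1, getTarget(7) == 2;
// a non-key ref to scan column 5 therefore lands at offset + 1 after the join back.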
Use of org.apache.calcite.plan.RelOptCluster in project hive by apache.
The class HiveAggregateReduceFunctionsRule, method reduceStddev.
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased,
    boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
    List<RexNode> inputExprs) {
  // stddev_pop(x) ==>
  //   power(
  //     (sum(x * x) - sum(x) * sum(x) / count(x))
  //     / count(x),
  //     .5)
  //
  // stddev_samp(x) ==>
  //   power(
  //     (sum(x * x) - sum(x) * sum(x) / count(x))
  //     / nullif(count(x) - 1, 0),
  //     .5)
  final int nGroups = oldAggRel.getGroupCount();
  final RelOptCluster cluster = oldAggRel.getCluster();
  final RexBuilder rexBuilder = cluster.getRexBuilder();
  final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
  assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
  final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);
  final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
  final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
  final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argRef.getType());
  final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
  final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
  final RelDataType sumSquaredReturnType =
      getSumReturnType(rexBuilder.getTypeFactory(), argSquared.getType());
  final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory,
      new HiveSqlSumAggFunction(oldCall.isDistinct(),
          ReturnTypes.explicit(sumSquaredReturnType),
          InferTypes.explicit(Collections.singletonList(argSquared.getType())),
          // SqlStdOperatorTable.SUM,
          oldCall.getAggregation().getOperandTypeChecker()),
      argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
  final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups,
      oldAggRel.indicator, newCalls, aggCallMapping,
      ImmutableList.of(sumArgSquaredAggCall.getType()));
  final AggregateCall sumArgAggCall = AggregateCall.create(
      new HiveSqlSumAggFunction(oldCall.isDistinct(),
          ReturnTypes.explicit(sumReturnType),
          InferTypes.explicit(Collections.singletonList(argOrdinalType)),
          // SqlStdOperatorTable.SUM,
          oldCall.getAggregation().getOperandTypeChecker()),
      oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal),
      oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
  final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator,
      newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
  final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
  final RexNode sumSquaredArg =
      rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
  RelDataType countRetType =
      typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
  final AggregateCall countArgAggCall = AggregateCall.create(
      new HiveSqlCountAggFunction(oldCall.isDistinct(),
          ReturnTypes.explicit(countRetType),
          oldCall.getAggregation().getOperandTypeInference(),
          // SqlStdOperatorTable.COUNT,
          oldCall.getAggregation().getOperandTypeChecker()),
      oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
      oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
  final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator,
      newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
  final RexNode avgSumSquaredArg =
      rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
  final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
  final RexNode denominator;
  if (biased) {
    denominator = countArg;
  } else {
    final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
    final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
    final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
    final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
    denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
  }
  final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
  RexNode result = div;
  if (sqrt) {
    final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
    result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
  }
  return rexBuilder.makeCast(oldCall.getType(), result);
}
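The rewrite relies on the algebraic identity sum((x - avg(x))^2) = sum(x*x) - sum(x)*sum(x)/count(x), so stddev can be computed from three simpler aggregates. A quick arithmetic check in plain Java with made-up input x = {2, 4, 6}:

// Sanity check of the sum-of-squares decomposition used above (illustrative only).
double[] x = {2, 4, 6};
double sum = 0, sumSq = 0;
for (double v : x) {
  sum += v;
  sumSq += v * v;
}
double n = x.length;
double popVar  = (sumSq - sum * sum / n) / n;        // (56 - 144/3) / 3 = 8/3
double sampVar = (sumSq - sum * sum / n) / (n - 1);  // (56 - 144/3) / 2 = 4
double stddevPop  = Math.pow(popVar, 0.5);           // ~1.633
double stddevSamp = Math.pow(sampVar, 0.5);          // 2.0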
Use of org.apache.calcite.plan.RelOptCluster in project hive by apache.
The class HiveProject, method create.
/**
 * Creates a HiveProject with no sort keys.
 *
 * @param child
 *          input relational expression
 * @param exps
 *          set of expressions for the input columns
 * @param fieldNames
 *          aliases of the expressions
 */
public static HiveProject create(RelNode child, List<? extends RexNode> exps,
    List<String> fieldNames) throws CalciteSemanticException {
  RelOptCluster cluster = child.getCluster();
  // 1 Ensure columnNames are unique - CALCITE-411
  if (fieldNames != null && !Util.isDistinct(fieldNames)) {
    String msg = "Select list contains multiple expressions with the same name." + fieldNames;
    throw new CalciteSemanticException(msg, UnsupportedFeature.Same_name_in_multiple_expressions);
  }
  RelDataType rowType = RexUtil.createStructType(cluster.getTypeFactory(), exps, fieldNames,
      SqlValidatorUtil.EXPR_SUGGESTER);
  return create(cluster, child, exps, rowType, Collections.<RelCollation>emptyList());
}
Use of org.apache.calcite.plan.RelOptCluster in project hive by apache.
The class HiveSortExchange, method create.
/**
 * Creates a HiveSortExchange.
 *
 * @param input Input relational expression
 * @param distribution Distribution specification
 * @param collation Collation specification
 * @param keys Keys specification
 */
public static HiveSortExchange create(RelNode input, RelDistribution distribution,
    RelCollation collation, ImmutableList<RexNode> keys) {
  RelOptCluster cluster = input.getCluster();
  distribution = RelDistributionTraitDef.INSTANCE.canonize(distribution);
  collation = RelCollationTraitDef.INSTANCE.canonize(collation);
  RelTraitSet traitSet = getTraitSet(cluster, collation, distribution);
  return new HiveSortExchange(cluster, traitSet, input, distribution, collation, keys);
}