Use of org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class RelFieldTrimmer, method trimFields.
/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.calcite.rel.core.Aggregate} (e.g.
 * {@link org.apache.calcite.rel.logical.LogicalAggregate}).
 *
 * <p>Output fields are laid out as {@code | group fields | agg calls |}.
 * Group fields are always kept, even if the consumer does not use them.
 * Aggregate calls whose output is not in {@code fieldsUsed} are removed.
 * The input is trimmed to the fields referenced by the group keys and by the
 * arguments, filter arguments and WITHIN GROUP collations of every aggregate call.
 *
 * @param aggregate   aggregate to trim
 * @param fieldsUsed  output fields of {@code aggregate} used by its consumer
 * @param extraFields extra fields requested by the consumer; not consulted by
 *                    this method (the input is trimmed with no extra fields)
 * @return the (possibly new) relational expression and a mapping from the old
 *         output ordinals of {@code aggregate} to the new ones
 */
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
  final RelDataType rowType = aggregate.getRowType();

  // Compute which input fields are used.
  // 1. Group fields are always used.
  final ImmutableBitSet.Builder inputFieldsUsed = aggregate.getGroupSet().rebuild();
  // 2. Arguments, filter arguments and collations of every agg call (computed
  //    before we know which calls the consumer actually needs).
  for (AggregateCall aggCall : aggregate.getAggCallList()) {
    inputFieldsUsed.addAll(aggCall.getArgList());
    if (aggCall.filterArg >= 0) {
      inputFieldsUsed.set(aggCall.filterArg);
    }
    inputFieldsUsed.addAll(RelCollations.ordinals(aggCall.collation));
  }

  // Create input with trimmed columns.
  final RelNode input = aggregate.getInput();
  final Set<RelDataTypeField> inputExtraFields = Collections.emptySet();
  final TrimResult trimResult = trimChild(aggregate, input, inputFieldsUsed.build(), inputExtraFields);
  final RelNode newInput = trimResult.left;
  final Mapping inputMapping = trimResult.right;

  // We have to return the group keys, so pretend that the consumer asked for them.
  final int groupCount = aggregate.getGroupSet().cardinality();
  fieldsUsed = fieldsUsed.union(ImmutableBitSet.range(groupCount));

  // If the input is unchanged and every output column is needed, there's nothing to do.
  if (input == newInput && fieldsUsed.equals(ImmutableBitSet.range(rowType.getFieldCount()))) {
    return result(aggregate, Mappings.createIdentity(rowType.getFieldCount()));
  }

  // Which agg calls are used by our consumer?
  int j = groupCount;
  int usedAggCallCount = 0;
  for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
    if (fieldsUsed.get(j++)) {
      ++usedAggCallCount;
    }
  }

  // Maps old output ordinals to new ones: group fields keep their positions and
  // the used agg calls are packed right after them.
  Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, rowType.getFieldCount(), groupCount + usedAggCallCount);

  final ImmutableBitSet newGroupSet = Mappings.apply(inputMapping, aggregate.getGroupSet());
  final ImmutableList<ImmutableBitSet> newGroupSets = ImmutableList.copyOf(
      Iterables.transform(aggregate.getGroupSets(), groupSet -> Mappings.apply(inputMapping, groupSet)));

  // Group fields first.
  for (j = 0; j < groupCount; j++) {
    mapping.set(j, j);
  }

  // Now create new agg calls, and populate mapping for them.
  final RelBuilder relBuilder = REL_BUILDER.get();
  relBuilder.push(newInput);
  final List<RelBuilder.AggCall> newAggCallList = new ArrayList<>();
  j = groupCount;
  for (AggregateCall aggCall : aggregate.getAggCallList()) {
    if (fieldsUsed.get(j)) {
      // Arguments, filter argument and collation all refer to ordinals of the
      // ORIGINAL input, so each must be translated through inputMapping before
      // being resolved against newInput.
      final ImmutableList<RexNode> args = relBuilder.fields(Mappings.apply2(inputMapping, aggCall.getArgList()));
      final RexNode filterArg = aggCall.filterArg < 0 ? null : relBuilder.field(Mappings.apply(inputMapping, aggCall.filterArg));
      RelBuilder.AggCall newAggCall = relBuilder.aggregateCall(aggCall.getAggregation(), args)
          .distinct(aggCall.isDistinct())
          .filter(filterArg)
          .approximate(aggCall.isApproximate())
          .sort(relBuilder.fields(org.apache.calcite.rex.RexUtil.apply(inputMapping, aggCall.collation)))
          .as(aggCall.name);
      mapping.set(j, groupCount + newAggCallList.size());
      newAggCallList.add(newAggCall);
    }
    ++j;
  }

  final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets);
  relBuilder.aggregate(groupKey, newAggCallList);
  return result(relBuilder.build(), mapping);
}
Use of org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveRemoveSqCountCheck, method onMatch.
/**
 * Replaces {@code topJoin} (operand 0) with a join of the left input of
 * {@code join} (operand 2) and the right input of {@code topJoin}. The new join
 * keeps {@code topJoin}'s condition and join type.
 *
 * <p>The rewrite fires only when the aggregate (operand 6) has no group-by keys
 * or only constant group-by keys, as decided by {@code isAggregateWithoutGbyKeys}
 * and {@code isAggWithConstantGbyKeys}. It is skipped for grouping sets.
 *
 * @param call the rule call holding the matched operands
 */
@Override
public void onMatch(RelOptRuleCall call) {
  final Join topJoin = call.rel(0);
  final Join join = call.rel(2);
  final Aggregate aggregate = call.rel(6);

  // In presence of grouping sets we can't remove sq_count_check: the aggregate
  // can emit one row per grouping set even when its keys are constant.
  // NOTE(review): Aggregate.indicator is deprecated in Calcite and recent versions
  // no longer set it for grouping sets, so the group type is checked as well.
  if (aggregate.indicator || aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
    return;
  }
  if (isAggregateWithoutGbyKeys(aggregate) || isAggWithConstantGbyKeys(aggregate, call)) {
    // join(join.getLeft(), topJoin.getRight()) with topJoin's condition and join type.
    RelNode newJoin = HiveJoin.getJoin(topJoin.getCluster(), join.getLeft(), topJoin.getRight(),
        topJoin.getCondition(), topJoin.getJoinType());
    call.transformTo(newJoin);
  }
}
Use of org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveAggregateIncrementalRewritingRuleBase, method onMatch.
/**
 * Rewrites the matched aggregate-over-union into a plan that merges the
 * materialized view contents with an incremental aggregate result.
 *
 * <p>The MV branch of the union ({@code union.getInput(1)}) is wrapped in a
 * Project that appends a {@code true} flag column. That Project becomes the left
 * side of a RIGHT join with the plan returned by {@code createJoinRightInput}
 * (right side). The join condition is IS NOT DISTINCT FROM on every group key.
 * On top of the join, a filter built by {@code createFilterCondition} is applied,
 * followed by a projection. The projection emits the right-side group keys and,
 * for each aggregate call, a CASE expression that merges the left and right values.
 *
 * <p>Leaves the plan unchanged when {@code createJoinRightInput} returns null.
 *
 * <p>NOTE(review): the right-side ordinals ({@code totalCount + 1 + ...}) assume
 * the MV row type has exactly {@code groupCount + aggCallCount} columns, laid out
 * as group keys followed by aggregate values. Confirm against the operand pattern.
 *
 * @param call the rule call holding the matched operands
 */
@Override
public void onMatch(RelOptRuleCall call) {
final Aggregate agg = call.rel(aggregateIndex);
final Union union = call.rel(1);
final RelBuilder relBuilder = call.builder();
final RexBuilder rexBuilder = relBuilder.getRexBuilder();
// 1) First branch is query, second branch is MV
RelNode joinLeftInput = union.getInput(1);
final T joinRightInput = createJoinRightInput(call);
// createJoinRightInput signals "cannot rewrite" by returning null; leave the plan as is.
if (joinRightInput == null) {
return;
}
// 2) Introduce a Project on top of MV scan having all columns from the view plus a boolean literal which indicates
// whether the row with the key values coming from the joinRightInput exists in the view:
// - true means exist
// - null means not exists (the RIGHT join below pads unmatched rows with nulls)
// Project also needed to encapsulate the view scan by a subquery -> this is required by
// CalcitePlanner.fixUpASTAggregateInsertIncrementalRebuild
// CalcitePlanner.fixUpASTAggregateInsertDeleteIncrementalRebuild
List<RexNode> mvCols = new ArrayList<>(joinLeftInput.getRowType().getFieldCount());
for (int i = 0; i < joinLeftInput.getRowType().getFieldCount(); ++i) {
mvCols.add(rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(i).getType(), i));
}
mvCols.add(rexBuilder.makeLiteral(true));
joinLeftInput = relBuilder.push(joinLeftInput).project(mvCols).build();
// 3) Build conditions for join and start adding
// expressions for project operator
// In the join row, left (MV) columns come first, followed by the flag column; right-side
// columns therefore start at totalCount + 1. The projected group keys are taken from the right side.
List<RexNode> projExprs = new ArrayList<>();
List<RexNode> joinConjs = new ArrayList<>();
int groupCount = agg.getGroupCount();
int totalCount = agg.getGroupCount() + agg.getAggCallList().size();
for (int leftPos = 0, rightPos = totalCount + 1; leftPos < groupCount; leftPos++, rightPos++) {
RexNode leftRef = rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(leftPos).getType(), leftPos);
// leftPos also indexes the right input's own row type: both sides share the group-keys-then-aggregates layout.
RexNode rightRef = rexBuilder.makeInputRef(joinRightInput.rightInput.getRowType().getFieldList().get(leftPos).getType(), rightPos);
projExprs.add(rightRef);
RexNode nsEqExpr = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, ImmutableList.of(leftRef, rightRef));
joinConjs.add(nsEqExpr);
}
// 4) Create join node
// RIGHT join: every row of the incremental (right) side is kept, whether or not the MV has a matching row.
RexNode joinCond = RexUtil.composeConjunction(rexBuilder, joinConjs);
RelNode join = relBuilder.push(joinLeftInput).push(joinRightInput.rightInput).join(JoinRelType.RIGHT, joinCond).build();
// 5) Merge the aggregate function values of the MV (left) and the incremental result (right).
// NOTE(review): leftRef is typed with joinLeftInput's field type, which may be NOT NULL even though the
// RIGHT join makes left columns nullable; a later simplification of IS_NULL(leftRef) could misfire — confirm.
for (int i = 0, leftPos = groupCount, rightPos = totalCount + 1 + groupCount; leftPos < totalCount; i++, leftPos++, rightPos++) {
// case when source.s is null and mv2.s is null then null
// case when source.s IS null then mv2.s
// case when mv2.s IS null then source.s
// else source.s + mv2.s end
RexNode leftRef = rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(leftPos).getType(), leftPos);
RexNode rightRef = rexBuilder.makeInputRef(joinRightInput.rightInput.getRowType().getFieldList().get(leftPos).getType(), rightPos);
// Generate SQLOperator for merging the aggregations
SqlAggFunction aggCall = agg.getAggCallList().get(i).getAggregation();
RexNode elseReturn = createAggregateNode(aggCall, leftRef, rightRef, rexBuilder);
// According to SQL standard (and Hive) Aggregate functions eliminates null values however operators used in
// elseReturn expressions returns null if one of their operands is null
// hence we need a null check of both operands.
// Note: If both are null, we will fall into branch WHEN leftNull THEN rightRef
RexNode leftNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, leftRef);
RexNode rightNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, rightRef);
projExprs.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE, leftNull, rightRef, rightNull, leftRef, elseReturn));
}
// The flag literal added in step 2 is the last left-side column, so its ordinal is unchanged in the join row.
int flagIndex = joinLeftInput.getRowType().getFieldCount() - 1;
RexNode flagNode = rexBuilder.makeInputRef(join.getRowType().getFieldList().get(flagIndex).getType(), flagIndex);
// 6) Build plan
RelNode newNode = relBuilder.push(join).filter(createFilterCondition(joinRightInput, flagNode, projExprs, relBuilder)).project(projExprs).build();
call.transformTo(newNode);
}
Use of org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveAggregateInsertDeleteIncrementalRewritingRule, method createJoinRightInput.
/**
 * Re-aggregates the input of the matched aggregate (operand 2) after appending a
 * row-is-deleted column to it. Each aggregate call is turned into a SUM whose
 * argument is negated for deleted rows:
 * <ul>
 *   <li>{@code count(*)} becomes {@code sum(case when deleted then -1 else 1 end)};</li>
 *   <li>{@code count(col)} uses the argument produced by {@code genArgumentForCountColumn};</li>
 *   <li>{@code sum(col)} becomes {@code sum(case when deleted then -col else col end)}.</li>
 * </ul>
 *
 * @param call the rule call holding the matched operands
 * @return the new aggregate plan together with the output ordinal of the
 *         {@code count(*)} call, or {@code null} if there is no {@code count(*)}
 * @throws AssertionError if an aggregate call other than COUNT or SUM is found
 */
@Override
protected IncrementalComputePlanWithDeletedRows createJoinRightInput(RelOptRuleCall call) {
  final RelBuilder builder = call.builder();
  final Aggregate aggregate = call.rel(2);
  final RexBuilder rexBuilder = builder.getRexBuilder();

  // Propagate rowIsDeleted column: after this it is the last column of the input.
  RelNode source = aggregate.getInput();
  source = HiveHepExtractRelNodeRule.execute(source);
  source = new HiveRowIsDeletedPropagator(builder).propagate(source);
  final int deletedFlagPos = source.getRowType().getFieldCount() - 1;
  final RexNode isDeleted = rexBuilder.makeInputRef(
      source.getRowType().getFieldList().get(deletedFlagPos).getType(), deletedFlagPos);

  // Rewrite every aggregate so that deleted rows contribute with a negative sign:
  //
  // SELECT
  // sum(case when t1.ROW__IS__DELETED then -b else b end) sumb,
  // sum(case when t1.ROW__IS__DELETED then -1 else 1 end) countStar
  final List<AggregateCall> originalCalls = aggregate.getAggCallList();
  final List<RelBuilder.AggCall> signedCalls = new ArrayList<>(originalCalls.size());
  int countStarPos = -1;
  int ordinal = 0;
  for (AggregateCall original : originalCalls) {
    final SqlAggFunction target;
    final RexNode value;
    switch (original.getAggregation().getKind()) {
    case COUNT:
      target = SqlStdOperatorTable.SUM;
      if (original.getArgList().isEmpty()) {
        // count(*): remember where its result lands and count each row as 1.
        countStarPos = aggregate.getGroupCount() + ordinal;
        value = builder.literal(1);
      } else {
        // count(<column_name>)
        value = genArgumentForCountColumn(builder, source, original.getArgList().get(0));
      }
      break;
    case SUM: {
      target = SqlStdOperatorTable.SUM;
      final Integer column = original.getArgList().get(0);
      value = rexBuilder.makeInputRef(source.getRowType().getFieldList().get(column).getType(), column);
      break;
    }
    default:
      throw new AssertionError("Found an aggregation that could not be" + " recognized: " + original);
    }
    final RexNode negated = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, builder.literal(-1), value);
    final RexNode signed = rexBuilder.makeCall(SqlStdOperatorTable.CASE, isDeleted, negated, value);
    signedCalls.add(builder.aggregateCall(target, signed));
    ordinal++;
  }

  if (countStarPos < 0) {
    // Can not rewrite, bail out;
    return null;
  }

  // groupKey(...) resolves fields against the top of the stack, so push first.
  builder.push(source);
  final RelNode deltaPlan = builder
      .aggregate(builder.groupKey(aggregate.getGroupSet()), signedCalls)
      .build();
  return new IncrementalComputePlanWithDeletedRows(deltaPlan, countStarPos);
}
Use of org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveAggregateJoinTransposeRule, method onMatch.
/**
 * Pushes an aggregate (operand 0) below an inner equi-join (operand 1).
 *
 * <p>Each join input is replaced either by a projection of its below-aggregate
 * key columns (when those columns are unique, or when aggregate functions are
 * not allowed) or by an aggregate on those columns with split aggregate calls.
 * The two new inputs are joined, and a top aggregate combines the sub-totals.
 * When possible, that top aggregate is replaced by a projection.
 *
 * <p>The result is registered when {@code uniqueBased} is set, the top aggregate
 * was converted to projects and the grouping is unique. Otherwise it is
 * registered when {@code costBased} is set and its cumulative cost is lower than
 * the original aggregate's.
 *
 * <p>If an exception is thrown while {@code noColsMissingStats} is positive, a
 * warning is logged, the counter is reset and the rule gives up. Otherwise the
 * exception is rethrown.
 *
 * @param call the rule call holding the matched operands
 */
@Override
public void onMatch(RelOptRuleCall call) {
try {
final Aggregate aggregate = call.rel(0);
final Join join = call.rel(1);
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
// If any aggregate function does not support splitting, bail out.
// If any aggregate call has a filter, bail out.
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
return;
}
if (aggregateCall.filterArg >= 0) {
return;
}
}
// If it is not an inner join, we do not push the
// aggregate operator
if (join.getJoinType() != JoinRelType.INNER) {
return;
}
// When aggregate functions are not allowed, only aggregates without agg calls are handled.
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return;
}
boolean groupingUnique = isGroupingUnique(join, aggregate.getGroupSet());
if (!groupingUnique && !costBased) {
// there is no need to check further - the transformation may not happen
return;
}
// Do the columns used by the join appear in the output of the aggregate?
final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
final RelMetadataQuery mq = call.getMetadataQuery();
// NOTE(review): keyColumns presumably widens the group set with columns equated to it by the
// pulled-up predicates of the join — confirm in the helper.
final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
// Columns each side must keep below the join: the group columns plus the join columns.
final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
// Split join condition
final List<Integer> leftKeys = Lists.newArrayList();
final List<Integer> rightKeys = Lists.newArrayList();
final List<Boolean> filterNulls = Lists.newArrayList();
RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
// If it contains non-equi join conditions, we bail out
if (!nonEquiConj.isAlwaysTrue()) {
return;
}
// Push each aggregate function down to each side that contains all of its
// arguments. Note that COUNT(*), because it has no arguments, can go to
// both sides.
// map: ordinal in the old join row -> ordinal in the new join row (only below-aggregate columns).
final Map<Integer, Integer> map = new HashMap<>();
final List<Side> sides = new ArrayList<>();
int uniqueCount = 0;
// offset: start of the current side in the old join row; belowOffset: start in the new join row.
int offset = 0;
int belowOffset = 0;
for (int s = 0; s < 2; s++) {
final Side side = new Side();
final RelNode joinInput = join.getInput(s);
int fieldCount = joinInput.getRowType().getFieldCount();
final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
map.put(c.e, belowOffset + c.i);
}
// Same columns, expressed as ordinals of joinInput.
final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
final boolean unique;
if (!allowFunctions) {
assert aggregate.getAggCallList().isEmpty();
// If there are no functions, it doesn't matter as much whether we
// aggregate the inputs before the join, because there will not be
// any functions experiencing a cartesian product effect.
//
// But finding out whether the input is already unique requires a call
// to areColumnsUnique that currently (until [CALCITE-1048] "Make
// metadata more robust" is fixed) places a heavy load on
// the metadata system.
//
// So we choose to imagine the the input is already unique, which is
// untrue but harmless.
//
unique = true;
} else {
final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey, true);
unique = unique0 != null && unique0;
}
if (unique) {
// Already unique on the key: just project the key columns, no aggregation needed.
++uniqueCount;
relBuilder.push(joinInput);
relBuilder.project(belowAggregateKey.asList().stream().map(relBuilder::field).collect(Collectors.toList()));
side.newInput = relBuilder.build();
} else {
List<AggregateCall> belowAggCalls = new ArrayList<>();
final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
// Maps old join-row ordinals of this side to joinInput ordinals (identity for the left, shift by -offset for the right).
final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final AggregateCall call1;
// Calls whose arguments all live on this side are split; others get the splitter's "other" call (or none).
if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
call1 = splitter.split(aggCall.e, mapping);
} else {
call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
}
if (call1 != null) {
// side.split: original agg call index -> sub-total column in side.newInput.
side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
}
}
side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
}
offset += fieldCount;
belowOffset += side.newInput.getRowType().getFieldCount();
sides.add(side);
}
if (uniqueCount == 2) {
// Both inputs to the join are unique. There is nothing to be gained by
// this rule. In fact, this aggregate+join may be the result of a previous
// invocation of this rule; if we continue we might loop forever.
return;
}
// Update condition
final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
// Create new join
relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
// Aggregate above to sum up the sub-totals
final List<AggregateCall> newAggCalls = new ArrayList<>();
final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
// -1 means "no sub-total on that side"; right sub-totals are shifted past the left input's columns.
newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
}
relBuilder.project(projects);
boolean aggConvertedToProjects = false;
if (allColumnsInAggregate) {
// let's see if we can convert aggregate into projects
List<RexNode> projects2 = new ArrayList<>();
for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
projects2.add(relBuilder.field(key));
}
for (AggregateCall newAggCall : newAggCalls) {
final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
if (splitter != null) {
final RelDataType rowType = relBuilder.peek().getRowType();
projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
}
}
if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
// We successfully converted agg calls into projects.
relBuilder.project(projects2);
aggConvertedToProjects = true;
}
}
if (!aggConvertedToProjects) {
relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
}
RelNode r = relBuilder.build();
// Decide whether to register the rewrite: uniqueness-based first, then cost-based.
boolean transform = false;
if (uniqueBased && aggConvertedToProjects) {
transform = groupingUnique;
}
if (!transform && costBased) {
RelOptCost afterCost = mq.getCumulativeCost(r);
RelOptCost beforeCost = mq.getCumulativeCost(aggregate);
transform = afterCost.isLt(beforeCost);
}
if (transform) {
call.transformTo(r);
}
} catch (Exception e) {
// NOTE(review): noColsMissingStats is presumably incremented elsewhere when column stats are missing;
// in that case the failure is logged and swallowed, otherwise it is rethrown.
if (noColsMissingStats.get() > 0) {
LOG.warn("Missing column stats (see previous messages), skipping aggregate-join transpose in CBO");
noColsMissingStats.set(0);
} else {
throw e;
}
}
}
Aggregations