Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.
The class PlanModifierForASTConv, method replaceEmptyGroupAggr.
/**
 * Rewrites {@code rel} so that its aggregate-call list consists of a single dummy
 * {@code count} over input column 0.
 *
 * <p>The copy is wrapped with {@code introduceDerivedTable} and installed as input 0 of
 * {@code parent}.
 *
 * @param rel the aggregate to rewrite; it is cast to {@code HiveAggregate}
 * @param parent the node whose first input is replaced; when it is a {@code Project}, every
 *     projected expression must be a constant
 * @throws RuntimeException if a projected expression of {@code parent} is not a constant
 */
private static void replaceEmptyGroupAggr(final RelNode rel, RelNode parent) {
  // This method is only expected to be reached when the parent projects constants alone;
  // verify that before touching the aggregate underneath it.
  final List<RexNode> parentExprs;
  if (parent instanceof Project) {
    parentExprs = ((Project) parent).getProjects();
  } else {
    parentExprs = Collections.emptyList();
  }
  for (RexNode expr : parentExprs) {
    final boolean isConstant = expr.accept(new HiveCalciteUtil.ConstantFinder());
    if (!isConstant) {
      throw new RuntimeException("We expect " + parent.toString() + " to contain only constants. However, " + expr.toString() + " is " + expr.getKind());
    }
  }

  final HiveAggregate oldAggRel = (HiveAggregate) rel;
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
  final RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory);

  // Dummy aggregation: count(<column 0>) returning a long.
  final SqlAggFunction countFn = SqlFunctionConverter.getCalciteAggFn("count", false, ImmutableList.of(intType), longType);
  // TODO: Using 0 might be wrong; might need to walk down to find the
  // proper index of a dummy.
  final AggregateCall dummyCall = new AggregateCall(countFn, false, ImmutableList.of(0), longType, null);

  final Aggregate newAggRel = oldAggRel.copy(oldAggRel.getTraitSet(), oldAggRel.getInput(), oldAggRel.indicator, oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), ImmutableList.of(dummyCall));
  parent.replaceInput(0, introduceDerivedTable(newAggRel));
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.
The class HiveAggregateIncrementalRewritingRuleBase, method onMatch.
/**
 * Merges newly aggregated rows into the existing contents of an aggregate materialized view.
 *
 * <p>The matched Union's second input is treated as the MV scan and its first input as the
 * query over new data. The MV side gets an extra boolean {@code true} flag column and is
 * RIGHT-joined with the plan returned by {@code createJoinRightInput} on the group-by keys,
 * using IS NOT DISTINCT FROM. Each aggregate column is then merged with a CASE expression
 * that falls back to whichever side is non-null. The result is filtered with
 * {@code createFilterCondition}, projected, and registered via {@code transformTo}.
 *
 * <p>If {@code createJoinRightInput} returns {@code null}, nothing is registered.
 */
@Override
public void onMatch(RelOptRuleCall call) {
final Aggregate agg = call.rel(aggregateIndex);
final Union union = call.rel(1);
final RelBuilder relBuilder = call.builder();
final RexBuilder rexBuilder = relBuilder.getRexBuilder();
// 1) First branch is query, second branch is MV
RelNode joinLeftInput = union.getInput(1);
final T joinRightInput = createJoinRightInput(call);
if (joinRightInput == null) {
return;
}
// 2) Introduce a Project on top of MV scan having all columns from the view plus a boolean literal which indicates
// whether the row with the key values coming from the joinRightInput exists in the view:
// - true means exist
// - null means not exists
// Project also needed to encapsulate the view scan by a subquery -> this is required by
// CalcitePlanner.fixUpASTAggregateInsertIncrementalRebuild
// CalcitePlanner.fixUpASTAggregateInsertDeleteIncrementalRebuild
List<RexNode> mvCols = new ArrayList<>(joinLeftInput.getRowType().getFieldCount());
for (int i = 0; i < joinLeftInput.getRowType().getFieldCount(); ++i) {
mvCols.add(rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(i).getType(), i));
}
mvCols.add(rexBuilder.makeLiteral(true));
joinLeftInput = relBuilder.push(joinLeftInput).project(mvCols).build();
// 3) Build conditions for join and start adding
// expressions for project operator
List<RexNode> projExprs = new ArrayList<>();
List<RexNode> joinConjs = new ArrayList<>();
int groupCount = agg.getGroupCount();
int totalCount = agg.getGroupCount() + agg.getAggCallList().size();
// Right-side columns start at join position totalCount + 1: the MV side contributes its
// columns plus the flag. This assumes the MV row type has exactly groupCount + aggCount
// columns, ordered keys first — TODO confirm. The right input's first groupCount columns
// are likewise assumed to be the group keys in the same order.
for (int leftPos = 0, rightPos = totalCount + 1; leftPos < groupCount; leftPos++, rightPos++) {
RexNode leftRef = rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(leftPos).getType(), leftPos);
RexNode rightRef = rexBuilder.makeInputRef(joinRightInput.rightInput.getRowType().getFieldList().get(leftPos).getType(), rightPos);
// Group keys in the output are taken from the right (new data) side, which is always
// present in a RIGHT join.
projExprs.add(rightRef);
// Null-safe equality, so that NULL group keys still match each other.
RexNode nsEqExpr = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, ImmutableList.of(leftRef, rightRef));
joinConjs.add(nsEqExpr);
}
// 4) Create join node
RexNode joinCond = RexUtil.composeConjunction(rexBuilder, joinConjs);
RelNode join = relBuilder.push(joinLeftInput).push(joinRightInput.rightInput).join(JoinRelType.RIGHT, joinCond).build();
// 5) Add one merged expression per aggregate function; on both sides the aggregate
// columns follow the group keys.
for (int i = 0, leftPos = groupCount, rightPos = totalCount + 1 + groupCount; leftPos < totalCount; i++, leftPos++, rightPos++) {
// case when source.s is null and mv2.s is null then null
// case when source.s IS null then mv2.s
// case when mv2.s IS null then source.s
// else source.s + mv2.s end
RexNode leftRef = rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(leftPos).getType(), leftPos);
RexNode rightRef = rexBuilder.makeInputRef(joinRightInput.rightInput.getRowType().getFieldList().get(leftPos).getType(), rightPos);
// Generate SQLOperator for merging the aggregations
SqlAggFunction aggCall = agg.getAggCallList().get(i).getAggregation();
RexNode elseReturn = createAggregateNode(aggCall, leftRef, rightRef, rexBuilder);
// According to SQL standard (and Hive) Aggregate functions eliminates null values however operators used in
// elseReturn expressions returns null if one of their operands is null
// hence we need a null check of both operands.
// Note: If both are null, we will fall into branch WHEN leftNull THEN rightRef
RexNode leftNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, leftRef);
RexNode rightNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, rightRef);
projExprs.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE, leftNull, rightRef, rightNull, leftRef, elseReturn));
}
// The flag is the last column of the projected MV side (the literal true added in step 2).
// After the RIGHT join it is null for rows that have no matching MV row.
int flagIndex = joinLeftInput.getRowType().getFieldCount() - 1;
RexNode flagNode = rexBuilder.makeInputRef(join.getRowType().getFieldList().get(flagIndex).getType(), flagIndex);
// 6) Build plan
RelNode newNode = relBuilder.push(join).filter(createFilterCondition(joinRightInput, flagNode, projExprs, relBuilder)).project(projExprs).build();
call.transformTo(newNode);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.
The class JDBCAggregationPushDownRule, method matches.
/**
 * Decides whether the aggregate on top of a JDBC converter can be pushed to the JDBC source.
 *
 * <p>Returns {@code false} in three cases: the aggregate uses grouping sets, a
 * {@code COUNT(DISTINCT ...)} call has more than one argument, or any aggregate function kind
 * is not supported by the converter's JDBC dialect.
 *
 * @param call the rule call; operand 0 is the {@code HiveAggregate}, operand 1 the
 *     {@code HiveJdbcConverter}
 * @return {@code true} if every aggregate call can be pushed down
 */
@Override
public boolean matches(RelOptRuleCall call) {
  final HiveAggregate agg = call.rel(0);
  final HiveJdbcConverter converter = call.rel(1);
  if (agg.getGroupType() != Group.SIMPLE) {
    // TODO: Grouping sets not supported yet
    return false;
  }
  for (AggregateCall aggCall : agg.getAggCallList()) {
    final SqlAggFunction f = aggCall.getAggregation();
    if (f instanceof HiveSqlCountAggFunction) {
      // COUNT(DISTINCT ...) with more than one argument is not supported.
      // Distinctness can be recorded on the call itself (the flag Calcite uses) and on
      // Hive's count function; treat the call as distinct if either says so. This can only
      // make the rule more conservative.
      final boolean distinct = aggCall.isDistinct() || ((HiveSqlCountAggFunction) f).isDistinct();
      if (distinct && aggCall.getArgList().size() > 1) {
        return false;
      }
    }
    if (!converter.getJdbcDialect().supportsAggregateFunction(f.getKind())) {
      return false;
    }
  }
  return true;
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.
The class HiveAggregateInsertDeleteIncrementalRewritingRule, method createJoinRightInput.
/**
 * Builds the incremental aggregate over new data, accounting for deleted rows.
 *
 * <p>The ROW__IS__DELETED column is propagated to the end of the aggregate input. Every
 * {@code COUNT} / {@code SUM} call is then replaced by {@code SUM(CASE WHEN deleted THEN -x
 * ELSE x END)}:
 * <pre>
 * SELECT
 *   sum(case when t1.ROW__IS__DELETED then -b else b end) sumb,
 *   sum(case when t1.ROW__IS__DELETED then -1 else 1 end) countStar
 * </pre>
 *
 * @param call the rule call; operand 2 is the aggregate over new data
 * @return the rewritten aggregate together with the output index of its {@code count(*)}
 *     column, or {@code null} when the aggregate has no {@code count(*)} and cannot be rewritten
 * @throws AssertionError if an aggregation other than COUNT or SUM is encountered
 */
@Override
protected IncrementalComputePlanWithDeletedRows createJoinRightInput(RelOptRuleCall call) {
  final RelBuilder relBuilder = call.builder();
  final Aggregate aggregate = call.rel(2);
  final RexBuilder rexBuilder = relBuilder.getRexBuilder();

  // Make the rowIsDeleted flag available as the last column of the aggregate input.
  RelNode aggInput = HiveHepExtractRelNodeRule.execute(aggregate.getInput());
  aggInput = new HiveRowIsDeletedPropagator(relBuilder).propagate(aggInput);
  final RexNode rowIsDeletedNode = rexBuilder.makeInputRef(aggInput, aggInput.getRowType().getFieldCount() - 1);

  // COUNT and SUM are both re-aggregated as a SUM over signed values.
  final SqlAggFunction sumFunction = SqlStdOperatorTable.SUM;
  final List<AggregateCall> aggCalls = aggregate.getAggCallList();
  final List<RelBuilder.AggCall> newAggregateCalls = new ArrayList<>(aggCalls.size());
  int countIdx = -1;
  for (int i = 0; i < aggCalls.size(); i++) {
    final AggregateCall aggCall = aggCalls.get(i);
    final RexNode argument;
    switch (aggCall.getAggregation().getKind()) {
      case COUNT:
        if (aggCall.getArgList().isEmpty()) {
          // count(*): remember where it lands in the output and count every row as 1.
          countIdx = aggregate.getGroupCount() + i;
          argument = relBuilder.literal(1);
        } else {
          // count(<column_name>)
          argument = genArgumentForCountColumn(relBuilder, aggInput, aggCall.getArgList().get(0));
        }
        break;
      case SUM:
        argument = rexBuilder.makeInputRef(aggInput, aggCall.getArgList().get(0));
        break;
      default:
        throw new AssertionError("Found an aggregation that could not be" + " recognized: " + aggCall);
    }
    // Deleted rows contribute the negated value.
    final RexNode negated = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, relBuilder.literal(-1), argument);
    final RexNode signedArgument = rexBuilder.makeCall(SqlStdOperatorTable.CASE, rowIsDeletedNode, negated, argument);
    newAggregateCalls.add(relBuilder.aggregateCall(sumFunction, signedArgument));
  }

  if (countIdx < 0) {
    // Can not rewrite without count(*); bail out.
    return null;
  }
  final RelNode newAggregate = relBuilder.push(aggInput).aggregate(relBuilder.groupKey(aggregate.getGroupSet()), newAggregateCalls).build();
  return new IncrementalComputePlanWithDeletedRows(newAggregate, countIdx);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project hive by apache.
The class HiveAggregateJoinTransposeRule, method onMatch.
/**
 * Attempts to push an Aggregate below an inner Join. Each join input is pre-aggregated (or
 * merely projected when already unique on its keys), and the sub-totals are re-aggregated, or
 * converted to projections, above the new join.
 *
 * <p>The rule bails out in any of these cases:
 * <ul>
 *   <li>an aggregate call is not splittable or has a filter;</li>
 *   <li>the join is not INNER;</li>
 *   <li>aggregate calls are present while {@code allowFunctions} is false;</li>
 *   <li>the grouping is not unique and {@code costBased} is false;</li>
 *   <li>the join condition has non-equi parts;</li>
 *   <li>both inputs are already unique.</li>
 * </ul>
 *
 * <p>The rewritten plan is registered if {@code uniqueBased} is set, the top aggregate was
 * converted to projections, and the grouping is unique. Otherwise, when {@code costBased} is
 * set, it is registered only if its cumulative cost is lower than the original aggregate's.
 *
 * <p>If an exception occurs while {@code noColsMissingStats} is positive, a warning is logged,
 * the counter is reset and the rule is skipped. Otherwise the exception is rethrown.
 */
@Override
public void onMatch(RelOptRuleCall call) {
try {
final Aggregate aggregate = call.rel(0);
final Join join = call.rel(1);
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
// If any aggregate call is not splittable or has a filter, bail out
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
return;
}
if (aggregateCall.filterArg >= 0) {
return;
}
}
// If it is not an inner join, we do not push the aggregate operator
if (join.getJoinType() != JoinRelType.INNER) {
return;
}
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return;
}
boolean groupingUnique = isGroupingUnique(join, aggregate.getGroupSet());
if (!groupingUnique && !costBased) {
// there is no need to check further - the transformation may not happen
return;
}
// Do the columns used by the join appear in the output of the aggregate?
final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
final RelMetadataQuery mq = call.getMetadataQuery();
final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
// Split join condition
final List<Integer> leftKeys = Lists.newArrayList();
final List<Integer> rightKeys = Lists.newArrayList();
final List<Boolean> filterNulls = Lists.newArrayList();
RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
// If it contains non-equi join conditions, we bail out
if (!nonEquiConj.isAlwaysTrue()) {
return;
}
// Push each aggregate function down to each side that contains all of its
// arguments. Note that COUNT(*), because it has no arguments, can go to
// both sides.
// map: join input column index -> column index in the concatenated new inputs.
final Map<Integer, Integer> map = new HashMap<>();
final List<Side> sides = new ArrayList<>();
int uniqueCount = 0;
int offset = 0;
int belowOffset = 0;
for (int s = 0; s < 2; s++) {
final Side side = new Side();
final RelNode joinInput = join.getInput(s);
int fieldCount = joinInput.getRowType().getFieldCount();
final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
map.put(c.e, belowOffset + c.i);
}
// Key columns relative to this input (right-side indexes shifted back to 0).
final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
final boolean unique;
if (!allowFunctions) {
assert aggregate.getAggCallList().isEmpty();
// If there are no functions, it doesn't matter as much whether we
// aggregate the inputs before the join, because there will not be
// any functions experiencing a cartesian product effect.
//
// But finding out whether the input is already unique requires a call
// to areColumnsUnique that currently (until [CALCITE-1048] "Make
// metadata more robust" is fixed) places a heavy load on
// the metadata system.
//
// So we choose to imagine the input is already unique, which is
// untrue but harmless.
//
unique = true;
} else {
final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey, true);
unique = unique0 != null && unique0;
}
if (unique) {
// Already unique on the keys: a projection of the key columns suffices.
++uniqueCount;
relBuilder.push(joinInput);
relBuilder.project(belowAggregateKey.asList().stream().map(relBuilder::field).collect(Collectors.toList()));
side.newInput = relBuilder.build();
} else {
// Pre-aggregate this input; side.split records, per original aggregate call, the
// output column holding its partial result.
List<AggregateCall> belowAggCalls = new ArrayList<>();
final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final AggregateCall call1;
if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
// All arguments come from this side: split the call itself.
call1 = splitter.split(aggCall.e, mapping);
} else {
call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
}
if (call1 != null) {
side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
}
}
side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
}
offset += fieldCount;
belowOffset += side.newInput.getRowType().getFieldCount();
sides.add(side);
}
if (uniqueCount == 2) {
// Both inputs to the join are already unique, so there is nothing to gain. In fact,
// this aggregate+join may be the result of a previous invocation of this rule; if we
// continue we might loop forever.
return;
}
// Update condition
final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
// Create new join
relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
// Aggregate above to sum up the sub-totals
final List<AggregateCall> newAggCalls = new ArrayList<>();
final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
// -1 marks a side with no sub-total; right-side indexes are shifted past the left input.
newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
}
relBuilder.project(projects);
boolean aggConvertedToProjects = false;
if (allColumnsInAggregate) {
// let's see if we can convert aggregate into projects
List<RexNode> projects2 = new ArrayList<>();
for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
projects2.add(relBuilder.field(key));
}
for (AggregateCall newAggCall : newAggCalls) {
final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
if (splitter != null) {
final RelDataType rowType = relBuilder.peek().getRowType();
projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
}
}
if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
// We successfully converted agg calls into projects.
relBuilder.project(projects2);
aggConvertedToProjects = true;
}
}
if (!aggConvertedToProjects) {
relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
}
RelNode r = relBuilder.build();
// Decide whether to register the rewrite: unique-based first, then cost-based.
boolean transform = false;
if (uniqueBased && aggConvertedToProjects) {
transform = groupingUnique;
}
if (!transform && costBased) {
RelOptCost afterCost = mq.getCumulativeCost(r);
RelOptCost beforeCost = mq.getCumulativeCost(aggregate);
transform = afterCost.isLt(beforeCost);
}
if (transform) {
call.transformTo(r);
}
} catch (Exception e) {
// When column stats are missing, skip this rule instead of failing planning.
if (noColsMissingStats.get() > 0) {
LOG.warn("Missing column stats (see previous messages), skipping aggregate-join transpose in CBO");
noColsMissingStats.set(0);
} else {
throw e;
}
}
}
Aggregations