Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveAggregateSplitRule, method onMatch.
/**
 * Splits the matched aggregate into two stacked aggregates: a bottom one that
 * groups by the aggregate's group set and evaluates the original calls, and a
 * top one that applies the (re-numbered) grouping sets with rollup functions
 * reading the bottom aggregate's outputs.
 *
 * <p>Nothing is produced when any call is DISTINCT, is the grouping id, carries
 * a filter, or has no rollup function, or when the group columns are already
 * unique in the input.
 */
@Override
public void onMatch(RelOptRuleCall call) {
  final Aggregate aggregate = call.rel(0);
  final RelBuilder relBuilder = call.builder();

  final ImmutableBitSet bottomGroupSet = aggregate.getGroupSet();
  final int groupCount = bottomGroupSet.cardinality();
  final List<AggregateCall> originalCalls = aggregate.getAggCallList();

  final List<AggregateCall> topCalls = new ArrayList<>();
  int ordinal = 0;
  for (AggregateCall aggCall : originalCalls) {
    // Bail out on DISTINCT calls, on the grouping id, and on filtered calls.
    if (aggCall.isDistinct()
        || aggCall.getAggregation().equals(HiveGroupingID.INSTANCE)
        || aggCall.filterArg >= 0) {
      return;
    }
    // Bail out when the function does not support splitting.
    final SqlAggFunction rollup = HiveRelBuilder.getRollup(aggCall.getAggregation());
    if (rollup == null) {
      return;
    }
    // The top call reads the bottom call's output column, which sits right
    // after the group columns.
    topCalls.add(
        AggregateCall.create(
            rollup,
            aggCall.isDistinct(),
            aggCall.isApproximate(),
            ImmutableList.of(groupCount + ordinal),
            -1,
            aggCall.collation,
            aggCall.type,
            aggCall.name));
    ordinal++;
  }

  if (aggregate.getCluster().getMetadataQuery()
      .areColumnsUnique(aggregate.getInput(), bottomGroupSet)) {
    // Nothing to do, probably already pushed
    return;
  }

  // Map every grouped input column to its position in the bottom aggregate's output.
  final Map<Integer, Integer> positionOf = new HashMap<>();
  for (Integer column : bottomGroupSet) {
    positionOf.put(column, positionOf.size());
  }
  final ImmutableBitSet topGroupSet = ImmutableBitSet.range(0, groupCount);
  final ImmutableList<ImmutableBitSet> topGroupSets =
      ImmutableBitSet.ORDERING.immutableSortedCopy(
          ImmutableBitSet.permute(aggregate.groupSets, positionOf));

  relBuilder
      .push(aggregate.getInput())
      .aggregate(relBuilder.groupKey(bottomGroupSet), originalCalls)
      .aggregate(relBuilder.groupKey(topGroupSet, topGroupSets), topCalls);
  call.transformTo(relBuilder.build());
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method rewriteAggCalls.
/**
 * Replaces, in place, every DISTINCT call in {@code newAggCalls} whose argument
 * list equals {@code argList} with a non-distinct call over remapped arguments;
 * for example "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
 *
 * @param newAggCalls calls to inspect; matching entries are overwritten
 * @param argList argument list that selects which distinct calls are rewritten
 * @param sourceOf lookup from an original argument ordinal to its new ordinal
 */
private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
  for (int pos = 0; pos < newAggCalls.size(); pos++) {
    final AggregateCall candidate = newAggCalls.get(pos);

    // Skip calls that are not DISTINCT over exactly these arguments,
    // e.g. SUM(sal) or COUNT(DISTINCT gender).
    if (!candidate.isDistinct() || !candidate.getArgList().equals(argList)) {
      continue;
    }

    // Re-map arguments.
    final List<Integer> oldArgs = candidate.getArgList();
    final List<Integer> remapped = new ArrayList<Integer>(oldArgs.size());
    for (Integer oldArg : oldArgs) {
      remapped.add(sourceOf.get(oldArg));
    }

    newAggCalls.set(
        pos,
        new AggregateCall(
            candidate.getAggregation(),
            false,
            remapped,
            candidate.getType(),
            candidate.getName()));
  }
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method createGroupingSets.
/**
 * Builds a {@link HiveAggregate} over the input of {@code aggregate} whose
 * grouping sets are the original group set united with each argument list, and
 * whose only aggregate call is the grouping id.
 *
 * @param aggregate the original aggregate
 * @param argList the original argument lists in the aggregate
 * @param cleanArgList output: receives the argument lists whose grouping set
 *     was not seen before, in encounter order
 * @param map output: for every index of {@code argList} whose grouping set
 *     duplicates an earlier one, the position of that earlier entry
 * @param groupSet group set of the new aggregate
 * @return the new aggregate with the de-duplicated, sorted grouping sets
 */
private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> argList, List<List<Integer>> cleanArgList, Map<Integer, Integer> map, ImmutableBitSet groupSet) {
  final List<ImmutableBitSet> groupingSets = new ArrayList<>();
  for (int idx = 0; idx < argList.size(); idx++) {
    final List<Integer> args = argList.get(idx);
    final ImmutableBitSet candidate = aggregate.getGroupSet().union(ImmutableBitSet.of(args));
    final int seenAt = groupingSets.indexOf(candidate);
    if (seenAt >= 0) {
      // Duplicate grouping set: remember which earlier entry it collapses into.
      map.put(idx, seenAt);
      continue;
    }
    groupingSets.add(candidate);
    cleanArgList.add(args);
  }

  // Calcite expects the grouping sets sorted and without duplicates
  groupingSets.sort(ImmutableBitSet.COMPARATOR);

  // Create GroupingID column
  final AggregateCall groupingIdCall =
      AggregateCall.create(
          HiveGroupingID.INSTANCE,
          false,
          ImmutableList.<Integer>of(),
          -1,
          this.cluster.getTypeFactory().createSqlType(SqlTypeName.BIGINT),
          HiveGroupingID.INSTANCE.getName());
  final List<AggregateCall> aggregateCalls = new ArrayList<>();
  aggregateCalls.add(groupingIdCall);

  return new HiveAggregate(
      cluster,
      cluster.traitSetOf(HiveRelNode.CONVENTION),
      aggregate.getInput(),
      groupSet,
      groupingSets,
      aggregateCalls);
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method convertMonopole.
/**
 * Rewrites an aggregate that has exactly one distinct aggregate function (or
 * several sharing the same arguments) and no non-distinct aggregate functions.
 *
 * @param aggregate the aggregate to rewrite
 * @param argList arguments shared by the distinct aggregate functions
 * @return a copy of {@code aggregate} placed on top of a SELECT DISTINCT input
 */
private RelNode convertMonopole(Aggregate aggregate, List<Integer> argList) {
  // Illustration:
  //
  //   SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
  //   FROM emp
  //   GROUP BY deptno
  //
  // is turned into
  //
  //   SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
  //   FROM (
  //     SELECT DISTINCT deptno, sal AS distinct_sal
  //     FROM EMP GROUP BY deptno)
  //   GROUP BY deptno

  // Bottom: project the GROUP BY columns plus the arguments to the agg
  // function; sourceOf is handed to createSelectDistinct and afterwards drives
  // the argument remapping.
  final Map<Integer, Integer> sourceOf = new HashMap<>();
  final Aggregate distinct = createSelectDistinct(aggregate, argList, sourceOf);

  // Top: the same aggregate over the distinct input, with rewritten calls.
  final List<AggregateCall> rewrittenCalls = new ArrayList<>(aggregate.getAggCallList());
  rewriteAggCalls(rewrittenCalls, argList, sourceOf);

  final ImmutableBitSet topGroupSet =
      ImmutableBitSet.range(aggregate.getGroupSet().cardinality());
  return aggregate.copy(
      aggregate.getTraitSet(),
      distinct,
      aggregate.indicator,
      topGroupSet,
      null,
      rewrittenCalls);
}
Use of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveAggregateReduceFunctionsRule, method reduceStddev.
/**
 * Builds the expression that replaces a standard-deviation / variance style
 * aggregate call with arithmetic over SUM(x * x), SUM(x) and COUNT(x).
 *
 * <p>The generated SUM and COUNT calls are registered through
 * {@code rexBuilder.addAggCall} using {@code newCalls} and
 * {@code aggCallMapping}; argument expressions are registered through
 * {@code lookupOrAdd} on {@code inputExprs}.
 *
 * @param oldAggRel aggregate that contains {@code oldCall}
 * @param oldCall call being reduced; asserted to have exactly one argument
 * @param biased if true the divisor is count(x); otherwise it is count(x) - 1,
 *     or NULL when count(x) equals 1
 * @param sqrt if true the quotient is raised to the power 0.5; otherwise the
 *     quotient itself is the result
 * @param newCalls list handed to {@code addAggCall} for the generated calls
 * @param aggCallMapping mapping handed to {@code addAggCall} with {@code newCalls}
 * @param inputExprs input expressions; read by ordinal and handed to {@code lookupOrAdd}
 * @return the resulting expression, cast to the type of {@code oldCall}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
// stddev_pop(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / count(x),
// .5)
//
// stddev_samp(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / nullif(count(x) - 1, 0),
// .5)
final int nGroups = oldAggRel.getGroupCount();
final RelOptCluster cluster = oldAggRel.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
// x, brought to the nullable version of the old call's result type.
final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);
final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
// presumably lookupOrAdd returns the ordinal of the expression within inputExprs,
// appending it when absent - confirm against the helper.
final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argRef.getType());
// x * x
final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
final RelDataType sumSquaredReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argSquared.getType());
// sum(x * x)
final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumSquaredReturnType), InferTypes.explicit(Collections.singletonList(argSquared.getType())), // SqlStdOperatorTable.SUM,
oldCall.getAggregation().getOperandTypeChecker()), argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));
// sum(x)
// NOTE(review): operand inference uses argOrdinalType (the raw input field type)
// while the argument ordinal points at the cast argRef - confirm this is intended.
final AggregateCall sumArgAggCall = AggregateCall.create(new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumReturnType), InferTypes.explicit(Collections.singletonList(argOrdinalType)), // SqlStdOperatorTable.SUM,
oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
// sum(x) * sum(x)
final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
// count(x), typed as nullable BIGINT; it takes the old call's original argument list.
RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
final AggregateCall countArgAggCall = AggregateCall.create(new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.COUNT,
oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
// sum(x) * sum(x) / count(x)
final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
// sum(x * x) - sum(x) * sum(x) / count(x)
final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
final RexNode denominator;
if (biased) {
denominator = countArg;
} else {
// CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END, i.e. the
// nullif(count(x) - 1, 0) of the header comment expressed as a CASE.
final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
}
final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
RexNode result = div;
if (sqrt) {
// Square root expressed as power(div, 0.5).
final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
}
// Cast back to the type declared by the call being replaced.
return rexBuilder.makeCast(oldCall.getType(), result);
}
Aggregations