use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.
the class SubstitutionVisitor method unifyAggregates.
/**
 * Tries to express the {@code query} aggregate as a relational expression on top of the
 * {@code target} aggregate.
 *
 * <p>If both aggregates have equal group sets, the result is a projection over
 * {@code target} that keeps the group columns and selects, for each aggregate call of the
 * query, the position of the equal call in the target. Otherwise the result is a new
 * aggregate over {@code target} whose group set is the query's group columns re-indexed
 * against the target's group columns, and whose calls are the roll-up forms of the query's
 * calls. In both cases the result is passed through {@code MutableRels.createCastRel} with
 * the query's row type.
 *
 * @param query the aggregate to rewrite
 * @param target the aggregate the rewrite is built on
 * @return the rewritten expression, or {@code null} when a query group column or aggregate
 *     call has no counterpart in the target, when a roll-up would be required for a distinct
 *     call, or when no roll-up function is available for a call
 * @throws AssertionError if either aggregate's group type is not {@code SIMPLE}
 */
public static MutableRel unifyAggregates(MutableAggregate query, MutableAggregate target) {
    if (query.getGroupType() != Aggregate.Group.SIMPLE
            || target.getGroupType() != Aggregate.Group.SIMPLE) {
        throw new AssertionError(Bug.CALCITE_461_FIXED);
    }
    final List<AggregateCall> targetCalls = target.getAggCallList();
    MutableRel result;
    if (query.getGroupSet().equals(target.getGroupSet())) {
        // Same level of aggregation. Generate a project.
        final List<Integer> projects = Lists.newArrayList();
        final int groupCount = query.getGroupSet().cardinality();
        for (int i = 0; i < groupCount; i++) {
            projects.add(i);
        }
        for (AggregateCall aggregateCall : query.getAggCallList()) {
            final int i = targetCalls.indexOf(aggregateCall);
            if (i < 0) {
                return null;
            }
            // Aggregate-call columns follow the group columns in the target's output.
            projects.add(groupCount + i);
        }
        result = MutableRels.createProject(target, projects);
    } else {
        // Target is coarser level of aggregation. Generate an aggregate.
        final ImmutableBitSet.Builder groupSet = ImmutableBitSet.builder();
        final List<Integer> targetGroupList = target.getGroupSet().asList();
        for (int c : query.getGroupSet()) {
            final int c2 = targetGroupList.indexOf(c);
            if (c2 < 0) {
                return null;
            }
            groupSet.set(c2);
        }
        final int targetGroupCount = target.getGroupSet().cardinality();
        final List<AggregateCall> aggregateCalls = Lists.newArrayList();
        for (AggregateCall aggregateCall : query.getAggCallList()) {
            if (aggregateCall.isDistinct()) {
                return null;
            }
            final int i = targetCalls.indexOf(aggregateCall);
            if (i < 0) {
                return null;
            }
            // NOTE(review): getRollup is defined outside this block; it is assumed to return
            // null for aggregations that have no roll-up form. Treat that as "cannot unify"
            // instead of handing a null function to AggregateCall.create.
            if (getRollup(aggregateCall.getAggregation()) == null) {
                return null;
            }
            aggregateCalls.add(
                    AggregateCall.create(
                            getRollup(aggregateCall.getAggregation()),
                            aggregateCall.isDistinct(),
                            ImmutableList.of(targetGroupCount + i),
                            -1,
                            aggregateCall.type,
                            aggregateCall.name));
        }
        result = MutableAggregate.of(target, false, groupSet.build(), null, aggregateCalls);
    }
    return MutableRels.createCastRel(result, query.getRowType(), true);
}
use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.
the class HiveExpandDistinctAggregatesRule method rewriteAggCalls.
/**
 * Replaces, in place, every distinct aggregate call whose argument list equals
 * {@code argList} with a non-distinct call over re-mapped arguments.
 *
 * <p>For example, "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)". Calls that are
 * not distinct, or whose arguments differ from {@code argList}, are left untouched.
 *
 * @param newAggCalls aggregate calls to scan; matching entries are overwritten
 * @param argList argument ordinals identifying the distinct calls to rewrite
 * @param sourceOf lookup from an original argument ordinal to its replacement ordinal
 */
private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
    for (int pos = 0; pos < newAggCalls.size(); pos++) {
        final AggregateCall current = newAggCalls.get(pos);
        // Only e.g. COUNT(DISTINCT gender) over exactly these arguments qualifies;
        // SUM(sal) or a distinct call over other arguments is skipped.
        final boolean rewritable = current.isDistinct() && current.getArgList().equals(argList);
        if (!rewritable) {
            continue;
        }

        // Re-map each argument through sourceOf, preserving argument order.
        final List<Integer> remapped = new ArrayList<Integer>(current.getArgList().size());
        for (Integer oldArg : current.getArgList()) {
            remapped.add(sourceOf.get(oldArg));
        }

        newAggCalls.set(
                pos,
                new AggregateCall(
                        current.getAggregation(),
                        false,
                        remapped,
                        current.getType(),
                        current.getName()));
    }
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class DrillReduceAggregatesRule method reduceStddev.
/**
 * Builds an expression that computes a standard-deviation / variance style aggregate from
 * SUM and COUNT aggregate calls.
 *
 * <p>The expression has the shape
 * <pre>
 *   (sum(x * x) - sum(x) * sum(x) / count(x)) / d
 * </pre>
 * where {@code d} is {@code count(x)} when {@code biased} is true, and
 * {@code CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END} otherwise. When
 * {@code sqrt} is true the quotient is additionally raised to the power 0.5.
 *
 * <p>Side effects: the input expression at the call's argument ordinal is replaced in
 * {@code inputExprs} by a {@code CastHighOp} call wrapping it, the squared argument is passed
 * to {@code lookupOrAdd} on {@code inputExprs}, and the three helper aggregate calls are
 * handed to {@code rexBuilder.addAggCall} together with {@code newCalls} and
 * {@code aggCallMapping}.
 *
 * @param oldAggRel aggregate containing the call being rewritten
 * @param oldCall the aggregate call being rewritten; must have exactly one argument
 * @param biased whether to divide by count(x) rather than count(x) - 1
 * @param sqrt whether to take the square root of the quotient
 * @param newCalls aggregate-call list passed to {@code addAggCall}
 * @param aggCallMapping aggregate-call mapping passed to {@code addAggCall}
 * @param inputExprs input expressions; modified as described above
 * @return the rewritten expression; cast to ANY when type inference is disabled
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();

    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int xOrdinal = oldCall.getArgList().get(0);
    final RelDataType xType = getFieldType(oldAggRel.getInput(), xOrdinal);

    // x is wrapped in a CastHighOp call, and the wrapped form replaces the original input.
    final RexNode x = builder.makeCall(CastHighOp, inputExprs.get(xOrdinal));
    inputExprs.set(xOrdinal, x);
    final RexNode xTimesX = builder.makeCall(SqlStdOperatorTable.MULTIPLY, x, x);
    final int xTimesXOrdinal = lookupOrAdd(inputExprs, xTimesX);

    // Both sums use the argument's type, forced nullable.
    final RelDataType sumType = types.createTypeWithNullability(xType, true);

    // sum(x * x)
    final AggregateCall sumOfSquaresCall = AggregateCall.create(
            new SqlSumAggFunction(sumType), oldCall.isDistinct(),
            ImmutableIntList.of(xTimesXOrdinal), -1, sumType, null);
    final List<RelDataType> aggArgTypes = ImmutableList.of(xType);
    final RexNode sumOfSquares = builder.addAggCall(
            sumOfSquaresCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);

    // sum(x), then sum(x) * sum(x)
    final AggregateCall sumCall = AggregateCall.create(
            new SqlSumAggFunction(sumType), oldCall.isDistinct(),
            ImmutableIntList.of(xOrdinal), -1, sumType, null);
    final RexNode sum = builder.addAggCall(
            sumCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);
    final RexNode squareOfSum = builder.makeCall(SqlStdOperatorTable.MULTIPLY, sum, sum);

    // count(x)
    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countResultType = countFunction.getReturnType(types);
    final AggregateCall countCall = AggregateCall.create(
            countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countResultType, null);
    final RexNode count = builder.addAggCall(
            countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);

    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode numerator = builder.makeCall(
            SqlStdOperatorTable.MINUS,
            sumOfSquares,
            builder.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count));

    final RexNode denominator;
    if (!biased) {
        // CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END
        final RexLiteral one = builder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nullLiteral = builder.makeNullLiteral(count.getType().getSqlTypeName());
        final RexNode countLessOne = builder.makeCall(SqlStdOperatorTable.MINUS, count, one);
        final RexNode countIsOne = builder.makeCall(SqlStdOperatorTable.EQUALS, count, one);
        denominator = builder.makeCall(SqlStdOperatorTable.CASE, countIsOne, nullLiteral, countLessOne);
    } else {
        denominator = count;
    }

    // With type inference on, the division operator is built with the old call's type.
    final SqlOperator divide = inferTypes
            ? new DrillSqlOperator("divide", 2, true, oldCall.getType(), false)
            : SqlStdOperatorTable.DIVIDE;
    final RexNode quotient = builder.makeCall(divide, numerator, denominator);

    final RexNode value = sqrt
            ? builder.makeCall(SqlStdOperatorTable.POWER, quotient, builder.makeExactLiteral(new BigDecimal("0.5")))
            : quotient;

    if (inferTypes) {
        return value;
    }
    /*
     * Without type inference, Calcite derives the aggregate's return type from the first
     * known argument, so stddev over an integer column would be typed as integer although
     * it should be double. Adding a cast to that derived type after the rewrite would
     * produce wrong results, so the expression is cast to ANY instead.
     */
    return builder.makeCast(types.createSqlType(SqlTypeName.ANY), value);
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class DrillReduceAggregatesRule method reduceSum.
/**
 * Rewrites a SUM aggregate call in terms of a "sum, empty is zero" aggregate
 * ({@code SqlSumEmptyIsZeroAggFunction}) and, when needed, a COUNT aggregate.
 *
 * <p>If the old call's type is not nullable, the reference to the sum-zero call is returned
 * directly. Otherwise the result is
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
 *
 * <p>With type inference enabled the sum-zero call uses the old call's type and wraps the
 * function in {@code DrillCalciteSqlAggFunctionWrapper}; otherwise it uses the argument's
 * field type with its existing nullability.
 *
 * @param oldAggRel aggregate containing the call being rewritten
 * @param oldCall the SUM call being rewritten
 * @param newCalls aggregate-call list passed to {@code rexBuilder.addAggCall}
 * @param aggCallMapping aggregate-call mapping passed to {@code rexBuilder.addAggCall}
 * @return the expression that replaces the old call
 */
private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
    final int inputOrdinal = oldCall.getArgList().get(0);
    final RelDataType inputType = getFieldType(oldAggRel.getInput(), inputOrdinal);

    final RelDataType sumType;
    final SqlAggFunction sumZeroFunction;
    if (!inferTypes) {
        sumType = types.createTypeWithNullability(inputType, inputType.isNullable());
        sumZeroFunction = new SqlSumEmptyIsZeroAggFunction();
    } else {
        sumType = oldCall.getType();
        sumZeroFunction = new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
    }
    final AggregateCall sumZeroCall = AggregateCall.create(
            sumZeroFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, sumType, null);

    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countResultType = countFunction.getReturnType(types);
    final AggregateCall countCall = AggregateCall.create(
            countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countResultType, null);

    // NOTE: the references obtained below are with respect to the output of the new
    // aggregate, not of oldAggRel.
    final List<RelDataType> aggArgTypes = ImmutableList.of(inputType);
    final RexNode sumZeroRef = builder.addAggCall(
            sumZeroCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);

    if (oldCall.getType().isNullable()) {
        // SUM(x) may be null: yield NULL when nothing was counted, SUM0(x) otherwise.
        final RexNode countRef = builder.addAggCall(
                countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);
        final RexNode countIsZero = builder.makeCall(
                SqlStdOperatorTable.EQUALS, countRef, builder.makeExactLiteral(BigDecimal.ZERO));
        return builder.makeCall(SqlStdOperatorTable.CASE, countIsZero, builder.constantNull(), sumZeroRef);
    }

    // SUM(x) is typed as not nullable, so no NULL case is needed and the call
    // translates directly to SUM0(x).
    return sumZeroRef;
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class AggPrelBase method createKeysAndExprs.
/**
 * Populates {@code keys}, {@code aggExprs} and, for a phase-1-of-2 operator,
 * {@code phase2AggCallList} from this operator's group set and aggregate calls.
 *
 * <p>Each group column becomes a key named after the corresponding input field. Each
 * aggregate call becomes a named expression whose output name is the field that follows the
 * group columns at the call's position in this operator's row type. In phase 1 of 2, a
 * matching second-phase call is recorded for every aggregate call, taking the first-phase
 * output column as its single argument.
 */
protected void createKeysAndExprs() {
    final List<String> inputNames = getInput().getRowType().getFieldNames();
    final List<String> outputNames = getRowType().getFieldNames();

    for (int groupKey : BitSets.toIter(groupSet)) {
        final FieldReference keyRef = FieldReference.getWithQuotedRef(inputNames.get(groupKey));
        keys.add(new NamedExpression(keyRef, keyRef));
    }

    for (Ord<AggregateCall> indexed : Ord.zip(aggCalls)) {
        final AggregateCall call = indexed.e;
        // Aggregate outputs come after the group columns in the output row.
        final int outputOrdinal = groupSet.cardinality() + indexed.i;
        final FieldReference outputRef = FieldReference.getWithQuotedRef(outputNames.get(outputOrdinal));
        aggExprs.add(new NamedExpression(toDrill(call, inputNames), outputRef));

        if (getOperatorPhase() != OperatorPhase.PHASE_1of2) {
            continue;
        }

        // A COUNT computed in phase 1 of 2 must be combined in phase 2 of 2 by summing the
        // partial counts; every other aggregation is reused unchanged.
        final boolean isCount = call.getAggregation().getName().equals("COUNT");
        final SqlAggFunction phase2Function = isCount
                ? new SqlSumCountAggFunction(call.getType())
                : call.getAggregation();
        phase2AggCallList.add(
                new AggregateCall(
                        phase2Function,
                        call.isDistinct(),
                        Collections.singletonList(outputOrdinal),
                        call.getType(),
                        call.getName()));
    }
}
Aggregations