use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.
/*
public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
// For example,
// SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
// FROM emp
// GROUP BY deptno
//
// becomes
//
// SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
// FROM (
// SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
// FROM EMP
// GROUP BY deptno, sal) // Aggregate B
// GROUP BY deptno // Aggregate A
relBuilder.push(aggregate.getInput());
final List<Pair<RexNode, String>> projects = new ArrayList<>();
final Map<Integer, Integer> sourceOf = new HashMap<>();
SortedSet<Integer> newGroupSet = new TreeSet<>();
final List<RelDataTypeField> childFields =
relBuilder.peek().getRowType().getFieldList();
final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
// Add the distinct aggregate column(s) to the group-by columns,
// if not already a part of the group-by
newGroupSet.addAll(aggregate.getGroupSet().asList());
for (Pair<List<Integer>, Integer> argList : argLists) {
newGroupSet.addAll(argList.getKey());
}
// Re-map the arguments to the aggregate A. These arguments will get
// remapped because of the intermediate aggregate B generated as part of the
// transformation.
for (int arg : newGroupSet) {
sourceOf.put(arg, projects.size());
projects.add(RexInputRef.of2(arg, childFields));
}
// Generate the intermediate aggregate B
final List<AggregateCall> aggCalls = aggregate.getAggCallList();
final List<AggregateCall> newAggCalls = new ArrayList<>();
final List<Integer> fakeArgs = new ArrayList<>();
final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
// First identify the real arguments, then use the rest for fake arguments
// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
for (final AggregateCall aggCall : aggCalls) {
if (!aggCall.isDistinct()) {
for (int arg : aggCall.getArgList()) {
if (!sourceOf.containsKey(arg)) {
sourceOf.put(arg, projects.size());
}
}
}
}
int fakeArg0 = 0;
for (final AggregateCall aggCall : aggCalls) {
// We will deal with non-distinct aggregates below
if (!aggCall.isDistinct()) {
boolean isGroupKeyUsedInAgg = false;
for (int arg : aggCall.getArgList()) {
if (sourceOf.containsKey(arg)) {
isGroupKeyUsedInAgg = true;
break;
}
}
if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
while (sourceOf.get(fakeArg0) != null) {
++fakeArg0;
}
fakeArgs.add(fakeArg0);
}
}
}
for (final AggregateCall aggCall : aggCalls) {
if (!aggCall.isDistinct()) {
for (int arg : aggCall.getArgList()) {
if (!sourceOf.containsKey(arg)) {
sourceOf.remove(arg);
}
}
}
}
// Compute the remapped arguments using fake arguments for non-distinct
// aggregates with no arguments e.g. count(*).
int fakeArgIdx = 0;
for (final AggregateCall aggCall : aggCalls) {
// Project the column corresponding to the distinct aggregate. Project
// as-is all the non-distinct aggregates
if (!aggCall.isDistinct()) {
final AggregateCall newCall =
AggregateCall.create(aggCall.getAggregation(), false,
aggCall.getArgList(), -1,
ImmutableBitSet.of(newGroupSet).cardinality(),
relBuilder.peek(), null, aggCall.name);
newAggCalls.add(newCall);
if (newCall.getArgList().size() == 0) {
int fakeArg = fakeArgs.get(fakeArgIdx);
callArgMap.put(newCall, fakeArg);
sourceOf.put(fakeArg, projects.size());
projects.add(
Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
newCall.getName()));
++fakeArgIdx;
} else {
for (int arg : newCall.getArgList()) {
if (sourceOf.containsKey(arg)) {
int fakeArg = fakeArgs.get(fakeArgIdx);
callArgMap.put(newCall, fakeArg);
sourceOf.put(fakeArg, projects.size());
projects.add(
Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
newCall.getName()));
++fakeArgIdx;
} else {
sourceOf.put(arg, projects.size());
projects.add(
Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
newCall.getName()));
}
}
}
}
}
// Generate the aggregate B (see the reference example above)
relBuilder.push(
aggregate.copy(
aggregate.getTraitSet(), relBuilder.build(),
false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
// Convert the existing aggregate to aggregate A (see the reference example above)
final List<AggregateCall> newTopAggCalls =
Lists.newArrayList(aggregate.getAggCallList());
// Use the remapped arguments for the (non)distinct aggregate calls
for (int i = 0; i < newTopAggCalls.size(); i++) {
// Re-map arguments.
final AggregateCall aggCall = newTopAggCalls.get(i);
final int argCount = aggCall.getArgList().size();
final List<Integer> newArgs = new ArrayList<>(argCount);
final AggregateCall newCall;
for (int j = 0; j < argCount; j++) {
final Integer arg = aggCall.getArgList().get(j);
if (callArgMap.containsKey(aggCall)) {
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
}
else {
newArgs.add(sourceOf.get(arg));
}
}
if (aggCall.isDistinct()) {
newCall =
AggregateCall.create(aggCall.getAggregation(), false, newArgs,
-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
aggCall.getType(), aggCall.name);
} else {
// If aggregate B had a COUNT aggregate call the corresponding aggregate at
// aggregate A must be SUM. For other aggregates, it remains the same.
if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
if (aggCall.getArgList().size() == 0) {
newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
}
if (hasGroupBy) {
SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
newCall =
AggregateCall.create(sumAgg, false, newArgs, -1,
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
aggCall.getType(), aggCall.getName());
} else {
SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
newCall =
AggregateCall.create(sumAgg, false, newArgs, -1,
aggregate.getGroupSet().cardinality(), relBuilder.peek(),
aggCall.getType(), aggCall.getName());
}
} else {
newCall =
AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
aggregate.getGroupSet().cardinality(),
relBuilder.peek(), aggCall.getType(), aggCall.name);
}
}
newTopAggCalls.set(i, newCall);
}
// Populate the group-by keys with the remapped arguments for aggregate A
newGroupSet.clear();
for (int arg : aggregate.getGroupSet()) {
newGroupSet.add(sourceOf.get(arg));
}
relBuilder.push(
aggregate.copy(aggregate.getTraitSet(),
relBuilder.build(), aggregate.indicator,
ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
return relBuilder;
}
*/
/**
 * Rewrites {@code aggregate}, which contains DISTINCT aggregate calls, as a
 * bottom aggregate over grouping sets followed by a top aggregate that picks
 * each result out of the grouping set it belongs to.
 *
 * <p>The grouping sets are the aggregate's own group set plus, for every entry
 * of {@code argLists}, that entry's argument columns (and its filter column,
 * when non-negative) unioned with the group set. The bottom aggregate computes
 * all non-distinct calls. The top aggregate evaluates each non-distinct call
 * as {@code MIN} over the matching bottom-aggregate column, filtered to the
 * original group set. It evaluates each distinct call with its original
 * function (DISTINCT flag dropped) over the argument columns remapped into
 * {@code fullGroupSet}, filtered to that call's grouping set. The result is
 * converted to the original row type and handed to {@code call.transformTo}.
 *
 * @param call      rule call supplying the builder and receiving the result
 * @param aggregate aggregate being rewritten
 * @param argLists  pairs of (distinct argument columns, filter column or -1)
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
// Collect the distinct grouping sets, ordered by ImmutableBitSet.ORDERING:
// the original group set, plus one set per distinct argument list.
final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
groupSetTreeSet.add(aggregate.getGroupSet());
for (Pair<List<Integer>, Integer> argList : argLists) {
groupSetTreeSet.add(ImmutableBitSet.of(argList.left).setIf(argList.right, argList.right >= 0).union(aggregate.getGroupSet()));
}
final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
// Union of all grouping sets: the key columns of the bottom aggregate.
final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
// NOTE(review): despite its name, this list holds the NON-distinct calls
// (see the negated isDistinct() check), renamed to their output names.
// They are computed directly by the bottom aggregate.
final List<AggregateCall> distinctAggCalls = new ArrayList<>();
for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
if (!aggCall.left.isDistinct()) {
distinctAggCalls.add(aggCall.left.rename(aggCall.right));
}
}
// Bottom aggregate: group by fullGroupSet over all grouping sets. Indicator
// columns are requested only when there is more than one grouping set.
final RelBuilder relBuilder = call.builder();
relBuilder.push(aggregate.getInput());
relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets), distinctAggCalls);
final RelNode distinct = relBuilder.peek();
// Bottom-aggregate row layout: groupCount key columns, then indicatorCount
// indicator columns, then one column per entry of distinctAggCalls.
final int groupCount = fullGroupSet.cardinality();
final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
final RelOptCluster cluster = aggregate.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
// Filter predicates to append to the bottom aggregate's output, and, per
// grouping set, the index of its predicate column in that projected row.
final List<Pair<RexNode, String>> predicates = new ArrayList<>();
final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
// The Javadoc on this local class is why the method suppresses
// "DanglingJavadoc".
/**
 * Registers, for a grouping set, a boolean column that is true on the
 * bottom-aggregate rows belonging to that set.
 */
class Registrar {
// Shared "group id" expression, built lazily on first registration.
RexNode group = null;
// Adds the predicate "groupId == id of groupSet" and returns the index of
// the new column: it follows every bottom-aggregate field and all
// previously registered predicates (see the project built below).
private int register(ImmutableBitSet groupSet) {
if (group == null) {
group = makeGroup(groupCount - 1);
}
final RexNode node = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet))));
predicates.add(Pair.of(node, toString(groupSet)));
return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
}
// Builds CASE(ind_0, 0, 2^0) + ... + CASE(ind_i, 0, 2^i), where ind_k is
// the boolean column at groupCount + k. Key k contributes 2^k only when
// its indicator is false.
// NOTE(review): this assumes indicator columns exist, i.e.
// groupSets.size() > 1. With a single grouping set indicatorCount is 0
// and groupCount + k points past the key columns. Confirm callers never
// reach this method with only one grouping set.
private RexNode makeGroup(int i) {
final RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
final RexNode kase = rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
if (i == 0) {
return kase;
} else {
return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), kase);
}
}
// Sum of 2^key over the bits set in bitSet.
private BigDecimal toNumber(ImmutableBitSet bitSet) {
BigDecimal n = BigDecimal.ZERO;
for (int key : bitSet) {
n = n.add(TWO.pow(key));
}
return n;
}
// Column name such as "$i0_2" listing the set's keys. An empty set
// yields "$".
private String toString(ImmutableBitSet bitSet) {
final StringBuilder buf = new StringBuilder("$i");
for (int key : bitSet) {
buf.append(key).append('_');
}
return buf.substring(0, buf.length() - 1);
}
}
// Register one filter column per grouping set.
final Registrar registrar = new Registrar();
for (ImmutableBitSet groupSet : groupSets) {
filters.put(groupSet, registrar.register(groupSet));
}
// Project all bottom-aggregate fields unchanged, followed by the predicates.
if (!predicates.isEmpty()) {
List<Pair<RexNode, String>> nodes = new ArrayList<>();
for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
nodes.add(Pair.of(node, f.getName()));
}
nodes.addAll(predicates);
relBuilder.project(Pair.left(nodes), Pair.right(nodes));
}
// x walks over the bottom-aggregate columns holding the non-distinct
// results, in call order.
int x = groupCount + indicatorCount;
final List<AggregateCall> newCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
final int newFilterArg;
final List<Integer> newArgList;
final SqlAggFunction aggregation;
if (!aggCall.isDistinct()) {
// Already computed below; pick the value from the rows of the original
// group set (presumably one such row per top-level group, so MIN just
// selects it).
aggregation = SqlStdOperatorTable.MIN;
newArgList = ImmutableIntList.of(x++);
newFilterArg = filters.get(aggregate.getGroupSet());
} else {
// Aggregate the de-duplicated argument columns, restricted to the rows
// of this call's grouping set; DISTINCT is no longer needed.
aggregation = aggCall.getAggregation();
newArgList = remap(fullGroupSet, aggCall.getArgList());
newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
}
final AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
newCalls.add(newCall);
}
// Top aggregate: the original grouping, remapped to positions within
// fullGroupSet.
relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
// Convert to the original row type (rename = true), then register the plan.
relBuilder.convert(aggregate.getRowType(), true);
call.transformTo(relBuilder.build());
}
use of org.apache.calcite.rel.core.AggregateCall in project flink by apache.
the class FlinkAggregateJoinTransposeRule method onMatch.
/**
 * Pushes {@code aggregate} below the inner {@code join} it sits on.
 *
 * <p>Each join input that is not already unique on its share of the group-by
 * and join columns gets its own pre-aggregation, holding the split parts of
 * each aggregate call. The join is rebuilt over the new inputs with a remapped
 * condition. A top aggregate combines the sub-totals; when possible, a project
 * replaces that aggregate.
 *
 * <p>The rule does nothing when:
 * <ul>
 *   <li>any aggregate call cannot be split or has a filter;</li>
 *   <li>the join is not INNER;</li>
 *   <li>functions are disallowed but present;</li>
 *   <li>the join condition has non-equi parts; or</li>
 *   <li>both inputs are already unique.</li>
 * </ul>
 *
 * @param call rule call whose rel(0) is the Aggregate and rel(1) the Join
 */
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
final Join join = call.rel(1);
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
// Bail out if any aggregate function cannot be split (it does not
// unwrap to SqlSplittableAggFunction), or if any aggregate call has a filter.
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
return;
}
if (aggregateCall.filterArg >= 0) {
return;
}
}
// Only inner joins are handled. For any other join type we do not push the
// aggregate operator.
if (join.getJoinType() != JoinRelType.INNER) {
return;
}
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return;
}
// Do the columns used by the join appear in the output of the aggregate?
// keyColumns(...) is given the join's pulled-up predicates. Presumably it
// widens the group keys with columns equated to them; see keyColumns.
final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
final RelMetadataQuery mq = RelMetadataQuery.instance();
final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
// Columns each pushed-down aggregate must group by: group keys plus join columns.
final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
// Split join condition
final List<Integer> leftKeys = Lists.newArrayList();
final List<Integer> rightKeys = Lists.newArrayList();
final List<Boolean> filterNulls = Lists.newArrayList();
RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
// If it contains non-equi join conditions, we bail out
if (!nonEquiConj.isAlwaysTrue()) {
return;
}
// Push each aggregate function down to each side that contains all of its
// arguments. Note that COUNT(*), because it has no arguments, can go to
// both sides.
// map: column index in the join's row -> column index in the concatenation
// of the new inputs (only for belowAggregateColumns).
final Map<Integer, Integer> map = new HashMap<>();
final List<Side> sides = new ArrayList<>();
int uniqueCount = 0;
// offset: start of side s in the old join row.
// belowOffset: start of side s in the new join row.
int offset = 0;
int belowOffset = 0;
for (int s = 0; s < 2; s++) {
final Side side = new Side();
final RelNode joinInput = join.getInput(s);
int fieldCount = joinInput.getRowType().getFieldCount();
final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
map.put(c.e, belowOffset + c.i);
}
// The same key columns, relative to this input.
final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
final boolean unique;
if (!allowFunctions) {
assert aggregate.getAggCallList().isEmpty();
// If there are no functions, it doesn't matter as much whether we
// aggregate the inputs before the join, because there will not be
// any functions experiencing a cartesian product effect.
//
// But finding out whether the input is already unique requires a call
// to areColumnsUnique that currently (until [CALCITE-1048] "Make
// metadata more robust" is fixed) places a heavy load on
// the metadata system.
//
// So we choose to imagine the the input is already unique, which is
// untrue but harmless.
//
Util.discard(Bug.CALCITE_1048_FIXED);
unique = true;
} else {
// Unknown (null) uniqueness is treated as "not unique".
final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
unique = unique0 != null && unique0;
}
if (unique) {
// Input already unique on its keys: use it as-is.
++uniqueCount;
side.aggregate = false;
side.newInput = joinInput;
} else {
side.aggregate = true;
List<AggregateCall> belowAggCalls = new ArrayList<>();
final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
// Maps join-row column indices to this input's indices: identity for the
// left input; for the right input, a shift by offset.
final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final AggregateCall call1;
if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
// All arguments come from this side: push (part of) the call here.
call1 = splitter.split(aggCall.e, mapping);
} else {
// Arguments (partly) on the other side: let the splitter decide what,
// if anything, this side contributes.
call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
}
if (call1 != null) {
// side.split: aggregate-call ordinal -> column index of its sub-total
// in this side's new input (after the key columns).
side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
}
}
side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, false, null), belowAggCalls).build();
}
offset += fieldCount;
belowOffset += side.newInput.getRowType().getFieldCount();
sides.add(side);
}
if (uniqueCount == 2) {
// Both inputs are already unique, so neither side was pre-aggregated and
// the new join would equal the old one. This aggregate+join may also be
// the result of a previous
// invocation of this rule; if we continue we might loop forever.
return;
}
// Update condition
final Mapping mapping = (Mapping) Mappings.target(new Function<Integer, Integer>() {
public Integer apply(Integer a0) {
return map.get(a0);
}
}, join.getRowType().getFieldCount(), belowOffset);
final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
// Create new join
relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
// Aggregate above to sum up the sub-totals
final List<AggregateCall> newAggCalls = new ArrayList<>();
final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
// Start from an identity projection of the new join. topSplit may register
// extra expressions into it through registry(projects).
final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
final SqlAggFunction aggregation = aggCall.e.getAggregation();
final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
// -1 means "no sub-total on that side". Right-side indices are shifted by
// the width of the left input to address the joined row.
newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
}
relBuilder.project(projects);
boolean aggConvertedToProjects = false;
if (allColumnsInAggregate) {
// let's see if we can convert aggregate into projects
List<RexNode> projects2 = new ArrayList<>();
for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
projects2.add(relBuilder.field(key));
}
for (AggregateCall newAggCall : newAggCalls) {
final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
if (splitter != null) {
projects2.add(splitter.singleton(rexBuilder, relBuilder.peek().getRowType(), newAggCall));
}
}
// Only use the projection if every aggregate call could be converted.
if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
// We successfully converted agg calls into projects.
relBuilder.project(projects2);
aggConvertedToProjects = true;
}
}
if (!aggConvertedToProjects) {
// Top aggregate over the group keys, remapped into the new join row.
relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), aggregate.indicator, Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
}
call.transformTo(relBuilder.build());
}
use of org.apache.calcite.rel.core.AggregateCall in project storm by apache.
the class RelNodeCompiler method emitAggregateStmts.
/**
 * Generates the code that emits one aggregated row.
 *
 * <p>The code is the statements written by {@code aggregateResult} for each
 * aggregate call, followed by a {@code ctx.emit(new Values(...))} statement.
 * Its arguments are the group values, followed by the result expression of
 * each aggregate call.
 *
 * <p>An empty group-values fragment is skipped rather than joined. Otherwise
 * an aggregate without aggregate calls (or without group keys) would generate
 * {@code new Values(x, )} (or {@code new Values(, x)}), which does not
 * compile.
 *
 * @param aggregate the aggregate whose emit statements are generated
 * @return the generated Java source lines
 */
private String emitAggregateStmts(Aggregate aggregate) {
  List<String> res = new ArrayList<>();
  StringWriter sw = new StringWriter();
  for (AggregateCall call : aggregate.getAggCallList()) {
    res.add(aggregateResult(call, new PrintWriter(sw)));
  }
  String groupValues =
      groupValueEmitStr("groupValues", aggregate.getGroupSet().cardinality());
  List<String> emitArgs = new ArrayList<>(res.size() + 1);
  if (groupValues != null && !groupValues.isEmpty()) {
    emitArgs.add(groupValues);
  }
  emitArgs.addAll(res);
  return NEW_LINE_JOINER.join(sw.toString(),
      String.format(" ctx.emit(new Values(%s));", Joiner.on(", ").join(emitArgs)));
}
use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.
the class HiveGBOpConvUtil method getGBInfo.
/**
 * Collects what is needed to build a Hive group-by operator for {@code aggRel}.
 *
 * <p>That is: output column names, group-by keys, grouping-set bitmaps, distinct-argument
 * expressions, per-UDAF attributes (including evaluators), memory settings and the physical
 * pipeline mode.
 *
 * <p>Call this separately for each GB operator generated for one logical GB. Otherwise GB
 * evaluators and expression nodes may get shared among multiple GB operators.
 *
 * @param aggRel the aggregate being translated
 * @param inputOpAf attributes of the translated input operator; its {@code tabAlias} is used
 *     when converting argument columns to expressions
 * @param hc configuration read for the grouping-set job threshold and the map-side
 *     aggregation memory settings
 * @return a freshly populated {@link GBInfo}
 * @throws SemanticException declared by this method; propagated from the Hive helpers it calls
 */
private static GBInfo getGBInfo(HiveAggregate aggRel, OpAttr inputOpAf, HiveConf hc)
    throws SemanticException {
  GBInfo gbInfo = new GBInfo();
  List<AggregateCall> aggCalls = aggRel.getAggCallList();

  // 0. Collect AggRel output col names
  gbInfo.outputColNames.addAll(aggRel.getRowType().getFieldNames());

  // 1. Collect GB keys
  RelNode aggInputRel = aggRel.getInput();
  ExprNodeConverter exprConv = new ExprNodeConverter(inputOpAf.tabAlias,
      aggInputRel.getRowType(), new HashSet<Integer>(),
      aggRel.getCluster().getTypeFactory(), true);
  for (int i : aggRel.getGroupSet()) {
    RexInputRef iRef =
        new RexInputRef(i, aggInputRel.getRowType().getFieldList().get(i).getType());
    ExprNodeDesc keyExpr = iRef.accept(exprConv);
    gbInfo.gbKeys.add(keyExpr);
    gbInfo.gbKeyColNamesInInput.add(aggInputRel.getRowType().getFieldNames().get(i));
    gbInfo.gbKeyTypes.add(keyExpr.getTypeInfo());
  }

  // 2. Collect Grouping Set info
  if (aggRel.indicator) {
    // 2.1 Translate each grouping set into an int bitmap of its column positions
    for (ImmutableBitSet grpSet : aggRel.getGroupSets()) {
      int bitmap = 0;
      for (Integer bitIdx : grpSet.asList()) {
        bitmap = SemanticAnalyzer.setBit(bitmap, bitIdx);
      }
      gbInfo.grpSets.add(bitmap);
    }
    Collections.sort(gbInfo.grpSets);

    // 2.2 Check if the grouping sets require an additional MR job
    gbInfo.grpSetRqrAdditionalMRJob = gbInfo.grpSets.size()
        > hc.getIntVar(HiveConf.ConfVars.HIVE_NEW_JOB_GROUPING_SET_CARDINALITY);

    // 2.3 Check if GROUPING_ID needs to be projected out (it is expected as the last call)
    if (!aggCalls.isEmpty()
        && aggCalls.get(aggCalls.size() - 1).getAggregation() == HiveGroupingID.INSTANCE) {
      gbInfo.grpIdFunctionNeeded = true;
    }
  }

  // 3. Walk through UDAF & Collect Distinct Info
  Set<Integer> distinctRefs = new HashSet<Integer>();
  Map<Integer, Integer> distParamInRefsToOutputPos = new HashMap<Integer, Integer>();
  for (AggregateCall aggCall : aggCalls) {
    if ((aggCall.getAggregation() == HiveGroupingID.INSTANCE) || !aggCall.isDistinct()) {
      continue;
    }
    List<Integer> argLst = new ArrayList<Integer>(aggCall.getArgList());
    List<String> argNames = HiveCalciteUtil.getFieldNames(argLst, aggInputRel);
    for (int i = 0; i < argLst.size(); i++) {
      Integer argRef = argLst.get(i);
      // Set.add returns false when this input ref was already handled
      if (distinctRefs.add(argRef)) {
        ExprNodeDesc distinctExpr = HiveCalciteUtil.getExprNode(argRef, aggInputRel, exprConv);
        // Only distinct nodes that are NOT part of the key should be added to distExprNodes
        if (ExprNodeDescUtils.indexOf(distinctExpr, gbInfo.gbKeys) < 0) {
          distParamInRefsToOutputPos.put(argRef, gbInfo.distExprNodes.size());
          gbInfo.distExprNodes.add(distinctExpr);
          gbInfo.distExprNames.add(argNames.get(i));
          gbInfo.distExprTypes.add(distinctExpr.getTypeInfo());
        }
      }
    }
  }

  // 4. Walk through UDAF & Collect UDAF Info
  Set<Integer> deDupedNonDistIrefsSet = new HashSet<Integer>();
  for (AggregateCall aggCall : aggCalls) {
    if (aggCall.getAggregation() == HiveGroupingID.INSTANCE) {
      continue;
    }

    UDAFAttrs udafAttrs = new UDAFAttrs();
    List<ExprNodeDesc> argExps = HiveCalciteUtil.getExprNodes(aggCall.getArgList(),
        aggInputRel, inputOpAf.tabAlias);
    udafAttrs.udafParams.addAll(argExps);
    udafAttrs.udafName = aggCall.getAggregation().getName();
    udafAttrs.argList = aggCall.getArgList();
    udafAttrs.isDistinctUDAF = aggCall.isDistinct();

    List<Integer> argLst = new ArrayList<Integer>(aggCall.getArgList());
    List<Integer> distColIndicesOfUDAF = new ArrayList<Integer>();
    List<Integer> distUDAFParamsIndxInDistExprs = new ArrayList<Integer>();
    for (int i = 0; i < argLst.size(); i++) {
      Integer argRef = argLst.get(i);
      // NOTE: distinct expr can be part of the GB key
      if (udafAttrs.isDistinctUDAF) {
        ExprNodeDesc argExpr = argExps.get(i);
        int found = ExprNodeDescUtils.indexOf(argExpr, gbInfo.gbKeys);
        // Not a key: position after the keys (and the grouping-set id column, if any)
        distColIndicesOfUDAF.add(found < 0
            ? distParamInRefsToOutputPos.get(argRef) + gbInfo.gbKeys.size()
                + (gbInfo.grpSets.size() > 0 ? 1 : 0)
            : found);
        distUDAFParamsIndxInDistExprs.add(distParamInRefsToOutputPos.get(argRef));
      } else {
        // TODO: this seems wrong (following what Hive Regular does)
        // Set.add returns false when this input ref was already recorded
        if (!distParamInRefsToOutputPos.containsKey(argRef)
            && deDupedNonDistIrefsSet.add(argRef)) {
          gbInfo.deDupedNonDistIrefs.add(udafAttrs.udafParams.get(i));
        }
      }
    }

    if (udafAttrs.isDistinctUDAF) {
      gbInfo.containsDistinctAggr = true;
      udafAttrs.udafParamsIndxInGBInfoDistExprs = distUDAFParamsIndxInDistExprs;
      gbInfo.distColIndices.add(distColIndicesOfUDAF);
    }

    // special handling for count, similar to PlanModifierForASTConv::replaceEmptyGroupAggr()
    boolean isNoArgCount =
        udafAttrs.udafParams.isEmpty() && "count".equalsIgnoreCase(udafAttrs.udafName);
    udafAttrs.udafEvaluator = SemanticAnalyzer.getGenericUDAFEvaluator(udafAttrs.udafName,
        new ArrayList<ExprNodeDesc>(udafAttrs.udafParams), new ASTNode(),
        udafAttrs.isDistinctUDAF, isNoArgCount);
    gbInfo.udafAttrs.add(udafAttrs);
  }

  // 5. Gather GB Memory threshold
  gbInfo.groupByMemoryUsage = HiveConf.getFloatVar(hc, HiveConf.ConfVars.HIVEMAPAGGRHASHMEMORY);
  gbInfo.memoryThreshold =
      HiveConf.getFloatVar(hc, HiveConf.ConfVars.HIVEMAPAGGRMEMORYTHRESHOLD);

  // 6. Gather GB Physical pipeline (based on user config & Grping Sets size)
  gbInfo.gbPhysicalPipelineMode = getAggOPMode(hc, gbInfo);

  return gbInfo;
}
use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.
the class HiveRelDecorrelator method decorrelateRel.
/**
 * Rewrites a {@link LogicalAggregate} whose input has been decorrelated.
 *
 * <p>Rewrite logic:
 * <ol>
 *   <li>Project the original group-by keys to the front. Keys that the new
 *   input projects as literal constants are omitted here and re-added at the
 *   end.</li>
 *   <li>If the new input produces correlated variables, append them right
 *   after those keys and make them part of the group key.</li>
 *   <li>Append all remaining input fields.</li>
 *   <li>Rebuild the aggregate over that project, grouping by the first
 *   {@code newGroupKeyCount} columns and with the call arguments
 *   remapped.</li>
 *   <li>Re-insert the omitted constant keys with a final project.</li>
 * </ol>
 *
 * @param rel Aggregate to rewrite
 * @return the frame registered for the rewritten aggregate, or {@code null}
 *     if the input was not rewritten
 * @throws AssertionError if the aggregate is not a simple (single group set) aggregate
 * @throws SemanticException declared by this method; presumably propagated from helpers
 */
public Frame decorrelateRel(LogicalAggregate rel) throws SemanticException {
if (rel.getGroupType() != Aggregate.Group.SIMPLE) {
throw new AssertionError(Bug.CALCITE_461_FIXED);
}
// Aggregate itself should not reference cor vars.
assert !cm.mapRefRelToCorRef.containsKey(rel);
final RelNode oldInput = rel.getInput();
final Frame frame = getInvoke(oldInput, rel);
if (frame == null) {
// If input has not been rewritten, do not rewrite this rel.
return null;
}
final RelNode newInput = frame.r;
// map from newInput position -> position in the Project built below
Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
final int oldGroupKeyCount = rel.getGroupSet().cardinality();
// Project projects the original expressions,
// plus any correlated variables the input wants to pass along.
final List<Pair<RexNode, String>> projects = Lists.newArrayList();
List<RelDataTypeField> newInputOutput = newInput.getRowType().getFieldList();
int newPos = 0;
// oldInput has the original group by keys in the front.
final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
for (int i = 0; i < oldGroupKeyCount; i++) {
final RexLiteral constant = projectedLiteral(newInput, i);
if (constant != null) {
// Exclude constants. Aggregate({true}) occurs because Aggregate({})
// would generate 1 row even when applied to an empty table.
omittedConstants.put(i, constant);
continue;
}
int newInputPos = frame.oldToNewOutputs.get(i);
projects.add(RexInputRef.of2(newInputPos, newInputOutput));
mapNewInputToProjOutputs.put(newInputPos, newPos);
newPos++;
}
final SortedMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
if (!frame.corDefOutputs.isEmpty()) {
// If input produces correlated variables, move them to the front, right
// after the non-constant GROUP BY fields projected above, starting at
// position newPos.
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
corDefOutputs.put(entry.getKey(), newPos);
mapNewInputToProjOutputs.put(entry.getValue(), newPos);
newPos++;
}
}
// add the remaining fields
// Everything projected so far (non-constant keys + correlated variables)
// becomes the new group key.
final int newGroupKeyCount = newPos;
for (int i = 0; i < newInputOutput.size(); i++) {
if (!mapNewInputToProjOutputs.containsKey(i)) {
projects.add(RexInputRef.of2(i, newInputOutput));
mapNewInputToProjOutputs.put(i, newPos);
newPos++;
}
}
assert newPos == newInputOutput.size();
// This Project will be what the old input maps to,
// replacing any previous mapping from old input).
RelNode newProject = HiveProject.create(newInput, Pair.left(projects), Pair.right(projects));
// update mappings:
// oldInput ----> newInput
//
// newProject
// |
// oldInput ----> newInput
//
// is transformed to
//
// oldInput ----> newProject
// |
// newInput
// combinedMap: oldInput position -> newProject position.
Map<Integer, Integer> combinedMap = Maps.newHashMap();
for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) {
combinedMap.put(oldInputPos, mapNewInputToProjOutputs.get(frame.oldToNewOutputs.get(oldInputPos)));
}
register(oldInput, newProject, combinedMap, corDefOutputs);
// now it's time to rewrite the Aggregate
final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
List<AggregateCall> newAggCalls = Lists.newArrayList();
List<AggregateCall> oldAggCalls = rel.getAggCallList();
int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
int newInputOutputFieldCount = newGroupSet.cardinality();
int i = -1;
for (AggregateCall oldAggCall : oldAggCalls) {
++i;
List<Integer> oldAggArgs = oldAggCall.getArgList();
List<Integer> aggArgs = Lists.newArrayList();
// Adjust the aggregator argument positions: map each old input position
// through combinedMap to its position in newProject. The same applies to
// the filter argument below; a negative filterArg means no filter.
for (int oldPos : oldAggArgs) {
aggArgs.add(combinedMap.get(oldPos));
}
final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg : combinedMap.get(oldAggCall.filterArg);
newAggCalls.add(oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount));
// The old to new output position mapping will be the same as that
// of newProject, plus any aggregates that the oldAgg produces.
combinedMap.put(oldInputOutputFieldCount + i, newInputOutputFieldCount + i);
}
relBuilder.push(LogicalAggregate.create(newProject, false, newGroupSet, null, newAggCalls));
if (!omittedConstants.isEmpty()) {
// Re-insert the omitted constant group keys. Highest key first, so
// earlier insertions do not shift later insertion points.
// NOTE(review): each insertion shifts every later column, but neither
// combinedMap nor corDefOutputs is adjusted here. Confirm the registered
// mappings are still correct when omittedConstants is non-empty.
final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.descendingMap().entrySet()) {
postProjects.add(entry.getKey() + frame.corDefOutputs.size(), entry.getValue());
}
relBuilder.project(postProjects);
}
// The Aggregate keeps its group keys, including the correlated variables, at
// the front, so the corDefOutputs positions computed for newProject are also
// the positions in the Aggregate's output; corVars will be
// located at the same position as the input newProject.
return register(rel, relBuilder.build(), combinedMap, corDefOutputs);
}
Aggregations