Usage of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveRelDecorrelator, method decorrelateRel.
/**
* Rewrites a {@link LogicalAggregate} on top of its decorrelated input.
*
* <p>The decorrelated input is first re-projected so that the original group-by keys
* (minus any that {@code projectedLiteral} reports as constants) come first, followed by
* the correlation-variable outputs of the input frame, followed by every remaining input
* field. The aggregate is then rebuilt grouping on those leading keys plus the correlation
* outputs, with its aggregate calls re-pointed at the new project. Omitted constant keys
* are re-added by a project placed on top of the new aggregate.
*
* @param rel Aggregate to rewrite
* @return the frame registered for {@code rel}, or {@code null} if the input was not
*     rewritten (in which case this rel is not rewritten either)
* @throws AssertionError if {@code rel} is not a SIMPLE (non grouping-sets) aggregate
* @throws SemanticException as declared; presumably propagated from decorrelating the
*     input via {@code getInvoke} — confirm
*/
public Frame decorrelateRel(LogicalAggregate rel) throws SemanticException {
if (rel.getGroupType() != Aggregate.Group.SIMPLE) {
throw new AssertionError(Bug.CALCITE_461_FIXED);
}
// Aggregate itself should not reference cor vars.
assert !cm.mapRefRelToCorRef.containsKey(rel);
final RelNode oldInput = rel.getInput();
final Frame frame = getInvoke(oldInput, rel);
if (frame == null) {
// If input has not been rewritten, do not rewrite this rel.
return null;
}
final RelNode newInput = frame.r;
// Maps each field position of newInput to its position in the project built below.
Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
final int oldGroupKeyCount = rel.getGroupSet().cardinality();
// Project projects the original expressions,
// plus any correlated variables the input wants to pass along.
final List<Pair<RexNode, String>> projects = Lists.newArrayList();
List<RelDataTypeField> newInputOutput = newInput.getRowType().getFieldList();
int newPos = 0;
// oldInput has the original group by keys in the front.
// Keys detected as constants are recorded (old key index -> literal) and skipped here.
final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
for (int i = 0; i < oldGroupKeyCount; i++) {
final RexLiteral constant = projectedLiteral(newInput, i);
if (constant != null) {
// Exclude constants. Aggregate({true}) occurs because Aggregate({})
// would generate 1 row even when applied to an empty table.
omittedConstants.put(i, constant);
continue;
}
int newInputPos = frame.oldToNewOutputs.get(i);
projects.add(RexInputRef.of2(newInputPos, newInputOutput));
mapNewInputToProjOutputs.put(newInputPos, newPos);
newPos++;
}
final SortedMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
if (!frame.corDefOutputs.isEmpty()) {
// Append the input's correlation-variable outputs right after the retained group
// keys; they fall inside newGroupKeyCount below and so become extra group keys.
for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
corDefOutputs.put(entry.getKey(), newPos);
mapNewInputToProjOutputs.put(entry.getValue(), newPos);
newPos++;
}
}
// Append every other newInput field, unchanged and in order, after the group keys.
final int newGroupKeyCount = newPos;
for (int i = 0; i < newInputOutput.size(); i++) {
if (!mapNewInputToProjOutputs.containsKey(i)) {
projects.add(RexInputRef.of2(i, newInputOutput));
mapNewInputToProjOutputs.put(i, newPos);
newPos++;
}
}
// newProject is a permutation of newInput: every input field appears exactly once.
assert newPos == newInputOutput.size();
// This Project will be what the old input maps to,
// replacing any previous mapping from old input).
RelNode newProject = HiveProject.create(newInput, Pair.left(projects), Pair.right(projects));
// update mappings:
// oldInput ----> newInput
//
// newProject
// |
// oldInput ----> newInput
//
// is transformed to
//
// oldInput ----> newProject
// |
// newInput
Map<Integer, Integer> combinedMap = Maps.newHashMap();
for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) {
combinedMap.put(oldInputPos, mapNewInputToProjOutputs.get(frame.oldToNewOutputs.get(oldInputPos)));
}
// NOTE(review): combinedMap is mutated further below (aggregate output entries) after
// being passed here; that is harmless only if register copies the map — not visible here.
register(oldInput, newProject, combinedMap, corDefOutputs);
// now it's time to rewrite the Aggregate
final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
List<AggregateCall> newAggCalls = Lists.newArrayList();
List<AggregateCall> oldAggCalls = rel.getAggCallList();
// These equal oldGroupKeyCount and newGroupKeyCount respectively.
int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
int newInputOutputFieldCount = newGroupSet.cardinality();
int i = -1;
for (AggregateCall oldAggCall : oldAggCalls) {
++i;
List<Integer> oldAggArgs = oldAggCall.getArgList();
List<Integer> aggArgs = Lists.newArrayList();
// Translate each argument's old input position into its position in newProject
// via combinedMap (old input -> newProject).
for (int oldPos : oldAggArgs) {
aggArgs.add(combinedMap.get(oldPos));
}
// A negative filterArg means "no filter" and is kept as is.
final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg : combinedMap.get(oldAggCall.filterArg);
newAggCalls.add(oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount));
// The old to new output position mapping will be the same as that
// of newProject, plus any aggregates that the oldAgg produces.
combinedMap.put(oldInputOutputFieldCount + i, newInputOutputFieldCount + i);
}
relBuilder.push(LogicalAggregate.create(newProject, false, newGroupSet, null, newAggCalls));
// NOTE(review): combinedMap becomes the output mapping of rel, but it still holds the
// oldInput->newProject entries for key positions < oldGroupKeyCount (for omitted constant
// keys these point into the "remaining fields" part of newProject, not at a key). Also,
// inserting literals below shifts later columns, yet neither combinedMap nor
// corDefOutputs is adjusted. The mapping looks wrong whenever omittedConstants is
// non-empty — verify (I believe upstream Calcite later switched to a separate output map
// plus index shifting).
if (!omittedConstants.isEmpty()) {
final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
// Re-insert each omitted constant, highest key first, at index key + #corDefOutputs.
// NOTE(review): descending insertion appears to misplace non-adjacent constants when
// aggregates exist, e.g. keys {0:const, 1, 2:const} with one agg a yields
// [c0, k1, a, c2] instead of [c0, k1, c2, a] — confirm intended layout.
for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.descendingMap().entrySet()) {
postProjects.add(entry.getKey() + frame.corDefOutputs.size(), entry.getValue());
}
relBuilder.project(postProjects);
}
// Correlation outputs are group keys of the new aggregate, so they keep the positions
// they had in newProject (see the NOTE above for the omitted-constants case).
return register(rel, relBuilder.build(), combinedMap, corDefOutputs);
}
Usage of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveGBOpConvUtil, method getGBInfo.
/**
 * Collects everything needed to translate one logical {@link HiveAggregate} into a Hive
 * group-by operator. This covers output column names, group-by keys, grouping-set bitmaps,
 * distinct parameters, per-UDAF attributes and the memory/pipeline settings.
 *
 * <p>Call this separately for each GB operator generated from the logical aggregate.
 * Otherwise the GenericUDAFEvaluators and expression nodes stored in the result may be
 * shared among multiple GB operators.
 *
 * @param aggRel the logical aggregate being translated
 * @param inputOpAf attributes of the translated input operator; its table alias is used
 *     when converting input references into expression nodes
 * @param hc configuration providing grouping-set and map-side aggregation settings
 * @return a newly created and populated {@link GBInfo}
 * @throws SemanticException propagated from the helpers called here (expression
 *     conversion, UDAF evaluator lookup, pipeline-mode selection)
 */
private static GBInfo getGBInfo(HiveAggregate aggRel, OpAttr inputOpAf, HiveConf hc) throws SemanticException {
  GBInfo gbInfo = new GBInfo();
  List<AggregateCall> aggCalls = aggRel.getAggCallList();

  // 0. Collect AggRel output col names
  gbInfo.outputColNames.addAll(aggRel.getRowType().getFieldNames());

  // 1. Collect GB keys
  RelNode aggInputRel = aggRel.getInput();
  List<String> inputFieldNames = aggInputRel.getRowType().getFieldNames();
  ExprNodeConverter exprConv = new ExprNodeConverter(inputOpAf.tabAlias, aggInputRel.getRowType(), new HashSet<Integer>(), aggRel.getCluster().getTypeFactory(), true);
  for (int i : aggRel.getGroupSet()) {
    RexInputRef iRef = new RexInputRef(i, aggInputRel.getRowType().getFieldList().get(i).getType());
    ExprNodeDesc keyExpr = iRef.accept(exprConv);
    gbInfo.gbKeys.add(keyExpr);
    gbInfo.gbKeyColNamesInInput.add(inputFieldNames.get(i));
    gbInfo.gbKeyTypes.add(keyExpr.getTypeInfo());
  }

  // 2. Collect grouping-set info
  if (aggRel.getGroupType() != Group.SIMPLE) {
    // 2.1 Translate each grouping-set column bitset into a long bitmap
    for (ImmutableBitSet grpSet : aggRel.getGroupSets()) {
      long bitmap = 0;
      for (Integer bitIdx : grpSet.asList()) {
        bitmap = SemanticAnalyzer.setBit(bitmap, bitIdx);
      }
      gbInfo.grpSets.add(bitmap);
    }
    Collections.sort(gbInfo.grpSets);

    // 2.2 Check whether the grouping sets require an additional MR job
    gbInfo.grpSetRqrAdditionalMRJob = gbInfo.grpSets.size() > hc.getIntVar(HiveConf.ConfVars.HIVE_NEW_JOB_GROUPING_SET_CARDINALITY);

    // 2.3 GROUPING_ID needs to be projected out when it is the last aggregate call
    if (!aggCalls.isEmpty() && aggCalls.get(aggCalls.size() - 1).getAggregation() == HiveGroupingID.INSTANCE) {
      gbInfo.grpIdFunctionNeeded = true;
    }
  }

  // 3. Walk through UDAFs & collect the parameters of DISTINCT aggregates
  Set<Integer> distinctRefs = new HashSet<Integer>();
  Map<Integer, Integer> distParamInRefsToOutputPos = new HashMap<Integer, Integer>();
  for (AggregateCall aggCall : aggCalls) {
    if (aggCall.getAggregation() == HiveGroupingID.INSTANCE || !aggCall.isDistinct()) {
      continue;
    }
    List<Integer> argLst = new ArrayList<Integer>(aggCall.getArgList());
    List<String> argNames = HiveCalciteUtil.getFieldNames(argLst, aggInputRel);
    for (int i = 0; i < argLst.size(); i++) {
      Integer argRef = argLst.get(i);
      // Set.add returns false when this input ref was already handled by an earlier call.
      if (!distinctRefs.add(argRef)) {
        continue;
      }
      ExprNodeDesc distinctExpr = HiveCalciteUtil.getExprNode(argRef, aggInputRel, exprConv);
      // Only distinct nodes that are NOT part of the key should be added to distExprNodes
      if (ExprNodeDescUtils.indexOf(distinctExpr, gbInfo.gbKeys) < 0) {
        distParamInRefsToOutputPos.put(argRef, gbInfo.distExprNodes.size());
        gbInfo.distExprNodes.add(distinctExpr);
        gbInfo.distExprNames.add(argNames.get(i));
        gbInfo.distExprTypes.add(distinctExpr.getTypeInfo());
      }
    }
  }

  // 4. Walk through UDAFs & collect per-UDAF info
  // Distinct-expr column positions follow the GB keys, plus one extra slot when grouping
  // sets are present (presumably the grouping-set id column — confirm).
  int distinctColOffset = gbInfo.gbKeys.size() + (gbInfo.grpSets.isEmpty() ? 0 : 1);
  Set<Integer> deDupedNonDistIrefsSet = new HashSet<Integer>();
  for (AggregateCall aggCall : aggCalls) {
    if (aggCall.getAggregation() == HiveGroupingID.INSTANCE) {
      continue;
    }
    UDAFAttrs udafAttrs = new UDAFAttrs();
    List<ExprNodeDesc> argExps = HiveCalciteUtil.getExprNodes(aggCall.getArgList(), aggInputRel, inputOpAf.tabAlias);
    udafAttrs.udafParams.addAll(argExps);
    udafAttrs.udafName = aggCall.getAggregation().getName();
    udafAttrs.argList = aggCall.getArgList();
    udafAttrs.isDistinctUDAF = aggCall.isDistinct();
    List<Integer> argLst = new ArrayList<Integer>(aggCall.getArgList());
    List<Integer> distColIndicesOfUDAF = new ArrayList<Integer>();
    List<Integer> distUDAFParamsIndxInDistExprs = new ArrayList<Integer>();
    for (int i = 0; i < argLst.size(); i++) {
      Integer argRef = argLst.get(i);
      if (udafAttrs.isDistinctUDAF) {
        // NOTE: a distinct expr can be part of the GB keys; then its key position is used.
        int keyPos = ExprNodeDescUtils.indexOf(argExps.get(i), gbInfo.gbKeys);
        Integer distPos = distParamInRefsToOutputPos.get(argRef);
        distColIndicesOfUDAF.add(keyPos < 0 ? distPos + distinctColOffset : keyPos);
        distUDAFParamsIndxInDistExprs.add(distPos);
      } else if (!distParamInRefsToOutputPos.containsKey(argRef) && deDupedNonDistIrefsSet.add(argRef)) {
        // TODO: this seems wrong (following what Hive Regular does)
        gbInfo.deDupedNonDistIrefs.add(udafAttrs.udafParams.get(i));
      }
    }
    if (udafAttrs.isDistinctUDAF) {
      gbInfo.containsDistinctAggr = true;
      udafAttrs.udafParamsIndxInGBInfoDistExprs = distUDAFParamsIndxInDistExprs;
      gbInfo.distColIndices.add(distColIndicesOfUDAF);
    }
    // Special handling for a parameterless count, similar to
    // PlanModifierForASTConv::replaceEmptyGroupAggr(); the last evaluator argument
    // presumably flags COUNT(*) — confirm against SemanticAnalyzer.
    boolean isParameterlessCount = udafAttrs.udafParams.isEmpty() && "count".equalsIgnoreCase(udafAttrs.udafName);
    udafAttrs.udafEvaluator = SemanticAnalyzer.getGenericUDAFEvaluator(udafAttrs.udafName, new ArrayList<ExprNodeDesc>(udafAttrs.udafParams), new ASTNode(), udafAttrs.isDistinctUDAF, isParameterlessCount);
    gbInfo.udafAttrs.add(udafAttrs);
  }

  // 5. Gather GB memory thresholds
  gbInfo.groupByMemoryUsage = HiveConf.getFloatVar(hc, HiveConf.ConfVars.HIVEMAPAGGRHASHMEMORY);
  gbInfo.memoryThreshold = HiveConf.getFloatVar(hc, HiveConf.ConfVars.HIVEMAPAGGRMEMORYTHRESHOLD);

  // 6. Choose the GB physical pipeline (based on user config & grouping-set size)
  gbInfo.gbPhysicalPipelineMode = getAggOPMode(hc, gbInfo);
  return gbInfo;
}
Usage of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveSubQRemoveRelBuilder, method aggregate.
/**
 * Creates an {@link org.apache.calcite.rel.core.Aggregate} over the relational expression
 * currently on top of the stack, grouping by {@code groupKey} and computing {@code aggCalls}.
 *
 * <p>Group-key, grouping-set, operand and filter expressions are registered in a list that
 * starts from the input's projections. If registration added expressions beyond the input's
 * fields, a project is pushed first so that every expression can be referenced by ordinal.
 *
 * @param groupKey group key, optionally carrying grouping sets; must be a {@code GroupKeyImpl}
 * @param aggCalls aggregate calls; each must be an {@code AggCallImpl} or {@code AggCallImpl2}.
 *     The calls are iterated twice
 * @return this builder, with the new aggregate pushed on the stack
 * @throws IllegalArgumentException if a grouping set is not a subset of the group key, or
 *     uses an expression that does not occur in the group key
 */
public HiveSubQRemoveRelBuilder aggregate(GroupKey groupKey, Iterable<AggCall> aggCalls) {
  final RelDataType inputRowType = peek().getRowType();
  final List<RexNode> extraNodes = projects(inputRowType);
  final GroupKeyImpl key = (GroupKeyImpl) groupKey;
  final ImmutableBitSet groupSet = ImmutableBitSet.of(registerExpressions(extraNodes, key.nodes));

  final ImmutableList<ImmutableBitSet> groupSets;
  if (key.nodeLists == null) {
    // Plain GROUP BY: the only grouping set is the full key.
    groupSets = ImmutableList.of(groupSet);
  } else {
    final int registeredBeforeSets = extraNodes.size();
    final SortedSet<ImmutableBitSet> orderedSets = new TreeSet<>(ImmutableBitSet.ORDERING);
    for (ImmutableList<RexNode> nodeList : key.nodeLists) {
      final ImmutableBitSet candidate = ImmutableBitSet.of(registerExpressions(extraNodes, nodeList));
      if (!groupSet.contains(candidate)) {
        throw new IllegalArgumentException("group set element " + nodeList + " must be a subset of group key");
      }
      orderedSets.add(candidate);
    }
    groupSets = ImmutableList.copyOf(orderedSets);
    // Any expression registered while scanning the sets was not part of the key.
    if (extraNodes.size() > registeredBeforeSets) {
      throw new IllegalArgumentException("group sets contained expressions not in group key: " + extraNodes.subList(registeredBeforeSets, extraNodes.size()));
    }
  }

  // First pass: give every operand and filter expression a slot in extraNodes.
  for (AggCall call : aggCalls) {
    if (!(call instanceof AggCallImpl)) {
      continue;
    }
    final AggCallImpl impl = (AggCallImpl) call;
    registerExpressions(extraNodes, impl.operands);
    if (impl.filter != null) {
      registerExpression(extraNodes, impl.filter);
    }
  }
  if (extraNodes.size() > inputRowType.getFieldCount()) {
    project(extraNodes);
  }
  final RelNode input = build();

  // Second pass: resolve the now-registered expressions to ordinals.
  final List<AggregateCall> aggregateCalls = new ArrayList<>();
  for (AggCall call : aggCalls) {
    if (!(call instanceof AggCallImpl)) {
      aggregateCalls.add(((AggCallImpl2) call).aggregateCall);
      continue;
    }
    final AggCallImpl impl = (AggCallImpl) call;
    final List<Integer> args = registerExpressions(extraNodes, impl.operands);
    final int filterArg = impl.filter != null ? registerExpression(extraNodes, impl.filter) : -1;
    aggregateCalls.add(AggregateCall.create(impl.aggFunction, impl.distinct, args, filterArg, groupSet.cardinality(), input, null, impl.alias));
  }

  assert ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets) : groupSets;
  for (ImmutableBitSet set : groupSets) {
    assert groupSet.contains(set);
  }
  push(aggregateFactory.createAggregate(input, key.indicator, groupSet, groupSets, aggregateCalls));
  return this;
}
Usage of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method createCount.
/**
 * Builds the counting part of the distinct-aggregate expansion on top of {@code aggr}.
 *
 * <p>For each entry of {@code cleanArgList}, a CASE column is projected over {@code aggr}.
 * The column yields {@code 1} when the last field of {@code aggr} equals that entry's
 * grouping-id value, and {@code NULL} otherwise. For single-column entries, the column must
 * also be non-null to yield {@code 1}. Multi-column entries are counted even when their
 * values are null, i.e. (,) counts as 1. A global {@link HiveAggregate} then applies
 * {@code count} to each CASE column. When {@code map} is non-empty, a final project emits
 * one column per entry of {@code argList}, re-using the count of the first occurrence for
 * duplicates.
 *
 * @param aggr the aggregate being counted over; its last output field is compared against
 *     grouping-id values (presumably the grouping-sets aggregate — verify against caller)
 * @param argList the original argument lists of the distinct aggregates
 * @param cleanArgList the argument lists without duplicates
 * @param map index in {@code argList} -> index in {@code cleanArgList}, for duplicates only
 * @param sourceOfForCountDistinct the sorted positions of the group set
 * @return the count aggregate, or a project over it when {@code map} is non-empty
 * @throws CalciteSemanticException as declared; presumably from building the Hive nodes
 */
private RelNode createCount(Aggregate aggr, List<List<Integer>> argList, List<List<Integer>> cleanArgList, Map<Integer, Integer> map, List<Integer> sourceOfForCountDistinct) throws CalciteSemanticException {
  // One input reference per output field of aggr.
  final List<RexNode> aggrFieldRefs = Lists.newArrayList();
  for (RelDataTypeField field : aggr.getRowType().getFieldList()) {
    aggrFieldRefs.add(new RexInputRef(field.getIndex(), field.getType()));
  }

  final List<RexNode> gbChildProjLst = Lists.newArrayList();
  for (List<Integer> cols : cleanArgList) {
    final RexNode groupingIdRef = aggrFieldRefs.get(aggrFieldRefs.size() - 1);
    final RexNode groupingIdValue = rexBuilder.makeExactLiteral(new BigDecimal(getGroupingIdValue(cols, sourceOfForCountDistinct, aggr.getGroupCount())));
    RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, groupingIdRef, groupingIdValue);
    if (cols.size() == 1) {
      // A single-column argument must not count NULL values.
      final RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, aggrFieldRefs.get(cols.get(0)));
      condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, notNull);
    }
    gbChildProjLst.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition, rexBuilder.makeExactLiteral(BigDecimal.ONE), rexBuilder.constantNull()));
  }

  // Project feeding the GB, then one count per CASE column with no grouping keys.
  final RelNode gbInputRel = HiveProject.create(aggr, gbChildProjLst, null);
  final RelDataType countRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
  final List<AggregateCall> countCalls = Lists.newArrayList();
  for (int col = 0; col < cleanArgList.size(); col++) {
    countCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, col, countRetType));
  }
  final Aggregate aggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), gbInputRel, ImmutableBitSet.of(), null, countCalls);

  if (map.isEmpty()) {
    return aggregate;
  }

  // Duplicates exist (e.g. count(distinct x, y) and count(distinct y, x)): emit one column
  // per original argument list, pointing duplicates at their first occurrence.
  final List<RexNode> countRefs = Lists.newArrayList();
  for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
    countRefs.add(new RexInputRef(field.getIndex(), field.getType()));
  }
  final List<RexNode> projLst = Lists.newArrayList();
  int nextUnique = 0;
  for (int argIdx = 0; argIdx < argList.size(); argIdx++) {
    final int source = map.containsKey(argIdx) ? map.get(argIdx) : nextUnique++;
    projLst.add(countRefs.get(source));
  }
  return HiveProject.create(aggregate, projLst, null);
}
Usage of org.apache.calcite.rel.core.AggregateCall in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method createGroupingSets.
/**
 * Builds a grouping-sets {@link HiveAggregate} over the input of {@code aggregate}.
 *
 * <p>The aggregate groups by the columns in {@code sourceOfForCountDistinct}, has one
 * grouping set per distinct column set in {@code argList} (sorted, as Calcite expects), and
 * computes a single BIGINT {@link HiveGroupingID} call.
 *
 * <p>The caller-supplied collections are filled as a side effect:
 * <ul>
 *   <li>{@code cleanArgList} receives, in first-seen order, every argument list whose column
 *       set has not been seen before. Only the grouping sets are sorted, not this list.</li>
 *   <li>{@code map} receives, for every argument list whose column set repeats an earlier
 *       one, its index in {@code argList} mapped to the index of the first occurrence in
 *       {@code cleanArgList} (assuming {@code cleanArgList} starts empty).</li>
 * </ul>
 *
 * @param aggregate the original aggregate; only its input is reused
 * @param argList the original argument lists of the distinct aggregates
 * @param cleanArgList output: argument lists without duplicates
 * @param map output: duplicate index in {@code argList} -> index in {@code cleanArgList}
 * @param sourceOfForCountDistinct the sorted positions of the group set
 * @return the new grouping-sets aggregate
 */
private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> argList, List<List<Integer>> cleanArgList, Map<Integer, Integer> map, List<Integer> sourceOfForCountDistinct) {
  final ImmutableBitSet groupSet = ImmutableBitSet.of(sourceOfForCountDistinct);
  final List<ImmutableBitSet> uniqueSets = new ArrayList<>();
  int argIdx = 0;
  for (List<Integer> args : argList) {
    final ImmutableBitSet argSet = ImmutableBitSet.of(args);
    final int firstSeenAt = uniqueSets.indexOf(argSet);
    if (firstSeenAt >= 0) {
      map.put(argIdx, firstSeenAt);
    } else {
      uniqueSets.add(argSet);
      cleanArgList.add(args);
    }
    argIdx++;
  }

  // Calcite expects the grouping sets sorted and without duplicates
  Collections.sort(uniqueSets, ImmutableBitSet.COMPARATOR);

  // The only aggregate call is the (argument-less) grouping id.
  final List<AggregateCall> aggregateCalls = new ArrayList<>();
  aggregateCalls.add(AggregateCall.create(HiveGroupingID.INSTANCE, false, ImmutableList.<Integer>of(), -1, this.cluster.getTypeFactory().createSqlType(SqlTypeName.BIGINT), HiveGroupingID.INSTANCE.getName()));
  return new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), aggregate.getInput(), groupSet, uniqueSets, aggregateCalls);
}
Aggregations