Usage of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method onMatch.
// ~ Methods ----------------------------------------------------------------
/**
 * Expands an {@link Aggregate} containing DISTINCT aggregate calls.
 *
 * <p>The rule bails out (no transformation) when {@code getNumCountDistinctCall} reports no
 * count-distinct calls, when the number of count-distinct calls plus the number of group keys
 * reaches {@link Long#SIZE} (presumably because a grouping id bitmask is a long — TODO confirm),
 * or when the aggregate's group type is not {@code Group.SIMPLE}.
 *
 * <p>Otherwise two rewrites are attempted, in order:
 * <ol>
 *   <li>If there is more than one count-distinct call and that count equals the total number of
 *       aggregate calls, the aggregate is rewritten by {@code convert} (grouping sets followed by
 *       counts).</li>
 *   <li>Else, if there are no non-distinct calls and every distinct call uses the same argument
 *       list, it is rewritten by {@code convertMonopole} — unless one of those arguments
 *       originates from a partition column of a {@code RelOptHiveTable}, in which case the rule
 *       returns without transforming.</li>
 * </ol>
 * In any other situation the method returns without transforming.
 *
 * @param call rule call whose operand 0 is the matched {@link Aggregate}
 * @throws IllegalArgumentException if no distinct aggregate call is found even though
 *     count-distinct calls were reported
 * @throws RuntimeException wrapping a {@code CalciteSemanticException} raised by {@code convert}
 */
@Override
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
int numCountDistinct = getNumCountDistinctCall(aggregate);
if (numCountDistinct == 0 || numCountDistinct + aggregate.getGroupCount() >= Long.SIZE || aggregate.getGroupType() != Group.SIMPLE) {
return;
}
// Find all of the distinct agg expressions. argListList keeps the argument list of every
// distinct call in call order (duplicates included); argListSets keeps only the unique
// argument lists, in first-seen order (LinkedHashSet) to ensure determinism.
// Non-distinct calls are only counted.
int nonDistinctCount = 0;
List<List<Integer>> argListList = new ArrayList<List<Integer>>();
Set<List<Integer>> argListSets = new LinkedHashSet<List<Integer>>();
// Original group keys plus every column used as an argument of a distinct call.
ImmutableBitSet.Builder newGroupSet = ImmutableBitSet.builder();
newGroupSet.addAll(aggregate.getGroupSet());
for (AggregateCall aggCall : aggregate.getAggCallList()) {
if (!aggCall.isDistinct()) {
++nonDistinctCount;
continue;
}
ArrayList<Integer> argList = new ArrayList<Integer>();
for (Integer arg : aggCall.getArgList()) {
argList.add(arg);
newGroupSet.set(arg);
}
// Aggr checks for sorted argList.
// NOTE(review): the list is copied in the call's own argument order and is not sorted here.
argListList.add(argList);
argListSets.add(argList);
}
Preconditions.checkArgument(argListSets.size() > 0, "containsDistinctCall lied");
if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size()) {
LOG.debug("Trigger countDistinct rewrite. numCountDistinct is " + numCountDistinct);
// now positions contains all the distinct positions, i.e., $5, $4, $6
// we need to first sort them as group by set
// and then get their position later, i.e., $4->1, $5->2, $6->3
// (newGroupSet is an ImmutableBitSet, so it iterates its positions in ascending order.)
// convert/createCount read these fields rather than taking them as parameters.
cluster = aggregate.getCluster();
rexBuilder = cluster.getRexBuilder();
try {
call.transformTo(convert(aggregate, argListList, newGroupSet.build()));
} catch (CalciteSemanticException e) {
LOG.debug(e.toString());
throw new RuntimeException(e);
}
return;
}
// If all of the agg expressions are distinct and have the same
// arguments then we can use a more efficient form.
final RelMetadataQuery mq = call.getMetadataQuery();
if ((nonDistinctCount == 0) && (argListSets.size() == 1)) {
// Skip the rewrite if any shared argument traces back to a partition column.
for (Integer arg : argListSets.iterator().next()) {
Set<RelColumnOrigin> colOrigs = mq.getColumnOrigins(aggregate.getInput(), arg);
if (null != colOrigs) {
for (RelColumnOrigin colOrig : colOrigs) {
// NOTE(review): unchecked cast — assumes every origin table is a RelOptHiveTable.
RelOptHiveTable hiveTbl = (RelOptHiveTable) colOrig.getOriginTable();
if (hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) {
// Encountered partitioning column, this will be better handled by MetadataOnly optimizer.
return;
}
}
}
}
RelNode converted = convertMonopole(aggregate, argListSets.iterator().next());
call.transformTo(converted);
return;
}
}
Usage of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method createCount.
/**
 * Builds the counting stage on top of the grouping-sets aggregate created for a group of
 * count(distinct) calls.
 *
 * <p>First, a project is built over {@code aggr}. For every entry of {@code cleanArgList} it
 * contains {@code CASE WHEN <grouping id> = <id> [AND <arg> IS NOT NULL] THEN 1 ELSE NULL END}.
 * The IS NOT NULL test is added only for single-argument lists. These CASE columns are followed
 * by the columns of {@code originalGroupSet}.
 *
 * <p>Next, an aggregate computes one {@code count} per CASE column, grouped by those original
 * group-by columns.
 *
 * <p>Finally, if {@code map} is non-empty, a project re-expands the counts so that there is one
 * output column per entry of {@code argList}.
 *
 * @param aggr the grouping-sets aggregate; its output holds the columns of {@code newGroupSet}
 *     in ascending order, and its last field is used as the grouping id
 * @param argList the original argument lists of the count(distinct) calls, in call order
 * @param cleanArgList the argument lists without duplicates
 * @param map maps a position in {@code argList} that duplicates an earlier argument list to the
 *     position of the shared grouping set, i.e. to an index into {@code cleanArgList}
 * @param originalGroupSet the group-by columns of the original aggregate
 * @param newGroupSet the sorted positions of the grouping-sets aggregate's group set
 * @return the group-by columns followed by one count column per original count(distinct) call
 * @throws CalciteSemanticException propagated from the Hive type-conversion / aggregate-call
 *     helpers used to build the counts
 */
private RelNode createCount(Aggregate aggr, List<List<Integer>> argList,
    List<List<Integer>> cleanArgList, Map<Integer, Integer> map,
    ImmutableBitSet originalGroupSet, ImmutableBitSet newGroupSet)
    throws CalciteSemanticException {
  final List<RexNode> originalInputRefs = aggr.getRowType().getFieldList().stream()
      .map(input -> new RexInputRef(input.getIndex(), input.getType()))
      .collect(Collectors.toList());
  // The grouping id is the last column produced by the grouping-sets aggregate.
  final RexNode groupingIdRef = originalInputRefs.get(originalInputRefs.size() - 1);
  // THEN 1 ELSE NULL: loop-invariant literals shared by every CASE expression.
  final RexNode caseExpr1 = rexBuilder.makeExactLiteral(BigDecimal.ONE);
  final RexNode caseExpr2 = rexBuilder.makeNullLiteral(caseExpr1.getType());
  final List<RexNode> gbChildProjLst = Lists.newArrayList();
  // for non-singular args, count can include null, i.e. (,) is counted as 1
  for (List<Integer> list : cleanArgList) {
    RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, groupingIdRef,
        rexBuilder.makeExactLiteral(new BigDecimal(
            getGroupingIdValue(list, originalGroupSet, newGroupSet, aggr.getGroupCount()))));
    if (list.size() == 1) {
      int pos = list.get(0);
      // pos is an input column; in aggr's output it sits at its rank within newGroupSet
      // (same mapping as the group-by columns below).
      RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
          originalInputRefs.get(newGroupSet.indexOf(pos)));
      condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, notNull);
    }
    gbChildProjLst.add(
        rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition, caseExpr1, caseExpr2));
  }
  for (int pos : originalGroupSet) {
    gbChildProjLst.add(originalInputRefs.get(newGroupSet.indexOf(pos)));
  }
  // create the project before GB
  RelNode gbInputRel = HiveProject.create(aggr, gbChildProjLst, null);
  // create the aggregate: count(CASE column i) for each clean arg list, grouped by the
  // original group-by columns, which follow the CASE columns in gbInputRel
  List<AggregateCall> aggregateCalls = Lists.newArrayList();
  RelDataType aggFnRetType =
      TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
  for (int i = 0; i < cleanArgList.size(); i++) {
    aggregateCalls.add(HiveCalciteUtil.createSingleArgAggCall(
        "count", cluster, TypeInfoFactory.longTypeInfo, i, aggFnRetType));
  }
  ImmutableBitSet groupSet = ImmutableBitSet.range(
      cleanArgList.size(), cleanArgList.size() + originalGroupSet.cardinality());
  Aggregate aggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
      gbInputRel, groupSet, null, aggregateCalls);
  // count(distinct x, y), count(distinct y, x), we find the correct mapping.
  if (map.isEmpty()) {
    return aggregate;
  }
  final List<RexNode> originalAggrRefs = aggregate.getRowType().getFieldList().stream()
      .map(input -> new RexInputRef(input.getIndex(), input.getType()))
      .collect(Collectors.toList());
  final List<RexNode> projLst = Lists.newArrayList();
  // The aggregate's output is the group keys followed by one count per cleanArgList entry.
  final int groupCount = groupSet.cardinality();
  for (int i = 0; i < groupCount; i++) {
    projLst.add(originalAggrRefs.get(i));
  }
  int index = groupCount;
  for (int i = 0; i < argList.size(); i++) {
    if (map.containsKey(i)) {
      // map values index cleanArgList (i.e. the count calls), so skip past the group keys.
      projLst.add(originalAggrRefs.get(groupCount + map.get(i)));
    } else {
      projLst.add(originalAggrRefs.get(index++));
    }
  }
  return HiveProject.create(aggregate, projLst, null);
}
Usage of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveExpandDistinctAggregatesRule, method convert.
/**
 * Rewrites an aggregate made only of count(distinct ...) calls into a grouping-sets aggregate
 * followed by plain counts.
 *
 * <p>For example,
 * {@code select count(distinct department_id), count(distinct gender),
 * count(distinct education_level) from employee} can be transformed to
 * <pre>
 * select
 *   count(case when i=1 and department_id is not null then 1 else null end) as c0,
 *   count(case when i=2 and gender is not null then 1 else null end) as c1,
 *   count(case when i=4 and education_level is not null then 1 else null end) as c2
 * from (select
 *   grouping__id as i, department_id, gender, education_level from employee
 *   group by department_id, gender, education_level grouping sets
 *   (department_id, gender, education_level))subq;
 * </pre>
 *
 * @param aggregate the original aggregate
 * @param argList argument lists of the count(distinct) calls, in call order
 * @param newGroupSet the original group-by columns plus every distinct-argument column
 * @return the rewritten relational expression
 * @throws CalciteSemanticException propagated from {@code createCount}
 */
private RelNode convert(Aggregate aggregate, List<List<Integer>> argList,
    ImmutableBitSet newGroupSet) throws CalciteSemanticException {
  // Both collections are filled by createGroupingSets: the argument lists without duplicates,
  // and, for each position in argList, the position of the grouping set it maps to.
  final List<List<Integer>> uniqueArgLists = new ArrayList<>();
  final Map<Integer, Integer> argToGroupingSet = new HashMap<>();
  final Aggregate groupingSetsAggregate =
      createGroupingSets(aggregate, argList, uniqueArgLists, argToGroupingSet, newGroupSet);
  final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();
  return createCount(groupingSetsAggregate, argList, uniqueArgLists, argToGroupingSet,
      originalGroupSet, newGroupSet);
}
Usage of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveSemiJoinRule, method recreateAggregateOperator.
/**
 * Re-creates {@code topAggregate} on top of {@code newInputOperator}, shifting every input
 * column reference by the per-column offset in {@code adjustments}.
 *
 * <p>Group keys (and the individual grouping sets when the group type is not
 * {@code Group.SIMPLE}), aggregate-call arguments, filter arguments and collation field indexes
 * are all remapped as {@code pos -> pos + adjustments[pos]}.
 *
 * @param builder relational builder used to construct the new aggregate
 * @param adjustments offset to add to each input column index of {@code topAggregate}
 * @param topAggregate the aggregate to re-create
 * @param newInputOperator the new input of the aggregate
 * @return the new aggregate; the builder's stack is left as it was on entry
 */
protected static RelNode recreateAggregateOperator(RelBuilder builder, int[] adjustments,
    Aggregate topAggregate, RelNode newInputOperator) {
  // The input must be on the builder stack before groupKey() is called: RelBuilder resolves
  // the group columns against the top of the stack.
  builder.push(newInputOperator);
  ImmutableBitSet.Builder newGroupSet = ImmutableBitSet.builder();
  for (int pos : topAggregate.getGroupSet()) {
    newGroupSet.set(pos + adjustments[pos]);
  }
  GroupKey groupKey;
  if (topAggregate.getGroupType() == Group.SIMPLE) {
    groupKey = builder.groupKey(newGroupSet.build());
  } else {
    List<ImmutableBitSet> newGroupSets = new ArrayList<>();
    for (ImmutableBitSet groupingSet : topAggregate.getGroupSets()) {
      ImmutableBitSet.Builder newGroupingSet = ImmutableBitSet.builder();
      for (int pos : groupingSet) {
        newGroupingSet.set(pos + adjustments[pos]);
      }
      newGroupSets.add(newGroupingSet.build());
    }
    groupKey = builder.groupKey(newGroupSet.build(), newGroupSets);
  }
  List<AggregateCall> newAggCallList = new ArrayList<>();
  for (AggregateCall aggregateCall : topAggregate.getAggCallList()) {
    List<Integer> newArgList = aggregateCall.getArgList().stream()
        .map(pos -> pos + adjustments[pos])
        .collect(Collectors.toList());
    int newFilterArg = aggregateCall.filterArg != -1
        ? aggregateCall.filterArg + adjustments[aggregateCall.filterArg]
        : -1;
    RelCollation newCollation = aggregateCall.getCollation() != null
        ? RelCollations.of(aggregateCall.getCollation().getFieldCollations().stream()
            .map(fc -> fc.withFieldIndex(fc.getFieldIndex() + adjustments[fc.getFieldIndex()]))
            .collect(Collectors.toList()))
        : null;
    newAggCallList.add(aggregateCall.copy(newArgList, newFilterArg, newCollation));
  }
  // newInputOperator is still on top of the stack from the push above. Pushing it a second
  // time would leave a stale entry on the builder after build().
  return builder.aggregate(groupKey, newAggCallList).build();
}
Usage of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in the Apache Hive project.
Class HiveWindowingFixRule, method onMatch.
/**
 * Makes sure every aggregate output column used by a windowing expression in the matched
 * {@link Project} is also projected explicitly, since Hive expects such a column to be in the
 * Select clause of the query.
 *
 * <p>Collected columns: input references at or beyond the aggregate's group (plus indicator)
 * columns that appear as operands, partition keys or order keys of a top-level {@link RexOver}.
 *
 * <p>If any of them is missing from the project's non-windowing expressions (compared by
 * {@code toString()} digest), the project is replaced by two projects:
 * <ul>
 *   <li>a lower one with the original expressions plus the missing columns, aliased
 *       {@code window_col_<n>};</li>
 *   <li>an upper one that keeps only the original columns and names.</li>
 * </ul>
 *
 * @param call rule call whose operand 0 is the {@link Project} and operand 1 the
 *     {@link Aggregate} beneath it
 */
@Override
public void onMatch(RelOptRuleCall call) {
  Project project = call.rel(0);
  Aggregate aggregate = call.rel(1);
  // 1. We go over the expressions in the project operator
  // and we separate the windowing nodes that are result
  // of an aggregate expression from the rest of nodes
  final int groupingFields = aggregate.getGroupCount() + aggregate.getIndicatorCount();
  Set<String> projectExprsDigest = new HashSet<String>();
  Map<String, RexNode> windowingExprsDigestToNodes = new HashMap<String, RexNode>();
  for (RexNode r : project.getProjects()) {
    if (r instanceof RexOver) {
      RexOver rexOverNode = (RexOver) r;
      // Operands
      for (RexNode operand : rexOverNode.getOperands()) {
        if (operand instanceof RexInputRef
            && ((RexInputRef) operand).getIndex() >= groupingFields) {
          windowingExprsDigestToNodes.put(operand.toString(), operand);
        }
      }
      // Partition keys
      for (RexNode partitionKey : rexOverNode.getWindow().partitionKeys) {
        if (partitionKey instanceof RexInputRef
            && ((RexInputRef) partitionKey).getIndex() >= groupingFields) {
          windowingExprsDigestToNodes.put(partitionKey.toString(), partitionKey);
        }
      }
      // Order keys
      for (RexFieldCollation orderKey : rexOverNode.getWindow().orderKeys) {
        if (orderKey.left instanceof RexInputRef
            && ((RexInputRef) orderKey.left).getIndex() >= groupingFields) {
          windowingExprsDigestToNodes.put(orderKey.left.toString(), orderKey.left);
        }
      }
    } else {
      projectExprsDigest.add(r.toString());
    }
  }
  // 2. We check whether there is a column needed by the windowing operation that is missing
  // in the project expressions. If so, the old project operator is replaced by two new ones:
  // - a project operator containing the original columns of the project operator plus all
  //   the columns that were missing
  // - a project on top of the previous one, that takes out the columns that were missing
  //   and were added by the previous project
  // These data structures are needed to create the new project operator (below)
  final List<RexNode> belowProjectExprs = new ArrayList<RexNode>();
  final List<String> belowProjectColumnNames = new ArrayList<String>();
  // This data structure is needed to create the new project operator (top)
  final List<RexNode> topProjectExprs = new ArrayList<RexNode>();
  final int projectCount = project.getProjects().size();
  for (int i = 0; i < projectCount; i++) {
    belowProjectExprs.add(project.getProjects().get(i));
    belowProjectColumnNames.add(project.getRowType().getFieldNames().get(i));
    topProjectExprs.add(RexInputRef.of(i, project.getRowType()));
  }
  boolean windowingFix = false;
  // Candidate suffix for generated aliases. Every window_col_<k> with k below it is already
  // taken and names are only ever added, so the search never has to restart from 0.
  int aliasIndex = 0;
  for (Entry<String, RexNode> windowingExpr : windowingExprsDigestToNodes.entrySet()) {
    if (!projectExprsDigest.contains(windowingExpr.getKey())) {
      windowingFix = true;
      belowProjectExprs.add(windowingExpr.getValue());
      String alias = "window_col_" + aliasIndex;
      while (belowProjectColumnNames.contains(alias)) {
        // Pre-increment so that each iteration tests a new name.
        alias = "window_col_" + (++aliasIndex);
      }
      belowProjectColumnNames.add(alias);
    }
  }
  if (!windowingFix) {
    // We do not need to do anything, we bail out
    return;
  }
  // 3. We need to fix it, we create the two replacement project operators
  RelNode newProjectRel = projectFactory.createProject(
      aggregate, Collections.emptyList(), belowProjectExprs, belowProjectColumnNames);
  RelNode newTopProjectRel = projectFactory.createProject(
      newProjectRel, Collections.emptyList(), topProjectExprs, project.getRowType().getFieldNames());
  call.transformTo(newTopProjectRel);
}
Aggregations