Search in sources :

Example 31 with Aggregate

Use of org.apache.calcite.rel.core.Aggregate in project hive by apache.

The class RelFieldTrimmer, method trimFields.

/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.calcite.rel.logical.LogicalAggregate}.
 *
 * <p>Trims the aggregate's input down to the fields referenced by the group
 * keys and by the aggregate calls (arguments, filter arguments and collation
 * keys). It also drops aggregate calls whose output column is not used by the
 * consumer. Group fields are always kept, even if the consumer does not use them.
 *
 * @param aggregate   aggregate to trim
 * @param fieldsUsed  output fields of {@code aggregate} that the consumer uses
 * @param extraFields extra fields requested by the consumer; not consulted here
 *                    (the input is trimmed with an empty extra-field set)
 * @return the (possibly new) aggregate together with the mapping from the old
 *         output fields to the new ones
 */
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
    // Fields:
    //
    // | sys fields | group fields | indicator fields | agg functions |
    //
    // Two kinds of trimming:
    //
    // 1. If agg rel has system fields but none of these are used, create an
    // agg rel with no system fields.
    //
    // 2. If aggregate functions are not used, remove them.
    //
    // But group and indicator fields stay, even if they are not used.
    final RelDataType rowType = aggregate.getRowType();
    // Compute which input fields are used.
    // 1. group fields are always used
    final ImmutableBitSet.Builder inputFieldsUsed = aggregate.getGroupSet().rebuild();
    // 2. agg functions: arguments, filter argument and collation (sort) keys
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        inputFieldsUsed.addAll(aggCall.getArgList());
        if (aggCall.filterArg >= 0) {
            inputFieldsUsed.set(aggCall.filterArg);
        }
        inputFieldsUsed.addAll(RelCollations.ordinals(aggCall.collation));
    }
    // Create input with trimmed columns.
    final RelNode input = aggregate.getInput();
    final Set<RelDataTypeField> inputExtraFields = Collections.emptySet();
    final TrimResult trimResult = trimChild(aggregate, input, inputFieldsUsed.build(), inputExtraFields);
    final RelNode newInput = trimResult.left;
    final Mapping inputMapping = trimResult.right;
    // We have to return group keys and (if present) indicators.
    // So, pretend that the consumer asked for them.
    final int groupCount = aggregate.getGroupSet().cardinality();
    fieldsUsed = fieldsUsed.union(ImmutableBitSet.range(groupCount));
    // If the input did not change and every output field is used, there's nothing to do.
    if (input == newInput && fieldsUsed.equals(ImmutableBitSet.range(rowType.getFieldCount()))) {
        return result(aggregate, Mappings.createIdentity(rowType.getFieldCount()));
    }
    // Which agg calls are used by our consumer?
    int j = groupCount;
    int usedAggCallCount = 0;
    for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
        if (fieldsUsed.get(j++)) {
            ++usedAggCallCount;
        }
    }
    // Offset due to the number of system fields having changed.
    Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, rowType.getFieldCount(), groupCount + usedAggCallCount);
    final ImmutableBitSet newGroupSet = Mappings.apply(inputMapping, aggregate.getGroupSet());
    final ImmutableList<ImmutableBitSet> newGroupSets = ImmutableList.copyOf(Iterables.transform(aggregate.getGroupSets(), input1 -> Mappings.apply(inputMapping, input1)));
    // Group fields keep their positions in the output.
    for (j = 0; j < groupCount; j++) {
        mapping.set(j, j);
    }
    // Now create new agg calls, and populate mapping for them.
    final RelBuilder relBuilder = REL_BUILDER.get();
    relBuilder.push(newInput);
    final List<RelBuilder.AggCall> newAggCallList = new ArrayList<>();
    j = groupCount;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (fieldsUsed.get(j)) {
            // Every input ordinal referenced by the call (arguments, filter and
            // collation keys) refers to the untrimmed input and must be
            // translated through inputMapping onto the trimmed input.
            final ImmutableList<RexNode> args = relBuilder.fields(Mappings.apply2(inputMapping, aggCall.getArgList()));
            final RexNode filterArg = aggCall.filterArg < 0 ? null : relBuilder.field(Mappings.apply(inputMapping, aggCall.filterArg));
            RelBuilder.AggCall newAggCall = relBuilder.aggregateCall(aggCall.getAggregation(), args)
                .distinct(aggCall.isDistinct())
                .filter(filterArg)
                .approximate(aggCall.isApproximate())
                .sort(relBuilder.fields(RexUtil.apply(inputMapping, aggCall.collation)))
                .as(aggCall.name);
            mapping.set(j, groupCount + newAggCallList.size());
            newAggCallList.add(newAggCall);
        }
        ++j;
    }
    final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets);
    relBuilder.aggregate(groupKey, newAggCallList);
    return result(relBuilder.build(), mapping);
}
Also used : RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) Mappings(org.apache.calcite.util.mapping.Mappings) MappingType(org.apache.calcite.util.mapping.MappingType) LoggerFactory(org.slf4j.LoggerFactory) LogicalTableModify(org.apache.calcite.rel.logical.LogicalTableModify) LogicalTableFunctionScan(org.apache.calcite.rel.logical.LogicalTableFunctionScan) IntPair(org.apache.calcite.util.mapping.IntPair) BigDecimal(java.math.BigDecimal) RexUtil(org.apache.calcite.rex.RexUtil) CorrelationId(org.apache.calcite.rel.core.CorrelationId) RexNode(org.apache.calcite.rex.RexNode) RelBuilder(org.apache.calcite.tools.RelBuilder) RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) RexLiteral(org.apache.calcite.rex.RexLiteral) Set(java.util.Set) SqlExplainLevel(org.apache.calcite.sql.SqlExplainLevel) RelFieldCollation(org.apache.calcite.rel.RelFieldCollation) List(java.util.List) RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) RelCollation(org.apache.calcite.rel.RelCollation) Sort(org.apache.calcite.rel.core.Sort) RexDynamicParam(org.apache.calcite.rex.RexDynamicParam) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RexCorrelVariable(org.apache.calcite.rex.RexCorrelVariable) SqlExplainFormat(org.apache.calcite.sql.SqlExplainFormat) RelDataTypeImpl(org.apache.calcite.rel.type.RelDataTypeImpl) Project(org.apache.calcite.rel.core.Project) TableScan(org.apache.calcite.rel.core.TableScan) RexFieldAccess(org.apache.calcite.rex.RexFieldAccess) Iterables(com.google.common.collect.Iterables) Ord(org.apache.calcite.linq4j.Ord) SetOp(org.apache.calcite.rel.core.SetOp) Filter(org.apache.calcite.rel.core.Filter) RelOptUtil(org.apache.calcite.plan.RelOptUtil) Join(org.apache.calcite.rel.core.Join) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) Pair(org.apache.calcite.util.Pair) Mapping(org.apache.calcite.util.mapping.Mapping) 
LogicalValues(org.apache.calcite.rel.logical.LogicalValues) RexPermuteInputsShuttle(org.apache.calcite.rex.RexPermuteInputsShuttle) ReflectiveVisitor(org.apache.calcite.util.ReflectiveVisitor) LinkedHashSet(java.util.LinkedHashSet) RelCollations(org.apache.calcite.rel.RelCollations) RelDataType(org.apache.calcite.rel.type.RelDataType) Bug(org.apache.calcite.util.Bug) Logger(org.slf4j.Logger) RexBuilder(org.apache.calcite.rex.RexBuilder) RelNode(org.apache.calcite.rel.RelNode) Aggregate(org.apache.calcite.rel.core.Aggregate) ReflectUtil(org.apache.calcite.util.ReflectUtil) RexVisitor(org.apache.calcite.rex.RexVisitor) JoinRelType(org.apache.calcite.rel.core.JoinRelType) AggregateCall(org.apache.calcite.rel.core.AggregateCall) CorrelationReferenceFinder(org.apache.calcite.sql2rel.CorrelationReferenceFinder) Util(org.apache.calcite.util.Util) Collections(java.util.Collections) RelBuilder(org.apache.calcite.tools.RelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) Mapping(org.apache.calcite.util.mapping.Mapping) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexNode(org.apache.calcite.rex.RexNode)

Example 32 with Aggregate

Use of org.apache.calcite.rel.core.Aggregate in project hive by apache.

The class HiveRemoveSqCountCheck, method onMatch.

/**
 * Drops the {@code sq_count_check} branch by joining the left input of the
 * inner join directly with the right input of the top join. The top join's
 * condition and join type are reused.
 *
 * <p>Applies only when the matched aggregate has no group-by keys or only
 * constant ones, as decided by {@code isAggregateWithoutGbyKeys} /
 * {@code isAggWithConstantGbyKeys}. Aggregates with grouping-set indicators
 * are left untouched.
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Join outerJoin = call.rel(0);
    final Join innerJoin = call.rel(2);
    final Aggregate countAggregate = call.rel(6);
    // Grouping sets (indicator columns) prevent removing sq_count_check.
    if (countAggregate.indicator) {
        return;
    }
    final boolean removable = isAggregateWithoutGbyKeys(countAggregate)
        || isAggWithConstantGbyKeys(countAggregate, call);
    if (!removable) {
        return;
    }
    // join(innerJoin.getLeft(), outerJoin.getRight()) with the top join's condition and type
    final RelNode replacement = HiveJoin.getJoin(
        outerJoin.getCluster(),
        innerJoin.getLeft(),
        outerJoin.getRight(),
        outerJoin.getCondition(),
        outerJoin.getJoinType());
    call.transformTo(replacement);
}
Also used : RelNode(org.apache.calcite.rel.RelNode) HiveJoin(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin) Join(org.apache.calcite.rel.core.Join) Aggregate(org.apache.calcite.rel.core.Aggregate)

Example 33 with Aggregate

Use of org.apache.calcite.rel.core.Aggregate in project hive by apache.

The class HiveAggregateIncrementalRewritingRuleBase, method onMatch.

/**
 * Rewrites an aggregate over a union of (query branch, materialized view
 * branch) into an incremental MV rebuild plan.
 *
 * <p>The MV branch (union input 1) is wrapped in a Project that appends a
 * {@code true} marker column. That Project is RIGHT-joined with the plan
 * returned by {@code createJoinRightInput}, using {@code IS NOT DISTINCT FROM}
 * on every group-by column. Each aggregate value of both sides is then merged
 * with a CASE expression built around {@code createAggregateNode}. Finally,
 * the join is filtered by {@code createFilterCondition} and projected.
 * Nothing is rewritten when {@code createJoinRightInput} returns {@code null}.
 *
 * <p>NOTE(review): the column arithmetic below assumes that the MV row type
 * has exactly {@code groupCount + aggCallCount} columns laid out like the
 * aggregate output (group keys first). Confirm against the callers that
 * build the union.
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Agg agg = call.rel(aggregateIndex);
    final Union union = call.rel(1);
    final RelBuilder relBuilder = call.builder();
    final RexBuilder rexBuilder = relBuilder.getRexBuilder();
    // 1) First branch is query, second branch is MV
    RelNode joinLeftInput = union.getInput(1);
    final T joinRightInput = createJoinRightInput(call);
    if (joinRightInput == null) {
        return;
    }
    // 2) Introduce a Project on top of MV scan having all columns from the view plus a boolean literal which indicates
    // whether the row with the key values coming from the joinRightInput exists in the view:
    // - true means exist
    // - null means not exists (the right outer join pads missing MV rows with nulls)
    // Project also needed to encapsulate the view scan by a subquery -> this is required by
    // CalcitePlanner.fixUpASTAggregateInsertIncrementalRebuild
    // CalcitePlanner.fixUpASTAggregateInsertDeleteIncrementalRebuild
    List<RexNode> mvCols = new ArrayList<>(joinLeftInput.getRowType().getFieldCount());
    for (int i = 0; i < joinLeftInput.getRowType().getFieldCount(); ++i) {
        mvCols.add(rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(i).getType(), i));
    }
    mvCols.add(rexBuilder.makeLiteral(true));
    joinLeftInput = relBuilder.push(joinLeftInput).project(mvCols).build();
    // 3) Build conditions for join and start adding
    // expressions for project operator
    List<RexNode> projExprs = new ArrayList<>();
    List<RexNode> joinConjs = new ArrayList<>();
    // Aggregate output layout: group keys first, then one column per aggregate call.
    int groupCount = agg.getGroupCount();
    int totalCount = agg.getGroupCount() + agg.getAggCallList().size();
    // In the join row type, right-side columns start after the totalCount MV
    // columns plus the appended marker column, hence the "+ 1".
    for (int leftPos = 0, rightPos = totalCount + 1; leftPos < groupCount; leftPos++, rightPos++) {
        RexNode leftRef = rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(leftPos).getType(), leftPos);
        RexNode rightRef = rexBuilder.makeInputRef(joinRightInput.rightInput.getRowType().getFieldList().get(leftPos).getType(), rightPos);
        // Group keys are projected from the right (new data) side, which is the preserved side of the RIGHT join.
        projExprs.add(rightRef);
        // Null-safe equality so that NULL group keys still match.
        RexNode nsEqExpr = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, ImmutableList.of(leftRef, rightRef));
        joinConjs.add(nsEqExpr);
    }
    // 4) Create join node
    RexNode joinCond = RexUtil.composeConjunction(rexBuilder, joinConjs);
    RelNode join = relBuilder.push(joinLeftInput).push(joinRightInput.rightInput).join(JoinRelType.RIGHT, joinCond).build();
    // 5) Merge the aggregate function values of both sides
    for (int i = 0, leftPos = groupCount, rightPos = totalCount + 1 + groupCount; leftPos < totalCount; i++, leftPos++, rightPos++) {
        // case when source.s is null and mv2.s is null then null
        // case when source.s IS null then mv2.s
        // case when mv2.s IS null then source.s
        // else source.s + mv2.s end
        RexNode leftRef = rexBuilder.makeInputRef(joinLeftInput.getRowType().getFieldList().get(leftPos).getType(), leftPos);
        RexNode rightRef = rexBuilder.makeInputRef(joinRightInput.rightInput.getRowType().getFieldList().get(leftPos).getType(), rightPos);
        // Generate SQLOperator for merging the aggregations
        SqlAggFunction aggCall = agg.getAggCallList().get(i).getAggregation();
        RexNode elseReturn = createAggregateNode(aggCall, leftRef, rightRef, rexBuilder);
        // According to SQL standard (and Hive) Aggregate functions eliminates null values however operators used in
        // elseReturn expressions returns null if one of their operands is null
        // hence we need a null check of both operands.
        // Note: If both are null, we will fall into branch    WHEN leftNull THEN rightRef
        RexNode leftNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, leftRef);
        RexNode rightNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, rightRef);
        projExprs.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE, leftNull, rightRef, rightNull, leftRef, elseReturn));
    }
    // The marker (exists-in-MV) column is the last column of the projected MV side,
    // which is also its position in the join output.
    int flagIndex = joinLeftInput.getRowType().getFieldCount() - 1;
    RexNode flagNode = rexBuilder.makeInputRef(join.getRowType().getFieldList().get(flagIndex).getType(), flagIndex);
    // 6) Build plan
    RelNode newNode = relBuilder.push(join).filter(createFilterCondition(joinRightInput, flagNode, projExprs, relBuilder)).project(projExprs).build();
    call.transformTo(newNode);
}
Also used : RelBuilder(org.apache.calcite.tools.RelBuilder) RelNode(org.apache.calcite.rel.RelNode) ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) Aggregate(org.apache.calcite.rel.core.Aggregate) Union(org.apache.calcite.rel.core.Union) RexNode(org.apache.calcite.rex.RexNode)

Example 34 with Aggregate

Use of org.apache.calcite.rel.core.Aggregate in project hive by apache.

The class HiveAggregateInsertDeleteIncrementalRewritingRule, method createJoinRightInput.

/**
 * Builds the right-hand input of the incremental-rebuild join.
 *
 * <p>The aggregate's input has the row-is-deleted flag propagated into it
 * (treated as its last column). It is then re-aggregated on the original
 * group set, with every aggregate call rewritten as a SUM in which deleted
 * rows contribute the negated value:
 * <pre>
 *   SELECT
 *     sum(case when t1.ROW__IS__DELETED then -b else b end) sumb,
 *     sum(case when t1.ROW__IS__DELETED then -1 else 1 end) countStar
 * </pre>
 * {@code COUNT(col)} takes its per-row value from {@code genArgumentForCountColumn}.
 *
 * @return the rewritten plan together with the output index of {@code COUNT(*)},
 *         or {@code null} when the aggregate contains no {@code COUNT(*)}
 * @throws AssertionError if an aggregate call other than COUNT or SUM is found
 */
@Override
protected IncrementalComputePlanWithDeletedRows createJoinRightInput(RelOptRuleCall call) {
    final RelBuilder builder = call.builder();
    final Aggregate aggregate = call.rel(2);
    final RexBuilder rexBuilder = builder.getRexBuilder();
    // Propagate rowIsDeleted column; it ends up as the last column of the input.
    final RelNode extracted = HiveHepExtractRelNodeRule.execute(aggregate.getInput());
    final RelNode input = new HiveRowIsDeletedPropagator(builder).propagate(extracted);
    final int deletedFlagIdx = input.getRowType().getFieldCount() - 1;
    final RexNode deletedFlag = rexBuilder.makeInputRef(
        input.getRowType().getFieldList().get(deletedFlagIdx).getType(), deletedFlagIdx);

    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final List<RelBuilder.AggCall> signedSums = new ArrayList<>(aggCalls.size());
    int countStarIdx = -1;
    // Output position of the current aggregate call: group keys come first.
    int outputIdx = aggregate.getGroupCount();
    for (AggregateCall aggregateCall : aggCalls) {
        final SqlKind kind = aggregateCall.getAggregation().getKind();
        final List<Integer> args = aggregateCall.getArgList();
        if (kind == SqlKind.COUNT && args.isEmpty()) {
            countStarIdx = outputIdx;
        }
        final RexNode value;
        if (kind == SqlKind.COUNT) {
            // count(*) adds 1 per row; count(<column_name>) delegates to the helper
            value = args.isEmpty()
                ? builder.literal(1)
                : genArgumentForCountColumn(builder, input, args.get(0));
        } else if (kind == SqlKind.SUM) {
            final int argIdx = args.get(0);
            value = rexBuilder.makeInputRef(input.getRowType().getFieldList().get(argIdx).getType(), argIdx);
        } else {
            throw new AssertionError("Found an aggregation that could not be" + " recognized: " + aggregateCall);
        }
        // Deleted rows contribute the negated value: CASE WHEN deleted THEN -1 * value ELSE value END
        final RexNode negated = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, builder.literal(-1), value);
        final RexNode signed = rexBuilder.makeCall(SqlStdOperatorTable.CASE, deletedFlag, negated, value);
        signedSums.add(builder.aggregateCall(SqlStdOperatorTable.SUM, signed));
        outputIdx++;
    }
    if (countStarIdx < 0) {
        // No COUNT(*): the rewrite is not possible, bail out.
        return null;
    }
    builder.push(input);
    final RelNode rightInput = builder.aggregate(builder.groupKey(aggregate.getGroupSet()), signedSums).build();
    return new IncrementalComputePlanWithDeletedRows(rightInput, countStarIdx);
}
Also used : RelBuilder(org.apache.calcite.tools.RelBuilder) ArrayList(java.util.ArrayList) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) RexBuilder(org.apache.calcite.rex.RexBuilder) Aggregate(org.apache.calcite.rel.core.Aggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 35 with Aggregate

Use of org.apache.calcite.rel.core.Aggregate in project hive by apache.

The class HiveAggregateJoinTransposeRule, method onMatch.

/**
 * Pushes an aggregate below an inner equi-join. Each join input is aggregated
 * on its share of the group-by and join columns, or only projected when it is
 * already unique on them. The partial results are then combined above the new
 * join, either by re-aggregating or, when possible, by a plain Project.
 *
 * <p>The rule bails out in any of these cases:
 * <ul>
 *   <li>an aggregate call is not splittable or has a filter;</li>
 *   <li>the join is not INNER;</li>
 *   <li>there are aggregate calls but {@code allowFunctions} is false;</li>
 *   <li>the grouping is not unique and the rule is not {@code costBased};</li>
 *   <li>the join condition has non-equi conjuncts;</li>
 *   <li>both inputs are already unique on their keys.</li>
 * </ul>
 *
 * <p>The rewritten plan is registered when either of these holds:
 * <ul>
 *   <li>{@code uniqueBased} is set, the grouping is unique and the top
 *       aggregate was converted into a Project;</li>
 *   <li>{@code costBased} is set and the new plan's cumulative cost is lower
 *       than the original aggregate's.</li>
 * </ul>
 *
 * <p>Exceptions are swallowed, with a warning, only when missing column
 * statistics were recorded in {@code noColsMissingStats} (which is then
 * reset); otherwise they are rethrown.
 */
@Override
public void onMatch(RelOptRuleCall call) {
    try {
        final Aggregate aggregate = call.rel(0);
        final Join join = call.rel(1);
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RelBuilder relBuilder = call.builder();
        // If any aggregate call is not splittable or has a filter, bail out
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
                return;
            }
            if (aggregateCall.filterArg >= 0) {
                return;
            }
        }
        // Only INNER joins are handled; for any other join type do not push the
        // aggregate operator
        if (join.getJoinType() != JoinRelType.INNER) {
            return;
        }
        if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
            return;
        }
        boolean groupingUnique = isGroupingUnique(join, aggregate.getGroupSet());
        if (!groupingUnique && !costBased) {
            // there is no need to check further - the transformation may not happen
            return;
        }
        // Do the columns used by the join appear in the output of the aggregate?
        final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        final RelMetadataQuery mq = call.getMetadataQuery();
        final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
        final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
        // Columns each side must keep below the aggregate: group keys plus join keys.
        final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
        // Split join condition
        final List<Integer> leftKeys = Lists.newArrayList();
        final List<Integer> rightKeys = Lists.newArrayList();
        final List<Boolean> filterNulls = Lists.newArrayList();
        RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
        // If it contains non-equi join conditions, we bail out
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }
        // Push each aggregate function down to each side that contains all of its
        // arguments. Note that COUNT(*), because it has no arguments, can go to
        // both sides.
        // map: join output column -> column in the new join's output
        final Map<Integer, Integer> map = new HashMap<>();
        final List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        // offset: start of the current input in the old join row type;
        // belowOffset: start of the current new input in the new join row type
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; s++) {
            final Side side = new Side();
            final RelNode joinInput = join.getInput(s);
            int fieldCount = joinInput.getRowType().getFieldCount();
            final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
            for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
                map.put(c.e, belowOffset + c.i);
            }
            final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
            final boolean unique;
            if (!allowFunctions) {
                assert aggregate.getAggCallList().isEmpty();
                // If there are no functions, it doesn't matter as much whether we
                // aggregate the inputs before the join, because there will not be
                // any functions experiencing a cartesian product effect.
                //
                // But finding out whether the input is already unique requires a call
                // to areColumnsUnique that currently (until [CALCITE-1048] "Make
                // metadata more robust" is fixed) places a heavy load on
                // the metadata system.
                //
                // So we choose to imagine the the input is already unique, which is
                // untrue but harmless.
                //
                unique = true;
            } else {
                final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey, true);
                unique = unique0 != null && unique0;
            }
            if (unique) {
                // Already unique on its keys: just project the key columns.
                ++uniqueCount;
                relBuilder.push(joinInput);
                relBuilder.project(belowAggregateKey.asList().stream().map(relBuilder::field).collect(Collectors.toList()));
                side.newInput = relBuilder.build();
            } else {
                // Aggregate this side on its keys; side.split records, per original
                // aggregate call, the output column of its partial result on this side.
                List<AggregateCall> belowAggCalls = new ArrayList<>();
                final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
                final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
                for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                    final SqlAggFunction aggregation = aggCall.e.getAggregation();
                    final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                    final AggregateCall call1;
                    if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                        call1 = splitter.split(aggCall.e, mapping);
                    } else {
                        call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                    }
                    if (call1 != null) {
                        side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                    }
                }
                side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }
        if (uniqueCount == 2) {
            // Both inputs are already unique on their keys, so nothing would be
            // aggregated below the join. This is likely the result of a previous
            // invocation of this rule; if we continue we might loop forever.
            return;
        }
        // Update condition
        final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
        final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
        // Create new join
        relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
        // Aggregate above to sum up the sub-totals
        final List<AggregateCall> newAggCalls = new ArrayList<>();
        final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
        final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            final SqlAggFunction aggregation = aggCall.e.getAggregation();
            final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
            final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
            final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
            // -1 signals "no sub-total on this side"; right positions are shifted past the left input.
            newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
        }
        relBuilder.project(projects);
        boolean aggConvertedToProjects = false;
        if (allColumnsInAggregate) {
            // let's see if we can convert aggregate into projects
            List<RexNode> projects2 = new ArrayList<>();
            for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
                projects2.add(relBuilder.field(key));
            }
            for (AggregateCall newAggCall : newAggCalls) {
                final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                if (splitter != null) {
                    final RelDataType rowType = relBuilder.peek().getRowType();
                    projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
                }
            }
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                // We successfully converted agg calls into projects.
                relBuilder.project(projects2);
                aggConvertedToProjects = true;
            }
        }
        if (!aggConvertedToProjects) {
            relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
        }
        RelNode r = relBuilder.build();
        // Decide whether to register the rewrite: uniqueness-based first, then cost-based.
        boolean transform = false;
        if (uniqueBased && aggConvertedToProjects) {
            transform = groupingUnique;
        }
        if (!transform && costBased) {
            RelOptCost afterCost = mq.getCumulativeCost(r);
            RelOptCost beforeCost = mq.getCumulativeCost(aggregate);
            transform = afterCost.isLt(beforeCost);
        }
        if (transform) {
            call.transformTo(r);
        }
    } catch (Exception e) {
        // Tolerate failures only when missing column stats were recorded; reset the
        // counter so the next invocation starts clean. Anything else is rethrown.
        if (noColsMissingStats.get() > 0) {
            LOG.warn("Missing column stats (see previous messages), skipping aggregate-join transpose in CBO");
            noColsMissingStats.set(0);
        } else {
            throw e;
        }
    }
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RelDataType(org.apache.calcite.rel.type.RelDataType) RelOptCost(org.apache.calcite.plan.RelOptCost) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) HiveJoin(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate) Aggregate(org.apache.calcite.rel.core.Aggregate) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

Aggregate (org.apache.calcite.rel.core.Aggregate)75 RelNode (org.apache.calcite.rel.RelNode)44 ArrayList (java.util.ArrayList)39 AggregateCall (org.apache.calcite.rel.core.AggregateCall)37 RexNode (org.apache.calcite.rex.RexNode)32 RelBuilder (org.apache.calcite.tools.RelBuilder)27 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)27 LogicalAggregate (org.apache.calcite.rel.logical.LogicalAggregate)23 RexBuilder (org.apache.calcite.rex.RexBuilder)22 Project (org.apache.calcite.rel.core.Project)21 HashMap (java.util.HashMap)17 RexInputRef (org.apache.calcite.rex.RexInputRef)15 Join (org.apache.calcite.rel.core.Join)14 RelMetadataQuery (org.apache.calcite.rel.metadata.RelMetadataQuery)14 RelDataType (org.apache.calcite.rel.type.RelDataType)14 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)14 Filter (org.apache.calcite.rel.core.Filter)13 HiveAggregate (org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate)13 List (java.util.List)12 ImmutableList (com.google.common.collect.ImmutableList)11