Search in sources :

Example 71 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveRelFieldTrimmer method trimFields.

/**
 * Trims unused fields from an {@link Aggregate}, producing a new aggregate over a trimmed input.
 *
 * <p>Steps, as implemented below:
 * <ol>
 * <li>collect the input fields referenced by the aggregate calls (arguments and filter);</li>
 * <li>hand the aggregate to {@code rewriteGBConstantKeys}, which may return a replacement;</li>
 * <li>trim the input down to the group-set fields plus the aggregate-call fields;</li>
 * <li>ask {@code generateNewGroupset} for the group set to keep, then rebuild the aggregate
 * with only the aggregate calls whose output field is set in {@code fieldsUsed}.</li>
 * </ol>
 *
 * @param aggregate the aggregate to trim
 * @param fieldsUsed output fields of {@code aggregate} required by the consumer
 * @param extraFields not read by this method; the input is always trimmed with an empty set
 * @return the (possibly unchanged) relational expression together with a mapping from the
 *         original output fields to the fields of the returned expression
 */
@Override
public TrimResult trimFields(Aggregate aggregate, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
    // Fields:
    // 
    // | sys fields | group fields | indicator fields | agg functions |
    // 
    // Two kinds of trimming:
    // 
    // 1. If agg rel has system fields but none of these are used, create an
    // agg rel with no system fields.
    // 
    // 2. If aggregate functions are not used, remove them.
    // 
    // But group and indicator fields stay, even if they are not used.
    // Compute which input fields are used.
    // Input fields read by the agg functions (arguments and filter argument) are
    // collected first, before the group set is added, because rewriteGBConstantKeys
    // needs them on their own.
    final ImmutableBitSet.Builder aggCallFieldsUsedBuilder = ImmutableBitSet.builder();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        for (int i : aggCall.getArgList()) {
            aggCallFieldsUsedBuilder.set(i);
        }
        if (aggCall.filterArg >= 0) {
            aggCallFieldsUsedBuilder.set(aggCall.filterArg);
        }
    }
    // transform if group by contain constant keys; from here on "aggregate" refers to
    // whatever rewriteGBConstantKeys returned
    ImmutableBitSet aggCallFieldsUsed = aggCallFieldsUsedBuilder.build();
    aggregate = rewriteGBConstantKeys(aggregate, fieldsUsed, aggCallFieldsUsed);
    // add group fields
    final ImmutableBitSet.Builder inputFieldsUsed = aggregate.getGroupSet().rebuild();
    inputFieldsUsed.addAll(aggCallFieldsUsed);
    final RelDataType rowType = aggregate.getRowType();
    // Create input with trimmed columns.
    final RelNode input = aggregate.getInput();
    final Set<RelDataTypeField> inputExtraFields = Collections.emptySet();
    final TrimResult trimResult = trimChild(aggregate, input, inputFieldsUsed.build(), inputExtraFields);
    final RelNode newInput = trimResult.left;
    final Mapping inputMapping = trimResult.right;
    // Group keys before and after generateNewGroupset; the difference is the set of
    // group keys that were dropped.
    ImmutableBitSet originalGroupSet = aggregate.getGroupSet();
    ImmutableBitSet updatedGroupSet = generateNewGroupset(aggregate, fieldsUsed);
    ImmutableBitSet gbKeysDeleted = originalGroupSet.except(updatedGroupSet);
    // One bit per original group key, by output position: [0, #group keys).
    ImmutableBitSet updatedGroupFields = ImmutableBitSet.range(originalGroupSet.cardinality());
    final int updatedGroupCount = updatedGroupSet.cardinality();
    // we need to clear the bits corresponding to deleted gb keys
    // NOTE(review): gbKeysDeleted is derived from the group set (input field ordinals),
    // while updatedGroupFields is indexed by output position. The bits cleared here only
    // line up if the two numberings coincide - confirm against generateNewGroupset.
    int setIdx = 0;
    while (setIdx != -1) {
        setIdx = gbKeysDeleted.nextSetBit(setIdx);
        if (setIdx != -1) {
            updatedGroupFields = updatedGroupFields.clear(setIdx);
            setIdx++;
        }
    }
    // Remaining group fields are treated as used, whether or not the consumer asked for them.
    fieldsUsed = fieldsUsed.union(updatedGroupFields);
    // If the input is unchanged and every output field is used, there's nothing to do.
    if (input == newInput && fieldsUsed.equals(ImmutableBitSet.range(rowType.getFieldCount()))) {
        return result(aggregate, Mappings.createIdentity(rowType.getFieldCount()));
    }
    // update the group by keys based on inputMapping
    ImmutableBitSet newGroupSet = Mappings.apply(inputMapping, updatedGroupSet);
    // Which agg calls are used by our consumer?
    // Agg call outputs start right after the original group keys in fieldsUsed.
    int originalGroupCount = aggregate.getGroupSet().cardinality();
    int j = originalGroupCount;
    int usedAggCallCount = 0;
    for (int i = 0; i < aggregate.getAggCallList().size(); i++) {
        if (fieldsUsed.get(j++)) {
            ++usedAggCallCount;
        }
    }
    // Mapping from the original output fields to the trimmed output, which consists of
    // the kept group keys followed by the used agg calls.
    Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, rowType.getFieldCount(), updatedGroupCount + usedAggCallCount);
    // if group keys were reduced, it means we didn't have grouping therefore
    // we don't need to transform group sets; otherwise every original group set is
    // translated through inputMapping
    ImmutableList<ImmutableBitSet> newGroupSets = null;
    if (!updatedGroupSet.equals(aggregate.getGroupSet())) {
        newGroupSets = ImmutableList.of(newGroupSet);
    } else {
        newGroupSets = ImmutableList.copyOf(Iterables.transform(aggregate.getGroupSets(), input1 -> Mappings.apply(inputMapping, input1)));
    }
    // Populate mapping of where to find the fields. Group key fields first:
    // each used original group position is assigned the next consecutive target slot.
    int gbKeyIdx = 0;
    for (j = 0; j < originalGroupCount; j++) {
        if (fieldsUsed.get(j)) {
            mapping.set(j, gbKeyIdx);
            gbKeyIdx++;
        }
    }
    // Now create new agg calls, and populate mapping for them.
    final RelBuilder relBuilder = REL_BUILDER.get();
    relBuilder.push(newInput);
    final List<RelBuilder.AggCall> newAggCallList = new ArrayList<>();
    // because lookup in fieldsUsed is done using original group count
    j = originalGroupCount;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (fieldsUsed.get(j)) {
            // Re-target arguments and filter at the trimmed input via inputMapping.
            final ImmutableList<RexNode> args = relBuilder.fields(Mappings.apply2(inputMapping, aggCall.getArgList()));
            final RexNode filterArg = aggCall.filterArg < 0 ? null : relBuilder.field(Mappings.apply(inputMapping, aggCall.filterArg));
            RelBuilder.AggCall newAggCall = relBuilder.aggregateCall(aggCall.getAggregation(), aggCall.isDistinct(), aggCall.isApproximate(), filterArg, aggCall.name, args);
            // Kept agg calls are placed after the kept group keys, in their original order.
            mapping.set(j, updatedGroupCount + newAggCallList.size());
            newAggCallList.add(newAggCall);
        }
        ++j;
    }
    final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets);
    relBuilder.aggregate(groupKey, newAggCallList);
    return result(relBuilder.build(), mapping);
}
Also used : RelBuilder(org.apache.calcite.tools.RelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) Mapping(org.apache.calcite.util.mapping.Mapping) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexNode(org.apache.calcite.rex.RexNode)

Example 72 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveIntersectRewriteRule method onMatch.

// ~ Methods ----------------------------------------------------------------
/**
 * Rewrites a {@link HiveIntersect} into aggregates over a union of its branches.
 *
 * <p>Each branch is grouped on all of its columns with {@code count(1) as c}; the branches are
 * unioned and grouped again on the original columns with {@code count(c)} (plus {@code min(c)}
 * for INTERSECT ALL); a filter keeps only the groups whose {@code count(c)} equals the number
 * of branches. For INTERSECT DISTINCT the count column is then projected away; for INTERSECT
 * ALL the rows are fed through the set-operation UDTF and its first column is projected away.
 *
 * @param call the rule call whose first operand is the {@link HiveIntersect} to rewrite
 */
public void onMatch(RelOptRuleCall call) {
    final HiveIntersect intersect = call.rel(0);
    final RelOptCluster cluster = intersect.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final int branchCount = intersect.getInputs().size();
    final Builder<RelNode> branchAggregates = new ImmutableList.Builder<RelNode>();

    // 1st level GB: create a GB (col0, col1, count(1) as c) for each branch
    for (int branch = 0; branch < branchCount; branch++) {
        final RelNode branchInput = intersect.getInputs().get(branch);
        final int branchWidth = branchInput.getRowType().getFieldList().size();

        // Project every column of the branch, followed by the literal 1 that count() consumes.
        final List<RexNode> projections = Lists.newArrayList();
        final List<Integer> allColumns = Lists.newArrayList();
        for (int col = 0; col < branchWidth; col++) {
            projections.add(rexBuilder.makeInputRef(branchInput, col));
            allColumns.add(col);
        }
        projections.add(rexBuilder.makeBigintLiteral(new BigDecimal(1)));

        // The project has to exist before the GB because the GB reads the extra column '1'.
        final RelNode projectedInput;
        try {
            projectedInput = HiveProject.create(branchInput, projections, null);
        } catch (CalciteSemanticException e) {
            LOG.debug(e.toString());
            throw new RuntimeException(e);
        }

        // Group on every original column; count(1) reads the literal at position branchWidth.
        final ImmutableBitSet branchGroupSet = ImmutableBitSet.of(allColumns);
        final List<AggregateCall> branchAggCalls = Lists.newArrayList();
        final RelDataType countType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
        branchAggCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, branchWidth, countType));
        final HiveRelNode branchAggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), projectedInput, branchGroupSet, null, branchAggCalls);
        branchAggregates.add(branchAggregate);
    }

    // create a union above all the branches
    final HiveRelNode union = new HiveUnion(cluster, TraitsUtil.getDefaultTraitSet(cluster), branchAggregates.build());

    // 2nd level GB: create a GB (col0, col1, count(c)) over the union.
    // c is the last column of the union; every column before it is a grouping key.
    final int cIdx = union.getRowType().getFieldList().size() - 1;
    final List<Integer> keyColumns = Lists.newArrayList();
    for (int col = 0; col < cIdx; col++) {
        keyColumns.add(col);
    }
    final List<AggregateCall> topAggCalls = Lists.newArrayList();
    final RelDataType topAggType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
    topAggCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, cIdx, topAggType));
    if (intersect.all) {
        topAggCalls.add(HiveCalciteUtil.createSingleArgAggCall("min", cluster, TypeInfoFactory.longTypeInfo, cIdx, topAggType));
    }
    final ImmutableBitSet topGroupSet = ImmutableBitSet.of(keyColumns);
    final HiveRelNode topAggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), union, topGroupSet, null, topAggCalls);

    // add a filter count(c) = #branches; count(c) sits right after the keys, i.e. at cIdx
    final List<RexNode> equalsOperands = new ArrayList<RexNode>();
    final RexInputRef countRef = rexBuilder.makeInputRef(topAggregate, cIdx);
    final RexLiteral branchCountLiteral = rexBuilder.makeBigintLiteral(new BigDecimal(branchCount));
    equalsOperands.add(countRef);
    equalsOperands.add(branchCountLiteral);
    final ImmutableList.Builder<RelDataType> operandTypes = new ImmutableList.Builder<RelDataType>();
    operandTypes.add(TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()));
    operandTypes.add(TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()));
    final RexNode countMatchesBranches;
    try {
        countMatchesBranches = rexBuilder.makeCall(SqlFunctionConverter.getCalciteFn("=", operandTypes.build(), TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()), true, false), equalsOperands);
    } catch (CalciteSemanticException e) {
        LOG.debug(e.toString());
        throw new RuntimeException(e);
    }
    final RelNode filterRel = new HiveFilter(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), topAggregate, countMatchesBranches);

    if (intersect.all) {
        // the schema for intersect all is like this
        // R3 + count(c) as cnt + min(c) as m
        // we create a input project for udtf whose schema is like this
        // min(c) as m + R3
        final List<RelDataTypeField> filterFields = filterRel.getRowType().getFieldList();
        final int filterWidth = filterFields.size();
        final List<RexNode> udtfInputRefs = new ArrayList<>();
        final RelDataTypeField minField = filterFields.get(filterWidth - 1);
        udtfInputRefs.add(new RexInputRef(minField.getIndex(), minField.getType()));
        // Everything except the trailing cnt and m columns, in original order.
        for (int col = 0; col < filterWidth - 2; col++) {
            final RelDataTypeField field = filterFields.get(col);
            udtfInputRefs.add(new RexInputRef(field.getIndex(), field.getType()));
        }
        try {
            final RelNode udtfInput = HiveProject.create(filterRel, udtfInputRefs, null);
            final HiveTableFunctionScan udtf = HiveCalciteUtil.createUDTFForSetOp(cluster, udtfInput);
            // finally add a project to project out the 1st column
            final Set<Integer> droppedColumns = new HashSet<>();
            droppedColumns.add(0);
            call.transformTo(HiveCalciteUtil.createProjectWithoutColumn(udtf, droppedColumns));
        } catch (SemanticException e) {
            LOG.debug(e.toString());
            throw new RuntimeException(e);
        }
    } else {
        // the schema for intersect distinct is like this
        // R3 on all attributes + count(c) as cnt
        // finally add a project to project out the last column
        final Set<Integer> droppedColumns = new HashSet<>();
        droppedColumns.add(filterRel.getRowType().getFieldList().size() - 1);
        try {
            call.transformTo(HiveCalciteUtil.createProjectWithoutColumn(filterRel, droppedColumns));
        } catch (CalciteSemanticException e) {
            LOG.debug(e.toString());
            throw new RuntimeException(e);
        }
    }
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) RexLiteral(org.apache.calcite.rex.RexLiteral) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ImmutableList(com.google.common.collect.ImmutableList) RexBuilder(org.apache.calcite.rex.RexBuilder) Builder(com.google.common.collect.ImmutableList.Builder) HiveTableFunctionScan(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan) HiveRelNode(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) RexBuilder(org.apache.calcite.rex.RexBuilder) CalciteSemanticException(org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException) HashSet(java.util.HashSet) SemanticException(org.apache.hadoop.hive.ql.parse.SemanticException) CalciteSemanticException(org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException) HiveIntersect(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveIntersect) HiveUnion(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion) BigDecimal(java.math.BigDecimal) HiveFilter(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter) AggregateCall(org.apache.calcite.rel.core.AggregateCall) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) HiveRelNode(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 73 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class JDBCAggregationPushDownRule method matches.

/**
 * Decides whether the aggregate can be pushed down to the JDBC source.
 *
 * <p>Returns {@code false} when the aggregate is not a simple group-by, when it contains a
 * distinct count over more than one argument, or when the JDBC dialect reports that it does
 * not support the kind of one of the aggregate functions.
 *
 * @param call rule call whose operands are the {@link HiveAggregate} (0) and the
 *             {@link HiveJdbcConverter} (1)
 * @return {@code true} if every aggregate call passes the checks above
 */
@Override
public boolean matches(RelOptRuleCall call) {
    final HiveAggregate aggregate = call.rel(0);
    final HiveJdbcConverter jdbcConverter = call.rel(1);

    // TODO: Grouping sets not supported yet
    if (aggregate.getGroupType() != Group.SIMPLE) {
        return false;
    }

    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final SqlAggFunction function = aggCall.getAggregation();

        // count distinct with more than one argument is not supported
        final boolean multiArgCountDistinct = function instanceof HiveSqlCountAggFunction
            && ((HiveSqlCountAggFunction) function).isDistinct()
            && aggCall.getArgList().size() > 1;
        if (multiArgCountDistinct) {
            return false;
        }

        // The target dialect must know how to evaluate this kind of aggregate.
        final SqlKind functionKind = function.getKind();
        final boolean supported = jdbcConverter.getJdbcDialect().supportsAggregateFunction(functionKind);
        if (!supported) {
            return false;
        }
    }
    return true;
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate) HiveSqlCountAggFunction(org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction) HiveJdbcConverter(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.jdbc.HiveJdbcConverter) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlKind(org.apache.calcite.sql.SqlKind)

Example 74 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveAggregateInsertDeleteIncrementalRewritingRule method createJoinRightInput.

/**
 * Builds the right input of the incremental-rebuild join: an aggregate over the input with
 * the row-is-deleted column propagated, in which every original aggregate call is replaced by
 * a {@code SUM} whose argument is negated for deleted rows.
 *
 * <pre>
 * SELECT
 *   sum(case when t1.ROW__IS__DELETED then -b else b end) sumb,
 *   sum(case when t1.ROW__IS__DELETED then -1 else 1 end) countStar
 * </pre>
 *
 * @param call the rule call; operand 2 is the {@link Aggregate} to rewrite
 * @return the rewritten plan together with the output index of the {@code count(*)} column,
 *         or {@code null} when the aggregate has no {@code count(*)} call
 * @throws AssertionError if an aggregate call is neither {@code COUNT} nor {@code SUM}
 */
@Override
protected IncrementalComputePlanWithDeletedRows createJoinRightInput(RelOptRuleCall call) {
    final RelBuilder relBuilder = call.builder();
    final Aggregate aggregate = call.rel(2);
    final RexBuilder rexBuilder = relBuilder.getRexBuilder();

    // Propagate rowIsDeleted column
    RelNode aggInput = HiveHepExtractRelNodeRule.execute(aggregate.getInput());
    aggInput = new HiveRowIsDeletedPropagator(relBuilder).propagate(aggInput);

    // The row-is-deleted flag is read from the last field of the propagated input.
    final int rowIsDeletedIdx = aggInput.getRowType().getFieldCount() - 1;
    final RexNode rowIsDeletedNode = rexBuilder.makeInputRef(aggInput.getRowType().getFieldList().get(rowIsDeletedIdx).getType(), rowIsDeletedIdx);

    // Every supported call (COUNT, SUM) is re-expressed as a SUM.
    final SqlAggFunction sum = SqlStdOperatorTable.SUM;
    final int aggCallCount = aggregate.getAggCallList().size();
    final List<RelBuilder.AggCall> rewrittenCalls = new ArrayList<>(aggCallCount);
    // Output position of count(*); stays -1 when the aggregate has none.
    int countStarIdx = -1;
    for (int pos = 0; pos < aggCallCount; ++pos) {
        final AggregateCall aggCall = aggregate.getAggCallList().get(pos);
        final SqlKind kind = aggCall.getAggregation().getKind();
        final RexNode argument;
        if (kind == SqlKind.COUNT) {
            if (aggCall.getArgList().isEmpty()) {
                // count(*): remember where it lands in the output and sum up ones
                countStarIdx = pos + aggregate.getGroupCount();
                argument = relBuilder.literal(1);
            } else {
                // count(<column_name>)
                argument = genArgumentForCountColumn(relBuilder, aggInput, aggCall.getArgList().get(0));
            }
        } else if (kind == SqlKind.SUM) {
            final int argumentIdx = aggCall.getArgList().get(0);
            argument = rexBuilder.makeInputRef(aggInput.getRowType().getFieldList().get(argumentIdx).getType(), argumentIdx);
        } else {
            throw new AssertionError("Found an aggregation that could not be" + " recognized: " + aggCall);
        }

        // case when ROW__IS__DELETED then -1 * argument else argument end
        final RexNode negated = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, relBuilder.literal(-1), argument);
        final RexNode signedArgument = rexBuilder.makeCall(SqlStdOperatorTable.CASE, rowIsDeletedNode, negated, argument);
        rewrittenCalls.add(relBuilder.aggregateCall(sum, signedArgument));
    }

    if (countStarIdx == -1) {
        // Can not rewrite, bail out;
        return null;
    }

    relBuilder.push(aggInput);
    final RelBuilder.GroupKey groupKey = relBuilder.groupKey(aggregate.getGroupSet());
    final RelNode rewrittenAggregate = relBuilder.aggregate(groupKey, rewrittenCalls).build();
    return new IncrementalComputePlanWithDeletedRows(rewrittenAggregate, countStarIdx);
}
Also used : RelBuilder(org.apache.calcite.tools.RelBuilder) ArrayList(java.util.ArrayList) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) RexBuilder(org.apache.calcite.rex.RexBuilder) Aggregate(org.apache.calcite.rel.core.Aggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 75 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveAggregateJoinTransposeRule method onMatch.

/**
 * Attempts to push an {@link Aggregate} below the {@link Join} it sits on: each join input is
 * either projected onto its keys (when treated as unique) or pre-aggregated, the join is
 * re-created on top, and the partial results are combined above the new join.
 *
 * <p>The method returns without transforming when any aggregate call is not a
 * {@link SqlSplittableAggFunction} or has a filter, when the join is not an inner join, when
 * {@code allowFunctions} is false and aggregate calls are present, when the grouping is not
 * unique and {@code costBased} is false, when the join condition has a non-equi part, or when
 * both join inputs were treated as unique.
 *
 * <p>The new plan is registered only if {@code uniqueBased} is set, the top aggregate could be
 * replaced by projects and the grouping is unique; or, failing that, if {@code costBased} is
 * set and the cumulative cost of the new plan is lower than that of the original aggregate.
 *
 * <p>If an exception is thrown while {@code noColsMissingStats} holds a positive value, a
 * warning is logged, the counter is reset to 0 and the rule is skipped; otherwise the exception
 * is rethrown.
 *
 * @param call the rule call; operand 0 is the aggregate, operand 1 the join beneath it
 */
@Override
public void onMatch(RelOptRuleCall call) {
    try {
        final Aggregate aggregate = call.rel(0);
        final Join join = call.rel(1);
        final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        final RelBuilder relBuilder = call.builder();
        // If any aggregate call is not splittable or has a filter, bail out
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
                return;
            }
            if (aggregateCall.filterArg >= 0) {
                return;
            }
        }
        // If it is not an inner join, we do not push the
        // aggregate operator
        if (join.getJoinType() != JoinRelType.INNER) {
            return;
        }
        // Aggregate functions are present but this rule instance does not allow them.
        if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
            return;
        }
        boolean groupingUnique = isGroupingUnique(join, aggregate.getGroupSet());
        if (!groupingUnique && !costBased) {
            // there is no need to check further - the transformation may not happen
            return;
        }
        // Do the columns used by the join appear in the output of the aggregate?
        final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
        final RelMetadataQuery mq = call.getMetadataQuery();
        final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
        final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
        final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
        // Columns that must survive below the join: grouping keys plus join keys.
        final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
        // Split join condition
        final List<Integer> leftKeys = Lists.newArrayList();
        final List<Integer> rightKeys = Lists.newArrayList();
        final List<Boolean> filterNulls = Lists.newArrayList();
        RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
        // If it contains non-equi join conditions, we bail out
        if (!nonEquiConj.isAlwaysTrue()) {
            return;
        }
        // Push each aggregate function down to each side that contains all of its
        // arguments. Note that COUNT(*), because it has no arguments, can go to
        // both sides.
        // map: original join field ordinal -> ordinal in the new join's row.
        final Map<Integer, Integer> map = new HashMap<>();
        final List<Side> sides = new ArrayList<>();
        int uniqueCount = 0;
        // offset: first field of the current side in the original join row;
        // belowOffset: first field of the current side in the new join row.
        int offset = 0;
        int belowOffset = 0;
        for (int s = 0; s < 2; s++) {
            final Side side = new Side();
            final RelNode joinInput = join.getInput(s);
            int fieldCount = joinInput.getRowType().getFieldCount();
            final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
            final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
            for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
                map.put(c.e, belowOffset + c.i);
            }
            // Same keys, expressed relative to this side's own input.
            final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
            final boolean unique;
            if (!allowFunctions) {
                assert aggregate.getAggCallList().isEmpty();
                // If there are no functions, it doesn't matter as much whether we
                // aggregate the inputs before the join, because there will not be
                // any functions experiencing a cartesian product effect.
                // 
                // But finding out whether the input is already unique requires a call
                // to areColumnsUnique that currently (until [CALCITE-1048] "Make
                // metadata more robust" is fixed) places a heavy load on
                // the metadata system.
                // 
                // So we choose to imagine that the input is already unique, which is
                // untrue but harmless.
                // 
                unique = true;
            } else {
                // A null answer from the metadata query is treated as "not unique".
                final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey, true);
                unique = unique0 != null && unique0;
            }
            if (unique) {
                // Nothing to aggregate on this side: just project the key columns.
                ++uniqueCount;
                relBuilder.push(joinInput);
                relBuilder.project(belowAggregateKey.asList().stream().map(relBuilder::field).collect(Collectors.toList()));
                side.newInput = relBuilder.build();
            } else {
                List<AggregateCall> belowAggCalls = new ArrayList<>();
                final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
                // Left side keeps its ordinals; right side ordinals are shifted down by offset.
                final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
                for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                    final SqlAggFunction aggregation = aggCall.e.getAggregation();
                    final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                    final AggregateCall call1;
                    // split() when all arguments come from this side, other() otherwise.
                    if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                        call1 = splitter.split(aggCall.e, mapping);
                    } else {
                        call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                    }
                    if (call1 != null) {
                        // Record where this side's partial result lands in its new input.
                        side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                    }
                }
                side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
            }
            offset += fieldCount;
            belowOffset += side.newInput.getRowType().getFieldCount();
            sides.add(side);
        }
        if (uniqueCount == 2) {
            // Both inputs were treated as unique, so nothing was aggregated below the
            // join. This aggregate+join may itself be the result of a previous
            // invocation of this rule; if we continue we might loop forever.
            return;
        }
        // Update condition
        final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
        final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
        // Create new join
        relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
        // Aggregate above to sum up the sub-totals
        final List<AggregateCall> newAggCalls = new ArrayList<>();
        final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
        final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
        final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
        for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
            final SqlAggFunction aggregation = aggCall.e.getAggregation();
            final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
            final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
            final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
            // -1 marks a side with no sub-total; right ordinals are offset by the new left width.
            newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
        }
        relBuilder.project(projects);
        boolean aggConvertedToProjects = false;
        if (allColumnsInAggregate) {
            // let's see if we can convert aggregate into projects
            List<RexNode> projects2 = new ArrayList<>();
            for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
                projects2.add(relBuilder.field(key));
            }
            for (AggregateCall newAggCall : newAggCalls) {
                final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
                if (splitter != null) {
                    final RelDataType rowType = relBuilder.peek().getRowType();
                    projects2.add(splitter.singleton(rexBuilder, rowType, newAggCall));
                }
            }
            // Only valid if every group key and every agg call produced a project expression.
            if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
                // We successfully converted agg calls into projects.
                relBuilder.project(projects2);
                aggConvertedToProjects = true;
            }
        }
        if (!aggConvertedToProjects) {
            relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
        }
        RelNode r = relBuilder.build();
        // Decide whether to register the new plan: first by uniqueness, then by cost.
        boolean transform = false;
        if (uniqueBased && aggConvertedToProjects) {
            transform = groupingUnique;
        }
        if (!transform && costBased) {
            RelOptCost afterCost = mq.getCumulativeCost(r);
            RelOptCost beforeCost = mq.getCumulativeCost(aggregate);
            transform = afterCost.isLt(beforeCost);
        }
        if (transform) {
            call.transformTo(r);
        }
    } catch (Exception e) {
        // A failure while column stats were reported missing is downgraded to a warning
        // and the counter is reset; any other failure is propagated unchanged.
        if (noColsMissingStats.get() > 0) {
            LOG.warn("Missing column stats (see previous messages), skipping aggregate-join transpose in CBO");
            noColsMissingStats.set(0);
        } else {
            throw e;
        }
    }
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RelDataType(org.apache.calcite.rel.type.RelDataType) RelOptCost(org.apache.calcite.plan.RelOptCost) RexBuilder(org.apache.calcite.rex.RexBuilder) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) HiveJoin(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate) Aggregate(org.apache.calcite.rel.core.Aggregate) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall): 158
ArrayList (java.util.ArrayList): 82
RexNode (org.apache.calcite.rex.RexNode): 78
ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet): 57
RelNode (org.apache.calcite.rel.RelNode): 54
RexBuilder (org.apache.calcite.rex.RexBuilder): 52
RelDataType (org.apache.calcite.rel.type.RelDataType): 42
Aggregate (org.apache.calcite.rel.core.Aggregate): 37
RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField): 36
RexInputRef (org.apache.calcite.rex.RexInputRef): 33
RelBuilder (org.apache.calcite.tools.RelBuilder): 29
HashMap (java.util.HashMap): 28
SqlAggFunction (org.apache.calcite.sql.SqlAggFunction): 28
List (java.util.List): 27
RexLiteral (org.apache.calcite.rex.RexLiteral): 23
Pair (org.apache.calcite.util.Pair): 20
ImmutableList (com.google.common.collect.ImmutableList): 19
Project (org.apache.calcite.rel.core.Project): 17
RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory): 17
LogicalAggregate (org.apache.calcite.rel.logical.LogicalAggregate): 16