Search in sources :

Example 76 with AggregateCall

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveAggregateSplitRule method onMatch.

/**
 * Splits the matched aggregate into a bottom aggregate (the original calls,
 * grouped by the original group set) and a top aggregate that rolls the
 * bottom results up, then registers the new plan with the planner.
 *
 * <p>The rule does nothing when any call is distinct, is the grouping id,
 * carries a filter, or has no rollup function, or when the group columns are
 * already unique in the input.
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final RelBuilder relBuilder = call.builder();
    final ImmutableBitSet bottomAggregateGroupSet = aggregate.getGroupSet();
    final List<AggregateCall> topAggregateCalls = new ArrayList<>();
    // Position of the current call within the aggregate's call list.
    int callOrdinal = 0;
    for (AggregateCall bottomCall : aggregate.getAggCallList()) {
        // Bail out on a distinct call, on the grouping id, or on a filtered call.
        if (bottomCall.isDistinct()
                || bottomCall.getAggregation().equals(HiveGroupingID.INSTANCE)
                || bottomCall.filterArg >= 0) {
            return;
        }
        // Bail out if the aggregate function does not support splitting.
        final SqlAggFunction rollup = HiveRelBuilder.getRollup(bottomCall.getAggregation());
        if (rollup == null) {
            return;
        }
        // The top call reads the bottom call's output column, which sits
        // after the group columns in the bottom aggregate's row.
        final int bottomOutputColumn = bottomAggregateGroupSet.cardinality() + callOrdinal;
        topAggregateCalls.add(
                AggregateCall.create(
                        rollup,
                        bottomCall.isDistinct(),
                        bottomCall.isApproximate(),
                        ImmutableList.of(bottomOutputColumn),
                        -1,
                        bottomCall.collation,
                        bottomCall.type,
                        bottomCall.name));
        callOrdinal++;
    }
    if (aggregate.getCluster().getMetadataQuery()
            .areColumnsUnique(aggregate.getInput(), bottomAggregateGroupSet)) {
        // Nothing to do, probably already pushed
        return;
    }
    final ImmutableBitSet topAggregateGroupSet =
            ImmutableBitSet.range(0, bottomAggregateGroupSet.cardinality());
    // Map each original group column to its position in the bottom aggregate's output.
    final Map<Integer, Integer> map = new HashMap<>();
    for (Integer groupColumn : bottomAggregateGroupSet) {
        map.put(groupColumn, map.size());
    }
    final Iterable<ImmutableBitSet> permutedGroupSets =
            ImmutableBitSet.permute(aggregate.groupSets, map);
    ImmutableList<ImmutableBitSet> topAggregateGroupSets =
            ImmutableBitSet.ORDERING.immutableSortedCopy(permutedGroupSets);
    // Bottom aggregate: original calls over the plain group set.
    relBuilder.push(aggregate.getInput());
    relBuilder.aggregate(relBuilder.groupKey(bottomAggregateGroupSet), aggregate.getAggCallList());
    // Top aggregate: rollup calls over the remapped grouping sets.
    relBuilder.aggregate(relBuilder.groupKey(topAggregateGroupSet, topAggregateGroupSets), topAggregateCalls);
    call.transformTo(relBuilder.build());
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelBuilder(org.apache.calcite.tools.RelBuilder) HiveRelBuilder(org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) Aggregate(org.apache.calcite.rel.core.Aggregate) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate)

Example 77 with AggregateCall

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveExpandDistinctAggregatesRule method rewriteAggCalls.

/**
 * Rewrites, in place, every distinct aggregate call whose argument list
 * equals {@code argList} into a non-distinct call over remapped arguments.
 * For example, "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
 *
 * @param newAggCalls calls to inspect; matching entries are overwritten
 * @param argList argument list that selects which distinct calls to rewrite
 * @param sourceOf lookup from an original argument ordinal to its replacement
 */
private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
    for (int position = 0; position < newAggCalls.size(); position++) {
        final AggregateCall candidate = newAggCalls.get(position);
        // Leave non-distinct calls (e.g. SUM(sal)) and distinct calls over
        // other arguments (e.g. COUNT(DISTINCT gender)) untouched.
        if (!(candidate.isDistinct() && candidate.getArgList().equals(argList))) {
            continue;
        }
        // Re-map arguments.
        final List<Integer> remappedArgs = new ArrayList<>(candidate.getArgList().size());
        for (Integer originalArg : candidate.getArgList()) {
            remappedArgs.add(sourceOf.get(originalArg));
        }
        newAggCalls.set(
                position,
                new AggregateCall(candidate.getAggregation(), false, remappedArgs, candidate.getType(), candidate.getName()));
    }
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) ArrayList(java.util.ArrayList)

Example 78 with AggregateCall

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveExpandDistinctAggregatesRule method createGroupingSets.

/**
 * Builds an aggregate over the original input that has one grouping set per
 * distinct argument list and a single grouping-id aggregate call.
 *
 * @param aggregate the original aggregate
 * @param argList the original argument lists in the aggregate
 * @param cleanArgList output: receives each argument list whose grouping set
 *     was not seen before
 * @param map output: for an argument list whose grouping set duplicates an
 *     earlier one, maps its index in {@code argList} to the index at which
 *     that grouping set was first recorded
 * @param groupSet new group set
 * @return a {@link HiveAggregate} on the original input using {@code groupSet},
 *     the sorted duplicate-free grouping sets, and the grouping-id call
 */
private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> argList, List<List<Integer>> cleanArgList, Map<Integer, Integer> map, ImmutableBitSet groupSet) {
    final List<ImmutableBitSet> groupingSets = new ArrayList<>();
    for (int argIndex = 0; argIndex < argList.size(); argIndex++) {
        final List<Integer> args = argList.get(argIndex);
        final ImmutableBitSet candidate = aggregate.getGroupSet().union(ImmutableBitSet.of(args));
        final int firstSeenAt = groupingSets.indexOf(candidate);
        if (firstSeenAt != -1) {
            // Same grouping set as an earlier argument list: only record the link.
            map.put(argIndex, firstSeenAt);
        } else {
            groupingSets.add(candidate);
            cleanArgList.add(args);
        }
    }
    // Calcite expects the grouping sets sorted and without duplicates
    groupingSets.sort(ImmutableBitSet.COMPARATOR);
    // Create GroupingID column
    final AggregateCall groupingIdCall = AggregateCall.create(
            HiveGroupingID.INSTANCE,
            false,
            ImmutableList.<Integer>of(),
            -1,
            cluster.getTypeFactory().createSqlType(SqlTypeName.BIGINT),
            HiveGroupingID.INSTANCE.getName());
    final List<AggregateCall> groupingCalls = new ArrayList<>();
    groupingCalls.add(groupingIdCall);
    return new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), aggregate.getInput(), groupSet, groupingSets, groupingCalls);
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) RexBuilder(org.apache.calcite.rex.RexBuilder) ArrayList(java.util.ArrayList)

Example 79 with AggregateCall

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveExpandDistinctAggregatesRule method convertMonopole.

/**
 * Converts an aggregate relational expression that contains just one
 * distinct aggregate function (or perhaps several over the same arguments)
 * and no non-distinct aggregate functions.
 *
 * @param aggregate the aggregate to convert
 * @param argList the arguments shared by the distinct aggregate calls
 * @return a copy of {@code aggregate} whose input is the SELECT DISTINCT
 *     relation and whose calls have been rewritten to be non-distinct
 */
private RelNode convertMonopole(Aggregate aggregate, List<Integer> argList) {
    // For example,
    // SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
    // FROM emp
    // GROUP BY deptno
    //
    // becomes
    //
    // SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
    // FROM (
    // SELECT DISTINCT deptno, sal AS distinct_sal
    // FROM EMP GROUP BY deptno)
    // GROUP BY deptno

    // Inner relation: project the columns of the GROUP BY plus the arguments
    // to the agg function. sourceOf is handed to createSelectDistinct and
    // then used to remap the call arguments.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    final Aggregate distinct = createSelectDistinct(aggregate, argList, sourceOf);

    // Outer relation: an aggregate on top, with the new aggregate list.
    final List<AggregateCall> rewrittenCalls = Lists.newArrayList(aggregate.getAggCallList());
    rewriteAggCalls(rewrittenCalls, argList, sourceOf);

    // Group by the first N columns, where N is the original number of group columns.
    final ImmutableBitSet topGroupSet = ImmutableBitSet.range(aggregate.getGroupSet().cardinality());
    return aggregate.copy(aggregate.getTraitSet(), distinct, aggregate.indicator, topGroupSet, null, rewrittenCalls);
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) HashMap(java.util.HashMap) HiveAggregate(org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate) Aggregate(org.apache.calcite.rel.core.Aggregate)

Example 80 with AggregateCall

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveAggregateReduceFunctionsRule method reduceStddev.

/**
 * Rewrites a standard-deviation style aggregate call as an expression over
 * SUM(x * x), SUM(x) and COUNT(x), registering those three calls through
 * {@code rexBuilder.addAggCall}.
 *
 * @param oldAggRel the aggregate that contains {@code oldCall}
 * @param oldCall the call being reduced; must have exactly one argument
 * @param biased if true the denominator is COUNT(x); otherwise it is
 *     COUNT(x) - 1, or NULL when COUNT(x) equals 1
 * @param sqrt if true the quotient is raised to the power 0.5
 * @param newCalls passed to {@code addAggCall} together with each new call
 * @param aggCallMapping passed to {@code addAggCall} together with each new call
 * @param inputExprs input expressions; the coerced argument and its square are
 *     looked up in (or added to) this list via {@code lookupOrAdd}
 * @return the reduced expression, cast to the type of {@code oldCall}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    // stddev_pop(x) ==>
    // power(
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / count(x),
    // .5)
    // 
    // stddev_samp(x) ==>
    // power(
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / nullif(count(x) - 1, 0),
    // .5)
    final int nGroups = oldAggRel.getGroupCount();
    final RelOptCluster cluster = oldAggRel.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
    // Work in the nullable variant of the call's result type: the argument is
    // coerced to it before being squared and summed.
    final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);
    final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
    // lookupOrAdd is defined elsewhere; presumably it returns the position of
    // the expression in inputExprs, appending it when absent -- TODO confirm.
    final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
    final RelDataType sumReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argRef.getType());
    // x * x, registered as its own input expression.
    final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
    final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
    final RelDataType sumSquaredReturnType = getSumReturnType(rexBuilder.getTypeFactory(), argSquared.getType());
    // sum(x * x)
    final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumSquaredReturnType), InferTypes.explicit(Collections.singletonList(argSquared.getType())), // SqlStdOperatorTable.SUM,
    oldCall.getAggregation().getOperandTypeChecker()), argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
    final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));
    // sum(x)
    // NOTE(review): the operand type inference here is built from
    // argOrdinalType (the raw input field type) while the call's argument is
    // argRefOrdinal (the coerced expression) -- confirm this is intended.
    final AggregateCall sumArgAggCall = AggregateCall.create(new HiveSqlSumAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(sumReturnType), InferTypes.explicit(Collections.singletonList(argOrdinalType)), // SqlStdOperatorTable.SUM,
    oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
    final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
    // sum(x) * sum(x), computed in the nullable result type.
    final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
    final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
    // count(x), typed as nullable BIGINT and taken over the original argument list.
    RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
    final AggregateCall countArgAggCall = AggregateCall.create(new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.COUNT,
    oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
    // NOTE(review): unlike the two SUM calls above, the last argument here is
    // the input field type rather than the new call's own type -- confirm.
    final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
    // sum(x) * sum(x) / count(x)
    final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
    final RexNode denominator;
    if (biased) {
        denominator = countArg;
    } else {
        // CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END, which
        // avoids a division by zero for a single-row group.
        final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
        final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
        final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
        denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
    }
    final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
    RexNode result = div;
    if (sqrt) {
        // Square root expressed as power(div, 0.5).
        final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
    }
    // Cast back to the original call's declared type.
    return rexBuilder.makeCast(oldCall.getType(), result);
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RexLiteral(org.apache.calcite.rex.RexLiteral) HiveSqlCountAggFunction(org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) HiveSqlSumAggFunction(org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction) BigDecimal(java.math.BigDecimal) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall)158 ArrayList (java.util.ArrayList)82 RexNode (org.apache.calcite.rex.RexNode)78 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)57 RelNode (org.apache.calcite.rel.RelNode)54 RexBuilder (org.apache.calcite.rex.RexBuilder)52 RelDataType (org.apache.calcite.rel.type.RelDataType)42 Aggregate (org.apache.calcite.rel.core.Aggregate)37 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)36 RexInputRef (org.apache.calcite.rex.RexInputRef)33 RelBuilder (org.apache.calcite.tools.RelBuilder)29 HashMap (java.util.HashMap)28 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)28 List (java.util.List)27 RexLiteral (org.apache.calcite.rex.RexLiteral)23 Pair (org.apache.calcite.util.Pair)20 ImmutableList (com.google.common.collect.ImmutableList)19 Project (org.apache.calcite.rel.core.Project)17 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)17 LogicalAggregate (org.apache.calcite.rel.logical.LogicalAggregate)16