Search in sources :

Example 11 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class SubstitutionVisitor method unifyAggregates.

/**
 * Tries to express the aggregate {@code query} on top of the aggregate {@code target}.
 *
 * <p>If both aggregates have the same group set, the result is a projection of
 * {@code target} that selects the group columns followed by the target columns holding
 * each of the query's aggregate calls. Otherwise the result is a new aggregate over
 * {@code target} whose calls are built from {@code getRollup(...)} of the query's calls.
 * In both cases the result is passed through {@code MutableRels.createCastRel} with the
 * query's row type.
 *
 * @param query  aggregate to be rewritten
 * @param target aggregate whose output the rewrite reads from
 * @return the rewritten expression, or {@code null} if a query aggregate call is not
 *         present in the target, if a query group key is not among the target's group
 *         keys, or if a distinct call would have to be re-aggregated
 * @throws AssertionError if either aggregate's group type is not
 *         {@code Aggregate.Group.SIMPLE}
 */
public static MutableRel unifyAggregates(MutableAggregate query, MutableAggregate target) {
    final boolean bothSimple = query.getGroupType() == Aggregate.Group.SIMPLE
            && target.getGroupType() == Aggregate.Group.SIMPLE;
    if (!bothSimple) {
        throw new AssertionError(Bug.CALCITE_461_FIXED);
    }

    final MutableRel rewritten;
    if (query.getGroupSet().equals(target.getGroupSet())) {
        // Identical grouping: a projection over the target is enough.
        final int keyCount = query.getGroupSet().cardinality();
        final List<Integer> fields = Lists.newArrayList();
        for (int key = 0; key < keyCount; key++) {
            fields.add(key);
        }
        for (AggregateCall queryCall : query.getAggCallList()) {
            final int callIndex = target.getAggCallList().indexOf(queryCall);
            if (callIndex < 0) {
                // The target does not compute this call.
                return null;
            }
            // Aggregate-call columns follow the group-key columns.
            fields.add(keyCount + callIndex);
        }
        rewritten = MutableRels.createProject(target, fields);
    } else {
        // Group sets differ: every query key must be one of the target's keys, and the
        // target's output is aggregated again.
        final List<Integer> targetKeys = target.getGroupSet().asList();
        final ImmutableBitSet.Builder rolledUpKeys = ImmutableBitSet.builder();
        for (int queryKey : query.getGroupSet()) {
            final int position = targetKeys.indexOf(queryKey);
            if (position < 0) {
                return null;
            }
            rolledUpKeys.set(position);
        }

        final List<AggregateCall> rolledUpCalls = Lists.newArrayList();
        for (AggregateCall queryCall : query.getAggCallList()) {
            if (queryCall.isDistinct()) {
                // Distinct calls are not re-aggregated here.
                return null;
            }
            final int callIndex = target.getAggCallList().indexOf(queryCall);
            if (callIndex < 0) {
                return null;
            }
            // The roll-up call reads the target column that holds the original call.
            rolledUpCalls.add(
                    AggregateCall.create(
                            getRollup(queryCall.getAggregation()),
                            queryCall.isDistinct(),
                            ImmutableList.of(target.groupSet.cardinality() + callIndex),
                            -1,
                            queryCall.type,
                            queryCall.name));
        }
        rewritten = MutableAggregate.of(target, false, rolledUpKeys.build(), null, rolledUpCalls);
    }
    return MutableRels.createCastRel(rewritten, query.getRowType(), true);
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet)

Example 12 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project hive by apache.

the class HiveExpandDistinctAggregatesRule method rewriteAggCalls.

/**
 * Replaces, in place, each distinct aggregate call in {@code newAggCalls} whose argument
 * list equals {@code argList} with a non-distinct call over remapped arguments.
 *
 * <p>For example, "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)". Calls that are
 * not distinct, or that are distinct over a different argument list, are left untouched.
 *
 * @param newAggCalls list of aggregate calls, modified in place
 * @param argList     argument list identifying the distinct calls to rewrite
 * @param sourceOf    maps each old argument ordinal to its new ordinal; an ordinal with
 *                    no entry is stored as {@code null} in the new argument list
 */
private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
    for (int pos = 0; pos < newAggCalls.size(); pos++) {
        final AggregateCall current = newAggCalls.get(pos);

        // Only distinct calls over exactly this argument list are rewritten; e.g.
        // COUNT(DISTINCT gender) or SUM(sal) are skipped.
        final boolean shouldRewrite = current.isDistinct() && current.getArgList().equals(argList);
        if (!shouldRewrite) {
            continue;
        }

        // Translate every argument ordinal through the mapping.
        final List<Integer> remappedArgs = new ArrayList<Integer>(current.getArgList().size());
        for (Integer oldArg : current.getArgList()) {
            remappedArgs.add(sourceOf.get(oldArg));
        }

        // Same function, type and name, but no longer distinct.
        newAggCalls.set(pos, new AggregateCall(current.getAggregation(), false, remappedArgs, current.getType(), current.getName()));
    }
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) ArrayList(java.util.ArrayList)

Example 13 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.

the class DrillReduceAggregatesRule method reduceStddev.

/**
 * Rewrites a standard-deviation/variance style aggregate call as an expression over
 * SUM and COUNT aggregate calls.
 *
 * <p>The expression built is
 * {@code (sum(x * x) - sum(x) * sum(x) / count(x)) / denominator}, optionally raised to
 * the power 0.5.
 *
 * <p>Side effects visible in this method: the entry of {@code inputExprs} at the call's
 * argument ordinal is replaced by that expression wrapped in {@code CastHighOp};
 * {@code lookupOrAdd} is invoked on {@code inputExprs} for the squared argument
 * (presumably appending it when absent; confirm against that helper); and three
 * aggregate calls (sum of squares, sum, count) are registered through
 * {@code rexBuilder.addAggCall} using {@code newCalls} and {@code aggCallMapping}.
 *
 * @param oldAggRel      aggregate that contains {@code oldCall}
 * @param oldCall        call being rewritten; must have exactly one argument (asserted)
 * @param biased         if true the denominator is {@code count(x)}; otherwise it is
 *                       {@code count(x) - 1}, or null when {@code count(x)} equals 1
 * @param sqrt           if true the quotient is raised to the power 0.5
 * @param newCalls       passed to {@code rexBuilder.addAggCall} for the generated calls
 * @param aggCallMapping passed to {@code rexBuilder.addAggCall} for the generated calls
 * @param inputExprs     input expressions of the aggregate; modified as described above
 * @return the replacement expression; when type inference is disabled it is wrapped in
 *         a cast to {@code SqlTypeName.ANY}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    // stddev_pop(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / count(x),
    //     .5)
    //
    // stddev_samp(x) ==>
    //   power(
    //     (sum(x * x) - sum(x) * sum(x) / count(x))
    //     / nullif(count(x) - 1, 0),
    //     .5)
    final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
    final int nGroups = oldAggRel.getGroupCount();
    RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
    // final RexNode argRef = inputExprs.get(argOrdinal);
    // Wrap the argument in CastHighOp and store the wrapped expression back into
    // inputExprs, so the argument ordinal now refers to the wrapped expression.
    // NOTE(review): CastHighOp presumably widens the argument type; confirm at its definition.
    RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
    inputExprs.set(argOrdinal, argRef);
    // x * x, registered as an input expression so that it can be summed.
    final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
    final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
    // Both SUM calls use the argument type made nullable.
    final RelDataType sumType = typeFactory.createTypeWithNullability(argType, true);
    // sum(x * x)
    final AggregateCall sumArgSquaredAggCall = AggregateCall.create(new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argSquaredOrdinal), -1, sumType, null);
    final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    // sum(x)
    final AggregateCall sumArgAggCall = AggregateCall.create(new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argOrdinal), -1, sumType, null);
    final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    // sum(x) * sum(x)
    final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
    // count(x)
    final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countAgg.getReturnType(typeFactory);
    final AggregateCall countArgAggCall = AggregateCall.create(countAgg, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);
    final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
    // sum(x) * sum(x) / count(x)
    final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
    // sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
    final RexNode denominator;
    if (biased) {
        denominator = countArg;
    } else {
        // The nullif(count(x) - 1, 0) of the header comment is expressed here as
        // CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END.
        final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nul = rexBuilder.makeNullLiteral(countArg.getType().getSqlTypeName());
        final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
        final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
        denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
    }
    // With type inference enabled, the division uses a Drill operator constructed with
    // the original call's type; otherwise the standard DIVIDE operator is used.
    final SqlOperator divide;
    if (isInferenceEnabled) {
        divide = new DrillSqlOperator("divide", 2, true, oldCall.getType(), false);
    } else {
        divide = SqlStdOperatorTable.DIVIDE;
    }
    final RexNode div = rexBuilder.makeCall(divide, diff, denominator);
    RexNode result = div;
    if (sqrt) {
        // power(div, 0.5)
        final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
    }
    if (isInferenceEnabled) {
        return result;
    } else {
        /*
         * Currently calcite's strategy to infer the return type of aggregate functions
         * is wrong because it uses the first known argument to determine output type. For
         * instance if we are performing stddev on an integer column then it interprets the
         * output type to be integer which is incorrect as it should be double. So based on
         * this if we add cast after rewriting the aggregate we add an additional cast which
         * would cause wrong results. So we simply add a cast to ANY.
         */
        return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), result);
    }
}
Also used : RexLiteral(org.apache.calcite.rex.RexLiteral) PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) SqlOperator(org.apache.calcite.sql.SqlOperator) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) BigDecimal(java.math.BigDecimal) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode)

Example 14 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.

the class DrillReduceAggregatesRule method reduceSum.

/**
 * Rewrites a SUM aggregate call in terms of a "sum, empty is zero" (SUM0) aggregate
 * call and, when the original call's type is nullable, a COUNT aggregate call.
 *
 * <p>If the original call's type is not nullable, the reference to the SUM0 call is
 * returned directly. Otherwise the result is
 * {@code CASE WHEN count(x) = 0 THEN NULL ELSE sum0(x) END}.
 *
 * @param oldAggRel      aggregate that contains {@code oldCall}
 * @param oldCall        SUM call being rewritten; its first argument is the summed field
 * @param newCalls       passed to {@code rexBuilder.addAggCall} for the generated calls
 * @param aggCallMapping passed to {@code rexBuilder.addAggCall} for the generated calls
 * @return the expression that replaces the original call
 */
private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
    final int inputOrdinal = oldCall.getArgList().get(0);
    final RelDataType inputType = getFieldType(oldAggRel.getInput(), inputOrdinal);

    // With type inference the SUM0 call keeps the original call's type and the function
    // is wrapped; without it the type is derived from the argument type.
    final RelDataType sum0Type = inferTypes
            ? oldCall.getType()
            : types.createTypeWithNullability(inputType, inputType.isNullable());
    final SqlAggFunction sum0Function = inferTypes
            ? new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sum0Type)
            : new SqlSumEmptyIsZeroAggFunction();
    final AggregateCall sum0Call = AggregateCall.create(sum0Function, oldCall.isDistinct(), oldCall.getArgList(), -1, sum0Type, null);

    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countFunction.getReturnType(types);
    final AggregateCall countCall = AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);

    // NOTE: these references are with respect to the output of newAggRel.
    final RexNode sum0Ref = builder.addAggCall(sum0Call, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(inputType));

    if (oldCall.getType().isNullable()) {
        // A nullable SUM yields NULL when nothing was counted, so guard SUM0 with COUNT.
        final RexNode countRef = builder.addAggCall(countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(inputType));
        final RexNode zero = builder.makeExactLiteral(BigDecimal.ZERO);
        final RexNode nothingCounted = builder.makeCall(SqlStdOperatorTable.EQUALS, countRef, zero);
        final RexNode nullValue = builder.constantNull();
        return builder.makeCall(SqlStdOperatorTable.CASE, nothingCounted, nullValue, sum0Ref);
    }

    // SUM(x) is typed as not nullable, so no NULL case is needed. Therefore we
    // translate to SUM0(x).
    return sum0Ref;
}
Also used : SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) DrillCalciteSqlAggFunctionWrapper(org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper) RexNode(org.apache.calcite.rex.RexNode)

Example 15 with AggregateCall

use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.

the class AggPrelBase method createKeysAndExprs.

/**
 * Populates {@code keys} and {@code aggExprs} from this operator's group set and
 * aggregate calls and, when the operator phase is {@code PHASE_1of2}, also fills
 * {@code phase2AggCallList} with the calls to apply to this operator's output.
 *
 * <p>Each group key becomes a named expression that refers to the input field of that
 * ordinal. Each aggregate call becomes a named expression whose output name is the
 * row-type field at position {@code groupSet.cardinality() + callIndex}.
 */
protected void createKeysAndExprs() {
    final List<String> inputNames = getInput().getRowType().getFieldNames();
    final List<String> outputNames = getRowType().getFieldNames();

    // Group keys: expression and output reference are the same input field.
    for (int keyOrdinal : BitSets.toIter(groupSet)) {
        final FieldReference keyRef = FieldReference.getWithQuotedRef(inputNames.get(keyOrdinal));
        keys.add(new NamedExpression(keyRef, keyRef));
    }

    for (Ord<AggregateCall> entry : Ord.zip(aggCalls)) {
        final AggregateCall call = entry.e;
        // Aggregate outputs follow the group-key columns in the output row.
        final int outputOrdinal = groupSet.cardinality() + entry.i;
        final FieldReference outputRef = FieldReference.getWithQuotedRef(outputNames.get(outputOrdinal));
        aggExprs.add(new NamedExpression(toDrill(call, inputNames), outputRef));

        if (getOperatorPhase() != OperatorPhase.PHASE_1of2) {
            continue;
        }

        // A COUNT computed in phase 1 of 2 is combined in phase 2 by summing the partial
        // counts; every other function is applied again unchanged.
        final SqlAggFunction phase2Function;
        if (call.getAggregation().getName().equals("COUNT")) {
            phase2Function = new SqlSumCountAggFunction(call.getType());
        } else {
            phase2Function = call.getAggregation();
        }
        // The phase-2 call reads this call's output column as its only argument.
        phase2AggCallList.add(new AggregateCall(phase2Function, call.isDistinct(), Collections.singletonList(outputOrdinal), call.getType(), call.getName()));
    }
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) LogicalExpression(org.apache.drill.common.expression.LogicalExpression) FieldReference(org.apache.drill.common.expression.FieldReference) NamedExpression(org.apache.drill.common.logical.data.NamedExpression) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall)40 RexNode (org.apache.calcite.rex.RexNode)21 ArrayList (java.util.ArrayList)20 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)17 RelNode (org.apache.calcite.rel.RelNode)16 RelDataType (org.apache.calcite.rel.type.RelDataType)13 RexBuilder (org.apache.calcite.rex.RexBuilder)13 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)12 HashMap (java.util.HashMap)11 RexInputRef (org.apache.calcite.rex.RexInputRef)11 HiveAggregate (org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate)10 Aggregate (org.apache.calcite.rel.core.Aggregate)9 ImmutableList (com.google.common.collect.ImmutableList)8 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)8 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)6 Pair (org.apache.calcite.util.Pair)6 BigDecimal (java.math.BigDecimal)5 List (java.util.List)5 RexLiteral (org.apache.calcite.rex.RexLiteral)5 SqlCountAggFunction (org.apache.calcite.sql.fun.SqlCountAggFunction)5