Search in sources:

Example 1 with HiveSqlCountAggFunction

Use of org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction in the Apache Hive project.

Class SqlFunctionConverter, method buildAST.

// TODO: 1) handle Agg Func Name translation 2) is it correct to add func
// args as child of func?
/**
 * Translates the application of a Calcite operator into a Hive {@link ASTNode} subtree.
 *
 * <p>If {@code calciteToHiveToken} has a token for {@code op}, the node is built from that
 * token. For IN, BETWEEN, ROW, IS [NOT] NULL, CASE, EXTRACT, FLOOR, CEIL, LIKE and
 * OTHER_FUNCTION kinds, the token node is nested under a {@code TOK_FUNCTION} node.
 *
 * <p>If there is no mapping:
 * <ul>
 *   <li>CAST yields a bare {@code TOK_FUNCTION} node.</li>
 *   <li>Prefix minus/plus yield {@code MINUS}/{@code PLUS}.</li>
 *   <li>Any other operator yields a function node with an {@code Identifier} child holding the
 *       operator name. For COUNT/SUM (Hive variants) and AVG (a {@code CalciteUDAF} named like
 *       {@code SqlStdOperatorTable.AVG}), the function node is {@code TOK_FUNCTIONSTAR} when
 *       there are no arguments, {@code TOK_FUNCTIONDI} when the call is DISTINCT, and
 *       {@code TOK_FUNCTION} otherwise.</li>
 * </ul>
 *
 * @param op the Calcite operator being translated
 * @param children already-translated argument subtrees, appended to the returned node in order
 * @return the root node of the translated subtree
 */
public static ASTNode buildAST(SqlOperator op, List<ASTNode> children) {
    final HiveToken mappedToken = calciteToHiveToken.get(op);
    final ASTNode root = (mappedToken == null)
            ? buildUnmappedOperatorNode(op, children)
            : buildMappedOperatorNode(op, mappedToken);
    for (ASTNode child : children) {
        ParseDriver.adaptor.addChild(root, child);
    }
    return root;
}

/** Returns whether a token-mapped operator of this kind is emitted under a TOK_FUNCTION node. */
private static boolean requiresFunctionWrapper(SqlKind kind) {
    switch (kind) {
        case IN:
        case BETWEEN:
        case ROW:
        case IS_NOT_NULL:
        case IS_NULL:
        case CASE:
        case EXTRACT:
        case FLOOR:
        case CEIL:
        case LIKE:
        case OTHER_FUNCTION:
            return true;
        default:
            return false;
    }
}

/** Builds the root node for an operator that has an entry in {@code calciteToHiveToken}. */
private static ASTNode buildMappedOperatorNode(SqlOperator op, HiveToken hToken) {
    if (!requiresFunctionWrapper(op.kind)) {
        return (ASTNode) ParseDriver.adaptor.create(hToken.type, hToken.text);
    }
    final ASTNode wrapper = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
    wrapper.addChild((ASTNode) ParseDriver.adaptor.create(hToken.type, hToken.text));
    return wrapper;
}

/** Builds the root node for an operator with no entry in {@code calciteToHiveToken}. */
private static ASTNode buildUnmappedOperatorNode(SqlOperator op, List<ASTNode> children) {
    if (op.kind == SqlKind.CAST) {
        // CAST gets a TOK_FUNCTION with no function-name child.
        return (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
    }
    if (op.kind == SqlKind.MINUS_PREFIX) {
        return (ASTNode) ParseDriver.adaptor.create(HiveParser.MINUS, "MINUS");
    }
    if (op.kind == SqlKind.PLUS_PREFIX) {
        return (ASTNode) ParseDriver.adaptor.create(HiveParser.PLUS, "PLUS");
    }
    // COUNT/SUM/AVG need special function tokens for COUNT(*) and COUNT(DISTINCT ...) forms.
    final boolean starOrDistinctCapable = op instanceof HiveSqlCountAggFunction
            || op instanceof HiveSqlSumAggFunction
            || (op instanceof CalciteUDAF
                && op.getName().equalsIgnoreCase(SqlStdOperatorTable.AVG.getName()));
    final ASTNode functionNode;
    if (starOrDistinctCapable && children.isEmpty()) {
        functionNode = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONSTAR, "TOK_FUNCTIONSTAR");
    } else if (starOrDistinctCapable && ((CanAggregateDistinct) op).isDistinct()) {
        functionNode = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONDI, "TOK_FUNCTIONDI");
    } else {
        functionNode = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
    }
    functionNode.addChild((ASTNode) ParseDriver.adaptor.create(HiveParser.Identifier, op.getName()));
    return functionNode;
}
Also used: HiveSqlCountAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction), CanAggregateDistinct (org.apache.hadoop.hive.ql.optimizer.calcite.functions.CanAggregateDistinct), ASTNode (org.apache.hadoop.hive.ql.parse.ASTNode), HiveSqlSumAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction)

Example 2 with HiveSqlCountAggFunction

Use of org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction in the Apache Hive project.

Class HiveAggregateReduceFunctionsRule, method reduceStddev.

/**
 * Reduces a standard-deviation / variance aggregate call to SUM and COUNT aggregate calls
 * combined with scalar arithmetic.
 *
 * <p>Three aggregate calls are registered through {@link RexBuilder#addAggCall}, in this
 * order: {@code sum(x * x)}, {@code sum(x)}, {@code count(x)}. Registration order affects the
 * output positions of the new aggregate, so statement order below is significant.
 *
 * @param oldAggRel the aggregate being rewritten
 * @param oldCall the call to reduce; must have exactly one argument (asserted)
 * @param biased {@code true} divides by {@code count(x)} (the stddev_pop form);
 *     {@code false} divides by {@code count(x) - 1}, or by NULL when {@code count(x) = 1}
 *     (the stddev_samp form)
 * @param sqrt {@code true} raises the quotient to the power 0.5 (standard deviation);
 *     {@code false} returns the quotient itself (variance)
 * @param newCalls accumulator of aggregate calls for the replacement aggregate
 * @param aggCallMapping map from aggregate call to the expression referencing its output
 * @param inputExprs input expressions of the aggregate; {@code x} and {@code x * x} are
 *     registered here via {@code lookupOrAdd} (presumably appended when not already present —
 *     TODO confirm against the helper)
 * @return the reduced expression, cast to the type of {@code oldCall}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    // stddev_pop(x) ==>
    // power(
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / count(x),
    // .5)
    // 
    // stddev_samp(x) ==>
    // power(
    // (sum(x * x) - sum(x) * sum(x) / count(x))
    // / nullif(count(x) - 1, 0),
    // .5)
    final int nGroups = oldAggRel.getGroupCount();
    final RelOptCluster cluster = oldAggRel.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    // Only single-argument calls are handled.
    assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
    final int argOrdinal = oldCall.getArgList().get(0);
    final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
    // The original call's result type, made nullable; x is brought to this type so that the
    // sums below are computed in it.
    final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);
    final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
    final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
    // x * x, registered as an input expression so it can be aggregated.
    final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
    final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
    // sum(x * x): Hive SUM reusing the original call's type inference and operand checking.
    final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, new HiveSqlSumAggFunction(oldCall.isDistinct(), oldCall.getAggregation().getReturnTypeInference(), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.SUM,
    oldCall.getAggregation().getOperandTypeChecker()), argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
    final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));
    // sum(x), over the (possibly cast) argument registered above.
    final AggregateCall sumArgAggCall = AggregateCall.create(new HiveSqlSumAggFunction(oldCall.isDistinct(), oldCall.getAggregation().getReturnTypeInference(), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.SUM,
    oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
    final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
    // sum(x) * sum(x), computed in the nullable result type of the original call.
    final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
    final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
    // count(x) with an explicit nullable BIGINT return type.
    RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
    final AggregateCall countArgAggCall = AggregateCall.create(new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.COUNT,
    oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
    final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
    // Numerator: sum(x * x) - sum(x) * sum(x) / count(x)
    final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
    final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
    final RexNode denominator;
    if (biased) {
        denominator = countArg;
    } else {
        // CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END, which is equivalent to the
        // nullif(count(x) - 1, 0) in the header comment and avoids dividing by zero.
        final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
        final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
        final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
        final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
        denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
    }
    final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
    RexNode result = div;
    if (sqrt) {
        // power(v, 0.5) turns the variance into a standard deviation.
        final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
        result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
    }
    return rexBuilder.makeCast(oldCall.getType(), result);
}
Also used: RelOptCluster (org.apache.calcite.plan.RelOptCluster), AggregateCall (org.apache.calcite.rel.core.AggregateCall), RexLiteral (org.apache.calcite.rex.RexLiteral), HiveSqlCountAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction), RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory), RexBuilder (org.apache.calcite.rex.RexBuilder), RelDataType (org.apache.calcite.rel.type.RelDataType), HiveSqlSumAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction), BigDecimal (java.math.BigDecimal), RexNode (org.apache.calcite.rex.RexNode)

Example 3 with HiveSqlCountAggFunction

Use of org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction in the Apache Hive project.

Class HiveAggregateReduceFunctionsRule, method reduceAvg.

/**
 * Reduces {@code AVG(x)} to {@code CAST(SUM(x) / COUNT(x) AS <type of the AVG call>)}.
 *
 * <p>The Hive SUM and COUNT calls are registered through {@link RexBuilder#addAggCall},
 * SUM first and then COUNT. The references that come back point at outputs of the rewritten
 * aggregate, not at its input.
 *
 * @param oldAggRel the aggregate being rewritten
 * @param oldCall the AVG call; its first argument is the averaged column
 * @param newCalls accumulator of aggregate calls for the replacement aggregate
 * @param aggCallMapping map from aggregate call to the expression referencing its output
 * @param inputExprs input expressions of the aggregate (not used by this reduction)
 * @return an expression computing the average, cast to the type of {@code oldCall}
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final int groupCount = oldAggRel.getGroupCount();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final boolean distinct = oldCall.isDistinct();
    final boolean approximate = oldCall.isApproximate();

    // Nullable type of the averaged column; passed to addAggCall as the argument type of
    // both new calls.
    final RelDataType nullableArgType = typeFactory.createTypeWithNullability(
            getFieldType(oldAggRel.getInput(), oldCall.getArgList().get(0)), true);
    final ImmutableList<RelDataType> argTypes = ImmutableList.of(nullableArgType);

    // SUM(x): Hive SUM (used in place of SqlStdOperatorTable.SUM) that reuses AVG's type
    // inference and operand checking; its result type is left to inference (null).
    final HiveSqlSumAggFunction sumFunction = new HiveSqlSumAggFunction(distinct,
            oldCall.getAggregation().getReturnTypeInference(),
            oldCall.getAggregation().getOperandTypeInference(),
            oldCall.getAggregation().getOperandTypeChecker());
    final AggregateCall sumCall = AggregateCall.create(sumFunction, distinct, approximate,
            oldCall.getArgList(), oldCall.filterArg, groupCount, oldAggRel.getInput(), null, null);

    // COUNT(x): Hive COUNT (used in place of SqlStdOperatorTable.COUNT) with an explicit
    // nullable BIGINT result.
    final RelDataType nullableBigint = typeFactory.createTypeWithNullability(
            typeFactory.createSqlType(SqlTypeName.BIGINT), true);
    final HiveSqlCountAggFunction countFunction = new HiveSqlCountAggFunction(distinct,
            ReturnTypes.explicit(nullableBigint),
            oldCall.getAggregation().getOperandTypeInference(),
            oldCall.getAggregation().getOperandTypeChecker());
    final AggregateCall countCall = AggregateCall.create(countFunction, distinct, approximate,
            oldCall.getArgList(), oldCall.filterArg, groupCount, oldAggRel.getInput(), nullableBigint, null);

    // Registration order (SUM, then COUNT) determines the output positions in newCalls.
    final RexNode sumRef = rexBuilder.ensureType(oldCall.getType(),
            rexBuilder.addAggCall(sumCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes),
            true);
    final RexNode countRef = rexBuilder.addAggCall(countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

    return rexBuilder.makeCast(oldCall.getType(),
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumRef, countRef));
}
Also used: AggregateCall (org.apache.calcite.rel.core.AggregateCall), HiveSqlCountAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction), RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory), RexBuilder (org.apache.calcite.rex.RexBuilder), RelDataType (org.apache.calcite.rel.type.RelDataType), HiveSqlSumAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction), RexNode (org.apache.calcite.rex.RexNode)

Aggregations (number of examples above that use each class):

HiveSqlCountAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction): 3; HiveSqlSumAggFunction (org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlSumAggFunction): 3; AggregateCall (org.apache.calcite.rel.core.AggregateCall): 2; RelDataType (org.apache.calcite.rel.type.RelDataType): 2; RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory): 2; RexBuilder (org.apache.calcite.rex.RexBuilder): 2; RexNode (org.apache.calcite.rex.RexNode): 2; BigDecimal (java.math.BigDecimal): 1; RelOptCluster (org.apache.calcite.plan.RelOptCluster): 1; RexLiteral (org.apache.calcite.rex.RexLiteral): 1; CanAggregateDistinct (org.apache.hadoop.hive.ql.optimizer.calcite.functions.CanAggregateDistinct): 1; ASTNode (org.apache.hadoop.hive.ql.parse.ASTNode): 1