Use of org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction in project hive by apache.
From the class SqlFunctionConverter, method buildAST.
/**
 * Builds the Hive AST node for a Calcite operator and hangs the supplied
 * operand nodes underneath it.
 *
 * <p>When {@code calciteToHiveToken} holds an entry for {@code op}, the mapped
 * token is used: for the operator kinds enumerated in the switch below it is
 * added as the first child of a {@code TOK_FUNCTION} node, for every other
 * kind it becomes the root itself.
 *
 * <p>When there is no mapping, the root is {@code MINUS} / {@code PLUS} for the
 * prefix operators, a bare {@code TOK_FUNCTION} for {@code CAST}, and otherwise
 * a function node ({@code TOK_FUNCTION}, {@code TOK_FUNCTIONSTAR} or
 * {@code TOK_FUNCTIONDI}) whose first child is an {@code Identifier} carrying
 * {@code op.getName()}.
 *
 * @param op       operator to translate
 * @param children operand nodes; appended to the root in iteration order, after
 *                 any token / identifier child added here
 * @return the root of the newly built subtree
 */
// TODO: 1) handle Agg Func Name translation 2) is it correct to add func
// args as child of func?
public static ASTNode buildAST(SqlOperator op, List<ASTNode> children) {
  final HiveToken mapped = calciteToHiveToken.get(op);
  ASTNode root;

  if (mapped == null) {
    // No direct token mapping: start from a generic function-call node.
    root = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
    if (op.kind == SqlKind.MINUS_PREFIX) {
      root = (ASTNode) ParseDriver.adaptor.create(HiveParser.MINUS, "MINUS");
    } else if (op.kind == SqlKind.PLUS_PREFIX) {
      root = (ASTNode) ParseDriver.adaptor.create(HiveParser.PLUS, "PLUS");
    } else if (op.kind != SqlKind.CAST) {
      // COUNT / SUM / AVG need special roots for the COUNT(*) and
      // COUNT(DISTINCT ...) shapes.
      final boolean starOrDistinctCapable = op instanceof HiveSqlCountAggFunction
          || op instanceof HiveSqlSumAggFunction
          || (op instanceof CalciteUDAF
              && op.getName().equalsIgnoreCase(SqlStdOperatorTable.AVG.getName()));
      if (starOrDistinctCapable) {
        if (children.isEmpty()) {
          root = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONSTAR, "TOK_FUNCTIONSTAR");
        } else if (((CanAggregateDistinct) op).isDistinct()) {
          root = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTIONDI, "TOK_FUNCTIONDI");
        }
      }
      // The function name always goes in as the first child.
      root.addChild((ASTNode) ParseDriver.adaptor.create(HiveParser.Identifier, op.getName()));
    }
    // CAST falls through with the bare TOK_FUNCTION and no name child.
  } else {
    // Decide whether the mapped token must be wrapped in a TOK_FUNCTION.
    final boolean wrapInFunction;
    switch (op.kind) {
      case IN:
      case BETWEEN:
      case ROW:
      case IS_NOT_NULL:
      case IS_NULL:
      case CASE:
      case EXTRACT:
      case FLOOR:
      case CEIL:
      case LIKE:
      case OTHER_FUNCTION:
        wrapInFunction = true;
        break;
      default:
        wrapInFunction = false;
        break;
    }
    if (wrapInFunction) {
      root = (ASTNode) ParseDriver.adaptor.create(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
      root.addChild((ASTNode) ParseDriver.adaptor.create(mapped.type, mapped.text));
    } else {
      root = (ASTNode) ParseDriver.adaptor.create(mapped.type, mapped.text);
    }
  }

  // Operands follow whatever child was added above.
  for (ASTNode operand : children) {
    ParseDriver.adaptor.addChild(root, operand);
  }
  return root;
}
Use of org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction in project hive by apache.
From the class HiveAggregateReduceFunctionsRule, method reduceStddev.
/**
 * Rewrites a standard-deviation / variance style aggregate call into an
 * expression over SUM(x * x), SUM(x) and COUNT(x).
 *
 * <p>The three helper aggregate calls are registered through
 * {@code rexBuilder.addAggCall} (which receives {@code newCalls} and
 * {@code aggCallMapping}); the returned expression combines the references
 * that {@code addAggCall} hands back.
 *
 * @param oldAggRel      aggregate that contains {@code oldCall}
 * @param oldCall        the call being reduced; asserted to have exactly one argument
 * @param biased         if true the divisor is COUNT(x); if false it is
 *                       COUNT(x) - 1, replaced by NULL when COUNT(x) = 1
 * @param sqrt           if true the quotient is raised to the power 0.5;
 *                       if false the quotient is returned as is
 * @param newCalls       passed to {@code rexBuilder.addAggCall}
 * @param aggCallMapping passed to {@code rexBuilder.addAggCall}
 * @param inputExprs     input expressions; read at the argument ordinal and
 *                       passed to {@code lookupOrAdd} for the coerced argument
 *                       and its square
 * @return the reduced expression, cast to {@code oldCall.getType()}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
// stddev_pop(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / count(x),
// .5)
//
// stddev_samp(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / nullif(count(x) - 1, 0),
// .5)
//
// NOTE: the code below expresses the nullif(...) as
// CASE WHEN count(x) = 1 THEN NULL ELSE count(x) - 1 END.
final int nGroups = oldAggRel.getGroupCount();
final RelOptCluster cluster = oldAggRel.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
// Only checked when the JVM runs with assertions enabled.
assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
// Nullable variant of the original call's result type; used as the working type below.
final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), true);
// The argument expression, passed through ensureType with the working type.
final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), false);
// NOTE(review): lookupOrAdd is defined outside this block; presumably it returns the
// position of the expression in inputExprs, appending it when absent - confirm.
final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
// x * x
final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
// SUM(x * x), built with a Hive SUM function that reuses the old aggregation's
// return-type inference, operand-type inference and operand-type checker.
// NOTE(review): createAggregateCallWithBinding is defined outside this block.
final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, new HiveSqlSumAggFunction(oldCall.isDistinct(), oldCall.getAggregation().getReturnTypeInference(), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.SUM,
oldCall.getAggregation().getOperandTypeChecker()), argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));
// SUM(x) over the coerced argument (argRefOrdinal), keeping the old call's
// distinct / approximate flags and filter argument.
final AggregateCall sumArgAggCall = AggregateCall.create(new HiveSqlSumAggFunction(oldCall.isDistinct(), oldCall.getAggregation().getReturnTypeInference(), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.SUM,
oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argRefOrdinal), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
// Bring SUM(x) to the working type before squaring it.
final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
// SUM(x) * SUM(x)
final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
// COUNT(x) is given an explicit nullable BIGINT return type.
RelDataType countRetType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BIGINT), true);
// Note: COUNT is taken over the original argument list (oldCall.getArgList()),
// not over the coerced argRefOrdinal used for SUM(x).
final AggregateCall countArgAggCall = AggregateCall.create(new HiveSqlCountAggFunction(oldCall.isDistinct(), ReturnTypes.explicit(countRetType), oldCall.getAggregation().getOperandTypeInference(), // SqlStdOperatorTable.COUNT,
oldCall.getAggregation().getOperandTypeChecker()), oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), countRetType, null);
final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
// SUM(x) * SUM(x) / COUNT(x)
final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
// SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
final RexNode denominator;
if (biased) {
// Biased form: divide by COUNT(x).
denominator = countArg;
} else {
// Unbiased form: divide by COUNT(x) - 1, or by NULL when COUNT(x) = 1.
final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
}
final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
RexNode result = div;
if (sqrt) {
// Square root expressed as POWER(div, 0.5).
final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
}
// Cast back to the type the original call was declared to return.
return rexBuilder.makeCast(oldCall.getType(), result);
}
Use of org.apache.hadoop.hive.ql.optimizer.calcite.functions.HiveSqlCountAggFunction in project hive by apache.
From the class HiveAggregateReduceFunctionsRule, method reduceAvg.
/**
 * Rewrites an average-style aggregate call as SUM(x) / COUNT(x).
 *
 * <p>A Hive SUM call and a Hive COUNT call are created over the old call's
 * argument list, carrying over its distinct / approximate flags and filter
 * argument. Both are registered through {@code rexBuilder.addAggCall}, SUM
 * first and COUNT second, and the resulting references are divided.
 *
 * @param oldAggRel      aggregate that contains {@code oldCall}
 * @param oldCall        the call being reduced; its first argument is the averaged input
 * @param newCalls       passed to {@code rexBuilder.addAggCall}
 * @param aggCallMapping passed to {@code rexBuilder.addAggCall}
 * @param inputExprs     not read by this method
 * @return SUM / COUNT, cast to {@code oldCall.getType()}
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
  final int groupCount = oldAggRel.getGroupCount();
  final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
  final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();

  final List<Integer> avgArgs = oldCall.getArgList();
  final int avgArgOrdinal = avgArgs.get(0);
  final boolean distinct = oldCall.isDistinct();
  final boolean approximate = oldCall.isApproximate();

  // Nullable variant of the averaged input column's type.
  final RelDataType nullableInputType =
      types.createTypeWithNullability(getFieldType(oldAggRel.getInput(), avgArgOrdinal), true);

  // SUM(x): reuses the old aggregation's type inference and operand checker;
  // the call's type and name are left null.
  final AggregateCall sumOfInput = AggregateCall.create(
      new HiveSqlSumAggFunction(
          distinct,
          oldCall.getAggregation().getReturnTypeInference(),
          oldCall.getAggregation().getOperandTypeInference(),
          oldCall.getAggregation().getOperandTypeChecker()),
      distinct,
      approximate,
      avgArgs,
      oldCall.filterArg,
      groupCount,
      oldAggRel.getInput(),
      null,
      null);

  // COUNT(x): explicit nullable BIGINT return type.
  final RelDataType countType =
      types.createTypeWithNullability(types.createSqlType(SqlTypeName.BIGINT), true);
  final AggregateCall countOfInput = AggregateCall.create(
      new HiveSqlCountAggFunction(
          distinct,
          ReturnTypes.explicit(countType),
          oldCall.getAggregation().getOperandTypeInference(),
          oldCall.getAggregation().getOperandTypeChecker()),
      distinct,
      approximate,
      avgArgs,
      oldCall.filterArg,
      groupCount,
      oldAggRel.getInput(),
      countType,
      null);

  // NOTE: these references are with respect to the output of newAggRel.
  // Registration order matters: SUM first, then COUNT.
  final RexNode sumRef = builder.addAggCall(
      sumOfInput, groupCount, oldAggRel.indicator, newCalls, aggCallMapping,
      ImmutableList.of(nullableInputType));
  final RexNode countRef = builder.addAggCall(
      countOfInput, groupCount, oldAggRel.indicator, newCalls, aggCallMapping,
      ImmutableList.of(nullableInputType));

  // Bring the numerator to the old call's type before dividing.
  final RexNode typedSum = builder.ensureType(oldCall.getType(), sumRef, true);
  final RexNode quotient = builder.makeCall(SqlStdOperatorTable.DIVIDE, typedSum, countRef);
  return builder.makeCast(oldCall.getType(), quotient);
}
Aggregations