use of org.apache.calcite.sql.SqlNumericLiteral in project drill by apache.
the class DrillAvgVarianceConvertlet method expandVariance.
private SqlNode expandVariance(final SqlNode arg, boolean biased, boolean sqrt) {
  /*
   * Rewrites the variance family in terms of SUM and COUNT:
   *
   *   var_pop(x)     ==> (sum(x * x) - sum(x) * sum(x) / count(x)) / count(x)
   *   var_samp(x)    ==> (sum(x * x) - sum(x) * sum(x) / count(x)) / (count(x) - 1)
   *   stddev_pop(x)  ==> power(var_pop(x), .5)
   *   stddev_samp(x) ==> power(var_samp(x), .5)
   *
   * "biased" selects the population forms (divide by count(x)); "sqrt"
   * selects the standard-deviation forms (wrap the quotient in power(.., .5)).
   */
  final SqlParserPos zero = SqlParserPos.ZERO;

  // Every aggregate below consumes the CastHighOp-wrapped argument.
  final SqlNode widened = CastHighOp.createCall(zero, arg);

  // sum(x * x)
  final SqlNode sumOfSquares = SqlStdOperatorTable.SUM.createCall(zero,
      SqlStdOperatorTable.MULTIPLY.createCall(zero, widened, widened));

  // sum(x) * sum(x) / count(x)
  final SqlNode total = SqlStdOperatorTable.SUM.createCall(zero, widened);
  final SqlNode rowCount = SqlStdOperatorTable.COUNT.createCall(zero, widened);
  final SqlNode squaredTotalOverCount = SqlStdOperatorTable.DIVIDE.createCall(zero,
      SqlStdOperatorTable.MULTIPLY.createCall(zero, total, total), rowCount);

  // The numerator is wrapped in CastHighOp once more before the final division.
  final SqlNode numerator = CastHighOp.createCall(zero,
      SqlStdOperatorTable.MINUS.createCall(zero, sumOfSquares, squaredTotalOverCount));

  // Population forms divide by count(x); sample forms divide by count(x) - 1.
  final SqlNode divisor = biased
      ? rowCount
      : SqlStdOperatorTable.MINUS.createCall(zero, rowCount,
          SqlLiteral.createExactNumeric("1", zero));

  final SqlNode variance = SqlStdOperatorTable.DIVIDE.createCall(zero, numerator, divisor);
  if (!sqrt) {
    return variance;
  }

  // Standard deviation: raise the variance to the power 0.5.
  return SqlStdOperatorTable.POWER.createCall(zero, variance,
      SqlLiteral.createExactNumeric("0.5", zero));
}
use of org.apache.calcite.sql.SqlNumericLiteral in project calcite by apache.
the class SqlToRelConverter method convertToSingleValueSubq.
/**
 * Converts the RelNode tree for a select statement to a select that
 * produces a single value.
 *
 * <p>The plan is returned unchanged when the query can be shown to yield at
 * most one value: a SELECT with exactly one select item and no GROUP BY whose
 * item satisfies {@code Util.isSingleValue}, or whose FETCH is a numeric
 * literal smaller than 2; or a VALUES call that satisfies
 * {@code Util.isSingleValue}. In every other case the plan is wrapped via
 * {@code RelOptUtil.createSingleValueAggRel}.
 *
 * @param query the query
 * @param plan the original RelNode tree corresponding to the statement
 * @return the converted RelNode tree
 */
public RelNode convertToSingleValueSubq(SqlNode query, RelNode plan) {
  // Check whether query is guaranteed to produce a single value.
  if (query instanceof SqlSelect) {
    SqlSelect select = (SqlSelect) query;
    SqlNodeList selectList = select.getSelectList();
    SqlNodeList groupList = select.getGroup();
    if ((selectList.size() == 1) && ((groupList == null) || (groupList.size() == 0))) {
      SqlNode selectExpr = selectList.get(0);
      if (selectExpr instanceof SqlCall) {
        SqlCall selectExprCall = (SqlCall) selectExpr;
        if (Util.isSingleValue(selectExprCall)) {
          return plan;
        }
      }
      // If there is a FETCH of 0 or 1 rows, the query is ensured to produce
      // a single value. instanceof already rejects null, so no separate
      // null check is needed.
      final SqlNode fetch = select.getFetch();
      if (fetch instanceof SqlNumericLiteral) {
        final BigDecimal limit = (BigDecimal) ((SqlNumericLiteral) fetch).getValue();
        // Compare as BigDecimal: intValue() keeps only the low-order 32 bits,
        // so a huge limit such as 4294967297 would wrap to 1 and be mistaken
        // for a single-row fetch.
        if (limit.compareTo(BigDecimal.valueOf(2)) < 0) {
          return plan;
        }
      }
    }
  } else if (query instanceof SqlCall) {
    // If the query is (values ...),
    // it is necessary to look into the operands to determine
    // whether SingleValueAgg is necessary
    SqlCall exprCall = (SqlCall) query;
    if (exprCall.getOperator() instanceof SqlValuesOperator && Util.isSingleValue(exprCall)) {
      return plan;
    }
  }
  // If not, project SingleValueAgg
  return RelOptUtil.createSingleValueAggRel(cluster, plan);
}
use of org.apache.calcite.sql.SqlNumericLiteral in project calcite by apache.
the class SqlToRelConverter method convertMatchRecognize.
/**
 * Converts a MATCH_RECOGNIZE call into a relational expression and installs
 * it as the root of the given blackboard.
 *
 * <p>The conversion proceeds in the order of the clauses handled below: the
 * input table reference, PARTITION BY keys, ORDER BY keys, the pattern, the
 * optional interval, SUBSET definitions, the AFTER MATCH option, MEASURES and
 * DEFINE. The pieces are then handed to
 * {@code RelFactories.DEFAULT_MATCH_FACTORY}.
 *
 * @param bb blackboard whose root is set to the resulting match node
 * @param call the call to convert; it is cast to {@code SqlMatchRecognize}
 */
protected void convertMatchRecognize(Blackboard bb, SqlCall call) {
final SqlMatchRecognize matchRecognize = (SqlMatchRecognize) call;
final SqlValidatorNamespace ns = validator.getNamespace(matchRecognize);
final SqlValidatorScope scope = validator.getMatchRecognizeScope(matchRecognize);
// A separate blackboard is used for everything inside the MATCH_RECOGNIZE clause.
final Blackboard matchBb = createBlackboard(scope, null, false);
final RelDataType rowType = ns.getRowType();
// convert inner query, could be a table name or a derived table
SqlNode expr = matchRecognize.getTableRef();
convertFrom(matchBb, expr);
final RelNode input = matchBb.root;
// PARTITION BY
final SqlNodeList partitionList = matchRecognize.getPartitionList();
final List<RexNode> partitionKeys = new ArrayList<>();
for (SqlNode partition : partitionList) {
RexNode e = matchBb.convertExpression(partition);
partitionKeys.add(e);
}
// ORDER BY
final SqlNodeList orderList = matchRecognize.getOrderList();
final List<RelFieldCollation> orderKeys = new ArrayList<>();
for (SqlNode order : orderList) {
final RelFieldCollation.Direction direction;
switch(order.getKind()) {
case DESCENDING:
direction = RelFieldCollation.Direction.DESCENDING;
// unwrap the DESC call so that only its operand is converted below
order = ((SqlCall) order).operand(0);
break;
case NULLS_FIRST:
case NULLS_LAST:
// explicit NULLS FIRST / NULLS LAST is not handled here
throw new AssertionError();
default:
direction = RelFieldCollation.Direction.ASCENDING;
break;
}
// null ordering comes from the validator's default null collation
final RelFieldCollation.NullDirection nullDirection = validator.getDefaultNullCollation().last(desc(direction)) ? RelFieldCollation.NullDirection.LAST : RelFieldCollation.NullDirection.FIRST;
RexNode e = matchBb.convertExpression(order);
// the converted key is cast to RexInputRef; any other expression fails with ClassCastException
orderKeys.add(new RelFieldCollation(((RexInputRef) e).getIndex(), direction, nullDirection));
}
final RelCollation orders = cluster.traitSet().canonize(RelCollations.of(orderKeys));
// convert pattern
// patternVarsSet collects every identifier seen in the pattern (and, below, each SUBSET name)
final Set<String> patternVarsSet = new HashSet<>();
SqlNode pattern = matchRecognize.getPattern();
final SqlBasicVisitor<RexNode> patternVarVisitor = new SqlBasicVisitor<RexNode>() {
// calls: convert each operand recursively and rebuild the call with the validator's unknown type
@Override
public RexNode visit(SqlCall call) {
List<SqlNode> operands = call.getOperandList();
List<RexNode> newOperands = Lists.newArrayList();
for (SqlNode node : operands) {
newOperands.add(node.accept(this));
}
return rexBuilder.makeCall(validator.getUnknownType(), call.getOperator(), newOperands);
}
// identifiers: record the pattern variable name and represent it as a literal of that name
@Override
public RexNode visit(SqlIdentifier id) {
assert id.isSimple();
patternVarsSet.add(id.getSimple());
return rexBuilder.makeLiteral(id.getSimple());
}
// literals: numeric literals become exact literals built from their int value;
// every other literal is read through booleanValue()
@Override
public RexNode visit(SqlLiteral literal) {
if (literal instanceof SqlNumericLiteral) {
return rexBuilder.makeExactLiteral(BigDecimal.valueOf(literal.intValue(true)));
} else {
return rexBuilder.makeLiteral(literal.booleanValue());
}
}
};
final RexNode patternNode = pattern.accept(patternVarVisitor);
// the interval is optional; intervalNode stays null when it is absent
SqlLiteral interval = matchRecognize.getInterval();
RexNode intervalNode = null;
if (interval != null) {
intervalNode = matchBb.convertLiteral(interval);
}
// convert subset
final SqlNodeList subsets = matchRecognize.getSubsetList();
final Map<String, TreeSet<String>> subsetMap = Maps.newHashMap();
for (SqlNode node : subsets) {
// operand 0 is the subset name, operand 1 the list of member identifiers
List<SqlNode> operands = ((SqlCall) node).getOperandList();
SqlIdentifier left = (SqlIdentifier) operands.get(0);
patternVarsSet.add(left.getSimple());
SqlNodeList rights = (SqlNodeList) operands.get(1);
final TreeSet<String> list = new TreeSet<String>();
for (SqlNode right : rights) {
assert right instanceof SqlIdentifier;
list.add(((SqlIdentifier) right).getSimple());
}
subsetMap.put(left.getSimple(), list);
}
// AFTER MATCH: falls back to SKIP_TO_NEXT_ROW when the clause is absent
SqlNode afterMatch = matchRecognize.getAfter();
if (afterMatch == null) {
afterMatch = SqlMatchRecognize.AfterOption.SKIP_TO_NEXT_ROW.symbol(SqlParserPos.ZERO);
}
final RexNode after;
if (afterMatch instanceof SqlCall) {
// call form: a single identifier operand that must already be in patternVarsSet
List<SqlNode> operands = ((SqlCall) afterMatch).getOperandList();
SqlOperator operator = ((SqlCall) afterMatch).getOperator();
assert operands.size() == 1;
SqlIdentifier id = (SqlIdentifier) operands.get(0);
assert patternVarsSet.contains(id.getSimple()) : id.getSimple() + " not defined in pattern";
RexNode rex = rexBuilder.makeLiteral(id.getSimple());
after = rexBuilder.makeCall(validator.getUnknownType(), operator, ImmutableList.of(rex));
} else {
after = matchBb.convertExpression(afterMatch);
}
// pattern-variable references are switched on while MEASURES and DEFINE are converted,
// and switched off again before the match node is built
matchBb.setPatternVarRef(true);
// convert measures
final ImmutableMap.Builder<String, RexNode> measureNodes = ImmutableMap.builder();
for (SqlNode measure : matchRecognize.getMeasureList()) {
// operand 0 is the expression, operand 1 its alias
List<SqlNode> operands = ((SqlCall) measure).getOperandList();
String alias = ((SqlIdentifier) operands.get(1)).getSimple();
RexNode rex = matchBb.convertExpression(operands.get(0));
measureNodes.put(alias, rex);
}
// convert definitions
final ImmutableMap.Builder<String, RexNode> definitionNodes = ImmutableMap.builder();
for (SqlNode def : matchRecognize.getPatternDefList()) {
// operand 0 is the expression, operand 1 its alias
List<SqlNode> operands = ((SqlCall) def).getOperandList();
String alias = ((SqlIdentifier) operands.get(1)).getSimple();
RexNode rex = matchBb.convertExpression(operands.get(0));
definitionNodes.put(alias, rex);
}
// allRows is true only when ROWS PER MATCH is present and set to ALL_ROWS
final SqlLiteral rowsPerMatch = matchRecognize.getRowsPerMatch();
final boolean allRows = rowsPerMatch != null && rowsPerMatch.getValue() == SqlMatchRecognize.RowsPerMatchOption.ALL_ROWS;
matchBb.setPatternVarRef(false);
final RelFactories.MatchFactory factory = RelFactories.DEFAULT_MATCH_FACTORY;
final RelNode rel = factory.createMatch(input, patternNode, rowType, matchRecognize.getStrictStart().booleanValue(), matchRecognize.getStrictEnd().booleanValue(), definitionNodes.build(), measureNodes.build(), after, subsetMap, allRows, partitionKeys, orders, intervalNode);
// publish the result on the caller's blackboard, not on matchBb
bb.setRoot(rel, false);
}
use of org.apache.calcite.sql.SqlNumericLiteral in project calcite by apache.
the class StandardConvertletTable method convertCast.
/**
 * Converts a CAST call into a row expression.
 *
 * <p>When the target is an interval qualifier, interval and numeric literal
 * sources are folded directly into an interval literal (numeric sources are
 * first scaled by the qualifier unit's multiplier); any other source is
 * converted as-is. In all three cases the result is passed through
 * {@code castToValidatedType}. Otherwise the target is treated as a data type
 * specification and a cast to the derived type is built, except that a NULL
 * literal source is returned converted but without a cast.
 *
 * @param cx conversion context
 * @param call the CAST call; operand 0 is the source, operand 1 the target
 * @return the converted expression
 */
protected RexNode convertCast(SqlRexContext cx, final SqlCall call) {
  RelDataTypeFactory typeFactory = cx.getTypeFactory();
  assert call.getKind() == SqlKind.CAST;
  final SqlNode source = call.operand(0);
  final SqlNode target = call.operand(1);

  if (target instanceof SqlIntervalQualifier) {
    final SqlIntervalQualifier qualifier = (SqlIntervalQualifier) target;
    final boolean fromInterval = source instanceof SqlIntervalLiteral;
    if (!fromInterval && !(source instanceof SqlNumericLiteral)) {
      // Not a literal we can fold: convert the operand unchanged.
      return castToValidatedType(cx, call, cx.convertExpression(source));
    }
    final RexLiteral sourceLiteral = (RexLiteral) cx.convertExpression(source);
    BigDecimal magnitude = (BigDecimal) sourceLiteral.getValue();
    if (!fromInterval) {
      // Plain numbers are scaled by the multiplier of the qualifier's unit.
      magnitude = magnitude.multiply(qualifier.getUnit().multiplier);
    }
    final RexLiteral folded = cx.getRexBuilder().makeIntervalLiteral(magnitude, qualifier);
    return castToValidatedType(cx, call, folded);
  }

  final SqlDataTypeSpec spec = (SqlDataTypeSpec) target;
  if (SqlUtil.isNullLiteral(source, false)) {
    return cx.convertExpression(source);
  }

  final RexNode operand = cx.convertExpression(source);
  RelDataType castType = spec.deriveType(typeFactory);
  if (operand.getType().isNullable()) {
    // A nullable operand makes the cast result nullable as well.
    castType = typeFactory.createTypeWithNullability(castType, true);
  }

  if (spec.getCollectionsTypeName() != null) {
    final RelDataType fromComponent = operand.getType().getComponentType();
    final RelDataType toComponent = castType.getComponentType();
    if (fromComponent.isStruct() && !toComponent.isStruct()) {
      // Wrap the target component in a one-field record named after the
      // operand's first field, then rebuild the multiset type around it
      // while keeping the nullability computed so far.
      final RelDataType wrapped = typeFactory.builder()
          .add(fromComponent.getFieldList().get(0).getName(), toComponent)
          .build();
      final RelDataType wrappedWithNullability =
          typeFactory.createTypeWithNullability(wrapped, toComponent.isNullable());
      final boolean wasNullable = castType.isNullable();
      castType = typeFactory.createTypeWithNullability(
          typeFactory.createMultisetType(wrappedWithNullability, -1), wasNullable);
    }
  }
  return cx.getRexBuilder().makeCast(castType, operand);
}
use of org.apache.calcite.sql.SqlNumericLiteral in project drill by axbaretto.
the class DrillAvgVarianceConvertlet method expandVariance.
private SqlNode expandVariance(final SqlNode arg, boolean biased, boolean sqrt) {
  /*
   * Expansion applied, with V(d) = (sum(x * x) - sum(x) * sum(x) / count(x)) / d:
   *
   *   var_pop(x)     ==> V(count(x))
   *   var_samp(x)    ==> V(count(x) - 1)
   *   stddev_pop(x)  ==> power(V(count(x)), .5)
   *   stddev_samp(x) ==> power(V(count(x) - 1), .5)
   *
   * biased == true picks the *_pop divisor; sqrt == true picks the stddev_* forms.
   */
  final SqlParserPos position = SqlParserPos.ZERO;

  // x, wrapped in CastHighOp; shared by both sums and the count
  final SqlNode x = CastHighOp.createCall(position, arg);

  // count(x) and sum(x)
  final SqlNode countOfX = SqlStdOperatorTable.COUNT.createCall(position, x);
  final SqlNode sumOfX = DrillCalciteSqlAggFunctionWrapper.SUM.createCall(position, x);

  // sum(x * x)
  final SqlNode xTimesX = SqlStdOperatorTable.MULTIPLY.createCall(position, x, x);
  final SqlNode sumOfXTimesX = DrillCalciteSqlAggFunctionWrapper.SUM.createCall(position, xTimesX);

  // sum(x) * sum(x) / count(x)
  final SqlNode sumTimesSum = SqlStdOperatorTable.MULTIPLY.createCall(position, sumOfX, sumOfX);
  final SqlNode correction = SqlStdOperatorTable.DIVIDE.createCall(position, sumTimesSum, countOfX);

  // numerator, wrapped in CastHighOp again before dividing
  final SqlNode rawNumerator =
      SqlStdOperatorTable.MINUS.createCall(position, sumOfXTimesX, correction);
  final SqlNode numerator = CastHighOp.createCall(position, rawNumerator);

  // divisor: count(x) for the population forms, count(x) - 1 for the sample forms
  SqlNode divisor = countOfX;
  if (!biased) {
    final SqlNumericLiteral literalOne = SqlLiteral.createExactNumeric("1", position);
    divisor = SqlStdOperatorTable.MINUS.createCall(position, countOfX, literalOne);
  }

  final SqlNode quotient = SqlStdOperatorTable.DIVIDE.createCall(position, numerator, divisor);
  if (sqrt) {
    // standard deviation is the variance raised to the power 0.5
    final SqlNumericLiteral literalHalf = SqlLiteral.createExactNumeric("0.5", position);
    return SqlStdOperatorTable.POWER.createCall(position, quotient, literalHalf);
  }
  return quotient;
}
Aggregations