Usage of org.apache.calcite.rex.RexInputRef in the Apache Flink project:
class FlinkRelDecorrelator, method projectJoinOutputWithNullability.
/**
 * Lifts a project that is the right-hand input of a join so that it sits above the join,
 * enforcing nullability on the join output.
 *
 * <p>The resulting row consists of every field of the join's left input (passed through
 * unchanged), followed by each named expression of {@code project} after it has been
 * rewritten by {@code removeCorrelationExpr}.
 *
 * @param join the join whose right-hand input is {@code project}
 * @param project the original project that was the right-hand input of the join
 * @param nullIndicatorPos position, within the join's row type, of the null indicator field
 * @return the subtree with the new project at its root, built on top of {@code join}
 */
private RelNode projectJoinOutputWithNullability(LogicalJoin join, LogicalProject project, int nullIndicatorPos) {
    final RelDataTypeFactory factory = join.getCluster().getTypeFactory();

    // Reference to the null indicator column of the join, with its type forced to be nullable.
    final RexInputRef nullIndicator = new RexInputRef(
        nullIndicatorPos,
        factory.createTypeWithNullability(
            join.getRowType().getFieldList().get(nullIndicatorPos).getType(), true));

    final List<Pair<RexNode, String>> projections = Lists.newArrayList();

    // Step 1: forward every field of the left input as-is.
    final List<RelDataTypeField> leftFields = join.getLeft().getRowType().getFieldList();
    for (int idx = 0, count = leftFields.size(); idx < count; idx++) {
        projections.add(RexInputRef.of2(idx, leftFields));
    }

    // Step 2: append the original projections. When the join can generate nulls on its
    // right side, these expressions now come out of the nullable side of the outer join;
    // the flag tells removeCorrelationExpr so their types can be made nullable.
    final boolean fromNullableSide = join.getJoinType().generatesNullsOnRight();
    for (Pair<RexNode, String> named : project.getNamedProjects()) {
        final RexNode rewritten = removeCorrelationExpr(named.left, fromNullableSide, nullIndicator);
        projections.add(Pair.of(rewritten, named.right));
    }

    return RelOptUtil.createProject(join, projections, false);
}
Usage of org.apache.calcite.rex.RexInputRef in the Apache Flink project:
class FlinkAggregateExpandDistinctAggregatesRule, method onMatch.
//~ Methods ----------------------------------------------------------------

/**
 * Rewrites an {@link Aggregate} that contains DISTINCT aggregate calls into a plan
 * without DISTINCT calls. Does nothing if the aggregate has no distinct call.
 *
 * <p>Strategies, tried in this order:
 * <ol>
 *   <li>every call is distinct and all share one (argument list, filter) pair:
 *       {@code convertMonopole};</li>
 *   <li>{@code useGroupingSets} is set: {@code rewriteUsingGroupingSets};</li>
 *   <li>exactly one distinct call, no filtered calls, and one or more non-distinct
 *       calls that are all COUNT/SUM/MIN/MAX: {@code convertSingletonDistinct};</li>
 *   <li>otherwise: aggregate the non-distinct calls once (if any), call
 *       {@code doRewrite} once per distinct (argument list, filter) pair, and
 *       project the original output fields.</li>
 * </ol>
 *
 * @param call rule call; {@code call.rel(0)} is the matched {@link Aggregate}
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    if (!aggregate.containsDistinctCall()) {
        return;
    }
    // Classify all aggregate calls and collect every distinct
    // (argument list, filter argument) pair. A LinkedHashSet keeps insertion
    // order so that the rewrite is deterministic.
    int nonDistinctCount = 0;
    int distinctCount = 0;
    // Number of calls (distinct or not) that carry a filter argument.
    int filterCount = 0;
    // Number of non-distinct calls that are not COUNT/SUM/MIN/MAX; any such
    // call rules out the multi-phase rewrite below.
    int unsupportedAggCount = 0;
    final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (aggCall.filterArg >= 0) {
            ++filterCount;
        }
        if (!aggCall.isDistinct()) {
            ++nonDistinctCount;
            if (!(aggCall.getAggregation() instanceof SqlCountAggFunction || aggCall.getAggregation() instanceof SqlSumAggFunction || aggCall.getAggregation() instanceof SqlMinMaxAggFunction)) {
                ++unsupportedAggCount;
            }
            continue;
        }
        ++distinctCount;
        argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
    }
    Preconditions.checkState(argLists.size() > 0, "containsDistinctCall lied");
    // If all of the agg expressions are distinct and have the same
    // arguments (and filter), then we can use a more efficient form.
    if (nonDistinctCount == 0 && argLists.size() == 1) {
        final Pair<List<Integer>, Integer> pair = Iterables.getOnlyElement(argLists);
        final RelBuilder relBuilder = call.builder();
        convertMonopole(relBuilder, aggregate, pair.left, pair.right);
        call.transformTo(relBuilder.build());
        return;
    }
    if (useGroupingSets) {
        rewriteUsingGroupingSets(call, aggregate, argLists);
        return;
    }
    // If there is exactly one distinct aggregate, no filtered calls, and one
    // or more non-distinct aggregates that are all COUNT/SUM/MIN/MAX, we can
    // generate multi-phase aggregates.
    if (distinctCount == 1 // one distinct aggregate
        && filterCount == 0 // no filter
        && unsupportedAggCount == 0 // sum/min/max/count in non-distinct aggregate
        && nonDistinctCount > 0) { // one or more non-distinct aggregates
        final RelBuilder relBuilder = call.builder();
        convertSingletonDistinct(relBuilder, aggregate, argLists);
        call.transformTo(relBuilder.build());
        return;
    }
    // Create a list of the expressions which will yield the final result.
    // Initially, the expressions point to the input field.
    // refs.get(k) yields output field k of the original aggregate: group and
    // indicator fields first, then one entry per aggregate call.
    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final List<RexInputRef> refs = new ArrayList<>();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    for (int i : Util.range(groupAndIndicatorCount)) {
        refs.add(RexInputRef.of(i, aggFields));
    }
    // Aggregate the original relation, including any non-distinct aggregates.
    // Distinct calls get a null placeholder in refs (presumably filled in by
    // doRewrite, which receives refs -- TODO confirm). Each non-distinct call
    // is referenced at its position in newAggCallList, keeping the original
    // output field type.
    final List<AggregateCall> newAggCallList = new ArrayList<>();
    int i = -1;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        ++i;
        if (aggCall.isDistinct()) {
            refs.add(null);
            continue;
        }
        refs.add(new RexInputRef(groupAndIndicatorCount + newAggCallList.size(), aggFields.get(groupAndIndicatorCount + i).getType()));
        newAggCallList.add(aggCall);
    }
    // In the case where there are no non-distinct aggregates (regardless of
    // whether there are group bys), there's no need to generate the
    // extra aggregate and join.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    // n is the index handed to doRewrite: it starts at 1 if the non-distinct
    // aggregate was built below, otherwise 0, and grows by one per rewrite.
    int n = 0;
    if (!newAggCallList.isEmpty()) {
        final RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupSet, aggregate.indicator, aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, newAggCallList);
        ++n;
    }
    // For each set of operands, find and rewrite all calls which have that
    // set of operands.
    for (Pair<List<Integer>, Integer> argList : argLists) {
        doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
    }
    relBuilder.project(refs, fieldNames);
    call.transformTo(relBuilder.build());
}
Usage of org.apache.calcite.rex.RexInputRef in the Druid project (druid-io):
class GroupByRules, method applyPostAggregation.
/**
 * Folds a projection over the aggregations of a druidRel into its grouping, adding
 * post-aggregators for any projected expression that is not a plain field reference.
 *
 * @param druidRel rel whose grouping receives the projection; must satisfy
 *                 {@code canApplyPostAggregation}
 * @param postProject projection applied on top of the aggregated rows
 *
 * @return new rel, or null if some projected expression cannot be turned into a PostAggregator
 */
private static DruidRel applyPostAggregation(final DruidRel druidRel, final Project postProject) {
  Preconditions.checkState(canApplyPostAggregation(druidRel), "Cannot applyPostAggregation");

  final List<String> inputRowOrder = druidRel.getQueryBuilder().getRowOrder();
  final Grouping grouping = druidRel.getQueryBuilder().getGrouping();

  // Index existing aggregations by output name, then record, per input row position,
  // the finalizing post-aggregator factory of that aggregation (null for non-aggregation fields).
  final Map<String, Aggregation> aggregationsByName = Maps.newHashMap();
  for (final Aggregation agg : grouping.getAggregations()) {
    aggregationsByName.put(agg.getOutputName(), agg);
  }
  final List<PostAggregatorFactory> finalizers = Lists.newArrayList();
  for (final String column : inputRowOrder) {
    final Aggregation agg = aggregationsByName.get(column);
    if (agg == null) {
      finalizers.add(null);
    } else {
      finalizers.add(agg.getFinalizingPostAggregatorFactory());
    }
  }

  final List<Aggregation> outAggregations = Lists.newArrayList(grouping.getAggregations());
  final List<String> outRowOrder = Lists.newArrayList();

  for (final RexNode expr : postProject.getChildExps()) {
    final String outputName;
    if (expr.isA(SqlKind.INPUT_REF)) {
      // Plain field reference: reuse the existing column.
      outputName = inputRowOrder.get(((RexInputRef) expr).getIndex());
    } else {
      // Anything else must become a PostAggregator, or the whole projection is rejected.
      final PostAggregator postAgg = Expressions.toPostAggregator(
          aggOutputName(outAggregations.size()),
          inputRowOrder,
          finalizers,
          expr
      );
      if (postAgg == null) {
        return null;
      }
      outAggregations.add(Aggregation.create(postAgg));
      outputName = postAgg.getName();
    }
    outRowOrder.add(outputName);
    // NOTE(review): lookups into this list are by inputRowOrder position, so these trailing
    // null entries appear never to be read -- kept for parity, confirm before removing.
    finalizers.add(null);
  }

  return druidRel.withQueryBuilder(
      druidRel.getQueryBuilder().withAdjustedGrouping(
          Grouping.create(grouping.getDimensions(), outAggregations),
          postProject.getRowType(),
          outRowOrder
      )
  );
}
Usage of org.apache.calcite.rex.RexInputRef in the Druid project (druid-io):
class Expressions, method toMathExpression.
/**
 * Translate a row-expression to a Druid math expression. One day, when possible, this could be folded into
 * {@link #toRowExtraction(DruidOperatorTable, PlannerContext, List, RexNode)}.
 *
 * <p>Handles field references, CAST/REINTERPRET between types present in {@code MATH_TYPES}, binary
 * TIMES/DIVIDE/PLUS/MINUS, MOD, functions present in {@code MATH_FUNCTIONS}, and numeric or string
 * literals. Any other expression (or any sub-expression that cannot be translated) yields null.
 *
 * @param rowOrder order of fields in the Druid rows to be extracted from
 * @param expression expression meant to be applied on top of the rows
 *
 * @return expression referring to fields in rowOrder, or null if not possible
 *
 * @throws ISE if a field reference points at a null entry of rowOrder
 */
public static String toMathExpression(final List<String> rowOrder, final RexNode expression) {
  final SqlKind kind = expression.getKind();
  final SqlTypeName sqlTypeName = expression.getType().getSqlTypeName();

  if (kind == SqlKind.INPUT_REF) {
    // Translate field references.
    final RexInputRef ref = (RexInputRef) expression;
    final String columnName = rowOrder.get(ref.getIndex());
    if (columnName == null) {
      throw new ISE("WTF?! Expression referred to nonexistent index[%d]", ref.getIndex());
    }
    return String.format("\"%s\"", escape(columnName));
  } else if (kind == SqlKind.CAST || kind == SqlKind.REINTERPRET) {
    // Translate casts.
    final RexNode operand = ((RexCall) expression).getOperands().get(0);
    final String operandExpression = toMathExpression(rowOrder, operand);
    if (operandExpression == null) {
      return null;
    }
    final ExprType fromType = MATH_TYPES.get(operand.getType().getSqlTypeName());
    final ExprType toType = MATH_TYPES.get(sqlTypeName);
    if (fromType == toType) {
      // No type change as far as Druid math is concerned; drop the cast.
      return operandExpression;
    }
    if (toType == null) {
      // The target type has no Druid math equivalent, so the cast cannot be expressed.
      // Previously this dereferenced toType and threw a NullPointerException.
      return null;
    }
    return String.format("CAST(%s, '%s')", operandExpression, toType.toString());
  } else if (kind == SqlKind.TIMES || kind == SqlKind.DIVIDE || kind == SqlKind.PLUS || kind == SqlKind.MINUS) {
    // Translate simple arithmetic.
    final List<RexNode> operands = ((RexCall) expression).getOperands();
    final String lhsExpression = toMathExpression(rowOrder, operands.get(0));
    final String rhsExpression = toMathExpression(rowOrder, operands.get(1));
    if (lhsExpression == null || rhsExpression == null) {
      return null;
    }
    final String op = ImmutableMap.of(SqlKind.TIMES, "*", SqlKind.DIVIDE, "/", SqlKind.PLUS, "+", SqlKind.MINUS, "-").get(kind);
    return String.format("(%s %s %s)", lhsExpression, op, rhsExpression);
  } else if (kind == SqlKind.OTHER_FUNCTION) {
    final String calciteFunction = ((RexCall) expression).getOperator().getName();
    final String druidFunction = MATH_FUNCTIONS.get(calciteFunction);
    final List<String> functionArgs = Lists.newArrayList();
    for (final RexNode operand : ((RexCall) expression).getOperands()) {
      final String operandExpression = toMathExpression(rowOrder, operand);
      if (operandExpression == null) {
        return null;
      }
      functionArgs.add(operandExpression);
    }
    if ("MOD".equals(calciteFunction)) {
      // Special handling for MOD, which is a function in Calcite but a binary operator in Druid.
      Preconditions.checkState(functionArgs.size() == 2, "WTF?! Expected 2 args for MOD.");
      return String.format("(%s %s %s)", functionArgs.get(0), "%", functionArgs.get(1));
    }
    if (druidFunction == null) {
      return null;
    }
    return String.format("%s(%s)", druidFunction, Joiner.on(", ").join(functionArgs));
  } else if (kind == SqlKind.LITERAL) {
    // Translate literal.
    if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
      // Include literal numbers as-is.
      return String.valueOf(RexLiteral.value(expression));
    } else if (SqlTypeName.STRING_TYPES.contains(sqlTypeName)) {
      // Quote literal strings.
      return "\'" + escape(RexLiteral.stringValue(expression)) + "\'";
    } else {
      // Can't translate other literals.
      return null;
    }
  } else {
    // Can't translate other kinds of expressions.
    return null;
  }
}
Usage of org.apache.calcite.rex.RexInputRef in the Druid project (druid-io):
class Expressions, method toPostAggregator.
/**
 * Translate a Calcite row-expression to a Druid PostAggregator. One day, when possible, this could be folded
 * into {@link #toRowExtraction(DruidOperatorTable, PlannerContext, List, RexNode)} .
 *
 * <p>Field references, CASTs (which are ignored), numeric literals and binary TIMES/DIVIDE/PLUS/MINUS are
 * translated directly. Anything else is attempted via {@link #toMathExpression(List, RexNode)}.
 *
 * @param name                              name of the PostAggregator; null for nested operands, in which
 *                                          case the name check at the end is skipped
 * @param rowOrder                          order of fields in the Druid rows to be extracted from
 * @param finalizingPostAggregatorFactories post-aggregators that should be used for specific entries in rowOrder.
 *                                          May be empty, and individual values may be null. Missing or null values
 *                                          will lead to creation of {@link FieldAccessPostAggregator}.
 * @param expression                        expression meant to be applied on top of the rows
 *
 * @return PostAggregator or null if not possible
 *
 * @throws ISE if the produced PostAggregator's name does not match {@code name}
 */
public static PostAggregator toPostAggregator(final String name, final List<String> rowOrder, final List<PostAggregatorFactory> finalizingPostAggregatorFactories, final RexNode expression) {
  final PostAggregator retVal;

  if (expression.getKind() == SqlKind.INPUT_REF) {
    final RexInputRef ref = (RexInputRef) expression;
    final int index = ref.getIndex();
    // Honor the documented contract: the factory list may be shorter than rowOrder (or empty),
    // and a missing entry behaves exactly like a null one.
    final PostAggregatorFactory finalizingPostAggregatorFactory =
        index < finalizingPostAggregatorFactories.size() ? finalizingPostAggregatorFactories.get(index) : null;
    retVal = finalizingPostAggregatorFactory != null
             ? finalizingPostAggregatorFactory.factorize(name)
             : new FieldAccessPostAggregator(name, rowOrder.get(index));
  } else if (expression.getKind() == SqlKind.CAST) {
    // Ignore CAST when translating to PostAggregators and hope for the best. They are really loosey-goosey with
    // types internally and there isn't much we can do to respect the cast's target type here.
    // TODO(gianm): Probably not a good idea to ignore CAST like this.
    final RexNode operand = ((RexCall) expression).getOperands().get(0);
    retVal = toPostAggregator(name, rowOrder, finalizingPostAggregatorFactories, operand);
  } else if (expression.getKind() == SqlKind.LITERAL && SqlTypeName.NUMERIC_TYPES.contains(expression.getType().getSqlTypeName())) {
    retVal = new ConstantPostAggregator(name, (Number) RexLiteral.value(expression));
  } else if (expression.getKind() == SqlKind.TIMES || expression.getKind() == SqlKind.DIVIDE || expression.getKind() == SqlKind.PLUS || expression.getKind() == SqlKind.MINUS) {
    final String fnName = ImmutableMap.<SqlKind, String>builder()
        .put(SqlKind.TIMES, "*")
        .put(SqlKind.DIVIDE, "quotient")
        .put(SqlKind.PLUS, "+")
        .put(SqlKind.MINUS, "-")
        .build()
        .get(expression.getKind());
    final List<PostAggregator> operands = Lists.newArrayList();
    for (RexNode operand : ((RexCall) expression).getOperands()) {
      // Nested operands are unnamed; only the outermost PostAggregator carries `name`.
      final PostAggregator translatedOperand = toPostAggregator(null, rowOrder, finalizingPostAggregatorFactories, operand);
      if (translatedOperand == null) {
        return null;
      }
      operands.add(translatedOperand);
    }
    retVal = new ArithmeticPostAggregator(name, fnName, operands);
  } else {
    // Try converting to a math expression.
    final String mathExpression = Expressions.toMathExpression(rowOrder, expression);
    if (mathExpression == null) {
      retVal = null;
    } else {
      retVal = new ExpressionPostAggregator(name, mathExpression);
    }
  }

  if (retVal != null && name != null && !name.equals(retVal.getName())) {
    throw new ISE("WTF?! Was about to return a PostAggregator with bad name, [%s] != [%s]", name, retVal.getName());
  }

  return retVal;
}
Aggregations