Search in sources :

Example 6 with RexInputRef

use of org.apache.calcite.rex.RexInputRef in project flink by apache.

the class FlinkRelDecorrelator method projectJoinOutputWithNullability.

/**
 * Creates a project on top of the join that takes over the expressions of
 * the project which used to be the join's right-hand input, forcing
 * nullability on the pulled-up expressions where the join type asks for it.
 *
 * @param join             Join
 * @param project          Original project as the right-hand input of the join
 * @param nullIndicatorPos Position of null indicator
 * @return the subtree with the new LogicalProject at the root
 */
private RelNode projectJoinOutputWithNullability(LogicalJoin join, LogicalProject project, int nullIndicatorPos) {
    final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory();

    // Reference to the null-indicator column of the join output, typed as
    // nullable regardless of the column's declared nullability.
    final RelDataTypeField indicatorField = join.getRowType().getFieldList().get(nullIndicatorPos);
    final RexInputRef nullIndicator = new RexInputRef(
            nullIndicatorPos,
            typeFactory.createTypeWithNullability(indicatorField.getType(), true));

    final List<Pair<RexNode, String>> projections = Lists.newArrayList();

    // Every field of the left input is passed through unchanged ...
    final List<RelDataTypeField> lhsFields = join.getLeft().getRowType().getFieldList();
    final int lhsFieldCount = lhsFields.size();
    for (int ordinal = 0; ordinal < lhsFieldCount; ordinal++) {
        projections.add(RexInputRef.of2(ordinal, lhsFields));
    }

    // ... followed by the expressions of the original project. When the join
    // generates nulls on its right side, that fact is passed on to
    // removeCorrelationExpr together with the null indicator.
    final boolean rightSideGeneratesNulls = join.getJoinType().generatesNullsOnRight();
    for (Pair<RexNode, String> namedExpr : project.getNamedProjects()) {
        final RexNode rewritten = removeCorrelationExpr(namedExpr.left, rightSideGeneratesNulls, nullIndicator);
        projections.add(Pair.of(rewritten, namedExpr.right));
    }

    return RelOptUtil.createProject(join, projections, false);
}
Also used : JoinRelType(org.apache.calcite.rel.core.JoinRelType) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexInputRef(org.apache.calcite.rex.RexInputRef) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)

Example 7 with RexInputRef

use of org.apache.calcite.rex.RexInputRef in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method onMatch.

//~ Methods ----------------------------------------------------------------
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    if (!aggregate.containsDistinctCall()) {
        // Nothing to expand.
        return;
    }

    // Tally the aggregate calls and collect the (argument list, filter
    // argument) combinations of the distinct ones. A LinkedHashSet keeps the
    // iteration order deterministic.
    final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
    int filterCount = 0;
    int distinctCount = 0;
    int nonDistinctCount = 0;
    int unsupportedAggCount = 0;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (aggCall.filterArg >= 0) {
            filterCount++;
        }
        if (aggCall.isDistinct()) {
            distinctCount++;
            argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
        } else {
            nonDistinctCount++;
            // Only count/sum/min/max count as supported non-distinct aggregates.
            final boolean supported = aggCall.getAggregation() instanceof SqlCountAggFunction
                    || aggCall.getAggregation() instanceof SqlSumAggFunction
                    || aggCall.getAggregation() instanceof SqlMinMaxAggFunction;
            if (!supported) {
                unsupportedAggCount++;
            }
        }
    }
    Preconditions.checkState(!argLists.isEmpty(), "containsDistinctCall lied");

    // If all of the agg expressions are distinct and share a single
    // (arguments, filter) combination then we can use a more efficient form.
    if (nonDistinctCount == 0 && argLists.size() == 1) {
        final Pair<List<Integer>, Integer> onlyArgs = Iterables.getOnlyElement(argLists);
        final RelBuilder monopoleBuilder = call.builder();
        convertMonopole(monopoleBuilder, aggregate, onlyArgs.left, onlyArgs.right);
        call.transformTo(monopoleBuilder.build());
        return;
    }

    if (useGroupingSets) {
        rewriteUsingGroupingSets(call, aggregate, argLists);
        return;
    }

    // Exactly one distinct aggregate, no filters, and one or more non-distinct
    // aggregates that are all count/sum/min/max: generate multi-phase
    // aggregates.
    final boolean singletonDistinct = distinctCount == 1
            && filterCount == 0
            && unsupportedAggCount == 0
            && nonDistinctCount > 0;
    if (singletonDistinct) {
        final RelBuilder singletonBuilder = call.builder();
        convertSingletonDistinct(singletonBuilder, aggregate, argLists);
        call.transformTo(singletonBuilder.build());
        return;
    }

    // Build the list of expressions that will yield the final result. The
    // group (and indicator) columns start out pointing at the fields of the
    // same ordinal.
    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final List<RexInputRef> refs = new ArrayList<>();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupAndIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    for (int field = 0; field < groupAndIndicatorCount; field++) {
        refs.add(RexInputRef.of(field, aggFields));
    }

    // Keep the non-distinct aggregate calls; each distinct call leaves a null
    // placeholder in refs at its position.
    final List<AggregateCall> newAggCallList = new ArrayList<>();
    int ordinal = 0;
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (aggCall.isDistinct()) {
            refs.add(null);
        } else {
            // The reference index follows the position in the reduced call
            // list, the type is taken from the field of the original call.
            refs.add(new RexInputRef(
                    groupAndIndicatorCount + newAggCallList.size(),
                    aggFields.get(groupAndIndicatorCount + ordinal).getType()));
            newAggCallList.add(aggCall);
        }
        ordinal++;
    }

    // Aggregate the original input over the non-distinct calls, but only when
    // there are any; otherwise that extra aggregate is skipped.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    int n = 0;
    if (!newAggCallList.isEmpty()) {
        relBuilder.aggregate(
                relBuilder.groupKey(groupSet, aggregate.indicator, aggregate.getGroupSets()),
                newAggCallList);
        n++;
    }

    // Hand every distinct (arguments, filter) combination to doRewrite,
    // numbering the invocations consecutively.
    for (Pair<List<Integer>, Integer> argList : argLists) {
        doRewrite(relBuilder, aggregate, n, argList.left, argList.right, refs);
        n++;
    }

    relBuilder.project(refs, fieldNames);
    call.transformTo(relBuilder.build());
}
Also used : LinkedHashSet(java.util.LinkedHashSet) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlMinMaxAggFunction(org.apache.calcite.sql.fun.SqlMinMaxAggFunction) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) RexInputRef(org.apache.calcite.rex.RexInputRef) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) Pair(org.apache.calcite.util.Pair)

Example 8 with RexInputRef

use of org.apache.calcite.rex.RexInputRef in project druid by druid-io.

the class GroupByRules method applyPostAggregation.

/**
   * Applies a projection on top of the aggregations of a druidRel. Projected expressions that are not plain input
   * references are turned into post-aggregators and appended to the aggregations.
   *
   * @return new rel, or null if the projection cannot be applied
   */
private static DruidRel applyPostAggregation(final DruidRel druidRel, final Project postProject) {
    Preconditions.checkState(canApplyPostAggregation(druidRel), "Cannot applyPostAggregation");
    final List<String> rowOrder = druidRel.getQueryBuilder().getRowOrder();
    final Grouping grouping = druidRel.getQueryBuilder().getGrouping();
    final List<Aggregation> newAggregations = Lists.newArrayList(grouping.getAggregations());
    final List<PostAggregatorFactory> finalizingPostAggregatorFactories = Lists.newArrayList();
    final List<String> newRowOrder = Lists.newArrayList();

    // Index the existing aggregations by their output name.
    final Map<String, Aggregation> aggregationsByName = Maps.newHashMap();
    for (final Aggregation existing : grouping.getAggregations()) {
        aggregationsByName.put(existing.getOutputName(), existing);
    }

    // One entry per rowOrder field: the finalizing factory of the aggregation with that output name, or null when
    // the field is not produced by an aggregation.
    for (final String field : rowOrder) {
        final Aggregation producer = aggregationsByName.get(field);
        if (producer == null) {
            finalizingPostAggregatorFactories.add(null);
        } else {
            finalizingPostAggregatorFactories.add(producer.getFinalizingPostAggregatorFactory());
        }
    }

    // Walk through the postProject expressions.
    for (final RexNode projectExpression : postProject.getChildExps()) {
        if (projectExpression.isA(SqlKind.INPUT_REF)) {
            // Plain field reference: reuse the name of the referenced field.
            final int referencedIndex = ((RexInputRef) projectExpression).getIndex();
            newRowOrder.add(rowOrder.get(referencedIndex));
            finalizingPostAggregatorFactories.add(null);
            continue;
        }

        // Anything else must be convertible to a PostAggregator, otherwise the projection cannot be applied.
        final PostAggregator converted = Expressions.toPostAggregator(
                aggOutputName(newAggregations.size()),
                rowOrder,
                finalizingPostAggregatorFactories,
                projectExpression);
        if (converted == null) {
            return null;
        }
        newAggregations.add(Aggregation.create(converted));
        newRowOrder.add(converted.getName());
        finalizingPostAggregatorFactories.add(null);
    }

    final Grouping adjustedGrouping = Grouping.create(grouping.getDimensions(), newAggregations);
    return druidRel.withQueryBuilder(
            druidRel.getQueryBuilder().withAdjustedGrouping(adjustedGrouping, postProject.getRowType(), newRowOrder));
}
Also used : Aggregation(io.druid.sql.calcite.aggregation.Aggregation) PostAggregator(io.druid.query.aggregation.PostAggregator) FieldAccessPostAggregator(io.druid.query.aggregation.post.FieldAccessPostAggregator) ArithmeticPostAggregator(io.druid.query.aggregation.post.ArithmeticPostAggregator) RexInputRef(org.apache.calcite.rex.RexInputRef) Grouping(io.druid.sql.calcite.rel.Grouping) PostAggregatorFactory(io.druid.sql.calcite.aggregation.PostAggregatorFactory) RexNode(org.apache.calcite.rex.RexNode)

Example 9 with RexInputRef

use of org.apache.calcite.rex.RexInputRef in project druid by druid-io.

the class Expressions method toMathExpression.

/**
   * Translate a row-expression to a Druid math expression. One day, when possible, this could be folded into
   * {@link #toRowExtraction(DruidOperatorTable, PlannerContext, List, RexNode)}.
   *
   * Handles field references, CAST/REINTERPRET, the four basic arithmetic operators, functions listed in
   * MATH_FUNCTIONS (plus MOD), and numeric or string literals. Translation of a composite expression fails as soon
   * as any of its operands cannot be translated.
   *
   * @param rowOrder   order of fields in the Druid rows to be extracted from
   * @param expression expression meant to be applied on top of the rows
   *
   * @return expression referring to fields in rowOrder, or null if not possible
   */
public static String toMathExpression(final List<String> rowOrder, final RexNode expression) {
    final SqlKind kind = expression.getKind();
    final SqlTypeName sqlTypeName = expression.getType().getSqlTypeName();
    if (kind == SqlKind.INPUT_REF) {
        // Translate field references.
        final RexInputRef ref = (RexInputRef) expression;
        final String columnName = rowOrder.get(ref.getIndex());
        if (columnName == null) {
            throw new ISE("WTF?! Expression referred to nonexistent index[%d]", ref.getIndex());
        }
        return String.format("\"%s\"", escape(columnName));
    } else if (kind == SqlKind.CAST || kind == SqlKind.REINTERPRET) {
        // Translate casts.
        final RexNode operand = ((RexCall) expression).getOperands().get(0);
        final String operandExpression = toMathExpression(rowOrder, operand);
        if (operandExpression == null) {
            return null;
        }
        final ExprType fromType = MATH_TYPES.get(operand.getType().getSqlTypeName());
        final ExprType toType = MATH_TYPES.get(sqlTypeName);
        if (fromType == toType) {
            // Source and target map to the same math type: the cast is a no-op.
            return operandExpression;
        }
        if (toType == null) {
            // The target SQL type has no math type, so the cast cannot be expressed. Report "not possible" instead
            // of failing with a NullPointerException on toType.toString().
            return null;
        }
        return String.format("CAST(%s, '%s')", operandExpression, toType.toString());
    } else if (kind == SqlKind.TIMES || kind == SqlKind.DIVIDE || kind == SqlKind.PLUS || kind == SqlKind.MINUS) {
        // Translate simple arithmetic.
        final List<RexNode> operands = ((RexCall) expression).getOperands();
        final String lhsExpression = toMathExpression(rowOrder, operands.get(0));
        final String rhsExpression = toMathExpression(rowOrder, operands.get(1));
        if (lhsExpression == null || rhsExpression == null) {
            return null;
        }
        final String op = ImmutableMap.of(SqlKind.TIMES, "*", SqlKind.DIVIDE, "/", SqlKind.PLUS, "+", SqlKind.MINUS, "-").get(kind);
        return String.format("(%s %s %s)", lhsExpression, op, rhsExpression);
    } else if (kind == SqlKind.OTHER_FUNCTION) {
        // Translate function calls: every operand must be translatable.
        final String calciteFunction = ((RexCall) expression).getOperator().getName();
        final String druidFunction = MATH_FUNCTIONS.get(calciteFunction);
        final List<String> functionArgs = Lists.newArrayList();
        for (final RexNode operand : ((RexCall) expression).getOperands()) {
            final String operandExpression = toMathExpression(rowOrder, operand);
            if (operandExpression == null) {
                return null;
            }
            functionArgs.add(operandExpression);
        }
        if ("MOD".equals(calciteFunction)) {
            // Special handling for MOD, which is a function in Calcite but a binary operator in Druid.
            Preconditions.checkState(functionArgs.size() == 2, "WTF?! Expected 2 args for MOD.");
            return String.format("(%s %s %s)", functionArgs.get(0), "%", functionArgs.get(1));
        }
        if (druidFunction == null) {
            // No Druid counterpart registered for this function.
            return null;
        }
        return String.format("%s(%s)", druidFunction, Joiner.on(", ").join(functionArgs));
    } else if (kind == SqlKind.LITERAL) {
        // Translate literal.
        if (SqlTypeName.NUMERIC_TYPES.contains(sqlTypeName)) {
            // Include literal numbers as-is.
            return String.valueOf(RexLiteral.value(expression));
        } else if (SqlTypeName.STRING_TYPES.contains(sqlTypeName)) {
            // Quote literal strings.
            return "\'" + escape(RexLiteral.stringValue(expression)) + "\'";
        } else {
            // Can't translate other literals.
            return null;
        }
    } else {
        // Can't translate other kinds of expressions.
        return null;
    }
}
Also used : RexCall(org.apache.calcite.rex.RexCall) SqlTypeName(org.apache.calcite.sql.type.SqlTypeName) ExprType(io.druid.math.expr.ExprType) RexInputRef(org.apache.calcite.rex.RexInputRef) ISE(io.druid.java.util.common.ISE) List(java.util.List) SqlKind(org.apache.calcite.sql.SqlKind) RexNode(org.apache.calcite.rex.RexNode)

Example 10 with RexInputRef

use of org.apache.calcite.rex.RexInputRef in project druid by druid-io.

the class Expressions method toPostAggregator.

/**
   * Translate a Calcite row-expression to a Druid PostAggregator. One day, when possible, this could be folded
   * into {@link #toRowExtraction(DruidOperatorTable, PlannerContext, List, RexNode)} .
   *
   * @param name                              name of the PostAggregator
   * @param rowOrder                          order of fields in the Druid rows to be extracted from
   * @param finalizingPostAggregatorFactories post-aggregators that should be used for specific entries in rowOrder.
   *                                          May be empty, and individual values may be null. Missing or null values
   *                                          will lead to creation of {@link FieldAccessPostAggregator}.
   * @param expression                        expression meant to be applied on top of the rows
   *
   * @return PostAggregator or null if not possible
   */
public static PostAggregator toPostAggregator(final String name, final List<String> rowOrder, final List<PostAggregatorFactory> finalizingPostAggregatorFactories, final RexNode expression) {
    final PostAggregator retVal;
    if (expression.getKind() == SqlKind.INPUT_REF) {
        final RexInputRef ref = (RexInputRef) expression;
        final int index = ref.getIndex();
        // The factory list is allowed to be shorter than rowOrder (even empty); an index past its end counts as
        // "no finalizing factory" rather than raising IndexOutOfBoundsException.
        final PostAggregatorFactory finalizingPostAggregatorFactory =
                index < finalizingPostAggregatorFactories.size() ? finalizingPostAggregatorFactories.get(index) : null;
        retVal = finalizingPostAggregatorFactory != null ? finalizingPostAggregatorFactory.factorize(name) : new FieldAccessPostAggregator(name, rowOrder.get(index));
    } else if (expression.getKind() == SqlKind.CAST) {
        // Ignore CAST when translating to PostAggregators and hope for the best. They are really loosey-goosey with
        // types internally and there isn't much we can do to respect the requested type.
        // TODO(gianm): Probably not a good idea to ignore CAST like this.
        final RexNode operand = ((RexCall) expression).getOperands().get(0);
        retVal = toPostAggregator(name, rowOrder, finalizingPostAggregatorFactories, operand);
    } else if (expression.getKind() == SqlKind.LITERAL && SqlTypeName.NUMERIC_TYPES.contains(expression.getType().getSqlTypeName())) {
        // Numeric literals become constants.
        retVal = new ConstantPostAggregator(name, (Number) RexLiteral.value(expression));
    } else if (expression.getKind() == SqlKind.TIMES || expression.getKind() == SqlKind.DIVIDE || expression.getKind() == SqlKind.PLUS || expression.getKind() == SqlKind.MINUS) {
        // Basic arithmetic: translate every operand (unnamed), then combine them.
        final String fnName = ImmutableMap.<SqlKind, String>builder().put(SqlKind.TIMES, "*").put(SqlKind.DIVIDE, "quotient").put(SqlKind.PLUS, "+").put(SqlKind.MINUS, "-").build().get(expression.getKind());
        final List<PostAggregator> operands = Lists.newArrayList();
        for (RexNode operand : ((RexCall) expression).getOperands()) {
            final PostAggregator translatedOperand = toPostAggregator(null, rowOrder, finalizingPostAggregatorFactories, operand);
            if (translatedOperand == null) {
                return null;
            }
            operands.add(translatedOperand);
        }
        retVal = new ArithmeticPostAggregator(name, fnName, operands);
    } else {
        // Try converting to a math expression.
        final String mathExpression = Expressions.toMathExpression(rowOrder, expression);
        if (mathExpression == null) {
            retVal = null;
        } else {
            retVal = new ExpressionPostAggregator(name, mathExpression);
        }
    }
    // Sanity check: when a name was requested, the result must carry exactly that name.
    if (retVal != null && name != null && !name.equals(retVal.getName())) {
        throw new ISE("WTF?! Was about to return a PostAggregator with bad name, [%s] != [%s]", name, retVal.getName());
    }
    return retVal;
}
Also used : ArithmeticPostAggregator(io.druid.query.aggregation.post.ArithmeticPostAggregator) FieldAccessPostAggregator(io.druid.query.aggregation.post.FieldAccessPostAggregator) ExpressionPostAggregator(io.druid.query.aggregation.post.ExpressionPostAggregator) ConstantPostAggregator(io.druid.query.aggregation.post.ConstantPostAggregator) PostAggregator(io.druid.query.aggregation.PostAggregator) FieldAccessPostAggregator(io.druid.query.aggregation.post.FieldAccessPostAggregator) ArithmeticPostAggregator(io.druid.query.aggregation.post.ArithmeticPostAggregator) ConstantPostAggregator(io.druid.query.aggregation.post.ConstantPostAggregator) PostAggregatorFactory(io.druid.sql.calcite.aggregation.PostAggregatorFactory) RexCall(org.apache.calcite.rex.RexCall) ExpressionPostAggregator(io.druid.query.aggregation.post.ExpressionPostAggregator) RexInputRef(org.apache.calcite.rex.RexInputRef) List(java.util.List) ISE(io.druid.java.util.common.ISE) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

RexInputRef (org.apache.calcite.rex.RexInputRef)241 RexNode (org.apache.calcite.rex.RexNode)200 ArrayList (java.util.ArrayList)103 RelNode (org.apache.calcite.rel.RelNode)85 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)80 RexCall (org.apache.calcite.rex.RexCall)67 RelDataType (org.apache.calcite.rel.type.RelDataType)63 RexBuilder (org.apache.calcite.rex.RexBuilder)54 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)52 HashMap (java.util.HashMap)47 AggregateCall (org.apache.calcite.rel.core.AggregateCall)36 List (java.util.List)35 HashSet (java.util.HashSet)32 Pair (org.apache.calcite.util.Pair)32 RexLiteral (org.apache.calcite.rex.RexLiteral)29 Map (java.util.Map)24 RelOptUtil (org.apache.calcite.plan.RelOptUtil)24 Set (java.util.Set)20 ImmutableList (com.google.common.collect.ImmutableList)19 LinkedHashMap (java.util.LinkedHashMap)19