Search in sources:

Example 46 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.

Used in class FlinkAggregateExpandDistinctAggregatesRule, method createSelectDistinct.

/**
 * Builds a 'select distinct' relational expression on top of the input of {@code aggregate}.
 *
 * <p>The expression projects the following columns, in this order:
 *
 * <ul>
 *   <li>the group columns of {@code aggregate};
 *   <li>the filter column, if there is one;
 *   <li>the columns in {@code argList}.
 * </ul>
 *
 * <p>It then groups on every projected column without any aggregate call, which yields the
 * distinct rows.
 *
 * <p>For example, given
 *
 * <blockquote>
 *
 * <pre>select f0, count(distinct f1), count(distinct f2)
 * from t group by f0</pre>
 *
 * </blockquote>
 *
 * <p>and the argument list {2}, the result is
 *
 * <blockquote>
 *
 * <pre>select distinct f0, f2 from t</pre>
 *
 * </blockquote>
 *
 * <p>{@code sourceOf} maps each input column ordinal to its position in the projection. In this
 * example, sourceOf.get(0) = 0 and sourceOf.get(2) = 1.
 *
 * <p>When a filter column is given, each argument is always projected as
 * {@code CASE WHEN filter THEN arg ELSE NULL END} under the name {@code "i$" + name}. This
 * happens even if the argument column was projected already, and it overwrites the
 * {@code sourceOf} entry.
 *
 * @param relBuilder Relational expression builder; the result is pushed onto it
 * @param aggregate Aggregate whose input and group columns are used
 * @param argList Ordinals of the input columns to make distinct
 * @param filterArg Ordinal of the filter column, or a negative value for none
 * @param sourceOf Out parameter, receives input ordinal to output position for every projected column
 * @return {@code relBuilder}, with the distinct aggregate on top of its stack
 */
private RelBuilder createSelectDistinct(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg, Map<Integer, Integer> sourceOf) {
    relBuilder.push(aggregate.getInput());
    final List<RelDataTypeField> inputFields = relBuilder.peek().getRowType().getFieldList();
    final List<Pair<RexNode, String>> projections = new ArrayList<>();
    // Group columns come first, in group-set order.
    for (int groupColumn : aggregate.getGroupSet()) {
        sourceOf.put(groupColumn, projections.size());
        projections.add(RexInputRef.of2(groupColumn, inputFields));
    }
    final boolean filtered = filterArg >= 0;
    if (filtered) {
        sourceOf.put(filterArg, projections.size());
        projections.add(RexInputRef.of2(filterArg, inputFields));
    }
    for (Integer arg : argList) {
        if (filtered) {
            // agg(DISTINCT arg) FILTER $f is implemented by projecting
            //   CASE WHEN $f THEN arg ELSE NULL END
            // and applying a plain agg(arg) on top of the distinct rows.
            // This is wrong only for the (rare) aggregate functions that need to see null values.
            final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
            final RexInputRef filterColumn = RexInputRef.of(filterArg, inputFields);
            final Pair<RexNode, String> argColumn = RexInputRef.of2(arg, inputFields);
            final RexNode nullValue = rexBuilder.makeNullLiteral(argColumn.left.getType());
            final RexNode caseExpr = rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterColumn, argColumn.left, nullValue);
            sourceOf.put(arg, projections.size());
            projections.add(Pair.of(caseExpr, "i$" + argColumn.right));
        } else if (sourceOf.get(arg) == null) {
            // Unfiltered argument that has not been projected yet.
            sourceOf.put(arg, projections.size());
            projections.add(RexInputRef.of2(arg, inputFields));
        }
    }
    relBuilder.project(Pair.left(projections), Pair.right(projections));
    // Grouping on every projected column, with no aggregate calls, yields the distinct rows.
    final ImmutableBitSet allColumns = ImmutableBitSet.range(projections.size());
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), allColumns, null, com.google.common.collect.ImmutableList.<AggregateCall>of()));
    return relBuilder;
}
Also used : ArrayList(java.util.ArrayList) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RexBuilder(org.apache.calcite.rex.RexBuilder) RexInputRef(org.apache.calcite.rex.RexInputRef) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)

Example 47 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.

Used in class FlinkAggregateExpandDistinctAggregatesRule, method onMatch.

// ~ Methods ----------------------------------------------------------------
/**
 * Expands the accurate DISTINCT aggregate calls of the matched {@link Aggregate}.
 *
 * <p>Nothing is done if the aggregate contains no accurate distinct call. Nothing is done either
 * if it has more than one grouping set; such aggregates are expected to be decomposed first, by
 * DecomposeGroupingSetsRule.
 *
 * <p>Otherwise the first applicable strategy is used:
 *
 * <ol>
 *   <li>The aggregate is simple, and every call is either distinct or ignores DISTINCT
 *       ({@link Optionality#IGNORED}). All of these calls also share a single (arguments,
 *       filter) pair. Strategy: {@code convertMonopole}.
 *   <li>{@code useGroupingSets} is set. Strategy: {@code rewriteUsingGroupingSets}.
 *   <li>There is exactly one distinct call and no filter. There is also at least one
 *       non-distinct call, and all such calls are COUNT/SUM/SUM0/MIN/MAX. Strategy:
 *       {@code convertSingletonDistinct}.
 *   <li>Otherwise, all non-distinct calls are aggregated once. One {@code doRewrite} is then
 *       applied per distinct (arguments, filter) pair.
 * </ol>
 *
 * <p>Non-distinct calls that ignore DISTINCT only take part in the first check. They are never
 * handed to the distinct rewrites, which expect the argument lists of DISTINCT calls only.
 *
 * @param call rule call whose first relational expression is the matched aggregate
 * @throws TableException if the aggregate mixes accurate and approximate distinct calls
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    if (!AggregateUtil.containsAccurateDistinctCall(aggCalls)) {
        return;
    }
    // Mixing accurate and approximate distinct calls in one aggregate is not supported.
    if (AggregateUtil.containsApproximateDistinctCall(aggCalls)) {
        throw new TableException("There are both Distinct AggCall and Approximate Distinct AggCall in one sql statement, " + "it is not supported yet.\nPlease choose one of them.");
    }
    // A non-simple aggregate (several grouping sets, e.g. CUBE or ROLLUP) is first turned into
    // simple aggregates by DecomposeGroupingSetsRule; this rule then expands their distinct calls.
    if (aggregate.getGroupSets().size() > 1) {
        return;
    }

    // Classify the aggregate calls.
    // LinkedHashSets keep iteration order, and hence the generated plan, deterministic.
    // Number of aggregate calls without DISTINCT.
    int nonDistinctAggCallCount = 0;
    // Number of non-distinct calls whose function does not ignore DISTINCT.
    // This excludes functions such as MAX, MIN, BIT_AND and BIT_OR.
    int nonDistinctAggCallExcludingIgnoredCount = 0;
    // Number of aggregate calls that have a FILTER argument.
    int filterCount = 0;
    // Number of non-distinct calls other than COUNT/SUM/SUM0/MIN/MAX.
    int unsupportedNonDistinctAggCallCount = 0;
    // (argument list, filter argument) of every DISTINCT call; these drive the rewrites.
    final Set<Pair<List<Integer>, Integer>> distinctArgLists = new LinkedHashSet<>();
    // Same as distinctArgLists, plus the pairs of non-distinct calls whose distinct optionality
    // is IGNORED. Those calls give the same result with or without DISTINCT, so they may share
    // the single-argument-list shortcut below.
    final Set<Pair<List<Integer>, Integer>> argListsIncludingIgnored = new LinkedHashSet<>();
    for (AggregateCall aggCall : aggCalls) {
        if (aggCall.filterArg >= 0) {
            ++filterCount;
        }
        final Pair<List<Integer>, Integer> argList = Pair.of(aggCall.getArgList(), aggCall.filterArg);
        if (aggCall.isDistinct()) {
            distinctArgLists.add(argList);
            argListsIncludingIgnored.add(argList);
            continue;
        }
        ++nonDistinctAggCallCount;
        // Only COUNT/SUM/SUM0/MIN/MAX are supported by the "single" count distinct optimization.
        switch (aggCall.getAggregation().getKind()) {
            case COUNT:
            case SUM:
            case SUM0:
            case MIN:
            case MAX:
                break;
            default:
                ++unsupportedNonDistinctAggCallCount;
        }
        if (aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED) {
            argListsIncludingIgnored.add(argList);
        } else {
            ++nonDistinctAggCallExcludingIgnoredCount;
        }
    }
    final int distinctAggCallCount = aggCalls.size() - nonDistinctAggCallCount;
    Preconditions.checkState(distinctArgLists.size() > 0, "containsDistinctCall lied");

    // If every call is distinct, or ignores DISTINCT anyway (e.g. MAX, MIN, BIT_AND, BIT_OR),
    // and all share the same arguments and filter, a more efficient form can be used.
    if (nonDistinctAggCallExcludingIgnoredCount == 0 && argListsIncludingIgnored.size() == 1 && aggregate.getGroupType() == Group.SIMPLE) {
        final Pair<List<Integer>, Integer> pair = com.google.common.collect.Iterables.getOnlyElement(argListsIncludingIgnored);
        final RelBuilder relBuilder = call.builder();
        convertMonopole(relBuilder, aggregate, pair.left, pair.right);
        call.transformTo(relBuilder.build());
        return;
    }
    if (useGroupingSets) {
        rewriteUsingGroupingSets(call, aggregate);
        return;
    }
    // Exactly one distinct call, no filter, and one or more non-distinct calls that are all
    // COUNT/SUM/SUM0/MIN/MAX: generate multi-phase aggregates.
    // Pass only the distinct calls' pairs. argListsIncludingIgnored could also hold the pair of a
    // non-distinct MAX/MIN, which would give more than one entry for this single-distinct path.
    if (distinctAggCallCount == 1 && filterCount == 0 && unsupportedNonDistinctAggCallCount == 0 && nonDistinctAggCallCount > 0) {
        final RelBuilder relBuilder = call.builder();
        convertSingletonDistinct(relBuilder, aggregate, distinctArgLists);
        call.transformTo(relBuilder.build());
        return;
    }

    // General case: build the list of expressions that yield the final result.
    // Group keys come first, pointing at the aggregate's own output fields.
    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final List<RexInputRef> refs = new ArrayList<>();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupCount = aggregate.getGroupCount();
    for (int groupIndex : Util.range(groupCount)) {
        refs.add(RexInputRef.of(groupIndex, aggFields));
    }
    // Aggregate the original relation once for all non-distinct calls (including those that
    // ignore DISTINCT). Distinct calls get a null placeholder in refs; refs is passed to
    // doRewrite below, which is expected to fill these in.
    final List<AggregateCall> newAggCallList = new ArrayList<>();
    for (int i = 0; i < aggCalls.size(); i++) {
        final AggregateCall aggCall = aggCalls.get(i);
        if (aggCall.isDistinct()) {
            refs.add(null);
            continue;
        }
        refs.add(new RexInputRef(groupCount + newAggCallList.size(), aggFields.get(groupCount + i).getType()));
        newAggCallList.add(aggCall);
    }
    // When there are no non-distinct aggregates (whether or not there are group keys), the extra
    // aggregate and join are not needed.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    int n = 0;
    if (!newAggCallList.isEmpty()) {
        final RelBuilder.GroupKey groupKey = relBuilder.groupKey(groupSet, aggregate.getGroupSets());
        relBuilder.aggregate(groupKey, newAggCallList);
        ++n;
    }
    // Add one rewrite (and join) per distinct set of (arguments, filter). Calls that ignore
    // DISTINCT are already computed by the aggregate above, so their pairs are not rewritten.
    for (Pair<List<Integer>, Integer> argList : distinctArgLists) {
        doRewrite(relBuilder, aggregate, n++, argList.left, argList.right, refs);
    }
    relBuilder.project(refs, fieldNames);
    call.transformTo(relBuilder.build());
}
Also used : LinkedHashSet(java.util.LinkedHashSet) TableException(org.apache.flink.table.api.TableException) RelBuilder(org.apache.calcite.tools.RelBuilder) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) SqlKind(org.apache.calcite.sql.SqlKind) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RexInputRef(org.apache.calcite.rex.RexInputRef) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) Pair(org.apache.calcite.util.Pair)

Example 48 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.

Used in class FlinkAggregateRemoveRule, method matches.

/**
 * Decides whether the matched aggregate is redundant and can be removed.
 *
 * <p>This requires all of the following:
 *
 * <ul>
 *   <li>The aggregate has at least one group column, a false {@code indicator} and a SIMPLE
 *       group type.
 *   <li>Every aggregate call is SUM, MIN, MAX or a {@link SqlAuxiliaryGroupAggFunction}, has no
 *       filter and takes exactly one argument.
 *   <li>The metadata query reports the input's group columns as unique.
 * </ul>
 *
 * @param call rule call; rel(0) is the aggregate and rel(1) its input
 * @return whether the aggregate can be replaced by a projection of its input
 */
@Override
public boolean matches(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final RelNode input = call.rel(1);
    final boolean simpleGroupBy = aggregate.getGroupCount() != 0 && !aggregate.indicator && aggregate.getGroupType() == Aggregate.Group.SIMPLE;
    if (!simpleGroupBy) {
        return false;
    }
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        // TODO supports more AggregateCalls
        final boolean allowedFunction;
        switch (aggCall.getAggregation().getKind()) {
            case SUM:
            case MIN:
            case MAX:
                allowedFunction = true;
                break;
            default:
                allowedFunction = aggCall.getAggregation() instanceof SqlAuxiliaryGroupAggFunction;
        }
        if (!allowedFunction || aggCall.filterArg >= 0 || aggCall.getArgList().size() != 1) {
            return false;
        }
    }
    // The group columns must already be unique in the input. An unknown (null) answer is
    // passed through SqlFunctions.isTrue and presumably counts as "not unique".
    return SqlFunctions.isTrue(call.getMetadataQuery().areColumnsUnique(input, aggregate.getGroupSet()));
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) RelNode(org.apache.calcite.rel.RelNode) SqlAuxiliaryGroupAggFunction(org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) SqlKind(org.apache.calcite.sql.SqlKind)

Example 49 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.

Used in class FlinkAggregateRemoveRule, method onMatch.

// ~ Methods ----------------------------------------------------------------
/**
 * Replaces the matched aggregate with a projection of its input.
 *
 * <p>The aggregate is a "distinct" whose group columns already contain a key of the input. Its
 * only aggregate calls are the single-argument functions accepted by {@code matches}, so each
 * call can be replaced by its argument column.
 *
 * <p>The projection keeps the group columns followed by every call's argument columns. The
 * result is then converted to the aggregate's row type, which covers columns whose
 * nullability changed once the aggregate functions were removed.
 *
 * @param call rule call; rel(0) is the aggregate and rel(1) its input
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final RelNode input = call.rel(1);
    final RelNode newInput = convert(input, aggregate.getTraitSet().simplify());
    final RelBuilder relBuilder = call.builder().push(newInput);
    // Group columns first, then the single argument of each aggregate call, in call order.
    final List<Integer> projectIndices = new ArrayList<>(aggregate.getGroupSet().asList());
    aggregate.getAggCallList().forEach(aggCall -> projectIndices.addAll(aggCall.getArgList()));
    relBuilder.project(relBuilder.fields(projectIndices));
    // Cast back to the aggregate's row type, e.g. where a column became NOT NULL.
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelBuilder(org.apache.calcite.tools.RelBuilder) RelNode(org.apache.calcite.rel.RelNode) ArrayList(java.util.ArrayList) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate)

Example 50 with AggregateCall

Use of org.apache.calcite.rel.core.AggregateCall in the Apache Flink project.

Used in class BatchExecOverAggregate, method createOverWindowFrames.

/**
 * Creates the window frames for the groups of {@code overSpec}, in group order.
 *
 * <p>An OFFSET-mode group contributes one {@link OffsetOverFrame} per aggregate call. Every other
 * group contributes a single frame chosen by its mode (RANGE, ROW or INSENSITIVE) and its
 * boundaries. Each frame gets its own code-generated {@link GeneratedAggsHandleFunction}, built
 * with the constants returned by {@code getConstants()}.
 *
 * @param relBuilder builder handed to the code generator and to {@code createBoundComparator}
 * @param config exec node config; its table config seeds the {@link CodeGeneratorContext}
 * @param inputType input row type. It supplies the generator's input field types and is used to
 *     read a LEAD/LAG offset column. It is also passed to the range/row frames.
 * @param sortSpec sort spec of the window. Its field indices go into the aggregate info list,
 *     and it is used to build the RANGE bound comparators.
 * @param inputTypeWithConstants row type used to build the aggregate info lists. It is presumably
 *     the input type followed by the constant columns; see the {@code constantIndex}
 *     computation. TODO confirm.
 * @return the created frames
 * @throws TableException if a group has a mode or boundary combination that is not handled
 * @throws RuntimeException if a LEAD/LAG offset column is not BIGINT, INTEGER or SMALLINT
 */
private List<OverWindowFrame> createOverWindowFrames(FlinkRelBuilder relBuilder, ExecNodeConfig config, RowType inputType, SortSpec sortSpec, RowType inputTypeWithConstants) {
    final List<OverWindowFrame> windowFrames = new ArrayList<>();
    for (GroupSpec group : overSpec.getGroups()) {
        OverWindowMode mode = inferGroupMode(group);
        if (mode == OverWindowMode.OFFSET) {
            // OFFSET mode: each aggregate call of the group (LEAD/LAG, per the comments below)
            // gets its own frame with its own generated handler.
            for (AggregateCall aggCall : group.getAggCalls()) {
                AggregateInfoList aggInfoList = AggregateUtil.transformToBatchAggregateInfoList(inputTypeWithConstants, JavaScalaConversionUtil.toScala(Collections.singletonList(aggCall)), new boolean[] { true }, /* needRetraction = true, See LeadLagAggFunction */
                sortSpec.getFieldIndices());
                AggsHandlerCodeGenerator generator = new AggsHandlerCodeGenerator(new CodeGeneratorContext(config.getTableConfig()), relBuilder, JavaScalaConversionUtil.toScala(inputType.getChildren()), // copyInputField
                false);
                // Over-aggregate code generation must receive the constants.
                GeneratedAggsHandleFunction genAggsHandler = generator.needAccumulate().needRetract().withConstants(JavaScalaConversionUtil.toScala(getConstants())).generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
                // LEAD looks at rows after the current row, so the offset is added (flag = 1).
                // LAG looks at rows before the current row, so the offset is subtracted (flag = -1).
                long flag = aggCall.getAggregation().kind == SqlKind.LEAD ? 1L : -1L;
                // Exactly one of offset / calcOffsetFunc is non-null: a fixed signed offset, or a
                // function computing the signed offset from each row.
                final Long offset;
                final OffsetOverFrame.CalcOffsetFunc calcOffsetFunc;
                // The second argument, if present, is the offset argument of LEAD/LAG.
                // Without it the offset defaults to 1.
                if (aggCall.getArgList().size() >= 2) {
                    // Argument indexes at or beyond overSpec.getOriginalInputFields() refer to
                    // constants. A negative constantIndex therefore means that the offset is an
                    // input column.
                    int constantIndex = aggCall.getArgList().get(1) - overSpec.getOriginalInputFields();
                    if (constantIndex < 0) {
                        // Per-row offset read from an input column; only BIGINT, INTEGER and
                        // SMALLINT columns are supported.
                        offset = null;
                        int rowIndex = aggCall.getArgList().get(1);
                        switch(inputType.getTypeAt(rowIndex).getTypeRoot()) {
                            case BIGINT:
                                calcOffsetFunc = row -> row.getLong(rowIndex) * flag;
                                break;
                            case INTEGER:
                                calcOffsetFunc = row -> (long) row.getInt(rowIndex) * flag;
                                break;
                            case SMALLINT:
                                calcOffsetFunc = row -> (long) row.getShort(rowIndex) * flag;
                                break;
                            default:
                                throw new RuntimeException("The column type must be in long/int/short.");
                        }
                    } else {
                        // Constant offset, read as a long from the constants list.
                        long constantOffset = getConstants().get(constantIndex).getValueAs(Long.class);
                        offset = constantOffset * flag;
                        calcOffsetFunc = null;
                    }
                } else {
                    // No offset argument: default offset of 1 in the direction given by flag.
                    offset = flag;
                    calcOffsetFunc = null;
                }
                windowFrames.add(new OffsetOverFrame(genAggsHandler, offset, calcOffsetFunc));
            }
        } else {
            // All other modes: one frame computes every aggregate call of the group.
            AggregateInfoList aggInfoList = AggregateUtil.transformToBatchAggregateInfoList(// inputSchema.relDataType
            inputTypeWithConstants, JavaScalaConversionUtil.toScala(group.getAggCalls()), // aggCallNeedRetractions
            null, sortSpec.getFieldIndices());
            AggsHandlerCodeGenerator generator = new AggsHandlerCodeGenerator(new CodeGeneratorContext(config.getTableConfig()), relBuilder, JavaScalaConversionUtil.toScala(inputType.getChildren()), // copyInputField
            false);
            // Over-aggregate code generation must receive the constants.
            GeneratedAggsHandleFunction genAggsHandler = generator.needAccumulate().withConstants(JavaScalaConversionUtil.toScala(getConstants())).generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
            RowType valueType = generator.valueType();
            final OverWindowFrame frame;
            switch(mode) {
                case RANGE:
                    // RANGE frames locate their bounds with generated record comparators built
                    // from the sort spec (lower bound: true flag, upper bound: false flag).
                    if (isUnboundedWindow(group)) {
                        frame = new UnboundedOverWindowFrame(genAggsHandler, valueType);
                    } else if (isUnboundedPrecedingWindow(group)) {
                        GeneratedRecordComparator genBoundComparator = createBoundComparator(relBuilder, config, sortSpec, group.getUpperBound(), false, inputType);
                        frame = new RangeUnboundedPrecedingOverFrame(genAggsHandler, genBoundComparator);
                    } else if (isUnboundedFollowingWindow(group)) {
                        GeneratedRecordComparator genBoundComparator = createBoundComparator(relBuilder, config, sortSpec, group.getLowerBound(), true, inputType);
                        frame = new RangeUnboundedFollowingOverFrame(valueType, genAggsHandler, genBoundComparator);
                    } else if (isSlidingWindow(group)) {
                        GeneratedRecordComparator genLeftBoundComparator = createBoundComparator(relBuilder, config, sortSpec, group.getLowerBound(), true, inputType);
                        GeneratedRecordComparator genRightBoundComparator = createBoundComparator(relBuilder, config, sortSpec, group.getUpperBound(), false, inputType);
                        frame = new RangeSlidingOverFrame(inputType, valueType, genAggsHandler, genLeftBoundComparator, genRightBoundComparator);
                    } else {
                        throw new TableException("This should not happen.");
                    }
                    break;
                case ROW:
                    // ROW frames take their bounds as long values from
                    // OverAggregateUtil.getLongBoundary.
                    if (isUnboundedWindow(group)) {
                        frame = new UnboundedOverWindowFrame(genAggsHandler, valueType);
                    } else if (isUnboundedPrecedingWindow(group)) {
                        frame = new RowUnboundedPrecedingOverFrame(genAggsHandler, OverAggregateUtil.getLongBoundary(overSpec, group.getUpperBound()));
                    } else if (isUnboundedFollowingWindow(group)) {
                        frame = new RowUnboundedFollowingOverFrame(valueType, genAggsHandler, OverAggregateUtil.getLongBoundary(overSpec, group.getLowerBound()));
                    } else if (isSlidingWindow(group)) {
                        frame = new RowSlidingOverFrame(inputType, valueType, genAggsHandler, OverAggregateUtil.getLongBoundary(overSpec, group.getLowerBound()), OverAggregateUtil.getLongBoundary(overSpec, group.getUpperBound()));
                    } else {
                        throw new TableException("This should not happen.");
                    }
                    break;
                case INSENSITIVE:
                    frame = new InsensitiveOverFrame(genAggsHandler);
                    break;
                default:
                    throw new TableException("This should not happen.");
            }
            windowFrames.add(frame);
        }
    }
    return windowFrames;
}
Also used : AggregateInfoList(org.apache.flink.table.planner.plan.utils.AggregateInfoList) ArrayList(java.util.ArrayList) AggsHandlerCodeGenerator(org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator) RowType(org.apache.flink.table.types.logical.RowType) RowUnboundedFollowingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowUnboundedFollowingOverFrame) RangeUnboundedPrecedingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedPrecedingOverFrame) RowUnboundedPrecedingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowUnboundedPrecedingOverFrame) RowSlidingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowSlidingOverFrame) InsensitiveOverFrame(org.apache.flink.table.runtime.operators.over.frame.InsensitiveOverFrame) TableException(org.apache.flink.table.api.TableException) CodeGeneratorContext(org.apache.flink.table.planner.codegen.CodeGeneratorContext) OffsetOverFrame(org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame) GeneratedAggsHandleFunction(org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction) RangeSlidingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeSlidingOverFrame) UnboundedOverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame) OverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame) AggregateCall(org.apache.calcite.rel.core.AggregateCall) GroupSpec(org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec) RangeUnboundedFollowingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedFollowingOverFrame) GeneratedRecordComparator(org.apache.flink.table.runtime.generated.GeneratedRecordComparator)

Aggregations

AggregateCall (org.apache.calcite.rel.core.AggregateCall)158 ArrayList (java.util.ArrayList)82 RexNode (org.apache.calcite.rex.RexNode)78 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)57 RelNode (org.apache.calcite.rel.RelNode)54 RexBuilder (org.apache.calcite.rex.RexBuilder)52 RelDataType (org.apache.calcite.rel.type.RelDataType)42 Aggregate (org.apache.calcite.rel.core.Aggregate)37 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)36 RexInputRef (org.apache.calcite.rex.RexInputRef)33 RelBuilder (org.apache.calcite.tools.RelBuilder)29 HashMap (java.util.HashMap)28 SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)28 List (java.util.List)27 RexLiteral (org.apache.calcite.rex.RexLiteral)23 Pair (org.apache.calcite.util.Pair)20 ImmutableList (com.google.common.collect.ImmutableList)19 Project (org.apache.calcite.rel.core.Project)17 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)17 LogicalAggregate (org.apache.calcite.rel.logical.LogicalAggregate)16