Search in sources:

Example 1 with GroupSpec

Use of org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec in the Apache Flink project.

Class BatchExecOverAggregate, method createOverWindowFrames.

/**
 * Creates the window frames evaluated by the buffered over-window operator.
 *
 * <p>For an {@code OFFSET} group (LEAD/LAG), one {@link OffsetOverFrame} is created per aggregate
 * call. Every other group yields exactly one frame, whose type depends on the group's window mode
 * and bounds.
 *
 * @param relBuilder builder handed to the aggregate-handler code generator
 * @param config node configuration; its table config drives code generation
 * @param inputType row type of the operator input
 * @param sortSpec sort spec whose field indices are passed to the aggregate info list and to the
 *     RANGE bound comparators
 * @param inputTypeWithConstants input row type extended with the over spec's constants
 * @return the frames, in group order
 * @throws TableException if a group's mode/bound combination is not supported, if a LEAD/LAG
 *     offset column is not BIGINT, INTEGER or SMALLINT, or if a constant LEAD/LAG offset is NULL
 */
private List<OverWindowFrame> createOverWindowFrames(
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        RowType inputType,
        SortSpec sortSpec,
        RowType inputTypeWithConstants) {
    final List<OverWindowFrame> windowFrames = new ArrayList<>();
    for (GroupSpec group : overSpec.getGroups()) {
        final OverWindowMode mode = inferGroupMode(group);
        if (mode == OverWindowMode.OFFSET) {
            // One frame per LEAD/LAG call, since each call carries its own offset.
            for (AggregateCall aggCall : group.getAggCalls()) {
                windowFrames.add(
                        createLeadLagFrame(
                                relBuilder,
                                config,
                                inputType,
                                sortSpec,
                                inputTypeWithConstants,
                                aggCall));
            }
        } else {
            windowFrames.add(
                    createAggregateFrame(
                            relBuilder,
                            config,
                            inputType,
                            sortSpec,
                            inputTypeWithConstants,
                            group,
                            mode));
        }
    }
    return windowFrames;
}

/**
 * Creates the {@link OffsetOverFrame} for a single LEAD/LAG aggregate call.
 *
 * <p>The offset is the call's optional second argument and defaults to 1. It is either a constant,
 * resolved once here, or a column of the original input, evaluated per row.
 */
private OffsetOverFrame createLeadLagFrame(
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        RowType inputType,
        SortSpec sortSpec,
        RowType inputTypeWithConstants,
        AggregateCall aggCall) {
    final AggregateInfoList aggInfoList =
            AggregateUtil.transformToBatchAggregateInfoList(
                    inputTypeWithConstants,
                    JavaScalaConversionUtil.toScala(Collections.singletonList(aggCall)),
                    // needRetraction = true, see LeadLagAggFunction
                    new boolean[] {true},
                    sortSpec.getFieldIndices());
    // over agg code gen must pass the constants
    final GeneratedAggsHandleFunction genAggsHandler =
            createOverAggsHandlerGenerator(relBuilder, config, inputType)
                    .needAccumulate()
                    .needRetract()
                    .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                    .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);

    // LEAD is behind the current row, so the offset is added;
    // LAG is in front of the current row, so the offset is subtracted.
    final long flag = aggCall.getAggregation().kind == SqlKind.LEAD ? 1L : -1L;
    final Long offset;
    final OffsetOverFrame.CalcOffsetFunc calcOffsetFunc;
    // The second argument of LEAD/LAG is the offset argument index; the offset defaults to 1.
    if (aggCall.getArgList().size() >= 2) {
        final int offsetArgIndex = aggCall.getArgList().get(1);
        final int constantIndex = offsetArgIndex - overSpec.getOriginalInputFields();
        if (constantIndex < 0) {
            // The offset refers to a column of the original input: evaluate it per row.
            offset = null;
            calcOffsetFunc = createCalcOffsetFunc(inputType, offsetArgIndex, flag);
        } else {
            // The offset refers to a constant appended after the original input fields.
            final Long constantOffset =
                    getConstants().get(constantIndex).getValueAs(Long.class);
            if (constantOffset == null) {
                // Previously an NPE on unboxing; fail with a descriptive message instead.
                throw new TableException("The offset of LEAD/LAG must not be null.");
            }
            offset = constantOffset * flag;
            calcOffsetFunc = null;
        }
    } else {
        offset = flag;
        calcOffsetFunc = null;
    }
    return new OffsetOverFrame(genAggsHandler, offset, calcOffsetFunc);
}

/**
 * Returns a function that reads the LEAD/LAG offset from column {@code rowIndex} of each row and
 * multiplies it by {@code flag}.
 *
 * @throws TableException if the column is not BIGINT, INTEGER or SMALLINT
 */
private static OffsetOverFrame.CalcOffsetFunc createCalcOffsetFunc(
        RowType inputType, int rowIndex, long flag) {
    switch (inputType.getTypeAt(rowIndex).getTypeRoot()) {
        case BIGINT:
            return row -> row.getLong(rowIndex) * flag;
        case INTEGER:
            return row -> (long) row.getInt(rowIndex) * flag;
        case SMALLINT:
            return row -> (long) row.getShort(rowIndex) * flag;
        default:
            throw new TableException(
                    "The column type must be in long/int/short, but was: "
                            + inputType.getTypeAt(rowIndex));
    }
}

/**
 * Creates the single frame that evaluates all aggregate calls of a non-OFFSET group. The frame
 * type is chosen from the group's window mode and bounds.
 *
 * @throws TableException if the mode/bound combination is not supported
 */
private OverWindowFrame createAggregateFrame(
        FlinkRelBuilder relBuilder,
        ExecNodeConfig config,
        RowType inputType,
        SortSpec sortSpec,
        RowType inputTypeWithConstants,
        GroupSpec group,
        OverWindowMode mode) {
    final AggregateInfoList aggInfoList =
            AggregateUtil.transformToBatchAggregateInfoList(
                    // inputSchema.relDataType
                    inputTypeWithConstants,
                    JavaScalaConversionUtil.toScala(group.getAggCalls()),
                    // aggCallNeedRetractions
                    null,
                    sortSpec.getFieldIndices());
    final AggsHandlerCodeGenerator generator =
            createOverAggsHandlerGenerator(relBuilder, config, inputType);
    // over agg code gen must pass the constants
    final GeneratedAggsHandleFunction genAggsHandler =
            generator
                    .needAccumulate()
                    .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                    .generateAggsHandler("BoundedOverAggregateHelper", aggInfoList);
    final RowType valueType = generator.valueType();
    switch (mode) {
        case RANGE:
            if (isUnboundedWindow(group)) {
                return new UnboundedOverWindowFrame(genAggsHandler, valueType);
            }
            if (isUnboundedPrecedingWindow(group)) {
                final GeneratedRecordComparator genBoundComparator =
                        createBoundComparator(
                                relBuilder, config, sortSpec, group.getUpperBound(), false, inputType);
                return new RangeUnboundedPrecedingOverFrame(genAggsHandler, genBoundComparator);
            }
            if (isUnboundedFollowingWindow(group)) {
                final GeneratedRecordComparator genBoundComparator =
                        createBoundComparator(
                                relBuilder, config, sortSpec, group.getLowerBound(), true, inputType);
                return new RangeUnboundedFollowingOverFrame(
                        valueType, genAggsHandler, genBoundComparator);
            }
            if (isSlidingWindow(group)) {
                final GeneratedRecordComparator genLeftBoundComparator =
                        createBoundComparator(
                                relBuilder, config, sortSpec, group.getLowerBound(), true, inputType);
                final GeneratedRecordComparator genRightBoundComparator =
                        createBoundComparator(
                                relBuilder, config, sortSpec, group.getUpperBound(), false, inputType);
                return new RangeSlidingOverFrame(
                        inputType,
                        valueType,
                        genAggsHandler,
                        genLeftBoundComparator,
                        genRightBoundComparator);
            }
            throw new TableException("This should not happen.");
        case ROW:
            if (isUnboundedWindow(group)) {
                return new UnboundedOverWindowFrame(genAggsHandler, valueType);
            }
            if (isUnboundedPrecedingWindow(group)) {
                return new RowUnboundedPrecedingOverFrame(
                        genAggsHandler,
                        OverAggregateUtil.getLongBoundary(overSpec, group.getUpperBound()));
            }
            if (isUnboundedFollowingWindow(group)) {
                return new RowUnboundedFollowingOverFrame(
                        valueType,
                        genAggsHandler,
                        OverAggregateUtil.getLongBoundary(overSpec, group.getLowerBound()));
            }
            if (isSlidingWindow(group)) {
                return new RowSlidingOverFrame(
                        inputType,
                        valueType,
                        genAggsHandler,
                        OverAggregateUtil.getLongBoundary(overSpec, group.getLowerBound()),
                        OverAggregateUtil.getLongBoundary(overSpec, group.getUpperBound()));
            }
            throw new TableException("This should not happen.");
        case INSENSITIVE:
            return new InsensitiveOverFrame(genAggsHandler);
        default:
            throw new TableException("This should not happen.");
    }
}

/** Creates an aggregate-handler code generator over the input fields, with copyInputField off. */
private AggsHandlerCodeGenerator createOverAggsHandlerGenerator(
        FlinkRelBuilder relBuilder, ExecNodeConfig config, RowType inputType) {
    return new AggsHandlerCodeGenerator(
            new CodeGeneratorContext(config.getTableConfig()),
            relBuilder,
            JavaScalaConversionUtil.toScala(inputType.getChildren()),
            // copyInputField
            false);
}
Also used: AggregateInfoList(org.apache.flink.table.planner.plan.utils.AggregateInfoList) ArrayList(java.util.ArrayList) AggsHandlerCodeGenerator(org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator) RowType(org.apache.flink.table.types.logical.RowType) RowUnboundedFollowingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowUnboundedFollowingOverFrame) RangeUnboundedPrecedingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedPrecedingOverFrame) RowUnboundedPrecedingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowUnboundedPrecedingOverFrame) RowSlidingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowSlidingOverFrame) InsensitiveOverFrame(org.apache.flink.table.runtime.operators.over.frame.InsensitiveOverFrame) TableException(org.apache.flink.table.api.TableException) CodeGeneratorContext(org.apache.flink.table.planner.codegen.CodeGeneratorContext) OffsetOverFrame(org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame) GeneratedAggsHandleFunction(org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction) RangeSlidingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeSlidingOverFrame) UnboundedOverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame) OverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame) AggregateCall(org.apache.calcite.rel.core.AggregateCall) GroupSpec(org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec) RangeUnboundedFollowingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedFollowingOverFrame) GeneratedRecordComparator(org.apache.flink.table.runtime.generated.GeneratedRecordComparator)

Example 2 with GroupSpec

Use of org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec in the Apache Flink project.

Class BatchExecOverAggregate, method translateToPlanInternal.

/**
 * Translates this batch over-aggregate node into a one-input transformation.
 *
 * <p>If {@code needBufferData()} is true, the operator is a {@link BufferDataOverWindowOperator}
 * built from {@code createOverWindowFrames} and reserving the configured external-buffer managed
 * memory. Otherwise it is a {@link NonBufferOverWindowOperator} with one generated aggregate
 * handler per group and no managed memory.
 */
@SuppressWarnings("unchecked")
@Override
protected Transformation<RowData> translateToPlanInternal(PlannerBase planner, ExecNodeConfig config) {
    final ExecEdge edge = getInputEdges().get(0);
    final Transformation<RowData> input = (Transformation<RowData>) edge.translateToPlan(planner);
    final RowType inputRowType = (RowType) edge.getOutputType();
    final List<GroupSpec> groups = overSpec.getGroups();

    // This comparator only compares rows across partitions, so the ASC/DESC order of the
    // grouping fields does not matter: every partition key is sorted ascending.
    // TODO just replace comparator to equaliser
    final int[] partitionKeys = overSpec.getPartition().getFieldIndices();
    final GeneratedRecordComparator partitionComparator =
            ComparatorCodeGenerator.gen(
                    config.getTableConfig(),
                    "SortComparator",
                    inputRowType,
                    SortUtil.getAscendingSortSpec(partitionKeys));

    // Aggregates take their input from this type, which includes the constants, rather than
    // from the plain input type.
    final RowType aggInputType = getInputTypeWithConstants();

    // The over operator could support different ORDER BY keys as long as the collation is
    // satisfied. Currently it requires all order keys (combined with partition keys) to be the
    // same, although the ORDER BY keys themselves may differ, e.g.
    //   select *, sum(b) over partition by a order by a, count(c) over partition by a from T
    // Any group's sort spec can therefore be used; to stay consistent with the rule, the last
    // one is taken.
    final SortSpec lastGroupSort = groups.get(groups.size() - 1).getSort();

    final TableStreamOperator<RowData> overOperator;
    final long managedMemoryBytes;
    if (needBufferData()) {
        final List<OverWindowFrame> frames =
                createOverWindowFrames(
                        planner.getRelBuilder(), config, inputRowType, lastGroupSort, aggInputType);
        final boolean allFieldsFixedLength =
                inputRowType.getChildren().stream().allMatch(BinaryRowData::isInFixedLengthPart);
        overOperator =
                new BufferDataOverWindowOperator(
                        frames.toArray(new OverWindowFrame[0]),
                        partitionComparator,
                        allFieldsFixedLength);
        managedMemoryBytes =
                config.get(ExecutionConfigOptions.TABLE_EXEC_RESOURCE_EXTERNAL_BUFFER_MEMORY)
                        .getBytes();
    } else {
        // The operator does not cache any input rows.
        final int groupCount = groups.size();
        final GeneratedAggsHandleFunction[] handlers = new GeneratedAggsHandleFunction[groupCount];
        final boolean[] resetFlags = new boolean[groupCount];
        int slot = 0;
        for (GroupSpec group : groups) {
            final AggregateInfoList aggInfos =
                    AggregateUtil.transformToBatchAggregateInfoList(
                            aggInputType,
                            JavaScalaConversionUtil.toScala(group.getAggCalls()),
                            // aggCallNeedRetractions
                            null,
                            lastGroupSort.getFieldIndices());
            final AggsHandlerCodeGenerator codeGen =
                    new AggsHandlerCodeGenerator(
                            new CodeGeneratorContext(config.getTableConfig()),
                            planner.getRelBuilder(),
                            JavaScalaConversionUtil.toScala(inputRowType.getChildren()),
                            // copyInputField
                            false);
            // over agg code gen must pass the constants
            handlers[slot] =
                    codeGen.needAccumulate()
                            .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                            .generateAggsHandler("BoundedOverAggregateHelper", aggInfos);
            // Reset is requested only for a ROW-mode group bounded by the current row on both
            // sides.
            final boolean rowMode = inferGroupMode(group) == OverWindowMode.ROW;
            resetFlags[slot] =
                    rowMode
                            && group.getLowerBound().isCurrentRow()
                            && group.getUpperBound().isCurrentRow();
            slot++;
        }
        overOperator = new NonBufferOverWindowOperator(handlers, partitionComparator, resetFlags);
        managedMemoryBytes = 0L;
    }

    return ExecNodeUtil.createOneInputTransformation(
            input,
            createTransformationName(config),
            createTransformationDescription(config),
            SimpleOperatorFactory.of(overOperator),
            InternalTypeInfo.of(getOutputType()),
            input.getParallelism(),
            managedMemoryBytes);
}
Also used: Transformation(org.apache.flink.api.dag.Transformation) AggregateInfoList(org.apache.flink.table.planner.plan.utils.AggregateInfoList) ExecEdge(org.apache.flink.table.planner.plan.nodes.exec.ExecEdge) CodeGeneratorContext(org.apache.flink.table.planner.codegen.CodeGeneratorContext) RowType(org.apache.flink.table.types.logical.RowType) AggsHandlerCodeGenerator(org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator) GeneratedAggsHandleFunction(org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction) BufferDataOverWindowOperator(org.apache.flink.table.runtime.operators.over.BufferDataOverWindowOperator) NonBufferOverWindowOperator(org.apache.flink.table.runtime.operators.over.NonBufferOverWindowOperator) UnboundedOverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame) OverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame) RowData(org.apache.flink.table.data.RowData) BinaryRowData(org.apache.flink.table.data.binary.BinaryRowData) GroupSpec(org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec) GeneratedRecordComparator(org.apache.flink.table.runtime.generated.GeneratedRecordComparator) SortSpec(org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec)

Aggregations

CodeGeneratorContext (org.apache.flink.table.planner.codegen.CodeGeneratorContext)2 AggsHandlerCodeGenerator (org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator)2 GroupSpec (org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec)2 AggregateInfoList (org.apache.flink.table.planner.plan.utils.AggregateInfoList)2 GeneratedAggsHandleFunction (org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction)2 GeneratedRecordComparator (org.apache.flink.table.runtime.generated.GeneratedRecordComparator)2 OverWindowFrame (org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame)2 UnboundedOverWindowFrame (org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame)2 RowType (org.apache.flink.table.types.logical.RowType)2 ArrayList (java.util.ArrayList)1 AggregateCall (org.apache.calcite.rel.core.AggregateCall)1 Transformation (org.apache.flink.api.dag.Transformation)1 TableException (org.apache.flink.table.api.TableException)1 RowData (org.apache.flink.table.data.RowData)1 BinaryRowData (org.apache.flink.table.data.binary.BinaryRowData)1 ExecEdge (org.apache.flink.table.planner.plan.nodes.exec.ExecEdge)1 SortSpec (org.apache.flink.table.planner.plan.nodes.exec.spec.SortSpec)1 BufferDataOverWindowOperator (org.apache.flink.table.runtime.operators.over.BufferDataOverWindowOperator)1 NonBufferOverWindowOperator (org.apache.flink.table.runtime.operators.over.NonBufferOverWindowOperator)1 InsensitiveOverFrame (org.apache.flink.table.runtime.operators.over.frame.InsensitiveOverFrame)1