Search in sources :

Example 1 with OffsetOverFrame

Use of org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame in project flink by apache.

The class BatchExecOverAggregate, method createOverWindowFrames.

/**
 * Creates the {@link OverWindowFrame}s for every group of {@code overSpec}.
 *
 * <p>A group whose inferred mode is {@code OFFSET} contributes one {@link OffsetOverFrame} per
 * aggregate call. Any other group contributes exactly one frame that is shared by all of its
 * aggregate calls; the concrete frame class depends on the mode (RANGE / ROW / INSENSITIVE) and
 * on which of the window-shape predicates (unbounded, unbounded preceding, unbounded following,
 * sliding) matches first.
 *
 * @param relBuilder builder passed to the aggregate handler code generator and to the bound
 *     comparator creation
 * @param config node configuration; its table config backs the code generator context
 * @param inputType input row type; its children are handed to the code generator and it is
 *     used to look up the type of a column-based offset argument
 * @param sortSpec sort specification; its field indices are passed to the aggregate info
 *     transformation and it is forwarded to the bound comparator creation
 * @param inputTypeWithConstants row type from which the aggregate information is derived
 *     (presumably the input type extended by the constants, as the name suggests)
 * @return the created frames, in iteration order of the groups (and of the aggregate calls
 *     inside OFFSET groups)
 * @throws RuntimeException if a column-based offset argument is not BIGINT, INTEGER or SMALLINT
 * @throws TableException if a RANGE / ROW group matches none of the window-shape predicates, or
 *     the mode is not handled
 */
private List<OverWindowFrame> createOverWindowFrames(FlinkRelBuilder relBuilder, ExecNodeConfig config, RowType inputType, SortSpec sortSpec, RowType inputTypeWithConstants) {
    final List<OverWindowFrame> frames = new ArrayList<>();

    for (GroupSpec groupSpec : overSpec.getGroups()) {
        final OverWindowMode windowMode = inferGroupMode(groupSpec);

        if (windowMode != OverWindowMode.OFFSET) {
            // All aggregate calls of the group share one generated handler and one frame.
            final AggregateInfoList groupAggInfos =
                    AggregateUtil.transformToBatchAggregateInfoList(
                            inputTypeWithConstants,
                            JavaScalaConversionUtil.toScala(groupSpec.getAggCalls()),
                            // aggCallNeedRetractions
                            null,
                            sortSpec.getFieldIndices());
            final AggsHandlerCodeGenerator groupCodeGen =
                    new AggsHandlerCodeGenerator(
                            new CodeGeneratorContext(config.getTableConfig()),
                            relBuilder,
                            JavaScalaConversionUtil.toScala(inputType.getChildren()),
                            // copyInputField
                            false);
            // The over aggregate code generation has to be given the constants.
            final GeneratedAggsHandleFunction groupHandler =
                    groupCodeGen
                            .needAccumulate()
                            .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                            .generateAggsHandler("BoundedOverAggregateHelper", groupAggInfos);
            final RowType aggValueType = groupCodeGen.valueType();

            final OverWindowFrame groupFrame;
            switch (windowMode) {
                case INSENSITIVE:
                    groupFrame = new InsensitiveOverFrame(groupHandler);
                    break;
                case ROW:
                    if (isUnboundedWindow(groupSpec)) {
                        groupFrame = new UnboundedOverWindowFrame(groupHandler, aggValueType);
                    } else if (isUnboundedPrecedingWindow(groupSpec)) {
                        groupFrame =
                                new RowUnboundedPrecedingOverFrame(
                                        groupHandler,
                                        OverAggregateUtil.getLongBoundary(overSpec, groupSpec.getUpperBound()));
                    } else if (isUnboundedFollowingWindow(groupSpec)) {
                        groupFrame =
                                new RowUnboundedFollowingOverFrame(
                                        aggValueType,
                                        groupHandler,
                                        OverAggregateUtil.getLongBoundary(overSpec, groupSpec.getLowerBound()));
                    } else if (isSlidingWindow(groupSpec)) {
                        groupFrame =
                                new RowSlidingOverFrame(
                                        inputType,
                                        aggValueType,
                                        groupHandler,
                                        OverAggregateUtil.getLongBoundary(overSpec, groupSpec.getLowerBound()),
                                        OverAggregateUtil.getLongBoundary(overSpec, groupSpec.getUpperBound()));
                    } else {
                        throw new TableException("This should not happen.");
                    }
                    break;
                case RANGE:
                    if (isUnboundedWindow(groupSpec)) {
                        groupFrame = new UnboundedOverWindowFrame(groupHandler, aggValueType);
                    } else if (isUnboundedPrecedingWindow(groupSpec)) {
                        final GeneratedRecordComparator upperComparator =
                                createBoundComparator(relBuilder, config, sortSpec, groupSpec.getUpperBound(), false, inputType);
                        groupFrame = new RangeUnboundedPrecedingOverFrame(groupHandler, upperComparator);
                    } else if (isUnboundedFollowingWindow(groupSpec)) {
                        final GeneratedRecordComparator lowerComparator =
                                createBoundComparator(relBuilder, config, sortSpec, groupSpec.getLowerBound(), true, inputType);
                        groupFrame = new RangeUnboundedFollowingOverFrame(aggValueType, groupHandler, lowerComparator);
                    } else if (isSlidingWindow(groupSpec)) {
                        // Lower bound comparator is created before the upper bound one.
                        final GeneratedRecordComparator lowerComparator =
                                createBoundComparator(relBuilder, config, sortSpec, groupSpec.getLowerBound(), true, inputType);
                        final GeneratedRecordComparator upperComparator =
                                createBoundComparator(relBuilder, config, sortSpec, groupSpec.getUpperBound(), false, inputType);
                        groupFrame =
                                new RangeSlidingOverFrame(
                                        inputType, aggValueType, groupHandler, lowerComparator, upperComparator);
                    } else {
                        throw new TableException("This should not happen.");
                    }
                    break;
                default:
                    throw new TableException("This should not happen.");
            }
            frames.add(groupFrame);
            continue;
        }

        // OFFSET mode: every aggregate call gets its own handler and its own frame.
        for (AggregateCall leadLagCall : groupSpec.getAggCalls()) {
            final AggregateInfoList callAggInfos =
                    AggregateUtil.transformToBatchAggregateInfoList(
                            inputTypeWithConstants,
                            JavaScalaConversionUtil.toScala(Collections.singletonList(leadLagCall)),
                            // needRetraction = true, see LeadLagAggFunction
                            new boolean[] {true},
                            sortSpec.getFieldIndices());
            final AggsHandlerCodeGenerator callCodeGen =
                    new AggsHandlerCodeGenerator(
                            new CodeGeneratorContext(config.getTableConfig()),
                            relBuilder,
                            JavaScalaConversionUtil.toScala(inputType.getChildren()),
                            // copyInputField
                            false);
            // The over aggregate code generation has to be given the constants.
            final GeneratedAggsHandleFunction callHandler =
                    callCodeGen
                            .needAccumulate()
                            .needRetract()
                            .withConstants(JavaScalaConversionUtil.toScala(getConstants()))
                            .generateAggsHandler("BoundedOverAggregateHelper", callAggInfos);

            // LEAD looks behind the current row, so the offset is added (+1);
            // LAG looks in front of the current row, so the offset is subtracted (-1).
            final long direction;
            if (leadLagCall.getAggregation().kind == SqlKind.LEAD) {
                direction = 1L;
            } else {
                direction = -1L;
            }

            // Exactly one of the two is non-null: a fixed offset or a per-row offset function.
            final Long fixedOffset;
            final OffsetOverFrame.CalcOffsetFunc perRowOffset;

            // The second argument of the lead/lag function is the offset; it defaults to 1.
            final List<Integer> callArgs = leadLagCall.getArgList();
            if (callArgs.size() < 2) {
                fixedOffset = direction;
                perRowOffset = null;
            } else {
                final int offsetArgIndex = callArgs.get(1);
                final int constantPos = offsetArgIndex - overSpec.getOriginalInputFields();
                if (constantPos >= 0) {
                    // The offset argument refers to a constant.
                    final long literalOffset = getConstants().get(constantPos).getValueAs(Long.class);
                    fixedOffset = literalOffset * direction;
                    perRowOffset = null;
                } else {
                    // The offset argument refers to an input column and is read from each row.
                    fixedOffset = null;
                    switch (inputType.getTypeAt(offsetArgIndex).getTypeRoot()) {
                        case SMALLINT:
                            perRowOffset = row -> (long) row.getShort(offsetArgIndex) * direction;
                            break;
                        case INTEGER:
                            perRowOffset = row -> (long) row.getInt(offsetArgIndex) * direction;
                            break;
                        case BIGINT:
                            perRowOffset = row -> row.getLong(offsetArgIndex) * direction;
                            break;
                        default:
                            throw new RuntimeException("The column type must be in long/int/short.");
                    }
                }
            }
            frames.add(new OffsetOverFrame(callHandler, fixedOffset, perRowOffset));
        }
    }

    return frames;
}
Also used : AggregateInfoList(org.apache.flink.table.planner.plan.utils.AggregateInfoList) ArrayList(java.util.ArrayList) AggsHandlerCodeGenerator(org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator) RowType(org.apache.flink.table.types.logical.RowType) RowUnboundedFollowingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowUnboundedFollowingOverFrame) RangeUnboundedPrecedingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedPrecedingOverFrame) RowUnboundedPrecedingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowUnboundedPrecedingOverFrame) RowSlidingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RowSlidingOverFrame) InsensitiveOverFrame(org.apache.flink.table.runtime.operators.over.frame.InsensitiveOverFrame) TableException(org.apache.flink.table.api.TableException) CodeGeneratorContext(org.apache.flink.table.planner.codegen.CodeGeneratorContext) OffsetOverFrame(org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame) GeneratedAggsHandleFunction(org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction) RangeSlidingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeSlidingOverFrame) UnboundedOverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame) UnboundedOverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame) OverWindowFrame(org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame) AggregateCall(org.apache.calcite.rel.core.AggregateCall) GroupSpec(org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec) RangeUnboundedFollowingOverFrame(org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedFollowingOverFrame) GeneratedRecordComparator(org.apache.flink.table.runtime.generated.GeneratedRecordComparator)

Aggregations

ArrayList (java.util.ArrayList)1 AggregateCall (org.apache.calcite.rel.core.AggregateCall)1 TableException (org.apache.flink.table.api.TableException)1 CodeGeneratorContext (org.apache.flink.table.planner.codegen.CodeGeneratorContext)1 AggsHandlerCodeGenerator (org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator)1 GroupSpec (org.apache.flink.table.planner.plan.nodes.exec.spec.OverSpec.GroupSpec)1 AggregateInfoList (org.apache.flink.table.planner.plan.utils.AggregateInfoList)1 GeneratedAggsHandleFunction (org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction)1 GeneratedRecordComparator (org.apache.flink.table.runtime.generated.GeneratedRecordComparator)1 InsensitiveOverFrame (org.apache.flink.table.runtime.operators.over.frame.InsensitiveOverFrame)1 OffsetOverFrame (org.apache.flink.table.runtime.operators.over.frame.OffsetOverFrame)1 OverWindowFrame (org.apache.flink.table.runtime.operators.over.frame.OverWindowFrame)1 RangeSlidingOverFrame (org.apache.flink.table.runtime.operators.over.frame.RangeSlidingOverFrame)1 RangeUnboundedFollowingOverFrame (org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedFollowingOverFrame)1 RangeUnboundedPrecedingOverFrame (org.apache.flink.table.runtime.operators.over.frame.RangeUnboundedPrecedingOverFrame)1 RowSlidingOverFrame (org.apache.flink.table.runtime.operators.over.frame.RowSlidingOverFrame)1 RowUnboundedFollowingOverFrame (org.apache.flink.table.runtime.operators.over.frame.RowUnboundedFollowingOverFrame)1 RowUnboundedPrecedingOverFrame (org.apache.flink.table.runtime.operators.over.frame.RowUnboundedPrecedingOverFrame)1 UnboundedOverWindowFrame (org.apache.flink.table.runtime.operators.over.frame.UnboundedOverWindowFrame)1 RowType (org.apache.flink.table.types.logical.RowType)1