Use of io.trino.sql.planner.OrderingScheme in project trino by trinodb.
Class PushDownDereferencesThroughLimit, method apply.
/**
 * Moves row-field dereferences out of {@code projectNode} into a new projection placed directly
 * beneath the captured {@link LimitNode}, then rewrites the original projection so it refers to
 * the symbols produced by that new projection.
 * <p>
 * Dereferences whose base symbol is used by the limit's ties-resolving ordering, or is listed among
 * its pre-sorted inputs, are not moved. Returns {@link Result#empty()} when nothing is left to move.
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    LimitNode limit = captures.get(CHILD);

    // Candidate dereferences found in the projection's assignment expressions
    Set<SubscriptExpression> candidates = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Symbols the limit's ties-resolving scheme and pre-sorted inputs refer to;
    // dereferences rooted at any of these stay above the limit
    Set<Symbol> pinnedSymbols = ImmutableSet.<Symbol>builder()
            .addAll(limit.getTiesResolvingScheme().map(OrderingScheme::getOrderBy).orElse(ImmutableList.of()))
            .addAll(limit.getPreSortedInputs())
            .build();

    Set<SubscriptExpression> pushable = candidates.stream()
            .filter(dereference -> !pinnedSymbols.contains(getBase(dereference)))
            .collect(toImmutableSet());
    if (pushable.isEmpty()) {
        return Result.empty();
    }

    // Allocate a fresh symbol for each dereference that will be computed below the limit
    Assignments dereferenceAssignments = Assignments.of(pushable, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Reverse lookup (dereference expression -> new symbol reference) used to rewrite the upper projection
    Map<Expression, SymbolReference> replacements = HashBiMap.create(dereferenceAssignments.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewrittenAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, replacements));

    // Lower projection: pass every output of the limit's source through and add the dereference symbols
    Assignments lowerAssignments = Assignments.builder()
            .putIdentities(limit.getSource().getOutputSymbols())
            .putAll(dereferenceAssignments)
            .build();

    // Arguments are evaluated left to right, so the upper projection draws its id before the lower one
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            limit.replaceChildren(ImmutableList.of(new ProjectNode(context.getIdAllocator().getNextId(), limit.getSource(), lowerAssignments))),
            rewrittenAssignments));
}
Use of io.trino.sql.planner.OrderingScheme in project trino by trinodb.
Class PushDownDereferencesThroughTopNRanking, method apply.
/**
 * Moves row-field dereferences out of {@code projectNode} into a new projection placed directly
 * beneath the captured {@link TopNRankingNode}, then rewrites the original projection so it refers
 * to the symbols produced by that new projection.
 * <p>
 * Dereferences whose base symbol appears in the node's partitionBy list or in its orderBy keys are
 * not moved. Returns {@link Result#empty()} when nothing is left to move.
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    TopNRankingNode topNRankingNode = captures.get(CHILD);

    // Extract dereferences from project node assignments for pushdown
    Set<SubscriptExpression> dereferences = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Exclude dereferences on symbols being used in partitionBy and orderBy.
    // The excluded symbols do not depend on the dereference being tested, so they are collected
    // once here rather than re-resolving the ordering scheme for every candidate (this mirrors
    // the precomputed exclusion set used by the Limit variant of this rule).
    WindowNode.Specification specification = topNRankingNode.getSpecification();
    ImmutableList<Symbol> excludedSymbols = ImmutableList.<Symbol>builder()
            .addAll(specification.getPartitionBy())
            .addAll(specification.getOrderingScheme().map(OrderingScheme::getOrderBy).orElse(ImmutableList.of()))
            .build();
    dereferences = dereferences.stream()
            .filter(expression -> !excludedSymbols.contains(getBase(expression)))
            .collect(toImmutableSet());
    if (dereferences.isEmpty()) {
        return Result.empty();
    }

    // Create new symbols for dereference expressions
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Rewrite project node assignments using new symbols for dereference expressions
    Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));

    // Arguments are evaluated left to right, so the upper projection draws its id before the lower one
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            topNRankingNode.replaceChildren(ImmutableList.of(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    topNRankingNode.getSource(),
                    Assignments.builder()
                            .putIdentities(topNRankingNode.getSource().getOutputSymbols())
                            .putAll(dereferenceAssignments)
                            .build()))),
            newAssignments));
}
Use of io.trino.sql.planner.OrderingScheme in project trino by trinodb.
Class TestMergePatternRecognitionNodes, method testMergeWithoutProject.
/**
 * A pattern-recognition node sitting directly on another pattern-recognition node (no projection in
 * between) with the same partitioning, ordering, frame, skip target, subset, pattern and variable
 * definition is expected to be rewritten by {@code MergePatternRecognitionNodesWithoutProject} into a
 * single node carrying the measures and window functions of both.
 */
@Test
public void testMergeWithoutProject() {
    ResolvedFunction lag = createTestMetadataManager().resolveFunction(tester().getSession(), QualifiedName.of("lag"), fromTypes(BIGINT));

    tester().assertThat(new MergePatternRecognitionNodesWithoutProject())
            .on(p -> p.patternRecognition(parent -> parent
                    .partitionBy(ImmutableList.of(p.symbol("c")))
                    .orderBy(new OrderingScheme(ImmutableList.of(p.symbol("d")), ImmutableMap.of(p.symbol("d"), ASC_NULLS_LAST)))
                    .addMeasure(p.symbol("parent_measure"), "LAST(X.b)", BIGINT)
                    .addWindowFunction(p.symbol("parent_function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), DEFAULT_FRAME, false))
                    .rowsPerMatch(WINDOW)
                    .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
                    .skipTo(LAST, new IrLabel("X"))
                    .seek()
                    .addSubset(new IrLabel("U"), ImmutableSet.of(new IrLabel("X")))
                    .pattern(new IrLabel("X"))
                    .addVariableDefinition(new IrLabel("X"), "true")
                    // Child node: identical shape, different measure and window function
                    .source(p.patternRecognition(child -> child
                            .partitionBy(ImmutableList.of(p.symbol("c")))
                            .orderBy(new OrderingScheme(ImmutableList.of(p.symbol("d")), ImmutableMap.of(p.symbol("d"), ASC_NULLS_LAST)))
                            .addMeasure(p.symbol("child_measure"), "FIRST(X.a)", BIGINT)
                            .addWindowFunction(p.symbol("child_function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b").toSymbolReference()), DEFAULT_FRAME, false))
                            .rowsPerMatch(WINDOW)
                            .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()))
                            .skipTo(LAST, new IrLabel("X"))
                            .seek()
                            .addSubset(new IrLabel("U"), ImmutableSet.of(new IrLabel("X")))
                            .pattern(new IrLabel("X"))
                            .addVariableDefinition(new IrLabel("X"), "true")
                            .source(p.values(p.symbol("a"), p.symbol("b"), p.symbol("c"), p.symbol("d")))))))
            .matches(patternRecognition(
                    expected -> expected
                            .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST)))
                            .addMeasure("parent_measure", "LAST(X.b)", BIGINT)
                            .addMeasure("child_measure", "FIRST(X.a)", BIGINT)
                            .addFunction("parent_function", functionCall("lag", ImmutableList.of("a")))
                            .addFunction("child_function", functionCall("lag", ImmutableList.of("b")))
                            .rowsPerMatch(WINDOW)
                            .frame(windowFrame(ROWS, CURRENT_ROW, Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty()))
                            .skipTo(LAST, new IrLabel("X"))
                            .seek()
                            .addSubset(new IrLabel("U"), ImmutableSet.of(new IrLabel("X")))
                            .pattern(new IrLabel("X"))
                            .addVariableDefinition(new IrLabel("X"), "true"),
                    values("a", "b", "c", "d")));
}
Use of io.trino.sql.planner.OrderingScheme in project trino by trinodb.
Class TestWindowNode, method testSerializationRoundtrip.
/**
 * Serializes a populated {@link WindowNode} to JSON with {@code objectMapper}, reads it back, and
 * checks that the id, specification, window functions, frames, hash symbol, pre-partitioned inputs
 * and pre-sorted order prefix all survive the round trip.
 */
@Test
public void testSerializationRoundtrip() throws Exception {
    Symbol sumSymbol = symbolAllocator.newSymbol("sum", BIGINT);
    ResolvedFunction sumFunction = functionResolution.resolveFunction(QualifiedName.of("sum"), fromTypes(BIGINT));
    // RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, with no frame offsets
    WindowNode.Frame unboundedFrame = new WindowNode.Frame(
            WindowFrame.Type.RANGE,
            FrameBound.Type.UNBOUNDED_PRECEDING,
            Optional.empty(),
            Optional.empty(),
            FrameBound.Type.UNBOUNDED_FOLLOWING,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    PlanNodeId nodeId = newId();
    WindowNode.Specification partitionAndOrder = new WindowNode.Specification(
            ImmutableList.of(columnA),
            Optional.of(new OrderingScheme(ImmutableList.of(columnB), ImmutableMap.of(columnB, SortOrder.ASC_NULLS_FIRST))));
    Map<Symbol, WindowNode.Function> windowFunctions = ImmutableMap.of(
            sumSymbol,
            new WindowNode.Function(sumFunction, ImmutableList.of(columnC.toSymbolReference()), unboundedFrame, false));

    WindowNode original = new WindowNode(nodeId, sourceNode, partitionAndOrder, windowFunctions, Optional.of(columnB), ImmutableSet.of(columnA), 0);
    WindowNode roundTripped = objectMapper.readValue(objectMapper.writeValueAsString(original), WindowNode.class);

    assertEquals(roundTripped.getId(), original.getId());
    assertEquals(roundTripped.getSpecification(), original.getSpecification());
    assertEquals(roundTripped.getWindowFunctions(), original.getWindowFunctions());
    assertEquals(roundTripped.getFrames(), original.getFrames());
    assertEquals(roundTripped.getHashSymbol(), original.getHashSymbol());
    assertEquals(roundTripped.getPrePartitionedInputs(), original.getPrePartitionedInputs());
    assertEquals(roundTripped.getPreSortedOrderPrefix(), original.getPreSortedOrderPrefix());
}
Use of io.trino.sql.planner.OrderingScheme in project trino by trinodb.
Class TestPushdownFilterIntoWindow, method assertKeepFilter.
/**
 * Checks that {@link PushdownFilterIntoWindow} rewrites a ranking window (partitioned and ordered by
 * {@code a}) under a filter into a non-partial TopNRanking with max ranking 99, while a filter is
 * still expected above it:
 * <ul>
 *   <li>with an extra lower bound on the ranking, the entire original predicate is expected to remain;</li>
 *   <li>with an extra predicate on {@code a}, only {@code a = 1} is expected to remain.</li>
 * </ul>
 *
 * @param rankingFunctionName name of the ranking window function to resolve (called with no arguments)
 */
private void assertKeepFilter(String rankingFunctionName) {
    ResolvedFunction rankingFunction = tester().getMetadata().resolveFunction(tester().getSession(), QualifiedName.of(rankingFunctionName), fromTypes());

    // Upper and lower bound on the ranking: the whole predicate is expected to stay
    tester().assertThat(new PushdownFilterIntoWindow(tester().getPlannerContext()))
            .on(p -> {
                Symbol rankingSymbol = p.symbol("row_number_1");
                Symbol a = p.symbol("a", BIGINT);
                WindowNode.Specification windowSpec = new WindowNode.Specification(
                        ImmutableList.of(a),
                        Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST))));
                return p.filter(
                        expression("cast(3 as bigint) < row_number_1 and row_number_1 < cast(100 as bigint)"),
                        p.window(windowSpec, ImmutableMap.of(rankingSymbol, newWindowNodeFunction(rankingFunction, a)), p.values(p.symbol("a"))));
            })
            .matches(filter(
                    "cast(3 as bigint) < row_number_1 and row_number_1 < cast(100 as bigint)",
                    topNRanking(
                            pattern -> pattern
                                    .partial(false)
                                    .maxRankingPerPartition(99)
                                    .specification(ImmutableList.of("a"), ImmutableList.of("a"), ImmutableMap.of("a", SortOrder.ASC_NULLS_FIRST)),
                            values("a"))
                            .withAlias("row_number_1", new TopNRankingSymbolMatcher())));

    // Upper bound on the ranking plus a predicate on a: only "a = 1" is expected to stay
    tester().assertThat(new PushdownFilterIntoWindow(tester().getPlannerContext()))
            .on(p -> {
                Symbol rankingSymbol = p.symbol("row_number_1");
                Symbol a = p.symbol("a", BIGINT);
                WindowNode.Specification windowSpec = new WindowNode.Specification(
                        ImmutableList.of(a),
                        Optional.of(new OrderingScheme(ImmutableList.of(a), ImmutableMap.of(a, SortOrder.ASC_NULLS_FIRST))));
                return p.filter(
                        expression("row_number_1 < cast(100 as bigint) and a = 1"),
                        p.window(windowSpec, ImmutableMap.of(rankingSymbol, newWindowNodeFunction(rankingFunction, a)), p.values(p.symbol("a"))));
            })
            .matches(filter(
                    "a = 1",
                    topNRanking(
                            pattern -> pattern
                                    .partial(false)
                                    .maxRankingPerPartition(99)
                                    .specification(ImmutableList.of("a"), ImmutableList.of("a"), ImmutableMap.of("a", SortOrder.ASC_NULLS_FIRST)),
                            values("a"))
                            .withAlias("row_number_1", new TopNRankingSymbolMatcher())));
}
Aggregations