Usage of io.trino.sql.planner.plan.TopNRankingNode in project trino (trinodb).
Class PushdownFilterIntoWindow, method apply.
/**
 * Replaces a filter over a window node with a {@link TopNRankingNode} when the filter places an
 * upper bound on the window's single ranking symbol.
 *
 * <p>Returns an empty result when no upper bound can be extracted for the ranking symbol, an empty
 * {@code ValuesNode} when the bound is {@code <= 0}, and otherwise the TopN ranking node. That node
 * is wrapped in a filter for whatever part of the predicate it does not absorb.
 */
@Override
public Result apply(FilterNode node, Captures captures, Context context) {
    Session session = context.getSession();
    TypeProvider types = context.getSymbolAllocator().getTypes();
    WindowNode windowNode = captures.get(childCapture);

    DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, session, node.getPredicate(), types);
    TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();

    Symbol rankingSymbol = getOnlyElement(windowNode.getWindowFunctions().keySet());
    OptionalInt upperBound = extractUpperBound(tupleDomain, rankingSymbol);
    if (upperBound.isEmpty()) {
        // No upper bound on the ranking symbol: nothing to push down
        return Result.empty();
    }
    if (upperBound.getAsInt() <= 0) {
        // A non-positive bound is treated as unsatisfiable, so the filter becomes an empty relation
        return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
    }

    // Resolved only after the early returns, as in PushPredicateThroughProjectIntoWindow.
    // NOTE(review): the rule pattern presumably guarantees a TopN ranking type; orElseThrow()
    // makes that assumption explicit instead of an unchecked get().
    RankingType rankingType = toTopNRankingType(windowNode).orElseThrow();
    TopNRankingNode newSource = new TopNRankingNode(windowNode.getId(), windowNode.getSource(), windowNode.getSpecification(), rankingType, rankingSymbol, upperBound.getAsInt(), false, Optional.empty());

    if (!allRowNumberValuesInDomain(tupleDomain, rankingSymbol, upperBound.getAsInt())) {
        // The bound does not capture the whole ranking constraint; keep the original predicate on top
        return Result.ofPlanNode(new FilterNode(node.getId(), newSource, node.getPredicate()));
    }

    // Remove the row number domain because it is absorbed into the node
    TupleDomain<Symbol> newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rankingSymbol));
    Expression newPredicate = ExpressionUtils.combineConjuncts(plannerContext.getMetadata(), extractionResult.getRemainingExpression(), new DomainTranslator(plannerContext).toPredicate(session, newTupleDomain));
    if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        // Nothing left to filter on
        return Result.ofPlanNode(newSource);
    }
    return Result.ofPlanNode(new FilterNode(node.getId(), newSource, newPredicate));
}
Usage of io.trino.sql.planner.plan.TopNRankingNode in project trino (trinodb).
Class TopNRankingMatcher, method detailMatches.
/**
 * Compares the matched {@link TopNRankingNode} against every expectation configured on this matcher.
 * An expectation that is absent matches any value; the first failing expectation yields
 * {@code NO_MATCH}.
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    TopNRankingNode rankingNode = (TopNRankingNode) node;

    // Short-circuit evaluation keeps the original check order: specification, ranking symbol,
    // ranking type, max ranking per partition, partial flag, hash symbol.
    boolean detailsMatch =
            (specification.isEmpty()
                    || specification.get().getExpectedValue(symbolAliases).equals(rankingNode.getSpecification()))
            && (rankingSymbol.isEmpty()
                    || rankingSymbol.get().toSymbol(symbolAliases).equals(rankingNode.getRankingSymbol()))
            && (rankingType.isEmpty()
                    || rankingType.get().equals(rankingNode.getRankingType()))
            && (maxRankingPerPartition.isEmpty()
                    || maxRankingPerPartition.get().equals(rankingNode.getMaxRankingPerPartition()))
            && (partial.isEmpty()
                    || partial.get().equals(rankingNode.isPartial()))
            && (hashSymbol.isEmpty()
                    || hashSymbol.get().map(alias -> alias.toSymbol(symbolAliases)).equals(rankingNode.getHashSymbol()));

    return detailsMatch ? match() : NO_MATCH;
}
Usage of io.trino.sql.planner.plan.TopNRankingNode in project trino (trinodb).
Class PushPredicateThroughProjectIntoWindow, method apply.
/**
 * Handles the shape Filter over Project over Window. When the project exposes the window's
 * ranking symbol and the filter bounds that symbol from above, the window is replaced by a
 * {@link TopNRankingNode} underneath the project. Only the part of the predicate that the new
 * node does not absorb is kept.
 */
@Override
public Result apply(FilterNode filter, Captures captures, Context context) {
    ProjectNode projectNode = captures.get(PROJECT);
    WindowNode windowNode = captures.get(WINDOW);
    Symbol ranking = getOnlyElement(windowNode.getWindowFunctions().keySet());

    // The filter can only refer to the ranking symbol if the project passes it through
    boolean rankingVisible = projectNode.getAssignments().getSymbols().contains(ranking);
    if (!rankingVisible) {
        return Result.empty();
    }

    DomainTranslator.ExtractionResult extracted = DomainTranslator.getExtractionResult(
            plannerContext,
            context.getSession(),
            filter.getPredicate(),
            context.getSymbolAllocator().getTypes());
    TupleDomain<Symbol> domain = extracted.getTupleDomain();

    OptionalInt bound = extractUpperBound(domain, ranking);
    if (bound.isEmpty()) {
        return Result.empty();
    }
    int maxRanking = bound.getAsInt();
    if (maxRanking <= 0) {
        // A non-positive bound turns the whole subtree into an empty relation
        return Result.ofPlanNode(new ValuesNode(filter.getId(), filter.getOutputSymbols(), ImmutableList.of()));
    }

    RankingType type = toTopNRankingType(windowNode).orElseThrow();
    TopNRankingNode topNRanking = new TopNRankingNode(windowNode.getId(), windowNode.getSource(), windowNode.getSpecification(), type, ranking, maxRanking, false, Optional.empty());
    ProjectNode rewiredProject = (ProjectNode) projectNode.replaceChildren(ImmutableList.of(topNRanking));

    if (allRankingValuesInDomain(domain, ranking, maxRanking)) {
        // The ranking constraint is fully absorbed by the new node; drop it from the predicate
        TupleDomain<Symbol> residualDomain = domain.filter((symbol, symbolDomain) -> !symbol.equals(ranking));
        Expression residual = ExpressionUtils.combineConjuncts(
                plannerContext.getMetadata(),
                extracted.getRemainingExpression(),
                new DomainTranslator(plannerContext).toPredicate(context.getSession(), residualDomain));
        return residual.equals(TRUE_LITERAL)
                ? Result.ofPlanNode(rewiredProject)
                : Result.ofPlanNode(new FilterNode(filter.getId(), rewiredProject, residual));
    }

    // The ranking constraint is only partly captured by the bound; keep the original filter on top
    return Result.ofPlanNode(filter.replaceChildren(ImmutableList.of(rewiredProject)));
}
Usage of io.trino.sql.planner.plan.TopNRankingNode in project trino (trinodb).
Class PushDownDereferencesThroughTopNRanking, method apply.
/**
 * Pushes row-field dereferences from a project down below a {@link TopNRankingNode}.
 *
 * <p>Dereferences found in the project's assignments are computed in a new project placed under
 * the TopN ranking node. Dereferences whose base symbol appears in the partitioning or ordering are
 * excluded. The outer project's assignments are rewritten to reference the new symbols. Returns an
 * empty result when there is nothing to push down.
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    TopNRankingNode topNRankingNode = captures.get(CHILD);

    // Extract dereferences from project node assignments for pushdown
    Set<SubscriptExpression> dereferences = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());

    // Exclude dereferences on symbols being used in partitionBy and orderBy
    WindowNode.Specification specification = topNRankingNode.getSpecification();
    // Computed once instead of once per dereference: it does not depend on the expression being tested
    Set<Symbol> orderBySymbols = specification.getOrderingScheme()
            .map(OrderingScheme::getOrderBy)
            .orElse(ImmutableList.of())
            .stream()
            .collect(toImmutableSet());
    dereferences = dereferences.stream()
            .filter(expression -> {
                Symbol symbol = getBase(expression);
                return !specification.getPartitionBy().contains(symbol) && !orderBySymbols.contains(symbol);
            })
            .collect(toImmutableSet());

    if (dereferences.isEmpty()) {
        return Result.empty();
    }

    // Create new symbols for dereference expressions
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Rewrite project node assignments using new symbols for dereference expressions
    Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));

    // Kept as a single nested expression: Java evaluates arguments left to right, so the outer
    // ProjectNode's id is allocated before the inner one's, exactly as before.
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            topNRankingNode.replaceChildren(ImmutableList.of(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    topNRankingNode.getSource(),
                    Assignments.builder()
                            .putIdentities(topNRankingNode.getSource().getOutputSymbols())
                            .putAll(dereferenceAssignments)
                            .build()))),
            newAssignments));
}
Usage of io.trino.sql.planner.plan.TopNRankingNode in project trino (trinodb).
Class PushdownLimitIntoWindow, method apply.
@Override
public Result apply(LimitNode node, Captures captures, Context context) {
WindowNode source = captures.get(childCapture);
Optional<RankingType> rankingType = toTopNRankingType(source);
int limit = toIntExact(node.getCount());
TopNRankingNode topNRowNumberNode = new TopNRankingNode(source.getId(), source.getSource(), source.getSpecification(), rankingType.get(), getOnlyElement(source.getWindowFunctions().keySet()), limit, false, Optional.empty());
if (rankingType.get() == ROW_NUMBER && source.getPartitionBy().isEmpty()) {
return Result.ofPlanNode(topNRowNumberNode);
}
return Result.ofPlanNode(replaceChildren(node, ImmutableList.of(topNRowNumberNode)));
}
Aggregations