Usage of io.trino.sql.planner.DomainTranslator in the trino project by trinodb.
Class PushPredicateThroughProjectIntoRowNumber, method apply.
/**
 * Pushes an upper bound on the row number, taken from the filter predicate, into the
 * RowNumberNode below the ProjectNode as its max row count per partition.
 *
 * <p>If every row-number value up to the node's max row count satisfies the extracted domain,
 * the row-number constraint is dropped from the filter. The filter is removed entirely when
 * nothing else remains. An upper bound of zero or less replaces the whole subtree with an
 * empty ValuesNode.
 */
@Override
public Result apply(FilterNode filter, Captures captures, Context context) {
    ProjectNode projectNode = captures.get(PROJECT);
    RowNumberNode rowNumberNode = captures.get(ROW_NUMBER);
    Symbol rowNumberOutput = rowNumberNode.getRowNumberSymbol();

    // The filter can only constrain the row number if the projection exposes that symbol.
    if (!projectNode.getAssignments().getSymbols().contains(rowNumberOutput)) {
        return Result.empty();
    }

    DomainTranslator.ExtractionResult extracted = DomainTranslator.getExtractionResult(
            plannerContext,
            context.getSession(),
            filter.getPredicate(),
            context.getSymbolAllocator().getTypes());
    TupleDomain<Symbol> domains = extracted.getTupleDomain();
    OptionalInt bound = extractUpperBound(domains, rowNumberOutput);
    if (!bound.isPresent()) {
        return Result.empty();
    }

    int limit = bound.getAsInt();
    if (limit <= 0) {
        // Row numbers are at least 1, so no row can pass the filter.
        return Result.ofPlanNode(new ValuesNode(filter.getId(), filter.getOutputSymbols(), ImmutableList.of()));
    }

    // Only tighten the node when it has no limit yet or its current limit is looser.
    boolean tightened = rowNumberNode.getMaxRowCountPerPartition().isEmpty()
            || rowNumberNode.getMaxRowCountPerPartition().get() > limit;
    if (tightened) {
        rowNumberNode = new RowNumberNode(
                rowNumberNode.getId(),
                rowNumberNode.getSource(),
                rowNumberNode.getPartitionBy(),
                rowNumberNode.isOrderSensitive(),
                rowNumberNode.getRowNumberSymbol(),
                Optional.of(limit),
                rowNumberNode.getHashSymbol());
        projectNode = (ProjectNode) projectNode.replaceChildren(ImmutableList.of(rowNumberNode));
    }

    if (!allRowNumberValuesInDomain(domains, rowNumberOutput, rowNumberNode.getMaxRowCountPerPartition().get())) {
        // The row-number constraint is not fully absorbed, so the filter must stay.
        // Produce a new plan only if the node below it actually changed.
        return tightened
                ? Result.ofPlanNode(filter.replaceChildren(ImmutableList.of(projectNode)))
                : Result.empty();
    }

    // The node now enforces the row-number constraint, so drop it from the predicate.
    TupleDomain<Symbol> remainingDomains = domains.filter((symbol, domain) -> !symbol.equals(rowNumberOutput));
    Expression remainingPredicate = ExpressionUtils.combineConjuncts(
            plannerContext.getMetadata(),
            extracted.getRemainingExpression(),
            new DomainTranslator(plannerContext).toPredicate(context.getSession(), remainingDomains));

    return remainingPredicate.equals(TRUE_LITERAL)
            ? Result.ofPlanNode(projectNode)
            : Result.ofPlanNode(new FilterNode(filter.getId(), projectNode, remainingPredicate));
}
Usage of io.trino.sql.planner.DomainTranslator in the trino project by trinodb.
Class PushPredicateThroughProjectIntoWindow, method apply.
/**
 * Replaces the WindowNode below the ProjectNode with a TopNRankingNode. The node's limit is
 * the upper bound on the ranking symbol, taken from the filter predicate.
 *
 * <p>The window must have exactly one window function, because {@code getOnlyElement} is
 * applied to its functions. If every ranking value up to the limit satisfies the extracted
 * domain, the ranking constraint is dropped from the filter. The filter is removed when nothing
 * else remains. An upper bound of zero or less yields an empty ValuesNode.
 */
@Override
public Result apply(FilterNode filter, Captures captures, Context context) {
    ProjectNode projectNode = captures.get(PROJECT);
    WindowNode windowNode = captures.get(WINDOW);
    Symbol rankingOutput = getOnlyElement(windowNode.getWindowFunctions().keySet());

    // Nothing to do unless the projection exposes the ranking symbol to the filter.
    if (!projectNode.getAssignments().getSymbols().contains(rankingOutput)) {
        return Result.empty();
    }

    DomainTranslator.ExtractionResult extracted = DomainTranslator.getExtractionResult(
            plannerContext,
            context.getSession(),
            filter.getPredicate(),
            context.getSymbolAllocator().getTypes());
    TupleDomain<Symbol> domains = extracted.getTupleDomain();
    OptionalInt bound = extractUpperBound(domains, rankingOutput);
    if (!bound.isPresent()) {
        return Result.empty();
    }

    int limit = bound.getAsInt();
    if (limit <= 0) {
        // An upper bound below 1 cannot be met by any ranking value.
        return Result.ofPlanNode(new ValuesNode(filter.getId(), filter.getOutputSymbols(), ImmutableList.of()));
    }

    // NOTE(review): orElseThrow assumes the rule's pattern only matches windows convertible
    // to a top-N ranking; confirm against the pattern definition.
    RankingType kind = toTopNRankingType(windowNode).orElseThrow();
    TopNRankingNode topN = new TopNRankingNode(
            windowNode.getId(),
            windowNode.getSource(),
            windowNode.getSpecification(),
            kind,
            rankingOutput,
            limit,
            false,
            Optional.empty());
    ProjectNode rewrittenProject = (ProjectNode) projectNode.replaceChildren(ImmutableList.of(topN));

    if (!allRankingValuesInDomain(domains, rankingOutput, limit)) {
        // The ranking constraint is not fully absorbed by the limit, so keep the filter.
        return Result.ofPlanNode(filter.replaceChildren(ImmutableList.of(rewrittenProject)));
    }

    // The TopNRankingNode now enforces the ranking constraint, so drop it from the predicate.
    TupleDomain<Symbol> remainingDomains = domains.filter((symbol, domain) -> !symbol.equals(rankingOutput));
    Expression remainingPredicate = ExpressionUtils.combineConjuncts(
            plannerContext.getMetadata(),
            extracted.getRemainingExpression(),
            new DomainTranslator(plannerContext).toPredicate(context.getSession(), remainingDomains));

    return remainingPredicate.equals(TRUE_LITERAL)
            ? Result.ofPlanNode(rewrittenProject)
            : Result.ofPlanNode(new FilterNode(filter.getId(), rewrittenProject, remainingPredicate));
}
Usage of io.trino.sql.planner.DomainTranslator in the trino project by trinodb.
Class PushdownFilterIntoRowNumber, method apply.
/**
 * Pushes an upper bound on the row number, taken from the filter predicate, into the
 * RowNumberNode directly below the filter. {@code mergeLimit} combines the bound with the
 * node's existing limit.
 *
 * <p>If every row-number value up to the node's max row count satisfies the extracted domain,
 * the row-number constraint is dropped from the filter. The filter is removed entirely when
 * nothing else remains. An upper bound of zero or less yields an empty ValuesNode.
 *
 * <p>A rewritten plan is returned whenever the source node's limit changed or the predicate
 * was simplified. The tightened source is never discarded.
 */
@Override
public Result apply(FilterNode node, Captures captures, Context context) {
    Session session = context.getSession();
    TypeProvider types = context.getSymbolAllocator().getTypes();
    DomainTranslator.ExtractionResult extractionResult = DomainTranslator.getExtractionResult(plannerContext, session, node.getPredicate(), types);
    TupleDomain<Symbol> tupleDomain = extractionResult.getTupleDomain();

    RowNumberNode source = captures.get(CHILD);
    Symbol rowNumberSymbol = source.getRowNumberSymbol();
    OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberSymbol);
    if (upperBound.isEmpty()) {
        return Result.empty();
    }
    if (upperBound.getAsInt() <= 0) {
        // Row numbers are at least 1, so no row can pass the filter.
        return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
    }

    RowNumberNode merged = mergeLimit(source, upperBound.getAsInt());
    // Swap in the merged node only if it actually changed the limit.
    boolean needRewriteSource = !merged.getMaxRowCountPerPartition().equals(source.getMaxRowCountPerPartition());
    if (needRewriteSource) {
        source = merged;
    }

    if (!allRowNumberValuesInDomain(tupleDomain, rowNumberSymbol, source.getMaxRowCountPerPartition().get())) {
        // The row-number constraint is not fully absorbed, so the original predicate must stay.
        if (needRewriteSource) {
            return Result.ofPlanNode(new FilterNode(node.getId(), source, node.getPredicate()));
        }
        return Result.empty();
    }

    // The node now enforces the row-number constraint, so drop it from the predicate.
    TupleDomain<Symbol> newTupleDomain = tupleDomain.filter((symbol, domain) -> !symbol.equals(rowNumberSymbol));
    Expression newPredicate = ExpressionUtils.combineConjuncts(
            plannerContext.getMetadata(),
            extractionResult.getRemainingExpression(),
            new DomainTranslator(plannerContext).toPredicate(session, newTupleDomain));

    if (newPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        return Result.ofPlanNode(source);
    }

    // Previously an unchanged predicate returned empty even when the source had been
    // tightened, which discarded that rewrite. Emit a new plan if either part changed.
    if (needRewriteSource || !newPredicate.equals(node.getPredicate())) {
        return Result.ofPlanNode(new FilterNode(node.getId(), source, newPredicate));
    }
    return Result.empty();
}
Aggregations