use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
the class PushDownDereferencesThroughMarkDistinct method apply.
/**
 * Moves row-subscript (dereference) expressions used by {@code projectNode} into a new
 * projection placed underneath the captured {@link MarkDistinctNode}.
 *
 * <p>The resulting plan is {@code Project(rewritten) -> MarkDistinct -> Project(identities + dereferences) -> source}.
 * Returns {@code Result.empty()} when no subscript qualifies for pushdown.
 *
 * @param projectNode the projection sitting on top of the mark-distinct node
 * @param captures pattern captures; {@code CHILD} yields the mark-distinct node
 * @param context supplies the session, symbol allocator and plan node id allocator
 * @return the rewritten plan, or an empty result if nothing can be pushed down
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context)
{
    MarkDistinctNode markDistinctNode = captures.get(CHILD);

    // Gather every row subscript that appears in the projection's expressions
    Set<SubscriptExpression> candidates = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Subscripts whose base is one of the distinct symbols stay where they are. The marker symbol
    // needs no such check since it is supposed to be of boolean type.
    Set<SubscriptExpression> pushable = candidates.stream()
            .filter(subscript -> !markDistinctNode.getDistinctSymbols().contains(getBase(subscript)))
            .collect(toImmutableSet());
    if (pushable.isEmpty()) {
        return Result.empty();
    }

    // Allocate a fresh symbol for each subscript that will be evaluated below the mark-distinct node
    Assignments pushedDown = Assignments.of(pushable, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> symbol reference, then substitute in the projection
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewritten = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, replacements));

    // The lower projection passes through all source outputs and adds the pushed-down subscripts
    Assignments lowerAssignments = Assignments.builder()
            .putIdentities(markDistinctNode.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Arguments are evaluated left to right, so the upper project's id is requested before the lower one's
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    markDistinctNode.replaceChildren(ImmutableList.of(
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    markDistinctNode.getSource(),
                                    lowerAssignments))),
                    rewritten));
}
use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
the class PushDownDereferencesThroughRowNumber method apply.
/**
 * Moves row-subscript (dereference) expressions used by {@code projectNode} into a new
 * projection placed underneath the captured {@link RowNumberNode}.
 *
 * <p>The resulting plan is {@code Project(rewritten) -> RowNumber -> Project(identities + dereferences) -> source}.
 * Returns {@code Result.empty()} when no subscript qualifies for pushdown.
 *
 * @param projectNode the projection sitting on top of the row-number node
 * @param captures pattern captures; {@code CHILD} yields the row-number node
 * @param context supplies the session, symbol allocator and plan node id allocator
 * @return the rewritten plan, or an empty result if nothing can be pushed down
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context)
{
    RowNumberNode rowNumberNode = captures.get(CHILD);

    // Gather every row subscript that appears in the projection's expressions
    Set<SubscriptExpression> candidates = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Subscripts whose base symbol is used in partitionBy are not pushed down
    Set<SubscriptExpression> pushable = candidates.stream()
            .filter(subscript -> !rowNumberNode.getPartitionBy().contains(getBase(subscript)))
            .collect(toImmutableSet());
    if (pushable.isEmpty()) {
        return Result.empty();
    }

    // Allocate a fresh symbol for each subscript that will be evaluated below the row-number node
    Assignments pushedDown = Assignments.of(pushable, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> symbol reference, then substitute in the projection
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewritten = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, replacements));

    // The lower projection passes through all source outputs and adds the pushed-down subscripts
    Assignments lowerAssignments = Assignments.builder()
            .putIdentities(rowNumberNode.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Arguments are evaluated left to right, so the upper project's id is requested before the lower one's
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    rowNumberNode.replaceChildren(ImmutableList.of(
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    rowNumberNode.getSource(),
                                    lowerAssignments))),
                    rewritten));
}
use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
the class PushDownDereferencesThroughTopN method apply.
/**
 * Moves row-subscript (dereference) expressions used by {@code projectNode} into a new
 * projection placed underneath the captured {@link TopNNode}.
 *
 * <p>The resulting plan is {@code Project(rewritten) -> TopN -> Project(identities + dereferences) -> source}.
 * Returns {@code Result.empty()} when no subscript qualifies for pushdown.
 *
 * @param projectNode the projection sitting on top of the top-N node
 * @param captures pattern captures; {@code CHILD} yields the top-N node
 * @param context supplies the session, symbol allocator and plan node id allocator
 * @return the rewritten plan, or an empty result if nothing can be pushed down
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context)
{
    TopNNode topNNode = captures.get(CHILD);

    // Gather every row subscript that appears in the projection's expressions
    Set<SubscriptExpression> candidates = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Subscripts whose base symbol is used in orderBy are not pushed down
    Set<SubscriptExpression> pushable = candidates.stream()
            .filter(subscript -> !topNNode.getOrderingScheme().getOrderBy().contains(getBase(subscript)))
            .collect(toImmutableSet());
    if (pushable.isEmpty()) {
        return Result.empty();
    }

    // Allocate a fresh symbol for each subscript that will be evaluated below the top-N node
    Assignments pushedDown = Assignments.of(pushable, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> symbol reference, then substitute in the projection
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewritten = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, replacements));

    // The lower projection passes through all source outputs and adds the pushed-down subscripts
    Assignments lowerAssignments = Assignments.builder()
            .putIdentities(topNNode.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Arguments are evaluated left to right, so the upper project's id is requested before the lower one's
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    topNNode.replaceChildren(ImmutableList.of(
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    topNNode.getSource(),
                                    lowerAssignments))),
                    rewritten));
}
use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
the class PushDownDereferencesThroughTopNRanking method apply.
/**
 * Moves row-subscript (dereference) expressions used by {@code projectNode} into a new
 * projection placed underneath the captured {@link TopNRankingNode}.
 *
 * <p>The resulting plan is {@code Project(rewritten) -> TopNRanking -> Project(identities + dereferences) -> source}.
 * Returns {@code Result.empty()} when no subscript qualifies for pushdown.
 *
 * @param projectNode the projection sitting on top of the top-N-ranking node
 * @param captures pattern captures; {@code CHILD} yields the top-N-ranking node
 * @param context supplies the session, symbol allocator and plan node id allocator
 * @return the rewritten plan, or an empty result if nothing can be pushed down
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context)
{
    TopNRankingNode topNRankingNode = captures.get(CHILD);

    // Gather every row subscript that appears in the projection's expressions
    Set<SubscriptExpression> candidates = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Subscripts whose base symbol is used in partitionBy or orderBy are not pushed down
    WindowNode.Specification specification = topNRankingNode.getSpecification();
    Set<SubscriptExpression> pushable = candidates.stream()
            .filter(subscript -> {
                Symbol base = getBase(subscript);
                if (specification.getPartitionBy().contains(base)) {
                    return false;
                }
                // A specification without an ordering scheme contributes no orderBy symbols
                return !specification.getOrderingScheme()
                        .map(OrderingScheme::getOrderBy)
                        .orElse(ImmutableList.of())
                        .contains(base);
            })
            .collect(toImmutableSet());
    if (pushable.isEmpty()) {
        return Result.empty();
    }

    // Allocate a fresh symbol for each subscript that will be evaluated below the top-N-ranking node
    Assignments pushedDown = Assignments.of(pushable, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> symbol reference, then substitute in the projection
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewritten = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, replacements));

    // The lower projection passes through all source outputs and adds the pushed-down subscripts
    Assignments lowerAssignments = Assignments.builder()
            .putIdentities(topNRankingNode.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Arguments are evaluated left to right, so the upper project's id is requested before the lower one's
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    topNRankingNode.replaceChildren(ImmutableList.of(
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    topNRankingNode.getSource(),
                                    lowerAssignments))),
                    rewritten));
}
use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
the class PushPartialAggregationThroughExchange method pushPartial.
/**
 * Rebuilds {@code exchange} so that a copy of {@code aggregation} runs on each of its sources.
 *
 * <p>For every exchange source, the aggregation is re-mapped from the exchange's output symbols
 * to that source's input symbols and wrapped in a projection that exposes the aggregation's
 * original output symbols. A new exchange is then created over those projections.
 *
 * @param aggregation the aggregation to replicate beneath the exchange
 * @param exchange the exchange whose sources receive the aggregation
 * @param context supplies the plan node id allocator
 * @return a new exchange whose sources are the projected, per-source aggregations
 */
private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Context context)
{
    List<Symbol> aggregationOutputs = aggregation.getOutputSymbols();
    List<Symbol> exchangeOutputs = exchange.getOutputSymbols();

    List<PlanNode> projectedPartials = new ArrayList<>();
    for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) {
        PlanNode source = exchange.getSources().get(sourceIndex);
        List<Symbol> sourceInputs = exchange.getInputs().get(sourceIndex);

        // Record only the exchange outputs that are named differently in this source
        SymbolMapper.Builder renames = SymbolMapper.builder();
        for (int column = 0; column < exchangeOutputs.size(); column++) {
            Symbol exchangeSymbol = exchangeOutputs.get(column);
            Symbol sourceSymbol = sourceInputs.get(column);
            if (exchangeSymbol.equals(sourceSymbol)) {
                continue;
            }
            renames.put(exchangeSymbol, sourceSymbol);
        }
        SymbolMapper mapper = renames.build();

        AggregationNode mappedAggregation = mapper.map(aggregation, source, context.getIdAllocator().getNextId());

        // Project the mapped symbols back to the aggregation's original output names
        Assignments.Builder backToOriginal = Assignments.builder();
        for (Symbol original : aggregationOutputs) {
            backToOriginal.put(original, mapper.map(original).toSymbolReference());
        }

        projectedPartials.add(new ProjectNode(context.getIdAllocator().getNextId(), mappedAggregation, backToOriginal.build()));
    }

    projectedPartials.forEach(partial -> verify(aggregationOutputs.equals(partial.getOutputSymbols())));

    // Since this exchange source is now guaranteed to have the same symbols as the inputs to the partial
    // aggregation, we don't need to rewrite symbols in the partitioning function
    PartitioningScheme existingScheme = exchange.getPartitioningScheme();
    PartitioningScheme partitioning = new PartitioningScheme(
            existingScheme.getPartitioning(),
            aggregationOutputs,
            existingScheme.getHashColumn(),
            existingScheme.isReplicateNullsAndAny(),
            existingScheme.getBucketToPartition());

    return new ExchangeNode(
            context.getIdAllocator().getNextId(),
            exchange.getType(),
            exchange.getScope(),
            partitioning,
            projectedPartials,
            ImmutableList.copyOf(Collections.nCopies(projectedPartials.size(), aggregationOutputs)),
            Optional.empty());
}
Aggregations