Use of io.trino.sql.planner.optimizations.SymbolMapper in the trino project by trinodb.
Class PushLimitThroughProject, method apply.
/**
 * Moves a limit below the projection it sits on.
 * <p>
 * A plain limit is simply swapped with the projection. A limit with ties or pre-sorted inputs refers to
 * symbols; those symbols are re-expressed in terms of the projection's source. That is only possible
 * when each of them is assigned a plain symbol reference.
 */
@Override
public Result apply(LimitNode parent, Captures captures, Context context)
{
    ProjectNode project = captures.get(CHILD);

    // Skip when the projection is made up of symbol references and exclusive dereferences: pushing the
    // limit here would undo PushDownDereferencesThroughLimit. With overlapping dereferences the limit is
    // still pushed, because that lets PushDownDereferencesThroughLimit push the optimal dereferences.
    Set<Expression> projections = ImmutableSet.copyOf(project.getAssignments().getExpressions());
    boolean projectsDereferences = !extractRowSubscripts(projections, false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes()).isEmpty();
    if (projectsDereferences && exclusiveDereferences(projections, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes())) {
        return Result.empty();
    }

    // Neither ties nor pre-sorted inputs: the limit has no symbols to rewrite, so just swap the nodes.
    if (!(parent.isWithTies() || parent.requiresPreSortedInputs())) {
        return Result.ofPlanNode(transpose(parent, project));
    }

    // Ties and/or pre-sorted inputs: collect every symbol the limit refers to (deduplicated, in order).
    ImmutableSet.Builder<Symbol> referencedSymbols = ImmutableSet.builder();
    referencedSymbols.addAll(parent.getPreSortedInputs());
    parent.getTiesResolvingScheme().ifPresent(scheme -> referencedSymbols.addAll(scheme.getOrderBy()));

    // Translate each referenced symbol to the source symbol it is projected from.
    SymbolMapper.Builder mapping = SymbolMapper.builder();
    for (Symbol referenced : referencedSymbols.build()) {
        Expression assigned = project.getAssignments().get(referenced);
        // a symbol produced by a computation has no counterpart below the projection
        if (!(assigned instanceof SymbolReference)) {
            return Result.empty();
        }
        mapping.put(referenced, Symbol.from(assigned));
    }

    LimitNode rewrittenLimit = mapping.build().map(parent, project.getSource());
    return Result.ofPlanNode(project.replaceChildren(ImmutableList.of(rewrittenLimit)));
}
Use of io.trino.sql.planner.optimizations.SymbolMapper in the trino project by trinodb.
Class PushTopNThroughProject, method symbolMapper.
/**
 * Builds a mapper that rewrites each of the given symbols to the symbol it is assigned from.
 *
 * @param symbols symbols to translate
 * @param assignments the projection's assignments
 * @return the mapper, or {@link Optional#empty()} if any symbol is not assigned a plain symbol reference
 */
private Optional<SymbolMapper> symbolMapper(List<Symbol> symbols, Assignments assignments)
{
    SymbolMapper.Builder builder = SymbolMapper.builder();
    for (Symbol target : symbols) {
        if (!(assignments.get(target) instanceof SymbolReference reference)) {
            // not a plain symbol reference (e.g. computed by an expression): cannot be translated
            return Optional.empty();
        }
        builder.put(target, Symbol.from(reference));
    }
    return Optional.of(builder.build());
}
Use of io.trino.sql.planner.optimizations.SymbolMapper in the trino project by trinodb.
Class PushTopNThroughUnion, method apply.
/**
 * Pushes a TopN into every source of the union below it.
 * <p>
 * A copy of the TopN, rewritten in terms of that source's symbols, is placed on top of each source.
 * The union itself is rebuilt with the same id, symbol mapping and output symbols.
 */
@Override
public Result apply(TopNNode topNNode, Captures captures, Context context)
{
    UnionNode unionNode = captures.get(CHILD);

    ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
    for (int sourceIndex = 0; sourceIndex < unionNode.getSources().size(); sourceIndex++) {
        PlanNode source = unionNode.getSources().get(sourceIndex);

        // Select this source's input symbol by position in the symbol mapping's value list. This relies on the
        // union's invariant that, for every output symbol, the list holds one input per source in source order.
        // Matching by intersection with the source's outputs (as done before) picks the wrong symbol when
        // several sources share symbols.
        SymbolMapper.Builder symbolMapper = SymbolMapper.builder();
        for (Symbol unionOutput : unionNode.getOutputSymbols()) {
            Symbol unionInput = unionNode.getSymbolMapping().get(unionOutput).get(sourceIndex);
            symbolMapper.put(unionOutput, unionInput);
        }
        sources.add(symbolMapper.build().map(topNNode, source, context.getIdAllocator().getNextId()));
    }

    return Result.ofPlanNode(new UnionNode(unionNode.getId(), sources.build(), unionNode.getSymbolMapping(), unionNode.getOutputSymbols()));
}
Use of io.trino.sql.planner.optimizations.SymbolMapper in the trino project by trinodb.
Class PushAggregationThroughOuterJoin, method createAggregationOverNull.
/**
 * Builds a copy of the reference aggregation that runs over a single row of typed NULLs.
 *
 * @param referenceAggregation aggregation whose functions are re-created over the null row
 * @param symbolAllocator used to look up types and allocate the new symbols
 * @param idAllocator used to allocate ids for the new VALUES and aggregation nodes
 * @return the new global aggregation plus a map from each original aggregation output symbol to its
 *         counterpart in the new aggregation
 */
private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator)
{
    // Step 1: one NULL (cast to the symbol's type) for every output symbol of the reference aggregation's
    // source, plus a fresh symbol for each, remembering which fresh symbol replaces which source symbol.
    ImmutableList.Builder<Symbol> valuesSymbols = ImmutableList.builder();
    ImmutableList.Builder<Expression> valuesRow = ImmutableList.builder();
    ImmutableMap.Builder<Symbol, Symbol> sourceToNull = ImmutableMap.builder();
    for (Symbol sourceSymbol : referenceAggregation.getSource().getOutputSymbols()) {
        Type sourceType = symbolAllocator.getTypes().get(sourceSymbol);
        valuesRow.add(new Cast(new NullLiteral(), toSqlType(sourceType)));
        Symbol nullSymbol = symbolAllocator.newSymbol("null", sourceType);
        valuesSymbols.add(nullSymbol);
        sourceToNull.put(sourceSymbol, nullSymbol);
    }
    ValuesNode nullRow = new ValuesNode(idAllocator.getNextId(), valuesSymbols.build(), ImmutableList.of(new Row(valuesRow.build())));

    // Step 2: re-target every aggregation of the reference node at the null row's symbols. Each one gets
    // a fresh output symbol of the same type; record original output -> new output.
    ImmutableMap.Builder<Symbol, Symbol> outputMapping = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> overNullAggregations = ImmutableMap.builder();
    SymbolMapper mapper = symbolMapper(sourceToNull.buildOrThrow());
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
        Symbol originalOutput = entry.getKey();
        AggregationNode.Aggregation rewritten = mapper.map(entry.getValue());
        Type outputType = symbolAllocator.getTypes().get(originalOutput);
        Symbol newOutput = symbolAllocator.newSymbol(rewritten.getResolvedFunction().getSignature().getName(), outputType);
        overNullAggregations.put(newOutput, rewritten);
        outputMapping.put(originalOutput, newOutput);
    }
    Map<Symbol, Symbol> outputSymbolMapping = outputMapping.buildOrThrow();

    // Step 3: a global, single-step aggregation reading from the null row.
    AggregationNode aggregationOverNullRow = new AggregationNode(
            idAllocator.getNextId(),
            nullRow,
            overNullAggregations.buildOrThrow(),
            globalAggregation(),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());

    return new MappedAggregationInfo(aggregationOverNullRow, outputSymbolMapping);
}
Use of io.trino.sql.planner.optimizations.SymbolMapper in the trino project by trinodb.
Class PushPartialAggregationThroughExchange, method pushPartial.
/**
 * Pushes a partial aggregation below an exchange by placing a rewritten copy on top of each exchange source.
 * <p>
 * Each copy is expressed in terms of its source's symbols. A projection above it restores the
 * aggregation's own output symbols, so every source of the new exchange produces exactly
 * {@code aggregation.getOutputSymbols()}.
 *
 * @return a new exchange over the per-source partial aggregations
 */
private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Context context)
{
    List<Symbol> exchangeOutputs = exchange.getOutputSymbols();
    List<PlanNode> partials = new ArrayList<>();
    for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) {
        PlanNode source = exchange.getSources().get(sourceIndex);

        // Map each exchange output to the symbol this source feeds into it; identical pairs need no entry.
        SymbolMapper.Builder sourceMapping = SymbolMapper.builder();
        for (int column = 0; column < exchangeOutputs.size(); column++) {
            Symbol exchangeOutput = exchangeOutputs.get(column);
            Symbol sourceInput = exchange.getInputs().get(sourceIndex).get(column);
            if (exchangeOutput.equals(sourceInput)) {
                continue;
            }
            sourceMapping.put(exchangeOutput, sourceInput);
        }

        SymbolMapper mapper = sourceMapping.build();
        AggregationNode partial = mapper.map(aggregation, source, context.getIdAllocator().getNextId());

        // Project the rewritten outputs back onto the aggregation's original output symbols.
        Assignments.Builder restoreOutputs = Assignments.builder();
        for (Symbol output : aggregation.getOutputSymbols()) {
            restoreOutputs.put(output, mapper.map(output).toSymbolReference());
        }
        partials.add(new ProjectNode(context.getIdAllocator().getNextId(), partial, restoreOutputs.build()));
    }

    for (PlanNode partial : partials) {
        verify(aggregation.getOutputSymbols().equals(partial.getOutputSymbols()));
    }

    // Every exchange source now produces exactly the partial aggregation's output symbols, so the
    // partitioning function needs no symbol rewriting; only its output layout changes.
    PartitioningScheme exchangePartitioning = exchange.getPartitioningScheme();
    PartitioningScheme partitioning = new PartitioningScheme(
            exchangePartitioning.getPartitioning(),
            aggregation.getOutputSymbols(),
            exchangePartitioning.getHashColumn(),
            exchangePartitioning.isReplicateNullsAndAny(),
            exchangePartitioning.getBucketToPartition());

    List<List<Symbol>> inputs = Collections.nCopies(partials.size(), aggregation.getOutputSymbols());
    return new ExchangeNode(
            context.getIdAllocator().getNextId(),
            exchange.getType(),
            exchange.getScope(),
            partitioning,
            partials,
            ImmutableList.copyOf(inputs),
            Optional.empty());
}
Aggregations