Use of io.trino.sql.planner.plan.ExchangeNode in project trino by trinodb.
Class AddIntermediateAggregations, method recurseToPartial.
/**
 * Walks down from {@code node} through any chain of ExchangeNodes and ProjectNodes until a PARTIAL
 * AggregationNode is reached, and rebuilds that chain on top of the result of
 * {@link #addGatheringIntermediate}.
 *
 * <p>Every source is resolved through {@code lookup} before recursing, and every source must
 * eventually reach a PARTIAL aggregation for the rewrite to succeed.
 *
 * @param node the node to start from
 * @param lookup used to resolve each source before descending into it
 * @param idAllocator supplies ids for the nodes created by {@link #addGatheringIntermediate}
 * @return the rewritten subtree, or empty if any branch reaches a node that is neither an
 *     ExchangeNode, a ProjectNode, nor a PARTIAL AggregationNode
 */
private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) {
    boolean isPartialAggregation = node instanceof AggregationNode
            && ((AggregationNode) node).getStep() == AggregationNode.Step.PARTIAL;
    if (isPartialAggregation) {
        AggregationNode partial = (AggregationNode) node;
        return Optional.of(addGatheringIntermediate(partial, idAllocator));
    }

    // Only exchanges and projections are traversed; anything else ends the search unsuccessfully
    boolean canDescendThrough = node instanceof ExchangeNode || node instanceof ProjectNode;
    if (!canDescendThrough) {
        return Optional.empty();
    }

    ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
    for (PlanNode child : node.getSources()) {
        PlanNode resolvedChild = lookup.resolve(child);
        Optional<PlanNode> rewrittenChild = recurseToPartial(resolvedChild, lookup, idAllocator);
        if (!rewrittenChild.isPresent()) {
            // One failing branch aborts the whole rewrite
            return Optional.empty();
        }
        rewrittenSources.add(rewrittenChild.get());
    }
    return Optional.of(node.replaceChildren(rewrittenSources.build()));
}
Use of io.trino.sql.planner.plan.ExchangeNode in project trino by trinodb.
Class AddIntermediateAggregations, method addGatheringIntermediate.
/**
 * Stacks a LOCAL gathering exchange on top of {@code aggregation}, then an INTERMEDIATE
 * AggregationNode on top of that exchange.
 *
 * <p>The intermediate node copies the grouping sets, pre-grouped symbols, hash symbol and
 * group-id symbol from {@code aggregation}. Its aggregations are
 * {@code outputsAsInputs(aggregation.getAggregations())}.
 *
 * @param aggregation an aggregation with no grouping keys; this precondition is enforced by
 *     {@code verify}
 * @param idAllocator ids are drawn from it in order: first the exchange, then the intermediate
 *     aggregation
 * @return the new INTERMEDIATE aggregation node
 */
private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) {
    verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");

    ExchangeNode localGather = ExchangeNode.gatheringExchange(
            idAllocator.getNextId(),
            ExchangeNode.Scope.LOCAL,
            aggregation);

    return new AggregationNode(
            idAllocator.getNextId(),
            localGather,
            outputsAsInputs(aggregation.getAggregations()),
            aggregation.getGroupingSets(),
            aggregation.getPreGroupedSymbols(),
            AggregationNode.Step.INTERMEDIATE,
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
}
Use of io.trino.sql.planner.plan.ExchangeNode in project trino by trinodb.
Class TestUnion, method testUnionUnderTopN.
/**
 * Plans an ORDER BY ... LIMIT over a two-branch UNION ALL and checks the distributed plan shape:
 * exactly one remote exchange, which must be a GATHER, and exactly two PARTIAL TopN nodes.
 */
@Test
public void testUnionUnderTopN() {
    String query = "SELECT * FROM ("
            + " SELECT regionkey FROM nation "
            + " UNION ALL "
            + " SELECT nationkey FROM nation"
            + ") t(a) "
            + "ORDER BY a LIMIT 1";
    Plan plan = plan(query, OPTIMIZED_AND_VALIDATED, false);

    List<PlanNode> remoteExchanges = searchFrom(plan.getRoot())
            .where(TestUnion::isRemoteExchange)
            .findAll();
    assertEquals(remoteExchanges.size(), 1, "There should be exactly one RemoteExchange");
    ExchangeNode onlyRemote = (ExchangeNode) Iterables.getOnlyElement(remoteExchanges);
    assertEquals(onlyRemote.getType(), GATHER);

    int numberOfPartialTopN = searchFrom(plan.getRoot())
            .where(node -> node instanceof TopNNode && ((TopNNode) node).getStep() == TopNNode.Step.PARTIAL)
            .count();
    assertEquals(numberOfPartialTopN, 2, "There should be exactly two partial TopN nodes");

    assertPlanIsFullyDistributed(plan);
}
Use of io.trino.sql.planner.plan.ExchangeNode in project trino by trinodb.
Class PruneExchangeColumns, method pushDownProjectOff.
/**
 * Drops exchange outputs that nobody needs, together with the matching position in every source's
 * input list.
 *
 * <p>An output is kept if the parent references it, or if the exchange itself uses it as a
 * partitioning column, the hash column, or an ordering column.
 *
 * @return empty when every output is kept; otherwise a copy of the exchange with the narrowed
 *     output layout and inputs
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(Context context, ExchangeNode exchangeNode, Set<Symbol> referencedOutputs) {
    PartitioningScheme partitioningScheme = exchangeNode.getPartitioningScheme();
    List<Symbol> outputSymbols = exchangeNode.getOutputSymbols();
    List<List<Symbol>> inputs = exchangeNode.getInputs();

    // Keep what the parent references plus whatever the exchange needs for partitioning, hashing and ordering
    ImmutableSet.Builder<Symbol> retained = ImmutableSet.builder();
    retained.addAll(referencedOutputs);
    retained.addAll(partitioningScheme.getPartitioning().getColumns());
    partitioningScheme.getHashColumn().ifPresent(retained::add);
    exchangeNode.getOrderingScheme().ifPresent(ordering -> retained.addAll(ordering.getOrderBy()));
    Set<Symbol> outputsToRetain = retained.build();

    if (outputsToRetain.size() == outputSymbols.size()) {
        // Nothing can be pruned
        return Optional.empty();
    }

    ImmutableList.Builder<Symbol> prunedOutputs = ImmutableList.builder();
    List<List<Symbol>> prunedInputs = new ArrayList<>(inputs.size());
    for (List<Symbol> ignored : inputs) {
        prunedInputs.add(new ArrayList<>());
    }

    // Walk output positions; a kept output keeps the same position in every source's input list
    for (int position = 0; position < outputSymbols.size(); position++) {
        Symbol output = outputSymbols.get(position);
        if (!outputsToRetain.contains(output)) {
            continue;
        }
        prunedOutputs.add(output);
        for (int sourceIndex = 0; sourceIndex < inputs.size(); sourceIndex++) {
            prunedInputs.get(sourceIndex).add(inputs.get(sourceIndex).get(position));
        }
    }

    // The pruned outputs still contain every partitioning, ordering and hash symbol, so only the layout changes
    PartitioningScheme prunedScheme = new PartitioningScheme(
            partitioningScheme.getPartitioning(),
            prunedOutputs.build(),
            partitioningScheme.getHashColumn(),
            partitioningScheme.isReplicateNullsAndAny(),
            partitioningScheme.getBucketToPartition());

    return Optional.of(new ExchangeNode(
            exchangeNode.getId(),
            exchangeNode.getType(),
            exchangeNode.getScope(),
            prunedScheme,
            exchangeNode.getSources(),
            prunedInputs,
            exchangeNode.getOrderingScheme()));
}
Use of io.trino.sql.planner.plan.ExchangeNode in project trino by trinodb.
Class PushProjectionThroughExchange, method apply.
/**
 * Pushes {@code project} below its child exchange by putting a ProjectNode on top of every exchange
 * source and rewriting the exchange's inputs and output layout to match.
 *
 * <p>Symbols the exchange itself depends on are always carried through as identity projections:
 * partitioning columns, the hash column, and ordering columns. Every other assignment of
 * {@code project} is rewritten via {@code inlineSymbols} in terms of the source's input symbols,
 * then assigned to a newly allocated symbol. The resulting exchange is passed through
 * {@code restrictOutputs} to narrow it to {@code project}'s output symbols. If that returns empty,
 * the unrestricted exchange is used.
 */
@Override
public Result apply(ProjectNode project, Captures captures, Context context) {
    ExchangeNode exchange = captures.get(CHILD);
    Set<Symbol> partitioningColumns = exchange.getPartitioningScheme().getPartitioning().getColumns();

    // Exchange outputs that must survive the push-down, in the positional layout shared by the exchange
    // outputs and every source's inputs. The order is: partitioning columns, then the hash column, then
    // ordering columns not already listed as partitioning columns (so no symbol is duplicated).
    // This is computed once, before the loop, so inputs and outputs cannot drift apart.
    ImmutableList.Builder<Symbol> retainedBuilder = ImmutableList.builder();
    retainedBuilder.addAll(partitioningColumns);
    exchange.getPartitioningScheme().getHashColumn().ifPresent(retainedBuilder::add);
    exchange.getOrderingScheme().ifPresent(orderingScheme -> orderingScheme.getOrderBy().stream()
            .filter(symbol -> !partitioningColumns.contains(symbol))
            .forEach(retainedBuilder::add));
    List<Symbol> retainedOutputs = retainedBuilder.build();
    // Membership view of the same symbols; used to skip identity projections of retained outputs
    Set<Symbol> partitioningHashAndOrderingOutputs = ImmutableSet.copyOf(retainedOutputs);

    ImmutableList.Builder<PlanNode> newSourceBuilder = ImmutableList.builder();
    ImmutableList.Builder<List<Symbol>> inputsBuilder = ImmutableList.builder();
    for (int i = 0; i < exchange.getSources().size(); i++) {
        Map<Symbol, Symbol> outputToInputMap = mapExchangeOutputToInput(exchange, i);

        Assignments.Builder projections = Assignments.builder();
        ImmutableList.Builder<Symbol> inputs = ImmutableList.builder();

        // Carry the symbols the exchange needs through this source as identity projections
        for (Symbol output : retainedOutputs) {
            Symbol inputSymbol = outputToInputMap.get(output);
            projections.put(inputSymbol, inputSymbol.toSymbolReference());
            inputs.add(inputSymbol);
        }

        Map<Symbol, Expression> translationMap = outputToInputMap.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));

        for (Map.Entry<Symbol, Expression> projection : project.getAssignments().entrySet()) {
            // Skip identity projection if symbol is in outputs already
            if (partitioningHashAndOrderingOutputs.contains(projection.getKey())) {
                continue;
            }
            Expression translatedExpression = inlineSymbols(translationMap, projection.getValue());
            Type type = context.getSymbolAllocator().getTypes().get(projection.getKey());
            Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type);
            projections.put(symbol, translatedExpression);
            inputs.add(symbol);
        }
        newSourceBuilder.add(new ProjectNode(context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build()));
        inputsBuilder.add(inputs.build());
    }

    // Output layout mirrors each source's inputs: retained symbols first, then the pushed-down projections
    ImmutableList.Builder<Symbol> outputBuilder = ImmutableList.builder();
    outputBuilder.addAll(retainedOutputs);
    for (Map.Entry<Symbol, Expression> projection : project.getAssignments().entrySet()) {
        // Do not add output for identity projection if symbol is in outputs already
        if (partitioningHashAndOrderingOutputs.contains(projection.getKey())) {
            continue;
        }
        outputBuilder.add(projection.getKey());
    }

    // outputBuilder contains all partition, ordering and hash symbols so simply swap the output layout
    PartitioningScheme partitioningScheme = new PartitioningScheme(
            exchange.getPartitioningScheme().getPartitioning(),
            outputBuilder.build(),
            exchange.getPartitioningScheme().getHashColumn(),
            exchange.getPartitioningScheme().isReplicateNullsAndAny(),
            exchange.getPartitioningScheme().getBucketToPartition());

    PlanNode result = new ExchangeNode(
            exchange.getId(),
            exchange.getType(),
            exchange.getScope(),
            partitioningScheme,
            newSourceBuilder.build(),
            inputsBuilder.build(),
            exchange.getOrderingScheme());

    // we need to strip unnecessary symbols (hash, partitioning columns).
    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(project.getOutputSymbols())).orElse(result));
}
Aggregations