use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
the class RemoveEmptyExceptBranches method apply.
/**
 * Simplifies an {@code ExceptNode} when some of its sources are reported empty by {@code isEmpty}.
 *
 * <ul>
 * <li>If the first source is empty, the node is replaced with a {@code ValuesNode} that has no rows.</li>
 * <li>If none of the remaining sources is empty, the rule does not fire.</li>
 * <li>Otherwise the empty sources after the first are dropped. When only the first source is left,
 * the node becomes a projection of that source, wrapped in a grouping-only aggregation over the
 * output symbols if the node is distinct; when several sources are left, a narrower
 * {@code ExceptNode} is produced.</li>
 * </ul>
 */
@Override
public Result apply(ExceptNode node, Captures captures, Context context) {
    List<PlanNode> sources = node.getSources();

    // The first source is the set rows are excluded from; if it is empty, no rows can be produced.
    if (isEmpty(sources.get(0), context.getLookup())) {
        return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
    }

    // Only the excluded branches (index 1 onwards) are candidates for removal.
    boolean hasEmptyBranches = false;
    for (int index = 1; index < sources.size(); index++) {
        if (isEmpty(sources.get(index), context.getLookup())) {
            hasEmptyBranches = true;
            break;
        }
    }
    if (!hasEmptyBranches) {
        return Result.empty();
    }

    // Keep the first source unconditionally, plus every non-empty excluded branch,
    // carrying over the output-to-input symbol mapping of each retained source.
    ImmutableList.Builder<PlanNode> retainedSources = ImmutableList.builder();
    ImmutableListMultimap.Builder<Symbol, Symbol> retainedMapping = ImmutableListMultimap.builder();
    for (int index = 0; index < sources.size(); index++) {
        PlanNode candidate = sources.get(index);
        if (index != 0 && isEmpty(candidate, context.getLookup())) {
            continue;
        }
        retainedSources.add(candidate);
        for (Symbol output : node.getOutputSymbols()) {
            retainedMapping.put(output, node.getSymbolMapping().get(output).get(index));
        }
    }
    List<PlanNode> newSources = retainedSources.build();
    ListMultimap<Symbol, Symbol> outputsToInputs = retainedMapping.build();

    // More than one source left: still an EXCEPT, just with fewer branches.
    if (newSources.size() != 1) {
        return Result.ofPlanNode(new ExceptNode(node.getId(), newSources, outputsToInputs, node.getOutputSymbols(), node.isDistinct()));
    }

    // Only the first source is left: rename its symbols to the EXCEPT outputs.
    Assignments.Builder assignments = Assignments.builder();
    outputsToInputs.entries().forEach(mapping -> assignments.put(mapping.getKey(), mapping.getValue().toSymbolReference()));

    if (!node.isDistinct()) {
        return Result.ofPlanNode(new ProjectNode(node.getId(), newSources.get(0), assignments.build()));
    }

    // Distinct EXCEPT: group by all output symbols on top of the projection.
    ProjectNode projection = new ProjectNode(context.getIdAllocator().getNextId(), newSources.get(0), assignments.build());
    return Result.ofPlanNode(new AggregationNode(
            node.getId(),
            projection,
            ImmutableMap.of(),
            singleGroupingSet(node.getOutputSymbols()),
            ImmutableList.of(),
            Step.SINGLE,
            Optional.empty(),
            Optional.empty()));
}
use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
the class PlanUtils method createAggregationFragment.
/**
 * Builds a plan fragment whose root is a FINAL-step aggregation with no aggregation functions and
 * an empty grouping set, reading from a REPARTITION remote source that points at
 * {@code sourceFragment}.
 *
 * @param name prefix of the aggregation node id; the id is {@code name + "_id"}
 * @param sourceFragment fragment whose id the remote source refers to
 * @return the fragment produced by {@code createFragment} for the aggregation node
 */
static PlanFragment createAggregationFragment(String name, PlanFragment sourceFragment) {
    PlanNodeId remoteSourceId = new PlanNodeId("source_id");
    RemoteSourceNode remoteSource = new RemoteSourceNode(
            remoteSourceId,
            sourceFragment.getId(),
            ImmutableList.of(),
            Optional.empty(),
            REPARTITION,
            RetryPolicy.NONE);

    PlanNodeId aggregationId = new PlanNodeId(name + "_id");
    PlanNode aggregation = new AggregationNode(
            aggregationId,
            remoteSource,
            ImmutableMap.of(),
            singleGroupingSet(ImmutableList.of()),
            ImmutableList.of(),
            FINAL,
            Optional.empty(),
            Optional.empty());

    return createFragment(aggregation);
}
use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
the class TestUnion method assertAtMostOneAggregationBetweenRemoteExchanges.
/**
 * Asserts that below every source of every remote exchange in {@code plan} there is at most one
 * {@code AggregationNode} before the next remote exchange is reached.
 *
 * @param plan plan whose root is searched for remote exchanges
 */
private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) {
    List<PlanNode> remoteExchanges = searchFrom(plan.getRoot())
            .where(TestUnion::isRemoteExchange)
            .findAll();

    for (PlanNode remoteExchange : remoteExchanges) {
        // Each source of a remote exchange is treated as the root of one fragment.
        for (PlanNode fragmentRoot : remoteExchange.getSources()) {
            // Stop descending at the next remote exchange so only this fragment is counted.
            List<PlanNode> aggregationsInFragment = searchFrom(fragmentRoot)
                    .where(AggregationNode.class::isInstance)
                    .recurseOnlyWhen(TestUnion::isNotRemoteExchange)
                    .findAll();
            assertFalse(aggregationsInFragment.size() > 1, "More than a single AggregationNode between remote exchanges");
        }
    }
}
use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
the class DecorrelateInnerUnnestWithGlobalAggregation method withGroupingAndMask.
/**
 * Rebuilds {@code aggregationNode} as a SINGLE-step aggregation grouped by {@code groupingSymbols},
 * with every aggregation masked by {@code mask}.
 *
 * <p>An aggregation without a mask simply receives {@code mask}. An aggregation that already has a
 * mask receives a freshly allocated symbol bound to the conjunction of its existing mask and
 * {@code mask}; those conjunctions are computed by a projection placed on top of {@code source},
 * which is added only when at least one such conjunction is needed.
 *
 * @param aggregationNode aggregation whose id and aggregations are reused
 * @param groupingSymbols grouping symbols of the resulting aggregation
 * @param mask mask symbol applied to every aggregation
 * @param source source of the resulting aggregation (possibly wrapped in a projection)
 * @param symbolAllocator used to allocate symbols for combined masks
 * @param idAllocator used to allocate the id of the extra projection
 * @return the rewritten aggregation node
 */
private static AggregationNode withGroupingAndMask(AggregationNode aggregationNode, List<Symbol> groupingSymbols, Symbol mask, PlanNode source, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
    ImmutableMap.Builder<Symbol, Symbol> maskPerAggregation = ImmutableMap.builder();
    Assignments.Builder combinedMasks = Assignments.builder();

    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Optional<Symbol> existingMask = entry.getValue().getMask();
        if (!existingMask.isPresent()) {
            // No prior mask: the new mask symbol is used as is.
            maskPerAggregation.put(entry.getKey(), mask);
            continue;
        }
        // Prior mask present: replace it with (existing mask AND new mask) under a new symbol.
        Symbol combinedMask = symbolAllocator.newSymbol("mask", BOOLEAN);
        combinedMasks.put(combinedMask, and(existingMask.get().toSymbolReference(), mask.toSymbolReference()));
        maskPerAggregation.put(entry.getKey(), combinedMask);
    }

    Assignments maskAssignments = combinedMasks.build();
    PlanNode aggregationSource = source;
    if (!maskAssignments.isEmpty()) {
        // Pass all source symbols through and append the combined mask expressions.
        aggregationSource = new ProjectNode(
                idAllocator.getNextId(),
                source,
                Assignments.builder()
                        .putIdentities(source.getOutputSymbols())
                        .putAll(maskAssignments)
                        .build());
    }

    return new AggregationNode(
            aggregationNode.getId(),
            aggregationSource,
            rewriteWithMasks(aggregationNode.getAggregations(), maskPerAggregation.buildOrThrow()),
            singleGroupingSet(groupingSymbols),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());
}
use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
the class DecorrelateLeftUnnestWithGlobalAggregation method apply.
/**
 * Rewrites a correlated join whose subquery is a sequence of projections and aggregations
 * (including a global aggregation) over a supported correlated unnest into a LEFT
 * {@code UnnestNode} over the join input, followed by the rewritten subquery nodes.
 *
 * <p>The rule does not fire when the subquery search finds no global aggregation or no supported
 * unnest.
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    PlanNode subquery = correlatedJoinNode.getSubquery();

    // a global aggregation must be reachable through projections and grouped aggregations only
    boolean hasGlobalAggregation = PlanNodeSearcher.searchFrom(subquery, context.getLookup())
            .where(DecorrelateLeftUnnestWithGlobalAggregation::isGlobalAggregation)
            .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
            .findFirst()
            .isPresent();
    if (!hasGlobalAggregation) {
        return Result.empty();
    }

    // locate the correlated unnest below the projections and aggregations of the subquery
    Optional<UnnestNode> correlatedUnnest = PlanNodeSearcher.searchFrom(subquery, context.getLookup())
            .where(candidate -> isSupportedUnnest(candidate, correlatedJoinNode.getCorrelation(), context.getLookup()))
            .recurseOnlyWhen(candidate -> candidate instanceof ProjectNode || isGlobalAggregation(candidate) || isGroupedAggregation(candidate))
            .findFirst();
    if (!correlatedUnnest.isPresent()) {
        return Result.empty();
    }
    UnnestNode unnestNode = correlatedUnnest.get();

    // tag every input row with a unique id so aggregation semantics can be restored after the rewrite
    PlanNode uniqueInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // The correlated UnnestNode unnests either correlation symbols directly, or symbols produced by a
    // projection over correlation symbols. Such a projection is re-applied on top of the input here;
    // if it turns out not to produce any unnest symbols, it should be pruned afterwards.
    PlanNode unnestInput = uniqueInput;
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode preProjection = (ProjectNode) unnestSource;
        unnestInput = new ProjectNode(
                preProjection.getId(),
                uniqueInput,
                Assignments.builder()
                        .putIdentities(uniqueInput.getOutputSymbols())
                        .putAll(preProjection.getAssignments())
                        .build());
    }

    // the correlated join itself becomes a LEFT unnest over the (possibly pre-projected) input
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            unnestInput,
            unnestInput.getOutputSymbols(),
            unnestNode.getMappings(),
            unnestNode.getOrdinalitySymbol(),
            LEFT,
            Optional.empty());

    // re-create the subquery's projections, grouped aggregations and global aggregations on top of it
    PlanNode rewrittenSubquery = rewriteNodeSequence(
            context.getLookup().resolve(subquery),
            unnestInput.getOutputSymbols(),
            rewrittenUnnest,
            unnestNode.getId(),
            context.getLookup());

    // expose only the symbols the original correlated join produced
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), rewrittenSubquery, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(rewrittenSubquery));
}
Aggregations