Usage of io.trino.sql.planner.plan.AggregationNode in the Trino project (trinodb).
Class ImplementFilteredAggregations, method apply.
/**
 * Strips FILTER clauses from the aggregates of {@code aggregationNode} by folding each filter into
 * the aggregate's mask.
 *
 * <p>An aggregate with only a filter uses the filter symbol as its mask. An aggregate with both a
 * mask and a filter gets a freshly allocated boolean symbol, projected below the aggregation as
 * {@code mask AND filter}. For a global aggregation (no non-empty grouping set) in which every
 * aggregate is masked or filtered, a FilterNode keeps only rows for which at least one mask holds.
 * Otherwise the FilterNode predicate is {@code TRUE}.
 *
 * @return an AggregationNode over a FilterNode over a ProjectNode over the original source
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    Assignments.Builder projections = Assignments.builder();
    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableList.Builder<Expression> effectiveMasks = ImmutableList.builder();
    boolean unmaskedAggregatePresent = false;

    for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Aggregation original = entry.getValue();
        Optional<Symbol> filter = original.getFilter();
        Optional<Symbol> mask = original.getMask();

        if (filter.isEmpty()) {
            if (mask.isEmpty()) {
                // this aggregate consumes every input row
                unmaskedAggregatePresent = true;
            } else {
                effectiveMasks.add(mask.get().toSymbolReference());
            }
        } else if (mask.isEmpty()) {
            // the filter alone becomes the mask
            mask = filter;
            effectiveMasks.add(filter.get().toSymbolReference());
        } else {
            // both present: project a fresh boolean symbol holding (mask AND filter)
            Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            projections.put(combinedMask, and(mask.get().toSymbolReference(), filter.get().toSymbolReference()));
            mask = Optional.of(combinedMask);
            effectiveMasks.add(combinedMask.toSymbolReference());
        }

        rewrittenAggregations.put(
                entry.getKey(),
                new Aggregation(
                        original.getResolvedFunction(),
                        original.getArguments(),
                        original.isDistinct(),
                        Optional.empty(),
                        original.getOrderingScheme(),
                        mask));
    }

    // Rows matched by no mask may be dropped only for a global aggregation (in a grouped one, dropping
    // rows could eliminate whole groups) and only when every aggregate is masked (an unmasked
    // aggregate needs every row).
    boolean canPrefilter = !aggregationNode.hasNonEmptyGroupingSet() && !unmaskedAggregatePresent;
    Expression predicate = canPrefilter
            ? combineDisjunctsWithDefault(metadata, effectiveMasks.build(), TRUE_LITERAL)
            : TRUE_LITERAL;

    // keep every source column available above the projection (identity assignments)
    projections.putIdentities(aggregationNode.getSource().getOutputSymbols());

    // Arguments are evaluated left to right, so ids are allocated: aggregation, filter, project.
    return Result.ofPlanNode(new AggregationNode(
            context.getIdAllocator().getNextId(),
            new FilterNode(
                    context.getIdAllocator().getNextId(),
                    new ProjectNode(
                            context.getIdAllocator().getNextId(),
                            aggregationNode.getSource(),
                            projections.build()),
                    predicate),
            rewrittenAggregations.buildOrThrow(),
            aggregationNode.getGroupingSets(),
            ImmutableList.of(),
            aggregationNode.getStep(),
            aggregationNode.getHashSymbol(),
            aggregationNode.getGroupIdSymbol()));
}
Usage of io.trino.sql.planner.plan.AggregationNode in the Trino project (trinodb).
Class DecorrelateInnerUnnestWithGlobalAggregation, method apply.
/**
 * Decorrelates a correlated join whose subquery contains a global aggregation fed by an UNNEST of
 * correlated symbols. The correlated join is replaced by a LEFT UnnestNode over the join input, and
 * the subquery's nodes are then rebuilt on top of it.
 *
 * <p>Steps:
 * <ol>
 * <li>Collect the global aggregations reachable from the subquery root through projections and global
 *     aggregations. The last one found is treated as the "reducing" aggregation.</li>
 * <li>Below it, find an UnnestNode accepted by {@code isSupportedUnnest}.</li>
 * <li>Tag every input row with a unique id ({@code AssignUniqueId}) and re-apply any projection that
 *     fed the subquery's unnest. Then unnest with join type LEFT and an ordinality symbol.</li>
 * <li>Project a boolean mask {@code ordinality IS NOT NULL}, and call {@code rewriteNodeSequence} to
 *     restore the subquery's projections and aggregations on top.</li>
 * <li>Restrict the outputs to those of the correlated join.</li>
 * </ol>
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures not used by this method
 * @param context supplies the lookup and the symbol / plan node id allocators
 * @return the rewritten plan, or {@code Result.empty()} if no global aggregation or no supported
 *         unnest is found
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
// find global aggregation in subquery
// (the search descends only through ProjectNodes and global aggregations)
List<PlanNode> globalAggregations = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup()).where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation).recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node)).findAll();
if (globalAggregations.isEmpty()) {
return Result.empty();
}
// if there are multiple global aggregations, the one that is closest to the source is the "reducing" aggregation, because it reduces multiple input rows to single output row
// NOTE(review): taking the last element assumes findAll() returns matches ordered from the subquery root downwards — confirm against PlanNodeSearcher
AggregationNode reducingAggregation = (AggregationNode) globalAggregations.get(globalAggregations.size() - 1);
// find unnest in subquery
// (searched below the reducing aggregation, descending only through ProjectNodes and grouped aggregations)
Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), context.getLookup()).where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup())).recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node)).findFirst();
if (subqueryUnnest.isEmpty()) {
return Result.empty();
}
UnnestNode unnestNode = subqueryUnnest.get();
// assign unique id to input rows to restore semantics of aggregations after rewrite
PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
// pre-project unnest symbols if they were pre-projected in subquery
// The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols produced by a projection that uses only correlation symbols.
// Here, any underlying projection that was a source of the correlated UnnestNode, is appended as a source of the rewritten UnnestNode.
// If the projection is not necessary for UnnestNode (i.e. it does not produce any unnest symbols), it should be pruned afterwards.
PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
if (unnestSource instanceof ProjectNode) {
ProjectNode sourceProjection = (ProjectNode) unnestSource;
// keep all input symbols (identity) and add the subquery projection's assignments; reuses the original projection's id
input = new ProjectNode(sourceProjection.getId(), input, Assignments.builder().putIdentities(input.getOutputSymbols()).putAll(sourceProjection.getAssignments()).build());
}
// rewrite correlated join to UnnestNode
// reuse the subquery unnest's ordinality symbol if it has one; otherwise allocate one, because the mask below is derived from it
Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol().orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));
// replicate every input symbol (including the unique id); join type LEFT keeps input rows even when the unnest yields no rows
UnnestNode rewrittenUnnest = new UnnestNode(context.getIdAllocator().getNextId(), input, input.getOutputSymbols(), unnestNode.getMappings(), Optional.of(ordinalitySymbol), LEFT, Optional.empty());
// append mask symbol based on ordinality to distinguish between the unnested rows and synthetic null rows
Symbol mask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
ProjectNode sourceWithMask = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenUnnest, Assignments.builder().putIdentities(rewrittenUnnest.getOutputSymbols()).put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference())).build());
// restore all projections, grouped aggregations and global aggregations from the subquery
// NOTE(review): presumably rewriteNodeSequence substitutes sourceWithMask for the node with unnestNode's id and
// uses mask / the input symbols when rebuilding the aggregations up to the reducing aggregation — see its implementation
PlanNode result = rewriteNodeSequence(context.getLookup().resolve(correlatedJoinNode.getSubquery()), input.getOutputSymbols(), mask, sourceWithMask, reducingAggregation.getId(), unnestNode.getId(), context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup());
// restrict outputs
// (restrictOutputs returns an Optional; when it is empty the result is used unchanged)
return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())).orElse(result));
}
Usage of io.trino.sql.planner.plan.AggregationNode in the Trino project (trinodb).
Class AddIntermediateAggregations, method apply.
/**
 * Adds a local intermediate aggregation step below {@code aggregation}.
 *
 * <p>The source is first rewritten by {@code recurseToPartial}. If that yields nothing, the rule does
 * not fire. When {@code getTaskConcurrency(session) > 1}, the rewritten source is wrapped (bottom-up)
 * as follows:
 * <ol>
 * <li>a LOCAL partitioned exchange with {@code FIXED_ARBITRARY_DISTRIBUTION};</li>
 * <li>an INTERMEDIATE aggregation;</li>
 * <li>a LOCAL gathering exchange.</li>
 * </ol>
 * The result then becomes the new child of {@code aggregation}.
 *
 * @param aggregation the matched aggregation; its own node is kept and only its child is replaced
 * @param captures not used by this method
 * @param context supplies the lookup, id allocator and session
 * @return the aggregation with its child replaced, or {@code Result.empty()} when
 *         {@code recurseToPartial} returns empty
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    Lookup lookup = context.getLookup();
    PlanNodeIdAllocator idAllocator = context.getIdAllocator();
    Session session = context.getSession();
    Optional<PlanNode> rewrittenSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);
    if (rewrittenSource.isEmpty()) {
        return Result.empty();
    }
    PlanNode source = rewrittenSource.get();
    if (getTaskConcurrency(session) > 1) {
        source = ExchangeNode.partitionedExchange(
                idAllocator.getNextId(),
                ExchangeNode.Scope.LOCAL,
                source,
                new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), source.getOutputSymbols()));
        // preGroupedSymbols describe properties of the aggregation's input. The input here is a
        // FIXED_ARBITRARY_DISTRIBUTION exchange, which does not preserve any grouping of rows, so
        // advertising the original aggregation's preGroupedSymbols would be unsound. Drop them.
        source = new AggregationNode(
                idAllocator.getNextId(),
                source,
                inputsAsOutputs(aggregation.getAggregations()),
                aggregation.getGroupingSets(),
                ImmutableList.of(),
                AggregationNode.Step.INTERMEDIATE,
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol());
        source = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, source);
    }
    return Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(source)));
}
Usage of io.trino.sql.planner.plan.AggregationNode in the Trino project (trinodb).
Class MultipleDistinctAggregationToMarkDistinct, method apply.
/**
 * Rewrites DISTINCT aggregates that have neither a filter nor a mask into non-distinct aggregates
 * masked by the marker symbol of a {@code MarkDistinctNode}.
 *
 * <p>One MarkDistinctNode is stacked onto the source per distinct set of argument symbols. Its
 * distinct symbols are the grouping keys, the arguments and, when present, the group id symbol.
 * All other aggregates are passed through unchanged.
 *
 * @param parent the matched aggregation; the rewritten node reuses its id, grouping sets, step,
 *        hash symbol and group id symbol
 * @param captures not used by this method
 * @param context supplies the session and the symbol / id allocators
 * @return the rewritten aggregation, or {@code Result.empty()} when
 *         {@code SystemSessionProperties.useMarkDistinct} is false for the session
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
        return Result.empty();
    }
    // the distinct marker for the given set of input columns; aggregates over the same set share one marker
    Map<Set<Symbol>, Symbol> markers = new HashMap<>();
    // LinkedHashMap (not HashMap) keeps the aggregations in the same order as parent.getAggregations(),
    // so the map handed to the rewritten AggregationNode does not depend on hash iteration order
    Map<Symbol, Aggregation> newAggregations = new java.util.LinkedHashMap<>();
    PlanNode subPlan = parent.getSource();
    for (Map.Entry<Symbol, Aggregation> entry : parent.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.isDistinct() && aggregation.getFilter().isEmpty() && aggregation.getMask().isEmpty()) {
            // NOTE(review): Symbol::from presumably requires every argument to be a symbol reference — confirm
            Set<Symbol> inputs = aggregation.getArguments().stream().map(Symbol::from).collect(toSet());
            Symbol marker = markers.get(inputs);
            if (marker == null) {
                // first aggregate over this argument set: stack a new MarkDistinctNode on the plan built so far
                marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
                markers.put(inputs, marker);
                ImmutableSet.Builder<Symbol> distinctSymbols = ImmutableSet.<Symbol>builder().addAll(parent.getGroupingKeys()).addAll(inputs);
                parent.getGroupIdSymbol().ifPresent(distinctSymbols::add);
                subPlan = new MarkDistinctNode(context.getIdAllocator().getNextId(), subPlan, marker, ImmutableList.copyOf(distinctSymbols.build()), Optional.empty());
            }
            // remove the distinct flag and set the distinct marker
            newAggregations.put(entry.getKey(), new Aggregation(aggregation.getResolvedFunction(), aggregation.getArguments(), false, aggregation.getFilter(), aggregation.getOrderingScheme(), Optional.of(marker)));
        } else {
            newAggregations.put(entry.getKey(), aggregation);
        }
    }
    return Result.ofPlanNode(new AggregationNode(parent.getId(), subPlan, newAggregations, parent.getGroupingSets(), ImmutableList.of(), parent.getStep(), parent.getHashSymbol(), parent.getGroupIdSymbol()));
}
Usage of io.trino.sql.planner.plan.AggregationNode in the Trino project (trinodb).
Class PushPartialAggregationThroughExchange, method split.
/**
 * Splits {@code node} into a PARTIAL aggregation over the original source, topped by a FINAL
 * aggregation (reusing {@code node}'s id) that combines the partial results.
 *
 * <p>Each aggregate gets a freshly allocated intermediate symbol. Its type is the function's
 * intermediate type, wrapped in an anonymous row type when the function has several intermediate
 * types. The final aggregate takes that symbol plus any lambda arguments of the original call.
 *
 * @param node the aggregation to split
 * @param context supplies the symbol and id allocators
 * @return the FINAL aggregation whose source is the new PARTIAL aggregation
 * @throws IllegalStateException if an aggregate has an ORDER BY, which partial aggregation does not support
 */
private PlanNode split(AggregationNode node, Context context) {
    // otherwise, add a partial and final with an exchange in between
    Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation originalAggregation = entry.getValue();
        // validate before any per-aggregate work (metadata lookup, symbol allocation)
        checkState(originalAggregation.getOrderingScheme().isEmpty(), "Aggregate with ORDER BY does not support partial aggregation");
        ResolvedFunction resolvedFunction = originalAggregation.getResolvedFunction();
        AggregationFunctionMetadata functionMetadata = plannerContext.getMetadata().getAggregationFunctionMetadata(resolvedFunction);
        List<Type> intermediateTypes = functionMetadata.getIntermediateTypes().stream().map(plannerContext.getTypeManager()::getType).collect(toImmutableList());
        Type intermediateType = intermediateTypes.size() == 1 ? intermediateTypes.get(0) : RowType.anonymous(intermediateTypes);
        Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(resolvedFunction.getSignature().getName(), intermediateType);
        intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(resolvedFunction, originalAggregation.getArguments(), originalAggregation.isDistinct(), originalAggregation.getFilter(), originalAggregation.getOrderingScheme(), originalAggregation.getMask()));
        // rewrite final aggregation in terms of intermediate function; lambda arguments of the original call are passed along
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(resolvedFunction, ImmutableList.<Expression>builder().add(intermediateSymbol.toSymbolReference()).addAll(originalAggregation.getArguments().stream().filter(LambdaExpression.class::isInstance).collect(toImmutableList())).build(), false, Optional.empty(), Optional.empty(), Optional.empty()));
    }
    // preGroupedSymbols reflect properties of the input. Splitting the aggregation and pushing the partial
    // aggregation through the exchange may or may not preserve these properties. Hence, it is safest to drop
    // preGroupedSymbols on both the partial and the final aggregation.
    PlanNode partial = new AggregationNode(context.getIdAllocator().getNextId(), node.getSource(), intermediateAggregation, node.getGroupingSets(), ImmutableList.of(), PARTIAL, node.getHashSymbol(), node.getGroupIdSymbol());
    return new AggregationNode(node.getId(), partial, finalAggregation, node.getGroupingSets(), ImmutableList.of(), FINAL, node.getHashSymbol(), node.getGroupIdSymbol());
}
Aggregations