use of io.trino.sql.planner.Symbol in project trino by trinodb.
the class DecorrelateInnerUnnestWithGlobalAggregation method apply.
/**
 * Rewrites a correlated join whose subquery consists of global aggregation(s) over a supported
 * correlated unnest into a LEFT UnnestNode over the join input, followed by the subquery's
 * projections and aggregations re-applied on top of it.
 *
 * @return the rewritten plan restricted to the outputs of the correlated join, or
 *         {@code Result.empty()} when the subquery has no global aggregation or no supported
 *         unnest is found below the reducing aggregation
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Step 1: collect global aggregations of the subquery; the search descends only through
    // projections and global aggregations.
    List<PlanNode> aggregationChain = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup())
            .where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation)
            .recurseOnlyWhen(candidate -> candidate instanceof ProjectNode || isGlobalAggregation(candidate))
            .findAll();
    if (aggregationChain.isEmpty()) {
        return Result.empty();
    }

    // With several global aggregations, the one closest to the source is the "reducing" one:
    // it collapses multiple input rows into a single output row.
    AggregationNode reducingAggregation = (AggregationNode) aggregationChain.get(aggregationChain.size() - 1);

    // Step 2: locate the correlated unnest below the reducing aggregation; the search descends
    // only through projections and grouped aggregations.
    List<Symbol> correlation = correlatedJoinNode.getCorrelation();
    Optional<UnnestNode> supportedUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), context.getLookup())
            .where(candidate -> isSupportedUnnest(candidate, correlation, context.getLookup()))
            .recurseOnlyWhen(candidate -> candidate instanceof ProjectNode || isGroupedAggregation(candidate))
            .findFirst();
    if (supportedUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = supportedUnnest.get();

    // Step 3: tag every input row with a unique id so that aggregation semantics can be
    // restored after the rewrite.
    Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BIGINT);
    PlanNode uniqueInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            uniqueSymbol);

    // Step 4: the correlated UnnestNode either unnests correlation symbols directly, or unnests
    // symbols produced by a projection over correlation symbols. In the latter case the
    // projection is re-applied on top of the input. If it turns out not to produce any unnest
    // symbols, it should be pruned afterwards.
    PlanNode resolvedUnnestSource = context.getLookup().resolve(unnestNode.getSource());
    PlanNode input = uniqueInput;
    if (resolvedUnnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) resolvedUnnestSource;
        Assignments preProjection = Assignments.builder()
                .putIdentities(uniqueInput.getOutputSymbols())
                .putAll(sourceProjection.getAssignments())
                .build();
        input = new ProjectNode(sourceProjection.getId(), uniqueInput, preProjection);
    }
    List<Symbol> inputSymbols = input.getOutputSymbols();

    // Step 5: replace the correlated join with a LEFT unnest that replicates all input symbols
    // and always exposes an ordinality symbol (reusing the original one when present).
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
            .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            input,
            inputSymbols,
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            LEFT,
            Optional.empty());

    // Step 6: derive a mask from the ordinality (IS NOT NULL) to tell genuinely unnested rows
    // apart from synthetic null rows.
    Symbol mask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
    Assignments maskedAssignments = Assignments.builder()
            .putIdentities(rewrittenUnnest.getOutputSymbols())
            .put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference()))
            .build();
    ProjectNode sourceWithMask = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenUnnest, maskedAssignments);

    // Step 7: re-apply all projections, grouped aggregations and global aggregations of the
    // subquery on top of the masked unnest.
    PlanNode rewritten = rewriteNodeSequence(
            context.getLookup().resolve(correlatedJoinNode.getSubquery()),
            inputSymbols,
            mask,
            sourceWithMask,
            reducingAggregation.getId(),
            unnestNode.getId(),
            context.getSymbolAllocator(),
            context.getIdAllocator(),
            context.getLookup());

    // Step 8: expose only the symbols the correlated join originally produced.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), rewritten, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(rewritten));
}
use of io.trino.sql.planner.Symbol in project trino by trinodb.
the class DecorrelateLeftUnnestWithGlobalAggregation method isSupportedUnnest.
/**
 * Tells whether the given node is an UnnestNode this rule is able to decorrelate.
 * <p>
 * All of the following must hold:
 * - the source of the UnnestNode is scalar (as determined by {@code isScalar}),
 * - the UnnestNode has no replicate symbols,
 * - the UnnestNode is based on correlation symbols: either every unnested symbol is a correlation symbol,
 * or the resolved source is a projection whose expressions reference only correlation symbols,
 * - the UnnestNode is of type LEFT,
 * - the UnnestNode has no filter, or its filter equals TRUE_LITERAL.
 *
 * @param node the plan node to examine
 * @param correlation the correlation symbols of the enclosing correlated join
 * @param lookup used to resolve the source of the UnnestNode
 * @return true if the node is an UnnestNode satisfying all conditions above
 */
private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
    if (!(node instanceof UnnestNode)) {
        return false;
    }
    UnnestNode unnest = (UnnestNode) node;
    PlanNode resolvedSource = lookup.resolve(unnest.getSource());
    ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);

    // case 1: every unnested symbol is itself a correlation symbol
    boolean basedOnCorrelation = unnest.getMappings().stream()
            .map(UnnestNode.Mapping::getInput)
            .allMatch(correlationSymbols::contains);
    // case 2: the unnested symbols come from a projection that references only correlation symbols
    if (!basedOnCorrelation && resolvedSource instanceof ProjectNode) {
        ProjectNode projection = (ProjectNode) resolvedSource;
        basedOnCorrelation = correlationSymbols.containsAll(SymbolsExtractor.extractUnique(projection.getAssignments().getExpressions()));
    }

    if (!isScalar(unnest.getSource(), lookup)) {
        return false;
    }
    if (!unnest.getReplicateSymbols().isEmpty()) {
        return false;
    }
    if (!basedOnCorrelation) {
        return false;
    }
    if (unnest.getJoinType() != LEFT) {
        return false;
    }
    // a missing filter and a trivially true filter are treated alike
    return unnest.getFilter()
            .map(filter -> filter.equals(TRUE_LITERAL))
            .orElse(true);
}
use of io.trino.sql.planner.Symbol in project trino by trinodb.
the class DecorrelateUnnest method isSupportedUnnest.
/**
 * Tells whether the given node is an UnnestNode this rule is able to decorrelate.
 * <p>
 * All of the following must hold:
 * - the source of the UnnestNode is scalar (as determined by {@code isScalar}),
 * - the UnnestNode has no replicate symbols,
 * - the UnnestNode is based on correlation symbols: either every unnested symbol is a correlation symbol,
 * or the resolved source is a projection whose expressions reference only correlation symbols,
 * - the UnnestNode is of type INNER or LEFT,
 * - the UnnestNode has no filter, or its filter equals TRUE_LITERAL.
 *
 * @param node the plan node to examine
 * @param correlation the correlation symbols of the enclosing correlated join
 * @param lookup used to resolve the source of the UnnestNode
 * @return true if the node is an UnnestNode satisfying all conditions above
 */
private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
    if (!(node instanceof UnnestNode)) {
        return false;
    }
    UnnestNode unnest = (UnnestNode) node;
    PlanNode resolvedSource = lookup.resolve(unnest.getSource());
    ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);

    // case 1: every unnested symbol is itself a correlation symbol
    boolean basedOnCorrelation = true;
    for (UnnestNode.Mapping mapping : unnest.getMappings()) {
        if (!correlationSymbols.contains(mapping.getInput())) {
            basedOnCorrelation = false;
            break;
        }
    }
    // case 2: the unnested symbols come from a projection that references only correlation symbols
    if (!basedOnCorrelation && resolvedSource instanceof ProjectNode) {
        ProjectNode projection = (ProjectNode) resolvedSource;
        basedOnCorrelation = correlationSymbols.containsAll(SymbolsExtractor.extractUnique(projection.getAssignments().getExpressions()));
    }

    if (!isScalar(unnest.getSource(), lookup)) {
        return false;
    }
    if (!unnest.getReplicateSymbols().isEmpty()) {
        return false;
    }
    if (!basedOnCorrelation) {
        return false;
    }
    if (unnest.getJoinType() != INNER && unnest.getJoinType() != LEFT) {
        return false;
    }
    // a missing filter and a trivially true filter are treated alike
    return unnest.getFilter()
            .map(filter -> filter.equals(TRUE_LITERAL))
            .orElse(true);
}
use of io.trino.sql.planner.Symbol in project trino by trinodb.
the class AddIntermediateAggregations method outputsAsInputs.
/**
 * Produces a copy of the assignments in which every aggregation takes its own output symbol
 * as its single argument. The resolved function is carried over from the original aggregation;
 * the remaining constructor arguments are {@code false} and empty Optionals.
 * <p>
 * Example:
 * 'a' := sum('b') => 'a' := sum('a')
 * 'a' := count(*) => 'a' := count('a')
 *
 * @param assignments aggregations keyed by their output symbol
 * @return a map with the same keys, each mapped to the rewritten aggregation
 * @throws IllegalStateException if any aggregation has an ordering scheme
 */
private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> rewritten = ImmutableMap.builder();
    assignments.forEach((outputSymbol, original) -> {
        checkState(original.getOrderingScheme().isEmpty(), "Intermediate aggregation does not support ORDER BY");
        AggregationNode.Aggregation overOwnOutput = new AggregationNode.Aggregation(
                original.getResolvedFunction(),
                ImmutableList.of(outputSymbol.toSymbolReference()),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty()); // No mask for INTERMEDIATE
        rewritten.put(outputSymbol, overOwnOutput);
    });
    return rewritten.buildOrThrow();
}
use of io.trino.sql.planner.Symbol in project trino by trinodb.
the class AddIntermediateAggregations method inputsAsOutputs.
/**
 * Re-keys the aggregations by their input symbol, leaving each aggregation itself untouched.
 * This operation only reliably applies to aggregation steps that take partial inputs
 * (e.g. INTERMEDIATE and split FINALs), which are guaranteed to have exactly one input and one output.
 * <p>
 * Example:
 * 'a' := sum('b') => 'b' := sum('b')
 *
 * @param assignments aggregations keyed by their output symbol; the keys are not consulted
 * @return the same aggregations keyed by the only symbol each of them references
 */
private static Map<Symbol, AggregationNode.Aggregation> inputsAsOutputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> keyedByInput = ImmutableMap.builder();
    for (AggregationNode.Aggregation aggregation : assignments.values()) {
        // each aggregation is expected to reference exactly one symbol
        Symbol soleInput = getOnlyElement(SymbolsExtractor.extractAll(aggregation));
        keyedByInput.put(soleInput, aggregation);
    }
    return keyedByInput.buildOrThrow();
}
Aggregations