Search in sources :

Example 11 with Symbol

use of io.trino.sql.planner.Symbol in project trino by trinodb.

the class DecorrelateInnerUnnestWithGlobalAggregation method apply.

/**
 * Rewrites a correlated join whose subquery is a chain of projections and aggregations
 * on top of a supported correlated unnest into an UnnestNode of type LEFT placed directly
 * over the join input, followed by the subquery's projections and aggregations.
 * <p>
 * The rule does not fire (returns an empty result) when the subquery has no global
 * aggregation, or when no supported unnest is found below the reducing aggregation.
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Locate global aggregations in the subquery, descending only through projections and global aggregations.
    List<PlanNode> globalAggregationNodes = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup())
            .where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation)
            .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node))
            .findAll();
    if (globalAggregationNodes.isEmpty()) {
        return Result.empty();
    }

    // With several global aggregations, the one closest to the source is the "reducing" one:
    // it is the aggregation that collapses multiple input rows into a single output row.
    int closestToSource = globalAggregationNodes.size() - 1;
    AggregationNode reducingAggregation = (AggregationNode) globalAggregationNodes.get(closestToSource);

    // Locate the correlated unnest below the reducing aggregation, descending only through
    // projections and grouped aggregations.
    Optional<UnnestNode> correlatedUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), context.getLookup())
            .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup()))
            .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
            .findFirst();
    if (correlatedUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = correlatedUnnest.get();

    // Tag every input row with a unique id so that the semantics of the aggregations can be restored after the rewrite.
    PlanNode rewrittenInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols
    // produced by a projection that uses only correlation symbols. Any projection that was the
    // source of the correlated UnnestNode is re-applied here on top of the input. If it turns out
    // not to produce any unnest symbols, it is expected to be pruned afterwards.
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) unnestSource;
        Assignments preProjection = Assignments.builder()
                .putIdentities(rewrittenInput.getOutputSymbols())
                .putAll(sourceProjection.getAssignments())
                .build();
        rewrittenInput = new ProjectNode(sourceProjection.getId(), rewrittenInput, preProjection);
    }

    // Replace the correlated join with a LEFT unnest over the input; reuse the original
    // ordinality symbol if there is one, otherwise allocate a fresh one.
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
            .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            rewrittenInput,
            rewrittenInput.getOutputSymbols(),
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            LEFT,
            Optional.empty());

    // The mask (ordinality IS NOT NULL) tells real unnested rows apart from synthetic null rows.
    Symbol mask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
    ProjectNode sourceWithMask = new ProjectNode(
            context.getIdAllocator().getNextId(),
            rewrittenUnnest,
            Assignments.builder()
                    .putIdentities(rewrittenUnnest.getOutputSymbols())
                    .put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference()))
                    .build());

    // Re-apply all projections, grouped aggregations and global aggregations from the subquery.
    PlanNode result = rewriteNodeSequence(
            context.getLookup().resolve(correlatedJoinNode.getSubquery()),
            rewrittenInput.getOutputSymbols(),
            mask,
            sourceWithMask,
            reducingAggregation.getId(),
            unnestNode.getId(),
            context.getSymbolAllocator(),
            context.getIdAllocator(),
            context.getLookup());

    // Restrict the outputs to those of the original correlated join; fall back to the
    // unrestricted plan when restrictOutputs yields nothing.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(result));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationDecorrelation.rewriteWithMasks(io.trino.sql.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) ImmutableList(com.google.common.collect.ImmutableList) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Lookup(io.trino.sql.planner.iterative.Lookup) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) Streams(com.google.common.collect.Streams) UnnestNode(io.trino.sql.planner.plan.UnnestNode) 
AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) List(java.util.List) Pattern(io.trino.matching.Pattern) BIGINT(io.trino.spi.type.BigintType.BIGINT) ExpressionUtils.and(io.trino.sql.ExpressionUtils.and) Captures(io.trino.matching.Captures) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Mapping(io.trino.sql.planner.plan.UnnestNode.Mapping) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) UnnestNode(io.trino.sql.planner.plan.UnnestNode) Symbol(io.trino.sql.planner.Symbol) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate)

Example 12 with Symbol

use of io.trino.sql.planner.Symbol in project trino by trinodb.

the class DecorrelateLeftUnnestWithGlobalAggregation method isSupportedUnnest.

/**
 * Tells whether the given node is an UnnestNode that this rule is able to decorrelate.
 * <p>
 * The node qualifies only if all of the following hold:
 * - {@code isScalar} accepts the source of the UnnestNode,
 * - the UnnestNode has no replicate symbols,
 * - the UnnestNode is based on correlation symbols: either every unnested input symbol is a
 * correlation symbol, or the resolved source is a ProjectNode whose assignment expressions
 * reference only correlation symbols,
 * - the UnnestNode is of type LEFT,
 * - the UnnestNode has no filter, or its filter equals TRUE_LITERAL.
 */
private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
    if (!(node instanceof UnnestNode)) {
        return false;
    }
    UnnestNode unnest = (UnnestNode) node;

    List<Symbol> unnestedInputs = unnest.getMappings().stream()
            .map(UnnestNode.Mapping::getInput)
            .collect(toImmutableList());
    PlanNode resolvedSource = lookup.resolve(unnest.getSource());

    // Correlation check: direct use of correlation symbols, or a source projection built solely from them.
    boolean basedOnCorrelation;
    if (ImmutableSet.copyOf(correlation).containsAll(unnestedInputs)) {
        basedOnCorrelation = true;
    } else if (resolvedSource instanceof ProjectNode) {
        ProjectNode projection = (ProjectNode) resolvedSource;
        basedOnCorrelation = ImmutableSet.copyOf(correlation)
                .containsAll(SymbolsExtractor.extractUnique(projection.getAssignments().getExpressions()));
    } else {
        basedOnCorrelation = false;
    }

    if (!isScalar(unnest.getSource(), lookup)) {
        return false;
    }
    if (!unnest.getReplicateSymbols().isEmpty()) {
        return false;
    }
    if (!basedOnCorrelation) {
        return false;
    }
    if (unnest.getJoinType() != LEFT) {
        return false;
    }
    // A filter that is literally TRUE is treated the same as no filter at all.
    return unnest.getFilter().isEmpty() || unnest.getFilter().get().equals(TRUE_LITERAL);
}
Also used : PlanNode(io.trino.sql.planner.plan.PlanNode) UnnestNode(io.trino.sql.planner.plan.UnnestNode) Symbol(io.trino.sql.planner.Symbol) ProjectNode(io.trino.sql.planner.plan.ProjectNode)

Example 13 with Symbol

use of io.trino.sql.planner.Symbol in project trino by trinodb.

the class DecorrelateUnnest method isSupportedUnnest.

/**
 * Decides whether the given node is an UnnestNode that this rule can decorrelate.
 * <p>
 * Every one of these conditions must be met:
 * - {@code isScalar} accepts the source of the UnnestNode,
 * - the UnnestNode has no replicate symbols,
 * - the UnnestNode is based on correlation symbols, that is: it either unnests correlation
 * symbols directly, or its resolved source is a ProjectNode whose assignment expressions use
 * only correlation symbols,
 * - the UnnestNode is of type INNER or LEFT,
 * - the UnnestNode has no filter, or its filter equals TRUE_LITERAL.
 */
private static boolean isSupportedUnnest(PlanNode node, List<Symbol> correlation, Lookup lookup) {
    if (!(node instanceof UnnestNode)) {
        return false;
    }
    UnnestNode candidate = (UnnestNode) node;

    List<Symbol> inputSymbols = candidate.getMappings().stream()
            .map(UnnestNode.Mapping::getInput)
            .collect(toImmutableList());
    PlanNode source = lookup.resolve(candidate.getSource());

    // Either the unnested symbols are correlation symbols themselves, or they come from a
    // projection whose expressions reference nothing but correlation symbols.
    boolean basedOnCorrelation = ImmutableSet.copyOf(correlation).containsAll(inputSymbols)
            || (source instanceof ProjectNode
                    && ImmutableSet.copyOf(correlation).containsAll(
                            SymbolsExtractor.extractUnique(((ProjectNode) source).getAssignments().getExpressions())));

    if (!isScalar(candidate.getSource(), lookup) || !candidate.getReplicateSymbols().isEmpty() || !basedOnCorrelation) {
        return false;
    }

    boolean supportedJoinType = candidate.getJoinType() == INNER || candidate.getJoinType() == LEFT;
    if (!supportedJoinType) {
        return false;
    }

    // A filter that is literally TRUE is as good as no filter.
    return candidate.getFilter().isEmpty() || candidate.getFilter().get().equals(TRUE_LITERAL);
}
Also used : PlanNode(io.trino.sql.planner.plan.PlanNode) UnnestNode(io.trino.sql.planner.plan.UnnestNode) Symbol(io.trino.sql.planner.Symbol) ProjectNode(io.trino.sql.planner.plan.ProjectNode)

Example 14 with Symbol

use of io.trino.sql.planner.Symbol in project trino by trinodb.

the class AddIntermediateAggregations method outputsAsInputs.

/**
 * Produces a copy of the assignments in which each aggregation takes its own output symbol
 * as its single argument. Output symbols (the map keys) are left unchanged.
 * <p>
 * Example:
 * 'a' := sum('b') => 'a' := sum('a')
 * 'a' := count(*) => 'a' := count('a')
 * <p>
 * Fails with an IllegalStateException (via checkState) if any aggregation has an ordering scheme.
 */
private static Map<Symbol, AggregationNode.Aggregation> outputsAsInputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> rewritten = ImmutableMap.builder();
    assignments.forEach((outputSymbol, original) -> {
        checkState(original.getOrderingScheme().isEmpty(), "Intermediate aggregation does not support ORDER BY");
        AggregationNode.Aggregation overOutput = new AggregationNode.Aggregation(
                original.getResolvedFunction(),
                ImmutableList.of(outputSymbol.toSymbolReference()),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty()); // No mask for INTERMEDIATE
        rewritten.put(outputSymbol, overOutput);
    });
    return rewritten.buildOrThrow();
}
Also used : Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 15 with Symbol

use of io.trino.sql.planner.Symbol in project trino by trinodb.

the class AddIntermediateAggregations method inputsAsOutputs.

/**
 * Re-keys the assignments so that each aggregation is stored under its input symbol instead
 * of its output symbol. The aggregations themselves are kept as they are.
 * This operation only reliably applies to aggregation steps that take partial inputs
 * (e.g. INTERMEDIATE and split FINALs), which are guaranteed to have exactly one input and one output.
 * <p>
 * Example:
 * 'a' := sum('b') => 'b' := sum('b')
 * <p>
 * getOnlyElement rejects any aggregation that does not reference exactly one symbol.
 */
private static Map<Symbol, AggregationNode.Aggregation> inputsAsOutputs(Map<Symbol, AggregationNode.Aggregation> assignments) {
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> rekeyed = ImmutableMap.builder();
    // The original output symbols (map keys) are not needed; only the aggregations matter here.
    for (AggregationNode.Aggregation aggregation : assignments.values()) {
        // Each aggregation is expected to reference exactly one input symbol.
        Symbol onlyInput = getOnlyElement(SymbolsExtractor.extractAll(aggregation));
        rekeyed.put(onlyInput, aggregation);
    }
    return rekeyed.buildOrThrow();
}
Also used : Symbol(io.trino.sql.planner.Symbol) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap(com.google.common.collect.ImmutableMap)

Aggregations

Symbol (io.trino.sql.planner.Symbol)366 Test (org.testng.annotations.Test)171 ImmutableList (com.google.common.collect.ImmutableList)124 ImmutableMap (com.google.common.collect.ImmutableMap)106 Optional (java.util.Optional)96 PlanNode (io.trino.sql.planner.plan.PlanNode)85 Expression (io.trino.sql.tree.Expression)85 Map (java.util.Map)66 Assignments (io.trino.sql.planner.plan.Assignments)65 JoinNode (io.trino.sql.planner.plan.JoinNode)64 ProjectNode (io.trino.sql.planner.plan.ProjectNode)61 PlanMatchPattern.values (io.trino.sql.planner.assertions.PlanMatchPattern.values)56 SymbolReference (io.trino.sql.tree.SymbolReference)55 BIGINT (io.trino.spi.type.BigintType.BIGINT)52 Type (io.trino.spi.type.Type)51 Session (io.trino.Session)46 List (java.util.List)46 RuleTester (io.trino.sql.planner.iterative.rule.test.RuleTester)43 RuleTester.defaultRuleTester (io.trino.sql.planner.iterative.rule.test.RuleTester.defaultRuleTester)43 PlanNodeId (io.trino.sql.planner.plan.PlanNodeId)43