Search in sources :

Example 16 with ProjectNode

use of io.trino.sql.planner.plan.ProjectNode in project trino by trinodb.

The method apply of the class PushDownDereferencesThroughAssignUniqueId.

/**
 * Moves row-subscript expressions used by the matched project node into a new
 * projection placed underneath the captured {@code AssignUniqueId} node.
 *
 * @param projectNode the matched project node whose assignments are inspected
 * @param captures pattern captures; {@code CHILD} holds the {@code AssignUniqueId} below the project
 * @param context rule context supplying the session, symbol allocator and id allocator
 * @return an empty result when no subscript expression is found; otherwise the rewritten plan
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    AssignUniqueId assignUniqueId = captures.get(CHILD);

    // Gather the subscript expressions referenced by the projection; bail out when there are none
    Set<SubscriptExpression> subscripts = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());
    if (subscripts.isEmpty()) {
        return Result.empty();
    }

    // Give every subscript expression a symbol of its own
    Assignments pushedDown = Assignments.of(subscripts, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Flip the assignments into an expression -> symbol-reference lookup used for the rewrite
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, inverted -> inverted.getValue().toSymbolReference()));

    // The upper projection now refers to the new symbols instead of the subscript expressions
    Assignments rewrittenProjections = projectNode.getAssignments()
            .rewrite(original -> replaceExpression(original, replacements));

    // The lower projection forwards all source outputs and additionally computes the subscripts
    Assignments belowUniqueId = Assignments.builder()
            .putIdentities(assignUniqueId.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Constructors stay nested so that plan node ids are requested outermost-first
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    assignUniqueId.replaceChildren(ImmutableList.of(
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    assignUniqueId.getSource(),
                                    belowUniqueId))),
                    rewrittenProjections));
}
Also used : Patterns.assignUniqueId(io.trino.sql.planner.plan.Patterns.assignUniqueId) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Capture.newCapture(io.trino.matching.Capture.newCapture) Capture(io.trino.matching.Capture) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) DereferencePushdown.extractRowSubscripts(io.trino.sql.planner.iterative.rule.DereferencePushdown.extractRowSubscripts) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) ImmutableList(com.google.common.collect.ImmutableList) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Rule(io.trino.sql.planner.iterative.Rule) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) SymbolReference(io.trino.sql.tree.SymbolReference) Assignments(io.trino.sql.planner.plan.Assignments) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Map(java.util.Map)

Example 17 with ProjectNode

use of io.trino.sql.planner.plan.ProjectNode in project trino by trinodb.

The method apply of the class PushDownDereferencesThroughSort.

/**
 * Moves row-subscript expressions used by the matched project node into a new
 * projection placed underneath the captured {@code SortNode}.
 * Subscripts whose base is one of the sort's ordering symbols are left in place.
 *
 * @param projectNode the matched project node whose assignments are inspected
 * @param captures pattern captures; {@code CHILD} holds the {@code SortNode} below the project
 * @param context rule context supplying the session, symbol allocator and id allocator
 * @return an empty result when nothing remains to push down; otherwise the rewritten plan
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    SortNode sortNode = captures.get(CHILD);

    // Gather the projection's subscript expressions, dropping those based on an ordering
    // symbol so that the data is not replicated
    Set<SubscriptExpression> subscripts = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes())
            .stream()
            .filter(candidate -> !sortNode.getOrderingScheme().getOrderBy().contains(getBase(candidate)))
            .collect(toImmutableSet());
    if (subscripts.isEmpty()) {
        return Result.empty();
    }

    // Give every remaining subscript expression a symbol of its own
    Assignments pushedDown = Assignments.of(subscripts, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Flip the assignments into an expression -> symbol-reference lookup used for the rewrite
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, inverted -> inverted.getValue().toSymbolReference()));

    // The upper projection now refers to the new symbols instead of the subscript expressions
    Assignments rewrittenProjections = projectNode.getAssignments()
            .rewrite(original -> replaceExpression(original, replacements));

    // The lower projection forwards all source outputs and additionally computes the subscripts
    Assignments belowSort = Assignments.builder()
            .putIdentities(sortNode.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Constructors stay nested so that plan node ids are requested outermost-first
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    sortNode.replaceChildren(ImmutableList.of(
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    sortNode.getSource(),
                                    belowSort))),
                    rewrittenProjections));
}
Also used : Patterns.sort(io.trino.sql.planner.plan.Patterns.sort) SortNode(io.trino.sql.planner.plan.SortNode) Capture.newCapture(io.trino.matching.Capture.newCapture) DereferencePushdown.getBase(io.trino.sql.planner.iterative.rule.DereferencePushdown.getBase) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Capture(io.trino.matching.Capture) DereferencePushdown.extractRowSubscripts(io.trino.sql.planner.iterative.rule.DereferencePushdown.extractRowSubscripts) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) SortNode(io.trino.sql.planner.plan.SortNode) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) SymbolReference(io.trino.sql.tree.SymbolReference) Assignments(io.trino.sql.planner.plan.Assignments) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) HashBiMap(com.google.common.collect.HashBiMap) 
ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap)

Example 18 with ProjectNode

use of io.trino.sql.planner.plan.ProjectNode in project trino by trinodb.

The method apply of the class PushDownDereferenceThroughFilter.

/**
 * Moves row-subscript expressions used by the matched project node or by the
 * captured filter's predicate into a new projection placed underneath the filter.
 *
 * @param node the matched project node whose assignments are inspected
 * @param captures pattern captures; {@code CHILD} holds the {@code FilterNode} below the project
 * @param context rule context supplying the session, symbol allocator and id allocator
 * @return an empty result when no subscript expression is found; otherwise the rewritten plan
 */
@Override
public Result apply(ProjectNode node, Captures captures, Rule.Context context) {
    FilterNode filterNode = captures.get(CHILD);
    PlanNode filterInput = filterNode.getSource();

    // Consider the projections together with the filtering predicate, so the pushed-down
    // set is a superset covering both
    List<Expression> candidates = ImmutableList.<Expression>builder()
            .addAll(node.getAssignments().getExpressions())
            .add(filterNode.getPredicate())
            .build();

    // Gather the subscript expressions from those candidates; bail out when there are none
    Set<SubscriptExpression> subscripts = extractRowSubscripts(
            candidates,
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());
    if (subscripts.isEmpty()) {
        return Result.empty();
    }

    // Give every subscript expression a symbol of its own
    Assignments pushedDown = Assignments.of(subscripts, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Flip the assignments into an expression -> symbol-reference lookup used for the rewrite
    Map<Expression, SymbolReference> replacements = HashBiMap.create(pushedDown.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, inverted -> inverted.getValue().toSymbolReference()));

    // Both the upper projection and the predicate now refer to the new symbols
    Assignments rewrittenProjections = node.getAssignments()
            .rewrite(original -> replaceExpression(original, replacements));
    Expression rewrittenPredicate = replaceExpression(filterNode.getPredicate(), replacements);

    // The lower projection forwards all source outputs and additionally computes the subscripts
    Assignments belowFilter = Assignments.builder()
            .putIdentities(filterInput.getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // Constructors stay nested so that plan node ids are requested outermost-first
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    new FilterNode(
                            context.getIdAllocator().getNextId(),
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    filterInput,
                                    belowFilter),
                            rewrittenPredicate),
                    rewrittenProjections));
}
Also used : Patterns.filter(io.trino.sql.planner.plan.Patterns.filter) Capture.newCapture(io.trino.matching.Capture.newCapture) FilterNode(io.trino.sql.planner.plan.FilterNode) PlanNode(io.trino.sql.planner.plan.PlanNode) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Capture(io.trino.matching.Capture) DereferencePushdown.extractRowSubscripts(io.trino.sql.planner.iterative.rule.DereferencePushdown.extractRowSubscripts) HashBiMap(com.google.common.collect.HashBiMap) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) PlanNode(io.trino.sql.planner.plan.PlanNode) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) SymbolReference(io.trino.sql.tree.SymbolReference) FilterNode(io.trino.sql.planner.plan.FilterNode) Assignments(io.trino.sql.planner.plan.Assignments) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap)

Example 19 with ProjectNode

use of io.trino.sql.planner.plan.ProjectNode in project trino by trinodb.

The method apply of the class TransformCorrelatedDistinctAggregationWithProjection.

/**
 * Rewrites a correlated join over a projected aggregation into a LEFT join followed by
 * the restored aggregation and a final projection.
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures pattern captures; {@code AGGREGATION} and {@code PROJECTION} are read here
 * @param context rule context supplying the symbol allocator, id allocator and lookup
 * @return an empty result when the aggregation's source cannot be decorrelated; otherwise the rewritten plan
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    AggregationNode capturedAggregation = captures.get(AGGREGATION);

    // Try to decorrelate the plan below the aggregation; give up if that is not possible
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelationResult = decorrelator.decorrelateFilters(
            capturedAggregation.getSource(),
            correlatedJoinNode.getCorrelation());
    if (decorrelationResult.isEmpty()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelated = decorrelationResult.get();
    PlanNode source = decorrelated.getNode();

    // Tag every input row with a unique id so that original input rows can be told apart after the join
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueId,
            source,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            source.getOutputSymbols(),
            false,
            decorrelated.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Re-create the aggregation on top of the join, grouping additionally by the join's left outputs
    List<Symbol> groupingKeys = ImmutableList.<Symbol>builder()
            .addAll(join.getLeftOutputSymbols())
            .addAll(capturedAggregation.getGroupingKeys())
            .build();
    AggregationNode restoredAggregation = new AggregationNode(
            capturedAggregation.getId(),
            join,
            capturedAggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Keep only the aggregation outputs the correlated join exposed, then apply the captured projection
    Set<Symbol> joinOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> retainedOutputs = restoredAggregation.getOutputSymbols().stream()
            .filter(symbol -> joinOutputs.contains(symbol))
            .collect(toImmutableList());
    Assignments projection = Assignments.builder()
            .putIdentities(retainedOutputs)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();

    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), restoredAggregation, projection));
}
Also used : CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) HashSet(java.util.HashSet)

Example 20 with ProjectNode

use of io.trino.sql.planner.plan.ProjectNode in project trino by trinodb.

The method apply of the class TransformCorrelatedGlobalAggregationWithProjection.

/**
 * Rewrites a correlated join over a projected global aggregation into a LEFT join followed by
 * the restored aggregation(s) and a final projection.
 *
 * <p>A {@code non_null} marker is projected on the nested plan and used as (part of) the mask of
 * every aggregation, so that the original result is restored for input rows that have no match.
 *
 * @param correlatedJoinNode the matched correlated join; its type must be INNER or LEFT
 * @param captures pattern captures; {@code SOURCE}, {@code AGGREGATION} and {@code PROJECTION} are read here
 * @param context rule context supplying the symbol allocator, id allocator and lookup
 * @return an empty result when the nested plan cannot be decorrelated; otherwise the rewritten plan
 * @throws IllegalArgumentException if the correlated join type is neither INNER nor LEFT
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Use the template overload so the message is only formatted when the check fails
    checkArgument(
            correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT,
            "unexpected correlated join type: %s",
            correlatedJoinNode.getType());

    // if there is another aggregation below the AggregationNode, handle both
    PlanNode source = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(source)) {
        distinct = (AggregationNode) source;
        source = distinct.getSource();
    }

    // decorrelate nested plan
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
    if (decorrelatedSource.isEmpty()) {
        return Result.empty();
    }
    source = decorrelatedSource.get().getNode();

    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    source = new ProjectNode(
            context.getIdAllocator().getNextId(),
            source,
            Assignments.builder()
                    .putIdentities(source.getOutputSymbols())
                    .put(nonNull, TRUE_LITERAL)
                    .build());

    // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueId,
            source,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            source.getOutputSymbols(),
            false,
            decorrelatedSource.get().getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    PlanNode root = join;

    // restore distinct aggregation, grouping additionally by the join's left outputs and the non_null marker
    if (distinct != null) {
        root = new AggregationNode(
                distinct.getId(),
                join,
                distinct.getAggregations(),
                singleGroupingSet(ImmutableList.<Symbol>builder()
                        .addAll(join.getLeftOutputSymbols())
                        .add(nonNull)
                        .addAll(distinct.getGroupingKeys())
                        .build()),
                ImmutableList.of(),
                distinct.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // prepare mask symbols for aggregations
    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
    // already has a mask, it will be replaced with conjunction of the existing mask and non_null.
    // This is necessary to restore the original aggregation result in case when:
    // - the nested lateral subquery returned empty result for some input row,
    // - aggregation is null-sensitive, which means that its result over a single null row is different
    //   than result for empty input (with global grouping)
    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
    AggregationNode globalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder assignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isPresent()) {
            // existing mask: combine it with non_null under a freshly allocated mask symbol
            Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
            assignmentsBuilder.put(newMask, expression);
            masks.put(entry.getKey(), newMask);
        }
        else {
            // no mask yet: non_null itself becomes the mask
            masks.put(entry.getKey(), nonNull);
        }
    }

    // a projection computing the combined masks is only needed when at least one was created
    Assignments maskAssignments = assignmentsBuilder.build();
    if (!maskAssignments.isEmpty()) {
        root = new ProjectNode(
                context.getIdAllocator().getNextId(),
                root,
                Assignments.builder()
                        .putIdentities(root.getOutputSymbols())
                        .putAll(maskAssignments)
                        .build());
    }

    // restore global aggregation, grouping additionally by the join's left outputs
    globalAggregation = new AggregationNode(
            globalAggregation.getId(),
            root,
            rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()),
            singleGroupingSet(ImmutableList.<Symbol>builder()
                    .addAll(join.getLeftOutputSymbols())
                    .addAll(globalAggregation.getGroupingKeys())
                    .build()),
            ImmutableList.of(),
            globalAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // restrict outputs and apply projection
    Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> expectedAggregationOutputs = globalAggregation.getOutputSymbols().stream()
            .filter(outputSymbols::contains)
            .collect(toImmutableList());
    Assignments assignments = Assignments.builder()
            .putIdentities(expectedAggregationOutputs)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();

    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), globalAggregation, assignments));
}
Also used : Symbol(io.trino.sql.planner.Symbol) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashSet(java.util.HashSet)

Aggregations

ProjectNode (io.trino.sql.planner.plan.ProjectNode)101 Assignments (io.trino.sql.planner.plan.Assignments)64 Expression (io.trino.sql.tree.Expression)58 Symbol (io.trino.sql.planner.Symbol)57 PlanNode (io.trino.sql.planner.plan.PlanNode)50 ImmutableList (com.google.common.collect.ImmutableList)41 Map (java.util.Map)35 Captures (io.trino.matching.Captures)29 Pattern (io.trino.matching.Pattern)29 Rule (io.trino.sql.planner.iterative.Rule)29 Set (java.util.Set)25 Capture (io.trino.matching.Capture)23 Capture.newCapture (io.trino.matching.Capture.newCapture)23 SymbolReference (io.trino.sql.tree.SymbolReference)23 Objects.requireNonNull (java.util.Objects.requireNonNull)23 Patterns.source (io.trino.sql.planner.plan.Patterns.source)22 TypeAnalyzer (io.trino.sql.planner.TypeAnalyzer)20 AggregationNode (io.trino.sql.planner.plan.AggregationNode)20 FilterNode (io.trino.sql.planner.plan.FilterNode)20 ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap)19