Search in sources:

Example 21 with PlanNode

Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.

In class TransformCorrelatedGlobalAggregationWithProjection, method apply.

/**
 * Decorrelates a correlated join whose nested plan was captured (by a pattern defined elsewhere —
 * presumably PROJECTION over a global AGGREGATION over SOURCE; TODO confirm) and may additionally
 * contain a distinct aggregation directly under the global aggregation.
 *
 * <p>Rewrite outline: decorrelate the nested plan's filters, LEFT-join it to the input tagged with a
 * unique id, rebuild the distinct and global aggregations grouped by all input columns, mask every
 * aggregation with a "non_null" marker, and finally re-apply the captured projection.
 *
 * <p>Note: plan node ids and symbols are allocated in statement order, so reordering this body may
 * change the generated ids and symbol names.
 *
 * @param correlatedJoinNode the matched correlated join; its type must be INNER or LEFT
 * @param captures rule captures; SOURCE, AGGREGATION and PROJECTION are read
 * @param context rule context supplying the lookup, the symbol allocator and the id allocator
 * @return the rewritten plan, or {@code Result.empty()} when the nested plan's filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());
    // if there is another aggregation below the AggregationNode, handle both
    PlanNode source = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(source)) {
        // peel off the distinct aggregation; it is rebuilt above the join further down
        distinct = (AggregationNode) source;
        source = distinct.getSource();
    }
    // decorrelate nested plan
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
    if (decorrelatedSource.isEmpty()) {
        // the nested plan's filters could not be decorrelated; leave the plan unchanged
        return Result.empty();
    }
    source = decorrelatedSource.get().getNode();
    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    source = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.builder().putIdentities(source.getOutputSymbols()).put(nonNull, TRUE_LITERAL).build());
    // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
    PlanNode inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
    // LEFT join with no equi-criteria: the decorrelated (formerly correlated) predicates become the join filter,
    // and input rows with no matching nested row are kept with NULLs on the right (non_null is then NULL)
    JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, inputWithUniqueId, source, ImmutableList.of(), inputWithUniqueId.getOutputSymbols(), source.getOutputSymbols(), false, decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
    PlanNode root = join;
    // restore distinct aggregation
    if (distinct != null) {
        // grouped by all input columns (incl. the unique id) plus non_null, so distinct values are computed per input row
        root = new AggregationNode(distinct.getId(), join, distinct.getAggregations(), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).add(nonNull).addAll(distinct.getGroupingKeys()).build()), ImmutableList.of(), distinct.getStep(), Optional.empty(), Optional.empty());
    }
    // prepare mask symbols for aggregations
    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
    // already has a mask, it will be replaced with conjunction of the existing mask and non_null.
    // This is necessary to restore the original aggregation result in case when:
    // - the nested lateral subquery returned empty result for some input row,
    // - aggregation is null-sensitive, which means that its result over a single null row is different
    // than result for empty input (with global grouping)
    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
    AggregationNode globalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder assignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isPresent()) {
            // combined mask: existing mask AND non_null, materialized as a new projected symbol
            Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
            assignmentsBuilder.put(newMask, expression);
            masks.put(entry.getKey(), newMask);
        } else {
            masks.put(entry.getKey(), nonNull);
        }
    }
    Assignments maskAssignments = assignmentsBuilder.build();
    if (!maskAssignments.isEmpty()) {
        // only add a projection when at least one combined mask had to be computed
        root = new ProjectNode(context.getIdAllocator().getNextId(), root, Assignments.builder().putIdentities(root.getOutputSymbols()).putAll(maskAssignments).build());
    }
    // restore global aggregation
    // grouped by all input columns (incl. the unique id) plus the original grouping keys, i.e. one group per input row
    globalAggregation = new AggregationNode(globalAggregation.getId(), root, rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).addAll(globalAggregation.getGroupingKeys()).build()), ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), Optional.empty());
    // restrict outputs and apply projection
    // keep only aggregation outputs that the correlated join itself exposes, then add the captured projection's assignments
    Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> expectedAggregationOutputs = globalAggregation.getOutputSymbols().stream().filter(outputSymbols::contains).collect(toImmutableList());
    Assignments assignments = Assignments.builder().putIdentities(expectedAggregationOutputs).putAll(captures.get(PROJECTION).getAssignments()).build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), globalAggregation, assignments));
}
Also used: Symbol(io.trino.sql.planner.Symbol) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) HashSet(java.util.HashSet)

Example 22 with PlanNode

Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.

In class PushProjectionThroughJoin, method inlineProjections.

/**
 * Merges {@code parentProjection} with the projection directly beneath it, repeating for as long as the
 * (resolved) source is a {@link ProjectNode} and {@code InlineProjections} succeeds in merging the pair.
 *
 * @return the last successfully merged projection, or {@code parentProjection} itself when nothing
 *         could be inlined
 */
private static PlanNode inlineProjections(PlannerContext plannerContext, ProjectNode parentProjection, Lookup lookup, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
    ProjectNode current = parentProjection;
    while (true) {
        PlanNode resolvedSource = lookup.resolve(current.getSource());
        if (!(resolvedSource instanceof ProjectNode)) {
            // nothing below to merge with
            return current;
        }
        Optional<ProjectNode> merged = InlineProjections.inlineProjections(plannerContext, current, (ProjectNode) resolvedSource, session, typeAnalyzer, types);
        if (merged.isEmpty()) {
            // the two projections could not be combined; stop at the current one
            return current;
        }
        current = merged.get();
    }
}
Also used: SymbolsExtractor.extractUnique(io.trino.sql.planner.SymbolsExtractor.extractUnique) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) PlanNode(io.trino.sql.planner.plan.PlanNode) DeterminismEvaluator.isDeterministic(io.trino.sql.planner.DeterminismEvaluator.isDeterministic) Map(java.util.Map) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) JoinNode(io.trino.sql.planner.plan.JoinNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) Lookup(io.trino.sql.planner.iterative.Lookup) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Streams(com.google.common.collect.Streams) Preconditions.checkState(com.google.common.base.Preconditions.checkState) List(java.util.List) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) TypeProvider(io.trino.sql.planner.TypeProvider) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext)

Example 23 with PlanNode

Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.

In class PushProjectionThroughJoin, method pushProjectionThroughJoin.

/**
 * Pushes a deterministic projection below an INNER join by splitting its assignments between the two
 * join sides.
 *
 * <p>Every assignment must reference symbols from a single join side; an assignment that references no
 * symbols at all is placed on the left side. Symbols the join itself requires ({@code getJoinRequiredSymbols})
 * are kept as identity assignments on the side that produces them. Each new side projection is then
 * merged with any projection directly below it.
 *
 * @return the rewritten join, or {@code Optional.empty()} when the projection contains a non-deterministic
 *         expression, its source is not an INNER join, or an assignment references symbols from both sides
 * @throws IllegalStateException if a join-required symbol is produced by neither join side
 */
public static Optional<PlanNode> pushProjectionThroughJoin(PlannerContext plannerContext, ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
    if (!projectNode.getAssignments().getExpressions().stream().allMatch(expression -> isDeterministic(expression, plannerContext.getMetadata()))) {
        // moving a non-deterministic expression below the join could change how many times it is evaluated
        return Optional.empty();
    }
    PlanNode child = lookup.resolve(projectNode.getSource());
    if (!(child instanceof JoinNode)) {
        return Optional.empty();
    }
    JoinNode joinNode = (JoinNode) child;
    if (joinNode.getType() != INNER) {
        return Optional.empty();
    }
    PlanNode leftChild = joinNode.getLeft();
    PlanNode rightChild = joinNode.getRight();
    // hash-based views of the child outputs, so membership checks in the loops below are not list scans
    Set<Symbol> leftChildOutputs = ImmutableSet.copyOf(leftChild.getOutputSymbols());
    Set<Symbol> rightChildOutputs = ImmutableSet.copyOf(rightChild.getOutputSymbols());

    Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
    Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
        Expression expression = assignment.getValue();
        Set<Symbol> symbols = extractUnique(expression);
        if (leftChildOutputs.containsAll(symbols)) {
            // expression is satisfied with left child symbols (also taken for symbol-free expressions)
            leftAssignmentsBuilder.put(assignment.getKey(), expression);
        } else if (rightChildOutputs.containsAll(symbols)) {
            // expression is satisfied with right child symbols
            rightAssignmentsBuilder.put(assignment.getKey(), expression);
        } else {
            // expression is using symbols from both join sides
            return Optional.empty();
        }
    }
    // add projections for symbols required by the join itself
    for (Symbol requiredSymbol : getJoinRequiredSymbols(joinNode)) {
        if (leftChildOutputs.contains(requiredSymbol)) {
            leftAssignmentsBuilder.putIdentity(requiredSymbol);
        } else {
            checkState(rightChildOutputs.contains(requiredSymbol), "Join required symbol %s is produced by neither join side", requiredSymbol);
            rightAssignmentsBuilder.putIdentity(requiredSymbol);
        }
    }
    Assignments leftAssignments = leftAssignmentsBuilder.build();
    Assignments rightAssignments = rightAssignmentsBuilder.build();
    // only symbols exposed by the original projection stay visible above the join
    Set<Symbol> projectOutputs = ImmutableSet.copyOf(projectNode.getOutputSymbols());
    List<Symbol> leftOutputSymbols = leftAssignments.getOutputs().stream().filter(projectOutputs::contains).collect(toImmutableList());
    List<Symbol> rightOutputSymbols = rightAssignments.getOutputs().stream().filter(projectOutputs::contains).collect(toImmutableList());
    // left source is built before right, preserving the order of plan node id allocation
    PlanNode newLeft = inlineProjections(plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), leftChild, leftAssignments), lookup, session, typeAnalyzer, types);
    PlanNode newRight = inlineProjections(plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), rightChild, rightAssignments), lookup, session, typeAnalyzer, types);
    return Optional.of(new JoinNode(joinNode.getId(), joinNode.getType(), newLeft, newRight, joinNode.getCriteria(), leftOutputSymbols, rightOutputSymbols, joinNode.isMaySkipOutputDuplicates(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters(), joinNode.getReorderJoinStatsAndCost()));
}
Also used: SymbolsExtractor.extractUnique(io.trino.sql.planner.SymbolsExtractor.extractUnique) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) PlanNode(io.trino.sql.planner.plan.PlanNode) DeterminismEvaluator.isDeterministic(io.trino.sql.planner.DeterminismEvaluator.isDeterministic) Map(java.util.Map) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) JoinNode(io.trino.sql.planner.plan.JoinNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) Lookup(io.trino.sql.planner.iterative.Lookup) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Streams(com.google.common.collect.Streams) Preconditions.checkState(com.google.common.base.Preconditions.checkState) List(java.util.List) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) TypeProvider(io.trino.sql.planner.TypeProvider) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext)

Example 24 with PlanNode

Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.

In class PushProjectionThroughUnion, method apply.

/**
 * Pushes a projection below a UNION. The projection is re-created on top of every union source, with its
 * expressions rewritten in terms of that source's symbols. A new UNION, reusing the projection's plan node
 * id and output layout, is then built over the rewritten sources.
 *
 * @param parent the projection being pushed down
 * @param captures rule captures; CHILD holds the union below {@code parent}
 * @param context rule context supplying the symbol allocator (symbols and types) and the id allocator
 * @return the new union over per-source projections
 */
@Override
public Result apply(ProjectNode parent, Captures captures, Context context) {
    UnionNode source = captures.get(CHILD);
    // OutputLayout of the resultant Union, will be same as the layout of the Project
    List<Symbol> outputLayout = parent.getOutputSymbols();
    // Mapping from the output symbol to ordered list of symbols from each of the sources
    ImmutableListMultimap.Builder<Symbol, Symbol> mappings = ImmutableListMultimap.builder();
    // sources for the resultant UnionNode
    ImmutableList.Builder<PlanNode> outputSources = ImmutableList.builder();
    for (int i = 0; i < source.getSources().size(); i++) {
        // Map: output of union -> input of this source to the union
        Map<Symbol, SymbolReference> outputToInput = source.sourceSymbolMap(i);
        // assignments for the new ProjectNode
        Assignments.Builder assignments = Assignments.builder();
        // mapping from current ProjectNode to new ProjectNode, used to identify the output layout
        Map<Symbol, Symbol> projectSymbolMapping = new HashMap<>();
        // Translate the assignments in the ProjectNode using symbols of the source of the UnionNode
        for (Map.Entry<Symbol, Expression> entry : parent.getAssignments().entrySet()) {
            Expression translatedExpression = inlineSymbols(outputToInput, entry.getValue());
            Type type = context.getSymbolAllocator().getTypes().get(entry.getKey());
            Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression, type);
            assignments.put(symbol, translatedExpression);
            projectSymbolMapping.put(entry.getKey(), symbol);
        }
        outputSources.add(new ProjectNode(context.getIdAllocator().getNextId(), source.getSources().get(i), assignments.build()));
        // append this source's symbol for every output, keeping the per-output lists in source order
        outputLayout.forEach(symbol -> mappings.put(symbol, projectSymbolMapping.get(symbol)));
    }
    // build the multimap once and derive the union's output symbols (its keys, in insertion order) from it
    ImmutableListMultimap<Symbol, Symbol> symbolMapping = mappings.build();
    return Result.ofPlanNode(new UnionNode(parent.getId(), outputSources.build(), symbolMapping, ImmutableList.copyOf(symbolMapping.keySet())));
}
Also used: HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) SymbolReference(io.trino.sql.tree.SymbolReference) Assignments(io.trino.sql.planner.plan.Assignments) Type(io.trino.spi.type.Type) PlanNode(io.trino.sql.planner.plan.PlanNode) UnionNode(io.trino.sql.planner.plan.UnionNode) Expression(io.trino.sql.tree.Expression) ImmutableListMultimap(com.google.common.collect.ImmutableListMultimap) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map)

Example 25 with PlanNode

Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.

In class ReplaceRedundantJoinWithSource, method apply.

/**
 * Replaces a join with one of its sources when the other source is scalar and produces no output symbols.
 * Such a source contributes no columns; this assumes {@code isScalar} means "exactly one row" (TODO confirm).
 *
 * <ul>
 * <li>INNER: either side may be removed; a join filter is preserved as a {@link FilterNode} on the kept source.</li>
 * <li>LEFT / RIGHT: only the inner side (right for LEFT, left for RIGHT) may be removed.</li>
 * <li>FULL: either side may be removed when the other side produces at least one row and the join has
 *     no filter.</li>
 * </ul>
 *
 * @return the kept source restricted to the join's output symbols from that side, or {@code Result.empty()}
 */
@Override
public Result apply(JoinNode node, Captures captures, Context context) {
    // bail out if either side is known to produce no rows
    if (isAtMost(node.getLeft(), context.getLookup(), 0) || isAtMost(node.getRight(), context.getLookup(), 0)) {
        return Result.empty();
    }
    boolean leftSourceScalarWithNoOutputs = node.getLeft().getOutputSymbols().isEmpty() && isScalar(node.getLeft(), context.getLookup());
    boolean rightSourceScalarWithNoOutputs = node.getRight().getOutputSymbols().isEmpty() && isScalar(node.getRight(), context.getLookup());
    switch(node.getType()) {
        case INNER:
            PlanNode source;
            List<Symbol> sourceOutputs;
            if (leftSourceScalarWithNoOutputs) {
                source = node.getRight();
                sourceOutputs = node.getRightOutputSymbols();
            } else if (rightSourceScalarWithNoOutputs) {
                source = node.getLeft();
                sourceOutputs = node.getLeftOutputSymbols();
            } else {
                return Result.empty();
            }
            // the removed side has no symbols, so the filter can only reference the kept source
            if (node.getFilter().isPresent()) {
                source = new FilterNode(context.getIdAllocator().getNextId(), source, node.getFilter().get());
            }
            return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), source, ImmutableSet.copyOf(sourceOutputs)).orElse(source));
        case LEFT:
            // every outer row matches the single inner row or is kept with NULLs; either way it appears once
            if (rightSourceScalarWithNoOutputs) {
                return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), node.getLeft(), ImmutableSet.copyOf(node.getLeftOutputSymbols())).orElse(node.getLeft()));
            }
            break;
        case RIGHT:
            if (leftSourceScalarWithNoOutputs) {
                return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), node.getRight(), ImmutableSet.copyOf(node.getRightOutputSymbols())).orElse(node.getRight()));
            }
            break;
        case FULL:
            // With a join filter, if no row of the other side satisfies it, the scalar row stays unmatched and
            // the FULL join emits one extra row with all other-side columns NULL. Dropping the join would lose
            // that row, so only filter-less FULL joins are rewritten.
            if (node.getFilter().isEmpty()) {
                if (leftSourceScalarWithNoOutputs && isAtLeastScalar(node.getRight(), context.getLookup())) {
                    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), node.getRight(), ImmutableSet.copyOf(node.getRightOutputSymbols())).orElse(node.getRight()));
                }
                if (rightSourceScalarWithNoOutputs && isAtLeastScalar(node.getLeft(), context.getLookup())) {
                    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), node.getLeft(), ImmutableSet.copyOf(node.getLeftOutputSymbols())).orElse(node.getLeft()));
                }
            }
            break;
    }
    return Result.empty();
}
Also used: PlanNode(io.trino.sql.planner.plan.PlanNode) Symbol(io.trino.sql.planner.Symbol) FilterNode(io.trino.sql.planner.plan.FilterNode)

Aggregations

PlanNode (io.trino.sql.planner.plan.PlanNode)235 Test (org.testng.annotations.Test)90 Symbol (io.trino.sql.planner.Symbol)83 Expression (io.trino.sql.tree.Expression)62 ImmutableList (com.google.common.collect.ImmutableList)52 JoinNode (io.trino.sql.planner.plan.JoinNode)52 ProjectNode (io.trino.sql.planner.plan.ProjectNode)51 List (java.util.List)43 Session (io.trino.Session)41 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)39 Map (java.util.Map)37 Optional (java.util.Optional)37 AggregationNode (io.trino.sql.planner.plan.AggregationNode)35 BasePlanTest (io.trino.sql.planner.assertions.BasePlanTest)34 Assignments (io.trino.sql.planner.plan.Assignments)33 ComparisonExpression (io.trino.sql.tree.ComparisonExpression)32 Objects.requireNonNull (java.util.Objects.requireNonNull)31 ImmutableMap (com.google.common.collect.ImmutableMap)29 ImmutableSet (com.google.common.collect.ImmutableSet)29 Set (java.util.Set)29