Search in sources :

Example 1 with Captures

use of io.trino.matching.Captures in project trino by trinodb.

the class InlineProjectIntoFilter method apply.

/**
 * Inlines projected expressions into the filter that consumes them as bare symbol conjuncts.
 *
 * <p>The filter predicate is split into conjuncts. A conjunct is inlined only when it is a plain
 * symbol reference that occurs exactly once among the symbol-reference conjuncts, is not referenced
 * by any other conjunct, and is assigned a non-trivial expression by the captured projection.
 * For every inlined symbol the projection stops producing it (it passes through the inputs of the
 * inlined expression instead), and a new projection placed above the filter re-creates the symbol
 * as {@code TRUE_LITERAL}.
 *
 * @param node the matched filter node
 * @param captures pattern captures; the child projection is read from {@code PROJECTION}
 * @param context rule context, used to allocate the id of the new top projection
 * @return the rewritten plan, or {@code Result.empty()} when nothing can be inlined
 */
@Override
public Result apply(FilterNode node, Captures captures, Context context) {
    ProjectNode childProjection = captures.get(PROJECTION);
    List<Expression> allConjuncts = extractConjuncts(node.getPredicate());

    // Split the conjuncts into bare symbol references and everything else.
    Map<Boolean, List<Expression>> partitioned = allConjuncts.stream()
            .collect(partitioningBy(SymbolReference.class::isInstance));
    List<Expression> symbolConjuncts = partitioned.get(true);
    List<Expression> compoundConjuncts = partitioned.get(false);

    // A symbol that is repeated among the symbol-reference conjuncts is never inlined.
    Map<Expression, Long> occurrences = symbolConjuncts.stream()
            .collect(Collectors.groupingBy(identity(), counting()));
    Set<Expression> usedOnce = occurrences.entrySet().stream()
            .filter(entry -> entry.getValue() == 1)
            .map(Map.Entry::getKey)
            .collect(toImmutableSet());

    // A symbol that is also referenced from a compound conjunct is never inlined either.
    Set<Expression> referencedElsewhere = SymbolsExtractor.extractUnique(compoundConjuncts).stream()
            .map(Symbol::toSymbolReference)
            .collect(toImmutableSet());
    Set<Expression> inlineCandidates = Sets.difference(usedOnce, referencedElsewhere);
    if (inlineCandidates.isEmpty()) {
        return Result.empty();
    }

    ImmutableList.Builder<Expression> rewrittenConjuncts = ImmutableList.builder();
    Assignments.Builder projectionAssignments = Assignments.builder();
    Assignments.Builder restoredSymbols = Assignments.builder();
    for (Expression conjunct : allConjuncts) {
        if (!inlineCandidates.contains(conjunct)) {
            rewrittenConjuncts.add(conjunct);
            continue;
        }

        Symbol symbol = Symbol.from(conjunct);
        Expression definition = childProjection.getAssignments().get(symbol);
        if (definition == null || definition instanceof SymbolReference) {
            // null: the projection does not produce this symbol (i.e. it is a correlation symbol).
            // SymbolReference: a trivial projection, which is not worth inlining.
            rewrittenConjuncts.add(conjunct);
            continue;
        }

        // Filter on the expression itself, keep its inputs available below the filter,
        // and remember to re-create the (now known to be true) symbol above the filter.
        rewrittenConjuncts.add(definition);
        projectionAssignments.putIdentities(SymbolsExtractor.extractUnique(definition));
        restoredSymbols.put(symbol, TRUE_LITERAL);
    }

    Assignments restored = restoredSymbols.build();
    if (restored.isEmpty()) {
        return Result.empty();
    }

    // Drop the inlined expressions from the underlying projection.
    Set<Symbol> inlinedSymbols = restored.getSymbols();
    projectionAssignments.putAll(childProjection.getAssignments().filter(symbol -> !inlinedSymbols.contains(symbol)));

    // Start from identity outputs of the filter, then overwrite the inlined symbols with their restored values.
    Map<Symbol, Expression> topAssignments = new HashMap<>();
    topAssignments.putAll(Assignments.identity(node.getOutputSymbols()).getMap());
    topAssignments.putAll(restored.getMap());

    ProjectNode prunedProjection = new ProjectNode(
            childProjection.getId(),
            childProjection.getSource(),
            projectionAssignments.build());
    FilterNode rewrittenFilter = new FilterNode(
            node.getId(),
            prunedProjection,
            combineConjuncts(metadata, rewrittenConjuncts.build()));
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            rewrittenFilter,
            Assignments.builder().putAll(topAssignments).build()));
}
Also used : Collectors.partitioningBy(java.util.stream.Collectors.partitioningBy) Patterns.filter(io.trino.sql.planner.plan.Patterns.filter) Collectors.counting(java.util.stream.Collectors.counting) HashMap(java.util.HashMap) Capture.newCapture(io.trino.matching.Capture.newCapture) FilterNode(io.trino.sql.planner.plan.FilterNode) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) Capture(io.trino.matching.Capture) List(java.util.List) Pattern(io.trino.matching.Pattern) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) Function.identity(java.util.function.Function.identity) Metadata(io.trino.metadata.Metadata) Expression(io.trino.sql.tree.Expression) ExpressionUtils.extractConjuncts(io.trino.sql.ExpressionUtils.extractConjuncts) Patterns.project(io.trino.sql.planner.plan.Patterns.project) ExpressionUtils.combineConjuncts(io.trino.sql.ExpressionUtils.combineConjuncts) HashMap(java.util.HashMap) ImmutableList(com.google.common.collect.ImmutableList) SymbolReference(io.trino.sql.tree.SymbolReference) Symbol(io.trino.sql.planner.Symbol) FilterNode(io.trino.sql.planner.plan.FilterNode) Assignments(io.trino.sql.planner.plan.Assignments) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ImmutableList(com.google.common.collect.ImmutableList) List(java.util.List) HashMap(java.util.HashMap) Map(java.util.Map)

Example 2 with Captures

use of io.trino.matching.Captures in project trino by trinodb.

the class DecorrelateInnerUnnestWithGlobalAggregation method apply.

/**
 * Rewrites a correlated join whose subquery is a global aggregation over a correlated unnest
 * into an {@code UnnestNode} over the join input, followed by the subquery's projections and
 * aggregations rebuilt on top of it.
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures pattern captures (not read by this method)
 * @param context rule context supplying the lookup and the id / symbol allocators
 * @return the rewritten plan restricted to the join's output symbols, or {@code Result.empty()}
 *         when the subquery has no global aggregation or no supported unnest below it
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Collect the global aggregations of the subquery, descending only through projections and global aggregations.
    List<PlanNode> globalAggregationNodes = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup())
            .where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation)
            .recurseOnlyWhen(candidate -> candidate instanceof ProjectNode || isGlobalAggregation(candidate))
            .findAll();
    if (globalAggregationNodes.isEmpty()) {
        return Result.empty();
    }

    // With several global aggregations, the one closest to the source is the "reducing" one:
    // it is the aggregation that turns many input rows into a single output row.
    int closestToSource = globalAggregationNodes.size() - 1;
    AggregationNode reducingAggregation = (AggregationNode) globalAggregationNodes.get(closestToSource);

    // Look for the unnest below the reducing aggregation, descending only through projections and grouped aggregations.
    Optional<UnnestNode> foundUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), context.getLookup())
            .where(candidate -> isSupportedUnnest(candidate, correlatedJoinNode.getCorrelation(), context.getLookup()))
            .recurseOnlyWhen(candidate -> candidate instanceof ProjectNode || isGroupedAggregation(candidate))
            .findFirst();
    if (foundUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = foundUnnest.get();

    // Tag every input row with a unique id so that aggregation semantics can be restored after the rewrite.
    PlanNode uniqueInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // The correlated unnest either unnests correlation symbols directly, or unnests symbols produced by a
    // projection over correlation symbols. In the latter case that projection is re-applied on top of the input.
    // If it turns out not to produce any unnest symbols, it is expected to be pruned afterwards.
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    PlanNode input;
    if (unnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) unnestSource;
        Assignments preProjection = Assignments.builder()
                .putIdentities(uniqueInput.getOutputSymbols())
                .putAll(sourceProjection.getAssignments())
                .build();
        input = new ProjectNode(sourceProjection.getId(), uniqueInput, preProjection);
    }
    else {
        input = uniqueInput;
    }

    // Replace the correlated join with a LEFT unnest over the input; reuse the ordinality symbol if one exists.
    Optional<Symbol> existingOrdinality = unnestNode.getOrdinalitySymbol();
    Symbol ordinalitySymbol;
    if (existingOrdinality.isPresent()) {
        ordinalitySymbol = existingOrdinality.get();
    }
    else {
        ordinalitySymbol = context.getSymbolAllocator().newSymbol("ordinality", BIGINT);
    }
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            input,
            input.getOutputSymbols(),
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            LEFT,
            Optional.empty());

    // The mask is true for genuinely unnested rows and false for the synthetic null rows (null ordinality).
    Symbol mask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
    Assignments maskedAssignments = Assignments.builder()
            .putIdentities(rewrittenUnnest.getOutputSymbols())
            .put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference()))
            .build();
    ProjectNode sourceWithMask = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenUnnest, maskedAssignments);

    // Re-create the subquery's projections, grouped aggregations and global aggregations over the new source.
    PlanNode rewritten = rewriteNodeSequence(
            context.getLookup().resolve(correlatedJoinNode.getSubquery()),
            input.getOutputSymbols(),
            mask,
            sourceWithMask,
            reducingAggregation.getId(),
            unnestNode.getId(),
            context.getSymbolAllocator(),
            context.getIdAllocator(),
            context.getLookup());

    // Expose only the symbols the original correlated join produced.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), rewritten, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(rewritten));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationDecorrelation.rewriteWithMasks(io.trino.sql.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) ImmutableList(com.google.common.collect.ImmutableList) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Lookup(io.trino.sql.planner.iterative.Lookup) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) Streams(com.google.common.collect.Streams) UnnestNode(io.trino.sql.planner.plan.UnnestNode) 
AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) List(java.util.List) Pattern(io.trino.matching.Pattern) BIGINT(io.trino.spi.type.BigintType.BIGINT) ExpressionUtils.and(io.trino.sql.ExpressionUtils.and) Captures(io.trino.matching.Captures) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Mapping(io.trino.sql.planner.plan.UnnestNode.Mapping) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) UnnestNode(io.trino.sql.planner.plan.UnnestNode) Symbol(io.trino.sql.planner.Symbol) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate)

Example 3 with Captures

use of io.trino.matching.Captures in project trino by trinodb.

the class PushDownDereferencesThroughWindow method apply.

/**
 * Pushes row-subscript (dereference) expressions used above a window node down into a new
 * projection below it.
 *
 * <p>Dereferences are collected from the project's assignment expressions and from the arguments
 * of the window functions. Those whose base symbol is a partitioning symbol, an ordering symbol,
 * or a symbol created by the window node are excluded. The remaining ones get fresh symbols that
 * are computed by a new projection under the window; the window function arguments and the top
 * projection are rewritten to reference those symbols.
 *
 * @param projectNode the matched projection sitting above the window
 * @param captures pattern captures; the window node is read from {@code CHILD}
 * @param context rule context supplying the session and the id / symbol allocators
 * @return the rewritten plan, or {@code Result.empty()} when no dereference qualifies for pushdown
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    WindowNode windowNode = captures.get(CHILD);
    // Extract dereferences for pushdown
    // NOTE(review): the meaning of the `false` flag is defined by extractRowSubscripts, which is not visible here.
    Set<SubscriptExpression> dereferences = extractRowSubscripts(ImmutableList.<Expression>builder().addAll(projectNode.getAssignments().getExpressions()).addAll(windowNode.getWindowFunctions().values().stream().flatMap(function -> function.getArguments().stream()).collect(toImmutableList())).build(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());
    WindowNode.Specification specification = windowNode.getSpecification();
    dereferences = dereferences.stream().filter(expression -> {
        // Base symbol of the subscript expression, as returned by getBase.
        Symbol symbol = getBase(expression);
        // Exclude partitionBy, orderBy and synthesized symbols
        return !specification.getPartitionBy().contains(symbol) && !specification.getOrderingScheme().map(OrderingScheme::getOrderBy).orElse(ImmutableList.of()).contains(symbol) && !windowNode.getCreatedSymbols().contains(symbol);
    }).collect(toImmutableSet());
    if (dereferences.isEmpty()) {
        // Nothing left to push down: leave the plan unchanged.
        return Result.empty();
    }
    // Create new symbols for dereference expressions
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);
    // Rewrite project node assignments using new symbols for dereference expressions
    // The assignments map is inverted (expression -> symbol) so that each dereference can be replaced by a reference to its new symbol.
    Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()).inverse().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));
    // Resulting shape: Project(newAssignments) -> Window(same id and specification, rewritten functions)
    // -> Project(identity of the window source outputs + dereference assignments) -> original window source.
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), new WindowNode(windowNode.getId(), new ProjectNode(context.getIdAllocator().getNextId(), windowNode.getSource(), Assignments.builder().putIdentities(windowNode.getSource().getOutputSymbols()).putAll(dereferenceAssignments).build()), windowNode.getSpecification(), // Replace dereference expressions in functions
    windowNode.getWindowFunctions().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> {
        WindowNode.Function oldFunction = entry.getValue();
        // Same resolved function, frame and ignore-nulls flag; only the arguments are rewritten.
        return new WindowNode.Function(oldFunction.getResolvedFunction(), oldFunction.getArguments().stream().map(expression -> replaceExpression(expression, mappings)).collect(toImmutableList()), oldFunction.getFrame(), oldFunction.isIgnoreNulls());
    })), windowNode.getHashSymbol(), windowNode.getPrePartitionedInputs(), windowNode.getPreSortedOrderPrefix()), newAssignments));
}
Also used : Capture.newCapture(io.trino.matching.Capture.newCapture) DereferencePushdown.getBase(io.trino.sql.planner.iterative.rule.DereferencePushdown.getBase) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) OrderingScheme(io.trino.sql.planner.OrderingScheme) Capture(io.trino.matching.Capture) DereferencePushdown.extractRowSubscripts(io.trino.sql.planner.iterative.rule.DereferencePushdown.extractRowSubscripts) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) Patterns.window(io.trino.sql.planner.plan.Patterns.window) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) WindowNode(io.trino.sql.planner.plan.WindowNode) WindowNode(io.trino.sql.planner.plan.WindowNode) Symbol(io.trino.sql.planner.Symbol) SymbolReference(io.trino.sql.tree.SymbolReference) Assignments(io.trino.sql.planner.plan.Assignments) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) 
SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap)

Example 4 with Captures

use of io.trino.matching.Captures in project trino by trinodb.

the class PushJoinIntoTableScan method apply.

/**
 * Attempts to replace a join of two table scans with a single table scan by offering the join
 * to {@code metadata.applyJoin}.
 *
 * <p>The rule gives up ({@code Result.empty()}) when the join is a cross join, when part of the
 * effective filter remains after splitting out the pushable conditions, when either scan's
 * enforced constraint is none, or when {@code applyJoin} returns no result.
 *
 * @param joinNode the matched join
 * @param captures pattern captures; the two scans are read from {@code LEFT_TABLE_SCAN} and {@code RIGHT_TABLE_SCAN}
 * @param context rule context supplying the session, the stats provider and the id allocator
 * @return an identity projection of the join's output symbols over the new table scan (which reuses
 *         the join's id), or {@code Result.empty()} when the join is not pushed down
 */
@Override
public Result apply(JoinNode joinNode, Captures captures, Context context) {
    if (joinNode.isCrossJoin()) {
        return Result.empty();
    }
    TableScanNode left = captures.get(LEFT_TABLE_SCAN);
    TableScanNode right = captures.get(RIGHT_TABLE_SCAN);
    // Sanity check: neither side may be the target of an update.
    verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");
    Expression effectiveFilter = getEffectiveFilter(joinNode);
    FilterSplitResult filterSplitResult = splitFilter(effectiveFilter, left.getOutputSymbols(), right.getOutputSymbols(), context);
    if (!filterSplitResult.getRemainingFilter().equals(BooleanLiteral.TRUE_LITERAL)) {
        // Some part of the filter could not be expressed as a pushable condition.
        // TODO add extra filter node above join
        return Result.empty();
    }
    if (left.getEnforcedConstraint().isNone() || right.getEnforcedConstraint().isNone()) {
        // Bail out when either side's enforced constraint is none.
        // NOTE(review): the original comment appears truncated; it presumably explained that this
        // case makes handling the enforced constraint harder below.
        return Result.empty();
    }
    // Column handles of each side, keyed by symbol name, as passed to applyJoin.
    Map<String, ColumnHandle> leftAssignments = left.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
    Map<String, ColumnHandle> rightAssignments = right.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
    /*
         * We are (lazily) computing estimated statistics for join node and left and right table
         * and passing those to connector via applyJoin.
         *
         * There are a couple reasons for this approach:
         * - the engine knows how to estimate join and connector may not
         * - the engine may have cached stats for the table scans (within context.getStatsProvider()), so can be able to provide information more inexpensively
         * - in the future, the engine may be able to provide stats for table scan even in case when connector no longer can (see https://github.com/trinodb/trino/issues/6998)
         * - the pushdown feasibility assessment logic may be different (or configured differently) for different connectors/catalogs.
         */
    JoinStatistics joinStatistics = getJoinStatistics(joinNode, left, right, context);
    Optional<JoinApplicationResult<TableHandle>> joinApplicationResult = metadata.applyJoin(context.getSession(), getJoinType(joinNode), left.getTable(), right.getTable(), filterSplitResult.getPushableConditions(), // TODO we could pass only subset of assignments here, those which are needed to resolve filterSplitResult.getPushableConditions
    leftAssignments, rightAssignments, joinStatistics);
    if (joinApplicationResult.isEmpty()) {
        // applyJoin returned no result, so the join stays in the plan.
        return Result.empty();
    }
    TableHandle handle = joinApplicationResult.get().getTableHandle();
    // Mappings from the original column handles of each side to the column handles of the joined table.
    Map<ColumnHandle, ColumnHandle> leftColumnHandlesMapping = joinApplicationResult.get().getLeftColumnHandles();
    Map<ColumnHandle, ColumnHandle> rightColumnHandlesMapping = joinApplicationResult.get().getRightColumnHandles();
    // Re-point every symbol of both scans at the corresponding column handle of the joined table.
    ImmutableMap.Builder<Symbol, ColumnHandle> assignmentsBuilder = ImmutableMap.builder();
    assignmentsBuilder.putAll(left.getAssignments().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> leftColumnHandlesMapping.get(entry.getValue()))));
    assignmentsBuilder.putAll(right.getAssignments().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> rightColumnHandlesMapping.get(entry.getValue()))));
    Map<Symbol, ColumnHandle> assignments = assignmentsBuilder.buildOrThrow();
    // convert enforced constraint
    // The boolean passed to deriveConstraint is true for the left side of RIGHT/FULL joins and for the
    // right side of LEFT/FULL joins — presumably the sides whose rows may be null-extended; confirm against deriveConstraint.
    JoinNode.Type joinType = joinNode.getType();
    TupleDomain<ColumnHandle> leftConstraint = deriveConstraint(left.getEnforcedConstraint(), leftColumnHandlesMapping, joinType == RIGHT || joinType == FULL);
    TupleDomain<ColumnHandle> rightConstraint = deriveConstraint(right.getEnforcedConstraint(), rightColumnHandlesMapping, joinType == LEFT || joinType == FULL);
    // Combine the column domains of both derived constraints into a single constraint.
    TupleDomain<ColumnHandle> newEnforcedConstraint = TupleDomain.withColumnDomains(ImmutableMap.<ColumnHandle, Domain>builder().putAll(leftConstraint.getDomains().orElseThrow()).putAll(rightConstraint.getDomains().orElseThrow()).buildOrThrow());
    // The new scan reuses the join's id; the projection on top restores the join's output symbols.
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), new TableScanNode(joinNode.getId(), handle, ImmutableList.copyOf(assignments.keySet()), assignments, newEnforcedConstraint, deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), joinApplicationResult.get().isPrecalculateStatistics(), joinNode), false, Optional.empty()), Assignments.identity(joinNode.getOutputSymbols())));
}
Also used : PlanNode(io.trino.sql.planner.plan.PlanNode) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) Map(java.util.Map) JoinNode(io.trino.sql.planner.plan.JoinNode) TableScanNode(io.trino.sql.planner.plan.TableScanNode) PlanNodeStatsEstimate(io.trino.cost.PlanNodeStatsEstimate) Rules.deriveTableStatisticsForPushdown(io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Domain(io.trino.spi.predicate.Domain) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) BasicRelationStatistics(io.trino.spi.connector.BasicRelationStatistics) Patterns.tableScan(io.trino.sql.planner.plan.Patterns.tableScan) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) SystemSessionProperties.isAllowPushdownIntoConnectors(io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors) RIGHT(io.trino.sql.planner.plan.JoinNode.Type.RIGHT) SymbolReference(io.trino.sql.tree.SymbolReference) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) ExpressionUtils.extractConjuncts(io.trino.sql.ExpressionUtils.extractConjuncts) Session(io.trino.Session) JoinCondition(io.trino.spi.connector.JoinCondition) Patterns(io.trino.sql.planner.plan.Patterns) Variable(io.trino.spi.expression.Variable) Capture.newCapture(io.trino.matching.Capture.newCapture) JoinStatistics(io.trino.spi.connector.JoinStatistics) ImmutableList(com.google.common.collect.ImmutableList) JoinType(io.trino.spi.connector.JoinType) Verify.verify(com.google.common.base.Verify.verify) Objects.requireNonNull(java.util.Objects.requireNonNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) 
Rule(io.trino.sql.planner.iterative.Rule) Join.left(io.trino.sql.planner.plan.Patterns.Join.left) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ExpressionUtils(io.trino.sql.ExpressionUtils) Symbol(io.trino.sql.planner.Symbol) FULL(io.trino.sql.planner.plan.JoinNode.Type.FULL) TupleDomain(io.trino.spi.predicate.TupleDomain) Capture(io.trino.matching.Capture) JoinApplicationResult(io.trino.spi.connector.JoinApplicationResult) TableHandle(io.trino.metadata.TableHandle) Join.right(io.trino.sql.planner.plan.Patterns.Join.right) ExpressionUtils.and(io.trino.sql.ExpressionUtils.and) Double.isNaN(java.lang.Double.isNaN) Captures(io.trino.matching.Captures) Metadata(io.trino.metadata.Metadata) TypeProvider(io.trino.sql.planner.TypeProvider) Domain.onlyNull(io.trino.spi.predicate.Domain.onlyNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) Symbol(io.trino.sql.planner.Symbol) JoinNode(io.trino.sql.planner.plan.JoinNode) JoinApplicationResult(io.trino.spi.connector.JoinApplicationResult) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) JoinStatistics(io.trino.spi.connector.JoinStatistics) TableScanNode(io.trino.sql.planner.plan.TableScanNode) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Expression(io.trino.sql.tree.Expression) TableHandle(io.trino.metadata.TableHandle) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Domain(io.trino.spi.predicate.Domain) TupleDomain(io.trino.spi.predicate.TupleDomain) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap)

Example 5 with Captures

use of io.trino.matching.Captures in project trino by trinodb.

the class PushDownDereferencesThroughAssignUniqueId method apply.

/**
 * Pushes row-subscript (dereference) expressions used by a projection down below the
 * {@code AssignUniqueId} node it sits on.
 *
 * <p>Each dereference found in the projection's assignments gets a fresh symbol that is computed by
 * a new projection under the {@code AssignUniqueId}; the original projection is rewritten to
 * reference those symbols instead.
 *
 * @param projectNode the matched projection
 * @param captures pattern captures; the {@code AssignUniqueId} child is read from {@code CHILD}
 * @param context rule context supplying the session and the id / symbol allocators
 * @return the rewritten plan, or {@code Result.empty()} when the projection contains no dereferences
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    AssignUniqueId assignUniqueId = captures.get(CHILD);

    // Collect the dereferences appearing in the projection's assignments.
    Set<SubscriptExpression> subscripts = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());
    if (subscripts.isEmpty()) {
        return Result.empty();
    }

    // Allocate one new symbol per dereference expression.
    Assignments pushedDown = Assignments.of(subscripts, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> symbol reference, for use as a replacement table.
    Map<Expression, Symbol> symbolByExpression = HashBiMap.create(pushedDown.getMap()).inverse();
    Map<Expression, SymbolReference> replacements = symbolByExpression.entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewrittenAssignments = projectNode.getAssignments()
            .rewrite(expression -> replaceExpression(expression, replacements));

    // The lower projection forwards everything the source produced and adds the dereference symbols.
    Assignments lowerAssignments = Assignments.builder()
            .putIdentities(assignUniqueId.getSource().getOutputSymbols())
            .putAll(pushedDown)
            .build();

    // The two node ids are requested in nested-argument order: the upper projection's id first, then the lower one's.
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            assignUniqueId.replaceChildren(ImmutableList.of(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    assignUniqueId.getSource(),
                    lowerAssignments))),
            rewrittenAssignments));
}
Also used : Patterns.assignUniqueId(io.trino.sql.planner.plan.Patterns.assignUniqueId) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Capture.newCapture(io.trino.matching.Capture.newCapture) Capture(io.trino.matching.Capture) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) DereferencePushdown.extractRowSubscripts(io.trino.sql.planner.iterative.rule.DereferencePushdown.extractRowSubscripts) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) ImmutableList(com.google.common.collect.ImmutableList) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Rule(io.trino.sql.planner.iterative.Rule) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) SymbolReference(io.trino.sql.tree.SymbolReference) Assignments(io.trino.sql.planner.plan.Assignments) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Map(java.util.Map)

Aggregations

Captures (io.trino.matching.Captures)28 Pattern (io.trino.matching.Pattern)28 Rule (io.trino.sql.planner.iterative.Rule)28 ImmutableList (com.google.common.collect.ImmutableList)25 Assignments (io.trino.sql.planner.plan.Assignments)25 ProjectNode (io.trino.sql.planner.plan.ProjectNode)25 Expression (io.trino.sql.tree.Expression)23 Map (java.util.Map)22 Capture (io.trino.matching.Capture)20 Capture.newCapture (io.trino.matching.Capture.newCapture)20 Patterns.source (io.trino.sql.planner.plan.Patterns.source)19 Set (java.util.Set)19 Objects.requireNonNull (java.util.Objects.requireNonNull)18 ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap)17 SymbolReference (io.trino.sql.tree.SymbolReference)17 Symbol (io.trino.sql.planner.Symbol)16 TypeAnalyzer (io.trino.sql.planner.TypeAnalyzer)16 Patterns.project (io.trino.sql.planner.plan.Patterns.project)16 SubscriptExpression (io.trino.sql.tree.SubscriptExpression)15 List (java.util.List)15