Search in sources:

Example 31 with Assignments

Use of io.trino.sql.planner.plan.Assignments in the project trino by trinodb.

In class DecorrelateUnnest, method apply.

/**
 * Decorrelates a {@link CorrelatedJoinNode} whose subquery is built around a correlated {@link UnnestNode}
 * by rewriting it into an uncorrelated UnnestNode over the join input.
 * <p>
 * Recognized subquery shape:
 * <ul>
 * <li>an optional {@link EnforceSingleRowNode} at the root of the subquery;</li>
 * <li>below it, a chain of {@link ProjectNode}s, {@link LimitNode}s with a positive count and
 * {@link TopNNode}s with a positive count;</li>
 * <li>ending in an UnnestNode accepted by {@code isSupportedUnnest}.</li>
 * </ul>
 * Input rows are tagged with a unique id. The rewritten unnest always produces an ordinality symbol. Both
 * symbols are handed to {@code Rewriter.rewriteNodeSequence}, which restores the remaining subquery nodes on
 * top of the rewritten unnest.
 *
 * @return the rewritten plan restricted to the correlated join's output symbols, or {@link Result#empty()}
 * when no supported correlated UnnestNode is found
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // determine shape of the subquery
    PlanNode searchRoot = correlatedJoinNode.getSubquery();
    // 1. find EnforceSingleRowNode in the subquery
    // (the search never recurses, so only the subquery root itself is checked)
    Optional<EnforceSingleRowNode> enforceSingleRow = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup()).where(EnforceSingleRowNode.class::isInstance).recurseOnlyWhen(planNode -> false).findFirst();
    if (enforceSingleRow.isPresent()) {
        searchRoot = enforceSingleRow.get().getSource();
    }
    // 2. find correlated UnnestNode in the subquery
    // The search only descends through ProjectNode, LimitNode with count > 0, and TopNNode with count > 0.
    Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup()).where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup())).recurseOnlyWhen(node -> node instanceof ProjectNode || (node instanceof LimitNode && ((LimitNode) node).getCount() > 0) || (node instanceof TopNNode && ((TopNNode) node).getCount() > 0)).findFirst();
    if (subqueryUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = subqueryUnnest.get();
    // assign unique id to input rows
    Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BIGINT);
    PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);
    // pre-project unnest symbols if they were pre-projected in subquery
    // The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols produced by a projection that uses only correlation symbols.
    // Here, any underlying projection that was a source of the correlated UnnestNode, is appended as a source of the rewritten UnnestNode.
    // If the projection is not necessary for UnnestNode (i.e. it does not produce any unnest symbols), it should be pruned afterwards.
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) unnestSource;
        // keep all input symbols and add the subquery projection's assignments on top of them
        input = new ProjectNode(sourceProjection.getId(), input, Assignments.builder().putIdentities(input.getOutputSymbols()).putAll(sourceProjection.getAssignments()).build());
    }
    // determine join type for rewritten UnnestNode
    // LEFT by default. INNER only when there is no EnforceSingleRowNode, the correlated join is INNER,
    // and the original unnest is INNER.
    Type unnestJoinType = LEFT;
    if (enforceSingleRow.isEmpty() && correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER && unnestNode.getJoinType() == INNER) {
        unnestJoinType = INNER;
    }
    // make sure that the rewritten node is with ordinality, which might be necessary to restore inner unnest semantics after rewrite.
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol().orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));
    // rewrite correlated join to UnnestNode.
    // All input symbols are replicated. The original unnest mappings are reused unchanged.
    UnnestNode rewrittenUnnest = new UnnestNode(context.getIdAllocator().getNextId(), input, input.getOutputSymbols(), unnestNode.getMappings(), Optional.of(ordinalitySymbol), unnestJoinType, Optional.empty());
    // restore all nodes from the subquery
    PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(correlatedJoinNode.getSubquery(), input.getOutputSymbols(), ordinalitySymbol, uniqueSymbol, rewrittenUnnest, context.getSession(), metadata, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator());
    // restore INNER semantics of the original UnnestNode when the rewritten UnnestNode is LEFT.
    // A null ordinality marks a synthetic row added by the left unnest, so every subquery output is forced to
    // NULL (cast to its type) for such rows. The ordinality symbol is therefore what distinguishes
    // between unnested rows and synthetic rows added by left unnest.
    if (unnestNode.getJoinType() == INNER && rewrittenUnnest.getJoinType() == LEFT) {
        // the correlated join's input symbols pass through unchanged
        Assignments.Builder assignments = Assignments.builder().putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
        for (Symbol subquerySymbol : correlatedJoinNode.getSubquery().getOutputSymbols()) {
            assignments.put(subquerySymbol, new IfExpression(new IsNullPredicate(ordinalitySymbol.toSymbolReference()), new Cast(new NullLiteral(), toSqlType(context.getSymbolAllocator().getTypes().get(subquerySymbol))), subquerySymbol.toSymbolReference()));
        }
        rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, assignments.build());
    }
    // restrict outputs
    // restrictOutputs returns empty when no restriction is needed, in which case the plan is used as-is
    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())).orElse(rewrittenPlan));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) TypeSignatureProvider.fromTypes(io.trino.sql.analyzer.TypeSignatureProvider.fromTypes) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) FilterNode(io.trino.sql.planner.plan.FilterNode) PlanNode(io.trino.sql.planner.plan.PlanNode) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) Type(io.trino.sql.planner.plan.JoinNode.Type) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) GREATER_THAN(io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) FunctionCall(io.trino.sql.tree.FunctionCall) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) ResolvedFunction(io.trino.metadata.ResolvedFunction) EnforceSingleRowNode(io.trino.sql.planner.plan.EnforceSingleRowNode) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) Assignments(io.trino.sql.planner.plan.Assignments) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) DEFAULT_FRAME(io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME) LESS_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) GenericLiteral(io.trino.sql.tree.GenericLiteral) List(java.util.List) Pattern(io.trino.matching.Pattern) IfExpression(io.trino.sql.tree.IfExpression) BIGINT(io.trino.spi.type.BigintType.BIGINT) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) WindowNode(io.trino.sql.planner.plan.WindowNode) Session(io.trino.Session) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) 
LimitNode(io.trino.sql.planner.plan.LimitNode) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) Cast(io.trino.sql.tree.Cast) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) ImmutableList(com.google.common.collect.ImmutableList) RowNumberNode(io.trino.sql.planner.plan.RowNumberNode) Objects.requireNonNull(java.util.Objects.requireNonNull) NullLiteral(io.trino.sql.tree.NullLiteral) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) StringLiteral(io.trino.sql.tree.StringLiteral) Lookup(io.trino.sql.planner.iterative.Lookup) PlanVisitor(io.trino.sql.planner.plan.PlanVisitor) TopNNode(io.trino.sql.planner.plan.TopNNode) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) UnnestNode(io.trino.sql.planner.plan.UnnestNode) ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning(io.trino.sql.planner.iterative.rule.ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning) QualifiedName(io.trino.sql.tree.QualifiedName) Captures(io.trino.matching.Captures) Specification(io.trino.sql.planner.plan.WindowNode.Specification) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Metadata(io.trino.metadata.Metadata) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin) Cast(io.trino.sql.tree.Cast) IfExpression(io.trino.sql.tree.IfExpression) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) Type(io.trino.sql.planner.plan.JoinNode.Type) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) 
UnnestNode(io.trino.sql.planner.plan.UnnestNode) LimitNode(io.trino.sql.planner.plan.LimitNode) EnforceSingleRowNode(io.trino.sql.planner.plan.EnforceSingleRowNode) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) ProjectNode(io.trino.sql.planner.plan.ProjectNode) NullLiteral(io.trino.sql.tree.NullLiteral) TopNNode(io.trino.sql.planner.plan.TopNNode)

Example 32 with Assignments

Use of io.trino.sql.planner.plan.Assignments in the project trino by trinodb.

In class SubqueryPlanner, method coerceIfNecessary.

/**
 * Makes {@code symbol} available under the type requested by {@code coercion}.
 * <p>
 * Without a coercion, the plan is returned untouched and {@code value} maps to {@code symbol} itself.
 * Otherwise, a {@link ProjectNode} is stacked on the plan root. It forwards every existing output and adds a
 * new symbol holding a {@link Cast} of {@code symbol} to the coercion type. The cast receives the result of
 * {@code typeCoercion.isTypeOnlyCoercion(type, coercion)} as its last flag. {@code value} then maps to the
 * new symbol.
 *
 * @param subPlan plan that produces {@code symbol}
 * @param symbol symbol holding {@code value} before any coercion
 * @param value expression whose symbol mapping is returned
 * @param type current type of {@code symbol}
 * @param coercion target type, present only when a coercion is required
 * @return the (possibly extended) plan together with the mapping from {@code value} to the symbol to read
 */
private PlanAndMappings coerceIfNecessary(PlanBuilder subPlan, Symbol symbol, Expression value, Type type, Optional<? extends Type> coercion) {
    if (coercion.isEmpty()) {
        return new PlanAndMappings(subPlan, ImmutableMap.of(NodeRef.of(value), symbol));
    }

    Type targetType = coercion.get();
    Symbol coerced = symbolAllocator.newSymbol(value, targetType);
    Cast cast = new Cast(symbol.toSymbolReference(), toSqlType(targetType), false, typeCoercion.isTypeOnlyCoercion(type, targetType));
    Assignments assignments = Assignments.builder()
            .putIdentities(subPlan.getRoot().getOutputSymbols())
            .put(coerced, cast)
            .build();
    PlanBuilder projected = subPlan.withNewRoot(new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments));
    return new PlanAndMappings(projected, ImmutableMap.of(NodeRef.of(value), coerced));
}
Also used : Cast(io.trino.sql.tree.Cast) Assignments(io.trino.sql.planner.plan.Assignments) ProjectNode(io.trino.sql.planner.plan.ProjectNode) PlanAndMappings(io.trino.sql.planner.QueryPlanner.PlanAndMappings)

Example 33 with Assignments

Use of io.trino.sql.planner.plan.Assignments in the project trino by trinodb.

In class PushAggregationIntoTableScan, method pushAggregationIntoTableScan.

/**
 * Offers an aggregation over {@code tableScan} to the connector for pushdown.
 * <p>
 * The aggregate calls and grouping keys are translated into their connector SPI form and passed to
 * {@code Metadata.applyAggregation}. If the connector accepts, the replacement plan has two parts:
 * <ul>
 * <li>a new {@link TableScanNode} exposing every original scan column plus one new symbol per
 * connector-returned column;</li>
 * <li>on top of it, a {@link ProjectNode} that binds each aggregation output symbol to the translated
 * connector projection, and each grouping key to the scan symbol of its grouping column. The connector may
 * have remapped that column.</li>
 * </ul>
 *
 * @return the replacement plan, or {@link Optional#empty()} when the connector declines the pushdown
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
    // Connector-facing view of the scan: column name -> column handle.
    Map<String, ColumnHandle> columnsByName = tableScan.getAssignments().entrySet().stream()
            .collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));

    // Snapshot the aggregations once so that each aggregate function stays index-aligned with its output symbol.
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationEntries = ImmutableList.copyOf(aggregations.entrySet());
    ImmutableList.Builder<AggregateFunction> functionsBuilder = ImmutableList.builder();
    for (Entry<Symbol, AggregationNode.Aggregation> aggregationEntry : aggregationEntries) {
        functionsBuilder.add(toAggregateFunction(plannerContext.getMetadata(), context, aggregationEntry.getValue()));
    }
    List<AggregateFunction> aggregateFunctions = functionsBuilder.build();
    ImmutableList.Builder<Symbol> outputSymbolsBuilder = ImmutableList.builder();
    for (Entry<Symbol, AggregationNode.Aggregation> aggregationEntry : aggregationEntries) {
        outputSymbolsBuilder.add(aggregationEntry.getKey());
    }
    List<Symbol> aggregationOutputSymbols = outputSymbolsBuilder.build();

    ImmutableList.Builder<ColumnHandle> groupByColumnsBuilder = ImmutableList.builder();
    for (Symbol groupingKey : groupingKeys) {
        groupByColumnsBuilder.add(columnsByName.get(groupingKey.getName()));
    }
    List<ColumnHandle> groupByColumns = groupByColumnsBuilder.build();

    Optional<AggregationApplicationResult<TableHandle>> pushdown = plannerContext.getMetadata()
            .applyAggregation(context.getSession(), tableScan.getTable(), aggregateFunctions, columnsByName, ImmutableList.of(groupByColumns));
    if (pushdown.isEmpty()) {
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = pushdown.get();

    // The new scan keeps every column of the original scan and appends one symbol per connector-returned column.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    // Connector variable name -> freshly allocated symbol; used to translate the returned projections.
    Map<String, Symbol> symbolsByVariable = new HashMap<>();
    for (Assignment connectorAssignment : result.getAssignments()) {
        Symbol newSymbol = context.getSymbolAllocator().newSymbol(connectorAssignment.getVariable(), connectorAssignment.getType());
        newScanOutputs.add(newSymbol);
        newScanAssignments.put(newSymbol, connectorAssignment.getColumn());
        symbolsByVariable.put(connectorAssignment.getVariable(), newSymbol);
    }

    List<Expression> newProjections = result.getProjections().stream()
            .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, symbolsByVariable, new LiteralEncoder(plannerContext)))
            .collect(toImmutableList());
    verify(aggregationOutputSymbols.size() == newProjections.size());

    // Aggregation outputs first, index-aligned with the connector projections.
    Assignments.Builder projection = Assignments.builder();
    for (int index = 0; index < aggregationOutputSymbols.size(); index++) {
        projection.put(aggregationOutputSymbols.get(index), newProjections.get(index));
    }

    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
    ImmutableBiMap<ColumnHandle, Symbol> symbolByColumn = scanAssignments.inverse();
    // Then the grouping keys, as plain symbol references. If the connector returned a new mapping from an
    // old column handle to a new one, the key must reference the new handle's symbol; otherwise it keeps
    // pointing at the original handle's symbol.
    for (Symbol groupBySymbol : groupingKeys) {
        ColumnHandle originalColumnHandle = columnsByName.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        projection.put(groupBySymbol, symbolByColumn.get(groupByColumnHandle).toSymbolReference());
    }

    // Note: the ProjectNode id is allocated before the TableScanNode id (left-to-right argument evaluation).
    return Optional.of(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new TableScanNode(
                    context.getIdAllocator().getNextId(),
                    result.getHandle(),
                    newScanOutputs.build(),
                    scanAssignments,
                    TupleDomain.all(),
                    deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode),
                    tableScan.isUpdateTarget(),
                    // table scan partitioning might have changed with new table handle
                    Optional.empty()),
            projection.build()));
}
Also used : IntStream(java.util.stream.IntStream) SortItem(io.trino.spi.connector.SortItem) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) Aggregation.step(io.trino.sql.planner.plan.Patterns.Aggregation.step) AggregateFunction(io.trino.spi.connector.AggregateFunction) HashMap(java.util.HashMap) Variable(io.trino.spi.expression.Variable) Capture.newCapture(io.trino.matching.Capture.newCapture) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) TableScanNode(io.trino.sql.planner.plan.TableScanNode) Symbol(io.trino.sql.planner.Symbol) Rules.deriveTableStatisticsForPushdown(io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ConnectorExpressionTranslator(io.trino.sql.planner.ConnectorExpressionTranslator) Assignments(io.trino.sql.planner.plan.Assignments) TupleDomain(io.trino.spi.predicate.TupleDomain) OrderingScheme(io.trino.sql.planner.OrderingScheme) Patterns.tableScan(io.trino.sql.planner.plan.Patterns.tableScan) Capture(io.trino.matching.Capture) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TableHandle(io.trino.metadata.TableHandle) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) 
SystemSessionProperties.isAllowPushdownIntoConnectors(io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) BoundSignature(io.trino.metadata.BoundSignature) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) Metadata(io.trino.metadata.Metadata) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) TableScanNode(io.trino.sql.planner.plan.TableScanNode) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) Expression(io.trino.sql.tree.Expression) AggregateFunction(io.trino.spi.connector.AggregateFunction) TableHandle(io.trino.metadata.TableHandle) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult)

Example 34 with Assignments

Use of io.trino.sql.planner.plan.Assignments in the project trino by trinodb.

In class PushDownDereferenceThroughJoin, method apply.

/**
 * Pushes the row subscript (dereference) expressions used by the projection and the join filter below the
 * {@link JoinNode}.
 * <p>
 * Each dereference gets a new symbol and is assigned on the join side whose outputs contain its base
 * symbol, via {@code createProjectNodeIfRequired}. The projection and the join filter are rewritten to use
 * the new symbols. Dereferences whose base symbol is a join criteria symbol are left in place.
 * <p>
 * The new join outputs exactly the symbols referenced by the rewritten projection, each listed once.
 *
 * @return the rewritten plan, or {@link Result#empty()} when there is nothing to push down
 * @throws IllegalArgumentException if a dereference's base symbol is produced by neither join side
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    JoinNode joinNode = captures.get(CHILD);

    // Consider dereferences in projections and join filter for pushdown
    ImmutableList.Builder<Expression> expressionsBuilder = ImmutableList.builder();
    expressionsBuilder.addAll(projectNode.getAssignments().getExpressions());
    joinNode.getFilter().ifPresent(expressionsBuilder::add);
    Set<SubscriptExpression> dereferences = extractRowSubscripts(expressionsBuilder.build(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());

    // Exclude criteria symbols
    ImmutableSet.Builder<Symbol> criteriaSymbolsBuilder = ImmutableSet.builder();
    joinNode.getCriteria().forEach(criteria -> {
        criteriaSymbolsBuilder.add(criteria.getLeft());
        criteriaSymbolsBuilder.add(criteria.getRight());
    });
    Set<Symbol> excludeSymbols = criteriaSymbolsBuilder.build();
    dereferences = dereferences.stream().filter(expression -> !excludeSymbols.contains(getBase(expression))).collect(toImmutableSet());
    if (dereferences.isEmpty()) {
        return Result.empty();
    }

    // Create new symbols for dereference expressions
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Rewrite project node assignments using new symbols for dereference expressions
    Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()).inverse().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));

    // Separate dereferences coming from left and right nodes
    Set<Symbol> leftSourceSymbols = ImmutableSet.copyOf(joinNode.getLeft().getOutputSymbols());
    Set<Symbol> rightSourceSymbols = ImmutableSet.copyOf(joinNode.getRight().getOutputSymbols());
    Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
    Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Expression> entry : dereferenceAssignments.entrySet()) {
        // getOnlyElement enforces that every dereference references exactly one symbol occurrence (its base)
        Symbol baseSymbol = getOnlyElement(extractAll(entry.getValue()));
        if (leftSourceSymbols.contains(baseSymbol)) {
            leftAssignmentsBuilder.put(entry.getKey(), entry.getValue());
        }
        else if (rightSourceSymbols.contains(baseSymbol)) {
            rightAssignmentsBuilder.put(entry.getKey(), entry.getValue());
        }
        else {
            throw new IllegalArgumentException(format("Unexpected symbol %s in projectNode", baseSymbol));
        }
    }
    Assignments leftAssignments = leftAssignmentsBuilder.build();
    Assignments rightAssignments = rightAssignmentsBuilder.build();
    PlanNode leftNode = createProjectNodeIfRequired(joinNode.getLeft(), leftAssignments, context.getIdAllocator());
    PlanNode rightNode = createProjectNodeIfRequired(joinNode.getRight(), rightAssignments, context.getIdAllocator());

    // Prepare new output symbols for join node.
    // extractAll reports every occurrence, so a symbol used by several projections (or several times in one)
    // would otherwise be listed repeatedly in the join output; distinct() keeps the first occurrence only.
    List<Symbol> referredSymbolsInAssignments = newAssignments.getExpressions().stream()
            .flatMap(expression -> extractAll(expression).stream())
            .distinct()
            .collect(toList());
    Set<Symbol> newLeftSourceSymbols = ImmutableSet.copyOf(leftNode.getOutputSymbols());
    Set<Symbol> newRightSourceSymbols = ImmutableSet.copyOf(rightNode.getOutputSymbols());
    List<Symbol> newLeftOutputSymbols = referredSymbolsInAssignments.stream().filter(newLeftSourceSymbols::contains).collect(toList());
    List<Symbol> newRightOutputSymbols = referredSymbolsInAssignments.stream().filter(newRightSourceSymbols::contains).collect(toList());

    JoinNode newJoinNode = new JoinNode(
            context.getIdAllocator().getNextId(),
            joinNode.getType(),
            leftNode,
            rightNode,
            joinNode.getCriteria(),
            newLeftOutputSymbols,
            newRightOutputSymbols,
            joinNode.isMaySkipOutputDuplicates(),
            // Use newly created symbols in filter
            joinNode.getFilter().map(expression -> replaceExpression(expression, mappings)),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType(),
            joinNode.isSpillable(),
            joinNode.getDynamicFilters(),
            joinNode.getReorderJoinStatsAndCost());
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, newAssignments));
}
Also used : Capture.newCapture(io.trino.matching.Capture.newCapture) DereferencePushdown.getBase(io.trino.sql.planner.iterative.rule.DereferencePushdown.getBase) PlanNode(io.trino.sql.planner.plan.PlanNode) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Rule(io.trino.sql.planner.iterative.Rule) JoinNode(io.trino.sql.planner.plan.JoinNode) Patterns.join(io.trino.sql.planner.plan.Patterns.join) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) Assignments(io.trino.sql.planner.plan.Assignments) SymbolsExtractor.extractAll(io.trino.sql.planner.SymbolsExtractor.extractAll) Set(java.util.Set) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) Capture(io.trino.matching.Capture) String.format(java.lang.String.format) DereferencePushdown.extractRowSubscripts(io.trino.sql.planner.iterative.rule.DereferencePushdown.extractRowSubscripts) HashBiMap(com.google.common.collect.HashBiMap) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Collectors.toList(java.util.stream.Collectors.toList) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) JoinNode(io.trino.sql.planner.plan.JoinNode) ImmutableList(com.google.common.collect.ImmutableList) 
Symbol(io.trino.sql.planner.Symbol) SymbolReference(io.trino.sql.tree.SymbolReference) Assignments(io.trino.sql.planner.plan.Assignments) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) ImmutableSet(com.google.common.collect.ImmutableSet) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) HashBiMap(com.google.common.collect.HashBiMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap)

Example 35 with Assignments

Use of io.trino.sql.planner.plan.Assignments in the project trino by trinodb.

In class InlineProjections, method inlineProjections.

/**
 * Inlines selected assignments of {@code child} into {@code parent}.
 * <p>
 * If both projections are identities, the child is simply removed. Otherwise, the symbols chosen by
 * {@code extractInliningTargets} are substituted into the parent's expressions. The child keeps its
 * non-inlined assignments plus identity assignments for every symbol those inlined expressions read. When
 * the remaining child projection is an identity, it is dropped.
 * <p>
 * The resulting parent keeps its original id and its original assignment (and therefore output) order.
 *
 * @return the rewritten parent projection, or {@link Optional#empty()} when there is nothing to inline
 */
static Optional<ProjectNode> inlineProjections(PlannerContext plannerContext, ProjectNode parent, ProjectNode child, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
    // squash identity projections
    if (parent.isIdentity() && child.isIdentity()) {
        return Optional.of((ProjectNode) parent.replaceChildren(ImmutableList.of(child.getSource())));
    }
    Set<Symbol> targets = extractInliningTargets(plannerContext, parent, child, session, typeAnalyzer, types);
    if (targets.isEmpty()) {
        return Optional.empty();
    }
    // inline the expressions
    Assignments assignments = child.getAssignments().filter(targets::contains);
    // Assignments.rewrite preserves the parent's assignment order. Collecting into a HashMap, as with
    // Collectors.toMap, would reorder the projection's output symbols.
    Assignments parentAssignments = parent.getAssignments().rewrite(expression -> inlineReferences(expression, assignments));
    // Synthesize identity assignments for the inputs of expressions that were inlined
    // to place in the child projection.
    Set<Symbol> inputs = assignments.getExpressions().stream()
            .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
            .collect(toSet());
    Assignments.Builder newChildAssignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Expression> assignment : child.getAssignments().entrySet()) {
        if (!targets.contains(assignment.getKey())) {
            newChildAssignmentsBuilder.put(assignment);
        }
    }
    for (Symbol input : inputs) {
        newChildAssignmentsBuilder.putIdentity(input);
    }
    Assignments newChildAssignments = newChildAssignmentsBuilder.build();
    PlanNode newChild;
    if (newChildAssignments.isIdentity()) {
        // nothing left to compute in the child: read directly from its source
        newChild = child.getSource();
    }
    else {
        newChild = new ProjectNode(child.getId(), child.getSource(), newChildAssignments);
    }
    return Optional.of(new ProjectNode(parent.getId(), newChild, parentAssignments));
}
Also used : Function(java.util.function.Function) Capture.newCapture(io.trino.matching.Capture.newCapture) PlanNode(io.trino.sql.planner.plan.PlanNode) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ExpressionSymbolInliner.inlineSymbols(io.trino.sql.planner.ExpressionSymbolInliner.inlineSymbols) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) TryExpression(io.trino.sql.tree.TryExpression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Collectors.toSet(java.util.stream.Collectors.toSet) Symbol(io.trino.sql.planner.Symbol) RowType(io.trino.spi.type.RowType) ImmutableSet(com.google.common.collect.ImmutableSet) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) Capture(io.trino.matching.Capture) ExpressionUtils.isEffectivelyLiteral(io.trino.sql.ExpressionUtils.isEffectivelyLiteral) Pattern(io.trino.matching.Pattern) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) AstUtils(io.trino.sql.util.AstUtils) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) TypeProvider(io.trino.sql.planner.TypeProvider) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) Patterns.project(io.trino.sql.planner.plan.Patterns.project) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) PlanNode(io.trino.sql.planner.plan.PlanNode) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) TryExpression(io.trino.sql.tree.TryExpression) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map)

Aggregations

Assignments (io.trino.sql.planner.plan.Assignments): 61
ProjectNode (io.trino.sql.planner.plan.ProjectNode): 50
Symbol (io.trino.sql.planner.Symbol): 39
Expression (io.trino.sql.tree.Expression): 39
ImmutableList (com.google.common.collect.ImmutableList): 36
Map (java.util.Map): 33
PlanNode (io.trino.sql.planner.plan.PlanNode): 26
SymbolReference (io.trino.sql.tree.SymbolReference): 26
Captures (io.trino.matching.Captures): 23
Pattern (io.trino.matching.Pattern): 23
Rule (io.trino.sql.planner.iterative.Rule): 23
ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap): 21
Objects.requireNonNull (java.util.Objects.requireNonNull): 21
Set (java.util.Set): 21
Capture (io.trino.matching.Capture): 20
Capture.newCapture (io.trino.matching.Capture.newCapture): 20
TypeAnalyzer (io.trino.sql.planner.TypeAnalyzer): 19
Patterns.source (io.trino.sql.planner.plan.Patterns.source): 19
SubscriptExpression (io.trino.sql.tree.SubscriptExpression): 18
ImmutableMap (com.google.common.collect.ImmutableMap): 17