
Example 1 with TopNNode

Use of io.trino.sql.planner.plan.TopNNode in project trino by trinodb.

From class MergeLimitOverProjectWithSort, method apply.

@Override
public Result apply(LimitNode parent, Captures captures, Context context) {
    ProjectNode project = captures.get(PROJECT);
    SortNode sort = captures.get(SORT);
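    // Replace Limit(Project(Sort(source))) with Project(TopN(source)): the TopN keeps the
    // limit's id and count and takes its ordering from the sort.
    return Result.ofPlanNode(project.replaceChildren(ImmutableList.of(
            new TopNNode(
                    parent.getId(),
                    sort.getSource(),
                    parent.getCount(),
                    sort.getOrderingScheme(),
                    parent.isPartial() ? PARTIAL : SINGLE))));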
}
Also used : SortNode(io.trino.sql.planner.plan.SortNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode) TopNNode(io.trino.sql.planner.plan.TopNNode)
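MergeLimitOverProjectWithSort turns a LimitNode over a SortNode, with a ProjectNode in between, into a single TopNNode. After the rewrite, the engine no longer has to sort every input row before discarding all but the first N. The Java sketch below shows why the two forms agree. It is not Trino code, and the class `TopNSketch` and its methods are hypothetical. It uses a bounded max-heap as the TopN strategy. That is one common way to implement top-N, but not necessarily what Trino's operator does.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical sketch, not Trino code: compares Sort + Limit with a single TopN step.
final class TopNSketch {
    // Sort every row, then keep the first n: what SortNode followed by LimitNode computes.
    static <T> List<T> sortThenLimit(List<T> rows, int n, Comparator<T> order) {
        List<T> sorted = new ArrayList<>(rows);
        sorted.sort(order);
        return new ArrayList<>(sorted.subList(0, Math.max(0, Math.min(n, sorted.size()))));
    }

    // Keep only the n best rows in a bounded heap: the work a TopN operator needs instead.
    static <T> List<T> topN(List<T> rows, int n, Comparator<T> order) {
        List<T> result = new ArrayList<>();
        if (n <= 0) {
            return result;
        }
        // Max-heap with respect to `order`: the head is the worst row retained so far.
        PriorityQueue<T> heap = new PriorityQueue<>(n, order.reversed());
        for (T row : rows) {
            if (heap.size() < n) {
                heap.add(row);
            } else if (order.compare(row, heap.peek()) < 0) {
                // The new row sorts before the worst retained row, so it replaces it.
                heap.poll();
                heap.add(row);
            }
        }
        result.addAll(heap);
        result.sort(order); // emit rows in ORDER BY order, as Sort + Limit would
        return result;
    }

    public static void main(String[] args) {
        List<Integer> rows = List.of(42, 7, 19, 3, 88, 7, 56);
        Comparator<Integer> ascending = Comparator.naturalOrder();
        System.out.println(sortThenLimit(rows, 3, ascending)); // [3, 7, 7]
        System.out.println(topN(rows, 3, ascending));          // [3, 7, 7]
    }
}
```

Both calls print the same three rows. The heap version holds at most n rows at a time. That saving is what makes the rewrite worthwhile when the limit is small.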

Example 2 with TopNNode

Use of io.trino.sql.planner.plan.TopNNode in project trino by trinodb.

From class DecorrelateUnnest, method apply.

@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // determine shape of the subquery
    PlanNode searchRoot = correlatedJoinNode.getSubquery();
    // 1. find EnforceSingleRowNode in the subquery
    Optional<EnforceSingleRowNode> enforceSingleRow = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(planNode -> false)
            .findFirst();
    if (enforceSingleRow.isPresent()) {
        searchRoot = enforceSingleRow.get().getSource();
    }
    // 2. find correlated UnnestNode in the subquery
    Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup())
            .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup()))
            .recurseOnlyWhen(node -> node instanceof ProjectNode
                    || (node instanceof LimitNode && ((LimitNode) node).getCount() > 0)
                    || (node instanceof TopNNode && ((TopNNode) node).getCount() > 0))
            .findFirst();
    if (subqueryUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = subqueryUnnest.get();
    // assign unique id to input rows
    Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BIGINT);
    PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);
    // pre-project unnest symbols if they were pre-projected in subquery
    // The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols produced by a projection that uses only correlation symbols.
    // Here, any underlying projection that was a source of the correlated UnnestNode, is appended as a source of the rewritten UnnestNode.
    // If the projection is not necessary for UnnestNode (i.e. it does not produce any unnest symbols), it should be pruned afterwards.
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) unnestSource;
        input = new ProjectNode(
                sourceProjection.getId(),
                input,
                Assignments.builder()
                        .putIdentities(input.getOutputSymbols())
                        .putAll(sourceProjection.getAssignments())
                        .build());
    }
    // determine join type for rewritten UnnestNode
    Type unnestJoinType = LEFT;
    if (enforceSingleRow.isEmpty() && correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER && unnestNode.getJoinType() == INNER) {
        unnestJoinType = INNER;
    }
    // make sure that the rewritten node has an ordinality symbol, which might be necessary to restore inner unnest semantics after the rewrite.
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol().orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));
    // rewrite correlated join to UnnestNode.
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            input,
            input.getOutputSymbols(),
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            unnestJoinType,
            Optional.empty());
    // restore all nodes from the subquery
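    PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(
            correlatedJoinNode.getSubquery(),
            input.getOutputSymbols(),
            ordinalitySymbol,
            uniqueSymbol,
            rewrittenUnnest,
            context.getSession(),
            metadata,
            context.getLookup(),
            context.getIdAllocator(),
            context.getSymbolAllocator());
    // If the original UnnestNode was INNER but the rewritten one is LEFT, restore INNER semantics.
    // The ordinality symbol is null only for synthetic rows, so it distinguishes
    // between unnested rows and synthetic rows added by left unnest.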
    if (unnestNode.getJoinType() == INNER && rewrittenUnnest.getJoinType() == LEFT) {
        Assignments.Builder assignments = Assignments.builder().putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
        for (Symbol subquerySymbol : correlatedJoinNode.getSubquery().getOutputSymbols()) {
            assignments.put(
                    subquerySymbol,
                    new IfExpression(
                            new IsNullPredicate(ordinalitySymbol.toSymbolReference()),
                            new Cast(new NullLiteral(), toSqlType(context.getSymbolAllocator().getTypes().get(subquerySymbol))),
                            subquerySymbol.toSymbolReference()));
        }
        rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, assignments.build());
    }
    // restrict outputs
    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
            .orElse(rewrittenPlan));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) TypeSignatureProvider.fromTypes(io.trino.sql.analyzer.TypeSignatureProvider.fromTypes) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) FilterNode(io.trino.sql.planner.plan.FilterNode) PlanNode(io.trino.sql.planner.plan.PlanNode) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) Type(io.trino.sql.planner.plan.JoinNode.Type) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) GREATER_THAN(io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) FunctionCall(io.trino.sql.tree.FunctionCall) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) ResolvedFunction(io.trino.metadata.ResolvedFunction) EnforceSingleRowNode(io.trino.sql.planner.plan.EnforceSingleRowNode) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) Assignments(io.trino.sql.planner.plan.Assignments) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) DEFAULT_FRAME(io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME) LESS_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) GenericLiteral(io.trino.sql.tree.GenericLiteral) List(java.util.List) Pattern(io.trino.matching.Pattern) IfExpression(io.trino.sql.tree.IfExpression) BIGINT(io.trino.spi.type.BigintType.BIGINT) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) WindowNode(io.trino.sql.planner.plan.WindowNode) Session(io.trino.Session) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) 
LimitNode(io.trino.sql.planner.plan.LimitNode) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) Cast(io.trino.sql.tree.Cast) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) ImmutableList(com.google.common.collect.ImmutableList) RowNumberNode(io.trino.sql.planner.plan.RowNumberNode) Objects.requireNonNull(java.util.Objects.requireNonNull) NullLiteral(io.trino.sql.tree.NullLiteral) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) StringLiteral(io.trino.sql.tree.StringLiteral) Lookup(io.trino.sql.planner.iterative.Lookup) PlanVisitor(io.trino.sql.planner.plan.PlanVisitor) TopNNode(io.trino.sql.planner.plan.TopNNode) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) UnnestNode(io.trino.sql.planner.plan.UnnestNode) ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning(io.trino.sql.planner.iterative.rule.ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning) QualifiedName(io.trino.sql.tree.QualifiedName) Captures(io.trino.matching.Captures) Specification(io.trino.sql.planner.plan.WindowNode.Specification) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Metadata(io.trino.metadata.Metadata) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin)
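DecorrelateUnnest may rewrite an INNER unnest as a LEFT unnest, for example when an EnforceSingleRowNode sits above it. A LEFT unnest emits one synthetic row for an input whose array is empty. A projection in the subquery can then turn that row's null element into a non-null value. The sketch below models this on plain Java lists. `UnnestDecorrelationSketch`, its row class, and the `coalesce(element, 0)` projection are all hypothetical. They are chosen only to show the join-type decision and why the rule nulls out every subquery symbol whenever the ordinality symbol is null.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch, not Trino code: models the rows that flow through the rewritten plan.
final class UnnestDecorrelationSketch {
    enum JoinType { INNER, LEFT }

    // One row produced by the rewritten unnest. A null ordinality marks a synthetic row.
    static final class UnnestRow {
        final long uniqueId;
        final Integer element;
        final Long ordinality;

        UnnestRow(long uniqueId, Integer element, Long ordinality) {
            this.uniqueId = uniqueId;
            this.element = element;
            this.ordinality = ordinality;
        }
    }

    // Same decision as the rule: LEFT unless there is no EnforceSingleRowNode
    // and both the correlated join and the original unnest are INNER.
    static JoinType rewrittenJoinType(boolean hasEnforceSingleRow, JoinType correlatedJoin, JoinType originalUnnest) {
        if (!hasEnforceSingleRow && correlatedJoin == JoinType.INNER && originalUnnest == JoinType.INNER) {
            return JoinType.INNER;
        }
        return JoinType.LEFT;
    }

    // LEFT unnest with ordinality; uniqueId plays the role of the AssignUniqueId symbol.
    static List<UnnestRow> leftUnnestWithOrdinality(List<List<Integer>> arrays) {
        List<UnnestRow> out = new ArrayList<>();
        for (int id = 0; id < arrays.size(); id++) {
            List<Integer> array = arrays.get(id);
            if (array.isEmpty()) {
                // An empty array still yields one synthetic row with null element and null ordinality.
                out.add(new UnnestRow(id, null, null));
                continue;
            }
            for (int i = 0; i < array.size(); i++) {
                out.add(new UnnestRow(id, array.get(i), (long) (i + 1)));
            }
        }
        return out;
    }

    // Hypothetical subquery projection on top of the unnest: coalesce(element, 0).
    static Integer subqueryValue(UnnestRow row) {
        return row.element != null ? row.element : 0;
    }

    // Restoring projection: IF(ordinality IS NULL, NULL, subquery value).
    static List<Integer> restoreInnerSemantics(List<UnnestRow> rows) {
        List<Integer> out = new ArrayList<>();
        for (UnnestRow row : rows) {
            out.add(row.ordinality == null ? null : subqueryValue(row));
        }
        return out;
    }

    public static void main(String[] args) {
        List<UnnestRow> rows = leftUnnestWithOrdinality(List.of(List.of(5, 6), List.<Integer>of(), List.of(7)));
        System.out.println(restoreInnerSemantics(rows)); // [5, 6, null, 7]
    }
}
```

Without the restoring projection, the empty array would surface as 0. With it, the value is null, which matches an empty subquery result being padded with nulls.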

Example 3 with TopNNode

Use of io.trino.sql.planner.plan.TopNNode in project trino by trinodb.

From class TestUnion, method testUnionUnderTopN.

@Test
public void testUnionUnderTopN() {
    Plan plan = plan(
            "SELECT * FROM (" +
                    "   SELECT regionkey FROM nation " +
                    "   UNION ALL " +
                    "   SELECT nationkey FROM nation" +
                    ") t(a) " +
                    "ORDER BY a LIMIT 1",
            OPTIMIZED_AND_VALIDATED,
            false);
    List<PlanNode> remotes = searchFrom(plan.getRoot()).where(TestUnion::isRemoteExchange).findAll();
    assertEquals(remotes.size(), 1, "There should be exactly one RemoteExchange");
    assertEquals(((ExchangeNode) Iterables.getOnlyElement(remotes)).getType(), GATHER);
    int numberOfPartialTopN = searchFrom(plan.getRoot())
            .where(planNode -> planNode instanceof TopNNode && ((TopNNode) planNode).getStep() == TopNNode.Step.PARTIAL)
            .count();
    assertEquals(numberOfPartialTopN, 2, "There should be exactly two partial TopN nodes");
    assertPlanIsFullyDistributed(plan);
}
Also used : Iterables(com.google.common.collect.Iterables) PlanNodeSearcher.searchFrom(io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom) TopNNode(io.trino.sql.planner.plan.TopNNode) OPTIMIZED_AND_VALIDATED(io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED) Assert.assertEquals(org.testng.Assert.assertEquals) Test(org.testng.annotations.Test) PlanNode(io.trino.sql.planner.plan.PlanNode) List(java.util.List) Collectors.toList(java.util.stream.Collectors.toList) GATHER(io.trino.sql.planner.plan.ExchangeNode.Type.GATHER) REMOTE(io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) REPARTITION(io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION) Assert.assertTrue(org.testng.Assert.assertTrue) ExchangeNode(io.trino.sql.planner.plan.ExchangeNode) Plan(io.trino.sql.planner.Plan) JoinNode(io.trino.sql.planner.plan.JoinNode) BasePlanTest(io.trino.sql.planner.assertions.BasePlanTest) Assert.assertFalse(org.testng.Assert.assertFalse)
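The test expects a single remote GATHER exchange and exactly two partial TopN nodes, one per UNION ALL branch. The split is safe because every row in the overall top N is also in the top N of its own branch. The sketch below checks that property on made-up integer data. `DistributedTopNSketch` and its methods are hypothetical and only model the PARTIAL and FINAL steps, not Trino's planner.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch, not Trino code: PARTIAL TopN per UNION ALL branch, then a FINAL TopN after gathering.
final class DistributedTopNSketch {
    // Smallest n values in ascending order, i.e. ORDER BY a LIMIT n.
    static List<Integer> topN(List<Integer> rows, int n) {
        List<Integer> sorted = new ArrayList<>(rows);
        sorted.sort(Comparator.naturalOrder());
        return new ArrayList<>(sorted.subList(0, Math.max(0, Math.min(n, sorted.size()))));
    }

    // Each branch keeps at most n rows before the exchange; the final step merges the survivors.
    static List<Integer> partialThenFinal(List<List<Integer>> branches, int n) {
        List<Integer> gathered = new ArrayList<>();
        for (List<Integer> branch : branches) {
            gathered.addAll(topN(branch, n)); // PARTIAL step on one branch
        }
        return topN(gathered, n); // FINAL step after the single GATHER exchange
    }

    // Reference result: one TopN over the whole UNION ALL.
    static List<Integer> singleStep(List<List<Integer>> branches, int n) {
        List<Integer> all = new ArrayList<>();
        for (List<Integer> branch : branches) {
            all.addAll(branch);
        }
        return topN(all, n);
    }

    public static void main(String[] args) {
        List<List<Integer>> branches = List.of(List.of(4, 2, 0, 3, 1), List.of(9, 0, 5, 7));
        System.out.println(partialThenFinal(branches, 3)); // [0, 0, 1]
        System.out.println(singleStep(branches, 3));       // [0, 0, 1]
    }
}
```

Only n rows per branch cross the exchange, which is why the planner prefers partial TopN nodes below the GATHER.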

Example 4 with TopNNode

Use of io.trino.sql.planner.plan.TopNNode in project trino by trinodb.

From class TopNStatsRule, method doCalculate.

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(TopNNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) {
    PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
    double rowCount = sourceStats.getOutputRowCount();
    if (node.getStep() != TopNNode.Step.SINGLE) {
        return Optional.empty();
    }
    if (rowCount <= node.getCount()) {
        return Optional.of(sourceStats);
    }
    long limitCount = node.getCount();
    PlanNodeStatsEstimate resultStats = PlanNodeStatsEstimate.buildFrom(sourceStats).setOutputRowCount(limitCount).build();
    if (limitCount == 0) {
        return Optional.of(resultStats);
    }
    // Augment the null fraction estimate for the first ORDER BY symbol.
    // This assumes the ORDER BY list is not empty.
    Symbol firstOrderSymbol = node.getOrderingScheme().getOrderBy().get(0);
    SortOrder sortOrder = node.getOrderingScheme().getOrdering(firstOrderSymbol);
    resultStats = resultStats.mapSymbolColumnStatistics(firstOrderSymbol, symbolStats -> {
        SymbolStatsEstimate.Builder newStats = SymbolStatsEstimate.buildFrom(symbolStats);
        double nullCount = rowCount * symbolStats.getNullsFraction();
        if (sortOrder.isNullsFirst()) {
            if (nullCount > limitCount) {
                newStats.setNullsFraction(1.0);
            } else {
                newStats.setNullsFraction(nullCount / limitCount);
            }
        } else {
            double nonNullCount = (rowCount - nullCount);
            if (nonNullCount > limitCount) {
                newStats.setNullsFraction(0.0);
            } else {
                newStats.setNullsFraction((limitCount - nonNullCount) / limitCount);
            }
        }
        return newStats.build();
    });
    // Reached only when TopN actually limits the row count, or when the source row count is unknown (NaN fails the rowCount <= count check above).
    return Optional.of(resultStats);
}
Also used : Symbol(io.trino.sql.planner.Symbol) Pattern(io.trino.matching.Pattern) Lookup(io.trino.sql.planner.iterative.Lookup) TopNNode(io.trino.sql.planner.plan.TopNNode) TypeProvider(io.trino.sql.planner.TypeProvider) Optional(java.util.Optional) Patterns.topN(io.trino.sql.planner.plan.Patterns.topN) Session(io.trino.Session) SortOrder(io.trino.spi.connector.SortOrder)
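The null-fraction adjustment in doCalculate is plain arithmetic, so it can be restated outside Trino. Here is a worked example with assumed numbers. A source of 100 rows with a 0.25 nulls fraction has 25 nulls. With NULLS FIRST and a limit of 50, all 25 nulls fit in the output, giving 25 / 50 = 0.5. With NULLS LAST and a limit of 80, the 75 non-null rows come first, so the fraction is (80 - 75) / 80 = 0.0625. The hypothetical `TopNNullFractionSketch` below mirrors those branches on doubles instead of PlanNodeStatsEstimate objects.

```java
// Hypothetical restatement, not Trino code: the null-fraction arithmetic of TopNStatsRule.doCalculate
// for the first ORDER BY symbol, on plain doubles instead of PlanNodeStatsEstimate objects.
final class TopNNullFractionSketch {
    static double estimateNullsFraction(double rowCount, double nullsFraction, long limitCount, boolean nullsFirst) {
        if (rowCount <= limitCount || limitCount == 0) {
            // TopN does not limit, or produces no rows: the rule leaves the symbol statistics unchanged.
            return nullsFraction;
        }
        double nullCount = rowCount * nullsFraction;
        if (nullsFirst) {
            // Nulls sort first, so they fill the output before any non-null value does.
            return nullCount > limitCount ? 1.0 : nullCount / limitCount;
        }
        // Nulls sort last, so they only appear once every non-null row has been emitted.
        double nonNullCount = rowCount - nullCount;
        return nonNullCount > limitCount ? 0.0 : (limitCount - nonNullCount) / limitCount;
    }

    public static void main(String[] args) {
        System.out.println(estimateNullsFraction(100, 0.25, 50, true));  // 0.5
        System.out.println(estimateNullsFraction(100, 0.25, 80, false)); // 0.0625
    }
}
```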

Example 5 with TopNNode

Use of io.trino.sql.planner.plan.TopNNode in project trino by trinodb.

From class TopNMatcher, method detailMatches.

@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    TopNNode topNNode = (TopNNode) node;
    if (topNNode.getCount() != count) {
        return NO_MATCH;
    }
    if (!orderingSchemeMatches(orderBy, topNNode.getOrderingScheme(), symbolAliases)) {
        return NO_MATCH;
    }
    if (topNNode.getStep() != step) {
        return NO_MATCH;
    }
    return match();
}
Also used : TopNNode(io.trino.sql.planner.plan.TopNNode)
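TopNMatcher.detailMatches rejects a node as soon as one of three properties differs from the expected value:

- the row count
- the ordering scheme, resolved through symbol aliases
- the step

The sketch below replays those checks on plain values. `TopNMatcherSketch`, `TopNShape`, and the alias map are hypothetical stand-ins for TopNNode and SymbolAliases. The ordering check compares only the column sequence and leaves out per-column sort order.

```java
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical sketch, not Trino's plan-testing framework: the three checks in
// TopNMatcher.detailMatches, applied to plain values instead of plan nodes.
final class TopNMatcherSketch {
    // Stand-in for TopNNode.Step.
    enum Step { SINGLE, PARTIAL, FINAL }

    // Stand-in for the parts of a TopNNode the matcher inspects.
    static final class TopNShape {
        final long count;
        final List<String> orderBy; // actual symbol names, in ORDER BY order
        final Step step;

        TopNShape(long count, List<String> orderBy, Step step) {
            this.count = count;
            this.orderBy = orderBy;
            this.step = step;
        }
    }

    // expectedOrderBy holds aliases; aliases maps each alias to a plan symbol (a stand-in for SymbolAliases).
    static boolean matches(TopNShape actual, long expectedCount, List<String> expectedOrderBy,
            Step expectedStep, Map<String, String> aliases) {
        if (actual.count != expectedCount) {
            return false;
        }
        if (actual.orderBy.size() != expectedOrderBy.size()) {
            return false;
        }
        for (int i = 0; i < expectedOrderBy.size(); i++) {
            // An unknown alias resolves to null and therefore never matches.
            if (!Objects.equals(aliases.get(expectedOrderBy.get(i)), actual.orderBy.get(i))) {
                return false;
            }
        }
        return actual.step == expectedStep;
    }

    public static void main(String[] args) {
        TopNShape partial = new TopNShape(1, List.of("regionkey"), Step.PARTIAL);
        Map<String, String> aliases = Map.of("A", "regionkey");
        System.out.println(matches(partial, 1, List.of("A"), Step.PARTIAL, aliases)); // true
        System.out.println(matches(partial, 1, List.of("A"), Step.SINGLE, aliases));  // false
    }
}
```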

Aggregations

TopNNode (io.trino.sql.planner.plan.TopNNode): 9
PlanNode (io.trino.sql.planner.plan.PlanNode): 5
Pattern (io.trino.matching.Pattern): 4
ProjectNode (io.trino.sql.planner.plan.ProjectNode): 4
Expression (io.trino.sql.tree.Expression): 4
ImmutableList (com.google.common.collect.ImmutableList): 3
Session (io.trino.Session): 3
Captures (io.trino.matching.Captures): 3
Rule (io.trino.sql.planner.iterative.Rule): 3
Patterns.topN (io.trino.sql.planner.plan.Patterns.topN): 3
List (java.util.List): 3
Map (java.util.Map): 3
Optional (java.util.Optional): 3
ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap): 2
Capture (io.trino.matching.Capture): 2
Capture.newCapture (io.trino.matching.Capture.newCapture): 2
Metadata (io.trino.metadata.Metadata): 2
Symbol (io.trino.sql.planner.Symbol): 2
Lookup (io.trino.sql.planner.iterative.Lookup): 2
Assignments (io.trino.sql.planner.plan.Assignments): 2