Search in sources :

Example 1 with Type

use of io.trino.sql.planner.plan.JoinNode.Type in project trino by trinodb.

the class DecorrelateUnnest method apply.

/**
 * Rewrites a correlated join whose subquery contains a supported correlated UnnestNode
 * into an UnnestNode placed directly over the join input, and then re-applies the
 * remaining subquery nodes on top of the rewritten UnnestNode.
 *
 * @param correlatedJoinNode the correlated join being rewritten
 * @param captures pattern captures (not read by this method)
 * @param context rule context providing the lookup, the session and the id / symbol allocators
 * @return the rewritten plan, or an empty result when no supported correlated UnnestNode is found
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    PlanNode subquery = correlatedJoinNode.getSubquery();

    // Step 1: check for an EnforceSingleRowNode; the searcher is told never to recurse below the start node.
    Optional<EnforceSingleRowNode> enforceSingleRow = PlanNodeSearcher.searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(planNode -> false)
            .findFirst();

    // When an EnforceSingleRowNode was found, the unnest search starts from its source instead of the subquery root.
    PlanNode unnestSearchRoot = enforceSingleRow.isPresent() ? enforceSingleRow.get().getSource() : subquery;

    // Step 2: locate the correlated UnnestNode, descending only through projections and through
    // LimitNode / TopNNode with a positive count.
    Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(unnestSearchRoot, context.getLookup())
            .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup()))
            .recurseOnlyWhen(node -> node instanceof ProjectNode
                    || (node instanceof LimitNode && ((LimitNode) node).getCount() > 0)
                    || (node instanceof TopNNode && ((TopNNode) node).getCount() > 0))
            .findFirst();
    if (subqueryUnnest.isEmpty()) {
        // nothing this rule can rewrite
        return Result.empty();
    }
    UnnestNode unnestNode = subqueryUnnest.get();

    // Tag every input row with a unique id.
    Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BIGINT);
    PlanNode uniquelyIdentifiedInput = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);

    // If the correlated UnnestNode was fed by a projection in the subquery, replay that projection
    // (alongside identity assignments for all current input symbols) as the source of the rewritten
    // UnnestNode. The correlated UnnestNode either unnests correlation symbols directly, or unnests
    // symbols produced by a projection that uses only correlation symbols. A projection that turns
    // out to produce no unnest symbols should be pruned afterwards.
    PlanNode input = uniquelyIdentifiedInput;
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode subqueryProjection = (ProjectNode) unnestSource;
        Assignments preProjection = Assignments.builder()
                .putIdentities(uniquelyIdentifiedInput.getOutputSymbols())
                .putAll(subqueryProjection.getAssignments())
                .build();
        input = new ProjectNode(subqueryProjection.getId(), uniquelyIdentifiedInput, preProjection);
    }

    // The rewritten UnnestNode is INNER only when there is no EnforceSingleRowNode and both the
    // correlated join and the original unnest are INNER; in every other case it is LEFT.
    boolean canStayInner = enforceSingleRow.isEmpty()
            && correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER
            && unnestNode.getJoinType() == INNER;
    Type unnestJoinType = canStayInner ? INNER : LEFT;

    // The rewritten node always carries ordinality; it may be needed below to restore inner unnest semantics.
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
            .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));

    // Replace the correlated join with an UnnestNode over the (possibly pre-projected) input.
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            input,
            input.getOutputSymbols(),
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            unnestJoinType,
            Optional.empty());

    // Re-apply the nodes of the subquery on top of the rewritten UnnestNode.
    PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(
            correlatedJoinNode.getSubquery(),
            input.getOutputSymbols(),
            ordinalitySymbol,
            uniqueSymbol,
            rewrittenUnnest,
            context.getSession(),
            metadata,
            context.getLookup(),
            context.getIdAllocator(),
            context.getSymbolAllocator());

    // The original unnest was INNER but the rewritten one is LEFT: project every subquery output
    // symbol to a typed NULL wherever ordinality is NULL, which distinguishes
    // between unnested rows and synthetic rows added by left unnest.
    boolean restoreInnerSemantics = unnestNode.getJoinType() == INNER && rewrittenUnnest.getJoinType() == LEFT;
    if (restoreInnerSemantics) {
        Assignments.Builder nullingAssignments = Assignments.builder()
                .putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
        for (Symbol subquerySymbol : correlatedJoinNode.getSubquery().getOutputSymbols()) {
            IsNullPredicate ordinalityIsNull = new IsNullPredicate(ordinalitySymbol.toSymbolReference());
            Cast typedNull = new Cast(new NullLiteral(), toSqlType(context.getSymbolAllocator().getTypes().get(subquerySymbol)));
            nullingAssignments.put(subquerySymbol, new IfExpression(ordinalityIsNull, typedNull, subquerySymbol.toSymbolReference()));
        }
        rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, nullingAssignments.build());
    }

    // Expose only the symbols the original correlated join produced.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(rewrittenPlan));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) TypeSignatureProvider.fromTypes(io.trino.sql.analyzer.TypeSignatureProvider.fromTypes) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) FilterNode(io.trino.sql.planner.plan.FilterNode) PlanNode(io.trino.sql.planner.plan.PlanNode) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) Type(io.trino.sql.planner.plan.JoinNode.Type) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) GREATER_THAN(io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) FunctionCall(io.trino.sql.tree.FunctionCall) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) ResolvedFunction(io.trino.metadata.ResolvedFunction) EnforceSingleRowNode(io.trino.sql.planner.plan.EnforceSingleRowNode) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) Assignments(io.trino.sql.planner.plan.Assignments) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) DEFAULT_FRAME(io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME) LESS_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) GenericLiteral(io.trino.sql.tree.GenericLiteral) List(java.util.List) Pattern(io.trino.matching.Pattern) IfExpression(io.trino.sql.tree.IfExpression) BIGINT(io.trino.spi.type.BigintType.BIGINT) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) WindowNode(io.trino.sql.planner.plan.WindowNode) Session(io.trino.Session) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) 
LimitNode(io.trino.sql.planner.plan.LimitNode) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) Cast(io.trino.sql.tree.Cast) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) ImmutableList(com.google.common.collect.ImmutableList) RowNumberNode(io.trino.sql.planner.plan.RowNumberNode) Objects.requireNonNull(java.util.Objects.requireNonNull) NullLiteral(io.trino.sql.tree.NullLiteral) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) StringLiteral(io.trino.sql.tree.StringLiteral) Lookup(io.trino.sql.planner.iterative.Lookup) PlanVisitor(io.trino.sql.planner.plan.PlanVisitor) TopNNode(io.trino.sql.planner.plan.TopNNode) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) UnnestNode(io.trino.sql.planner.plan.UnnestNode) ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning(io.trino.sql.planner.iterative.rule.ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning) QualifiedName(io.trino.sql.tree.QualifiedName) Captures(io.trino.matching.Captures) Specification(io.trino.sql.planner.plan.WindowNode.Specification) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Metadata(io.trino.metadata.Metadata) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin) Cast(io.trino.sql.tree.Cast) IfExpression(io.trino.sql.tree.IfExpression) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) Type(io.trino.sql.planner.plan.JoinNode.Type) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) 
UnnestNode(io.trino.sql.planner.plan.UnnestNode) LimitNode(io.trino.sql.planner.plan.LimitNode) EnforceSingleRowNode(io.trino.sql.planner.plan.EnforceSingleRowNode) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) ProjectNode(io.trino.sql.planner.plan.ProjectNode) NullLiteral(io.trino.sql.tree.NullLiteral) TopNNode(io.trino.sql.planner.plan.TopNNode)

Example 2 with Type

use of io.trino.sql.planner.plan.JoinNode.Type in project trino by trinodb.

the class TestDetermineJoinDistributionType method testReplicatesWhenSourceIsSmall.

/**
 * Exercises join distribution selection when the immediate join sources carry estimates that
 * exceed the AUTOMATIC_RESTRICTED limit (or are unknown) while the source table beneath the
 * filter on the B side is small.
 */
@Test
public void testReplicatesWhenSourceIsSmall() {
    // variable width so that average row size is respected
    VarcharType symbolType = createUnboundedVarcharType();
    int rowsInA = 10_000;
    int rowsInB = 10;

    // column statistics that push the output size over the AUTOMATIC_RESTRICTED limit
    SymbolStatsEstimate oversizedColumnStats = new SymbolStatsEstimate(0, 100, 0, 640000d * 10000, 10);

    // output size exceeds AUTOMATIC_RESTRICTED limit
    PlanNodeStatsEstimate aStatsEstimate = PlanNodeStatsEstimate.builder()
            .setOutputRowCount(rowsInA)
            .addSymbolStatistics(ImmutableMap.of(new Symbol("A1"), oversizedColumnStats))
            .build();
    // output size exceeds AUTOMATIC_RESTRICTED limit
    PlanNodeStatsEstimate bStatsEstimate = PlanNodeStatsEstimate.builder()
            .setOutputRowCount(rowsInB)
            .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), oversizedColumnStats))
            .build();
    // output size does not exceed AUTOMATIC_RESTRICTED limit
    PlanNodeStatsEstimate bSourceStatsEstimate = PlanNodeStatsEstimate.builder()
            .setOutputRowCount(rowsInB)
            .addSymbolStatistics(ImmutableMap.of(new Symbol("B1"), new SymbolStatsEstimate(0, 100, 0, 64, 10)))
            .build();

    // Case 1: the immediate join sources exceed the AUTOMATIC_RESTRICTED limit but the build
    // tables are small, therefore the replicated distribution type is chosen.
    assertDetermineJoinDistributionType()
            .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
            .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "100MB")
            .overrideStats("valuesA", aStatsEstimate)
            .overrideStats("filterB", bStatsEstimate)
            .overrideStats("valuesB", bSourceStatsEstimate)
            .on(p -> {
                Symbol a1 = p.symbol("A1", symbolType);
                Symbol b1 = p.symbol("B1", symbolType);
                return p.join(
                        INNER,
                        p.values(new PlanNodeId("valuesA"), rowsInA, a1),
                        p.filter(new PlanNodeId("filterB"), TRUE_LITERAL, p.values(new PlanNodeId("valuesB"), rowsInB, b1)),
                        ImmutableList.of(new JoinNode.EquiJoinClause(a1, b1)),
                        ImmutableList.of(a1),
                        ImmutableList.of(b1),
                        Optional.empty());
            })
            .matches(join(
                    INNER,
                    ImmutableList.of(equiJoinClause("A1", "B1")),
                    Optional.empty(),
                    Optional.of(REPLICATED),
                    values(ImmutableMap.of("A1", 0)),
                    filter("true", values(ImmutableMap.of("B1", 0)))));

    // Case 2: same statistics, but the join is built with its sides reversed; the expected plan is unchanged.
    assertDetermineJoinDistributionType()
            .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
            .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "100MB")
            .overrideStats("valuesA", aStatsEstimate)
            .overrideStats("filterB", bStatsEstimate)
            .overrideStats("valuesB", bSourceStatsEstimate)
            .on(p -> {
                Symbol a1 = p.symbol("A1", symbolType);
                Symbol b1 = p.symbol("B1", symbolType);
                return p.join(
                        INNER,
                        p.filter(new PlanNodeId("filterB"), TRUE_LITERAL, p.values(new PlanNodeId("valuesB"), rowsInB, b1)),
                        p.values(new PlanNodeId("valuesA"), rowsInA, a1),
                        ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)),
                        ImmutableList.of(b1),
                        ImmutableList.of(a1),
                        Optional.empty());
            })
            .matches(join(
                    INNER,
                    ImmutableList.of(equiJoinClause("A1", "B1")),
                    Optional.empty(),
                    Optional.of(REPLICATED),
                    values(ImmutableMap.of("A1", 0)),
                    filter("true", values(ImmutableMap.of("B1", 0)))));

    // Case 3: only probe side (with small tables) source stats are available, join sides should be flipped;
    // the LEFT join is expected to come out as a PARTITIONED RIGHT join.
    assertDetermineJoinDistributionType()
            .setSystemProperty(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.AUTOMATIC.name())
            .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "100MB")
            .overrideStats("valuesA", PlanNodeStatsEstimate.unknown())
            .overrideStats("filterB", PlanNodeStatsEstimate.unknown())
            .overrideStats("valuesB", bSourceStatsEstimate)
            .on(p -> {
                Symbol a1 = p.symbol("A1", symbolType);
                Symbol b1 = p.symbol("B1", symbolType);
                return p.join(
                        LEFT,
                        p.filter(new PlanNodeId("filterB"), TRUE_LITERAL, p.values(new PlanNodeId("valuesB"), rowsInB, b1)),
                        p.values(new PlanNodeId("valuesA"), rowsInA, a1),
                        ImmutableList.of(new JoinNode.EquiJoinClause(b1, a1)),
                        ImmutableList.of(b1),
                        ImmutableList.of(a1),
                        Optional.empty());
            })
            .matches(join(
                    RIGHT,
                    ImmutableList.of(equiJoinClause("A1", "B1")),
                    Optional.empty(),
                    Optional.of(PARTITIONED),
                    values(ImmutableMap.of("A1", 0)),
                    filter("true", values(ImmutableMap.of("B1", 0)))));
}
Also used : PARTITIONED(io.trino.sql.planner.plan.JoinNode.DistributionType.PARTITIONED) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) VarcharType.createUnboundedVarcharType(io.trino.spi.type.VarcharType.createUnboundedVarcharType) Assert.assertEquals(org.testng.Assert.assertEquals) Test(org.testng.annotations.Test) PlanMatchPattern.filter(io.trino.sql.planner.assertions.PlanMatchPattern.filter) REPLICATED(io.trino.sql.planner.plan.JoinNode.DistributionType.REPLICATED) Lookup.noLookup(io.trino.sql.planner.iterative.Lookup.noLookup) VarcharType(io.trino.spi.type.VarcharType) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) RuleAssert(io.trino.sql.planner.iterative.rule.test.RuleAssert) Type(io.trino.sql.planner.plan.JoinNode.Type) PlanBuilder.expressions(io.trino.sql.planner.iterative.rule.test.PlanBuilder.expressions) ImmutableList(com.google.common.collect.ImmutableList) NaN(java.lang.Double.NaN) DistributionType(io.trino.sql.planner.plan.JoinNode.DistributionType) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) PlanBuilder(io.trino.sql.planner.iterative.rule.test.PlanBuilder) PlanMatchPattern.equiJoinClause(io.trino.sql.planner.assertions.PlanMatchPattern.equiJoinClause) JoinNode(io.trino.sql.planner.plan.JoinNode) TableScanNode(io.trino.sql.planner.plan.TableScanNode) PlanMatchPattern.join(io.trino.sql.planner.assertions.PlanMatchPattern.join) PlanNodeStatsEstimate(io.trino.cost.PlanNodeStatsEstimate) JOIN_MAX_BROADCAST_TABLE_SIZE(io.trino.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE) TaskCountEstimator(io.trino.cost.TaskCountEstimator) Symbol(io.trino.sql.planner.Symbol) AfterClass(org.testng.annotations.AfterClass) SymbolStatsEstimate(io.trino.cost.SymbolStatsEstimate) RuleTester.defaultRuleTester(io.trino.sql.planner.iterative.rule.test.RuleTester.defaultRuleTester) ImmutableMap(com.google.common.collect.ImmutableMap) BeforeClass(org.testng.annotations.BeforeClass) FULL(io.trino.sql.planner.plan.JoinNode.Type.FULL) 
RuleTester(io.trino.sql.planner.iterative.rule.test.RuleTester) PlanMatchPattern.values(io.trino.sql.planner.assertions.PlanMatchPattern.values) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) JoinDistributionType(io.trino.sql.planner.OptimizerConfig.JoinDistributionType) DetermineJoinDistributionType.getSourceTablesSizeInBytes(io.trino.sql.planner.iterative.rule.DetermineJoinDistributionType.getSourceTablesSizeInBytes) CostComparator(io.trino.cost.CostComparator) PlanMatchPattern.enforceSingleRow(io.trino.sql.planner.assertions.PlanMatchPattern.enforceSingleRow) BIGINT(io.trino.spi.type.BigintType.BIGINT) RIGHT(io.trino.sql.planner.plan.JoinNode.Type.RIGHT) JOIN_DISTRIBUTION_TYPE(io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE) ImmutableListMultimap(com.google.common.collect.ImmutableListMultimap) Optional(java.util.Optional) ValuesNode(io.trino.sql.planner.plan.ValuesNode) TestingColumnHandle(io.trino.testing.TestingMetadata.TestingColumnHandle) PlanBuilder.expression(io.trino.sql.planner.iterative.rule.test.PlanBuilder.expression) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) VarcharType.createUnboundedVarcharType(io.trino.spi.type.VarcharType.createUnboundedVarcharType) VarcharType(io.trino.spi.type.VarcharType) PlanNodeStatsEstimate(io.trino.cost.PlanNodeStatsEstimate) Symbol(io.trino.sql.planner.Symbol) JoinNode(io.trino.sql.planner.plan.JoinNode) SymbolStatsEstimate(io.trino.cost.SymbolStatsEstimate) Test(org.testng.annotations.Test)

Aggregations

ImmutableList (com.google.common.collect.ImmutableList)2 ImmutableMap (com.google.common.collect.ImmutableMap)2 BIGINT (io.trino.spi.type.BigintType.BIGINT)2 PlanNodeIdAllocator (io.trino.sql.planner.PlanNodeIdAllocator)2 Symbol (io.trino.sql.planner.Symbol)2 Type (io.trino.sql.planner.plan.JoinNode.Type)2 INNER (io.trino.sql.planner.plan.JoinNode.Type.INNER)2 LEFT (io.trino.sql.planner.plan.JoinNode.Type.LEFT)2 TRUE_LITERAL (io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL)2 Optional (java.util.Optional)2 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)1 ImmutableListMultimap (com.google.common.collect.ImmutableListMultimap)1 ImmutableSet (com.google.common.collect.ImmutableSet)1 Session (io.trino.Session)1 JOIN_DISTRIBUTION_TYPE (io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE)1 JOIN_MAX_BROADCAST_TABLE_SIZE (io.trino.SystemSessionProperties.JOIN_MAX_BROADCAST_TABLE_SIZE)1 CostComparator (io.trino.cost.CostComparator)1 PlanNodeStatsEstimate (io.trino.cost.PlanNodeStatsEstimate)1 SymbolStatsEstimate (io.trino.cost.SymbolStatsEstimate)1 TaskCountEstimator (io.trino.cost.TaskCountEstimator)1