Search in sources :

Example 21 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

Class RemoveEmptyExceptBranches, method apply.

/**
 * Removes subtrahend branches of an EXCEPT that are known to be empty.
 *
 * <p>If the first source (the set rows are taken from) is empty, the whole EXCEPT is replaced by an
 * empty {@link ValuesNode}. Otherwise every empty source after the first is dropped. When only the
 * first source remains, the EXCEPT is replaced by a {@link ProjectNode} over it that maps each output
 * symbol to that source's input symbol. If the EXCEPT is DISTINCT, that projection is wrapped in an
 * {@link AggregationNode} grouped by all output symbols with no aggregate functions.
 *
 * @return the rewritten plan, or {@link Result#empty()} when no source after the first is empty
 */
@Override
public Result apply(ExceptNode node, Captures captures, Context context) {
    List<PlanNode> sources = node.getSources();
    if (isEmpty(sources.get(0), context.getLookup())) {
        // Nothing to subtract rows from: the result is empty whatever the other branches contain.
        return Result.ofPlanNode(new ValuesNode(node.getId(), node.getOutputSymbols(), ImmutableList.of()));
    }
    // The first source is the set we're excluding rows from, so only the remaining sources are
    // candidates for removal.
    boolean hasEmptyBranches = sources.stream()
            .skip(1)
            .anyMatch(source -> isEmpty(source, context.getLookup()));
    if (!hasEmptyBranches) {
        return Result.empty();
    }
    ImmutableList.Builder<PlanNode> newSourcesBuilder = ImmutableList.builder();
    ImmutableListMultimap.Builder<Symbol, Symbol> outputsToInputsBuilder = ImmutableListMultimap.builder();
    for (int i = 0; i < sources.size(); i++) {
        PlanNode source = sources.get(i);
        // Always keep the first source; keep any later source only if it is not known to be empty.
        if (i == 0 || !isEmpty(source, context.getLookup())) {
            newSourcesBuilder.add(source);
            for (Symbol column : node.getOutputSymbols()) {
                // The i-th entry of an output's symbol mapping is the input symbol from the i-th source.
                outputsToInputsBuilder.put(column, node.getSymbolMapping().get(column).get(i));
            }
        }
    }
    List<PlanNode> newSources = newSourcesBuilder.build();
    ListMultimap<Symbol, Symbol> outputsToInputs = outputsToInputsBuilder.build();
    if (newSources.size() == 1) {
        // Only the first source survived, so the EXCEPT reduces to a renaming projection over it.
        Assignments.Builder assignments = Assignments.builder();
        outputsToInputs.forEach((output, input) -> assignments.put(output, input.toSymbolReference()));
        if (node.isDistinct()) {
            // Preserve DISTINCT semantics by grouping on every output symbol (no aggregate functions).
            return Result.ofPlanNode(new AggregationNode(node.getId(), new ProjectNode(context.getIdAllocator().getNextId(), newSources.get(0), assignments.build()), ImmutableMap.of(), singleGroupingSet(node.getOutputSymbols()), ImmutableList.of(), Step.SINGLE, Optional.empty(), Optional.empty()));
        }
        return Result.ofPlanNode(new ProjectNode(node.getId(), newSources.get(0), assignments.build()));
    }
    return Result.ofPlanNode(new ExceptNode(node.getId(), newSources, outputsToInputs, node.getOutputSymbols(), node.isDistinct()));
}
Also used : ValuesNode(io.trino.sql.planner.plan.ValuesNode) ImmutableList(com.google.common.collect.ImmutableList) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) PlanNode(io.trino.sql.planner.plan.PlanNode) ExceptNode(io.trino.sql.planner.plan.ExceptNode) ImmutableListMultimap(com.google.common.collect.ImmutableListMultimap) ProjectNode(io.trino.sql.planner.plan.ProjectNode)

Example 22 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

Class PlanUtils, method createAggregationFragment.

/**
 * Builds a plan fragment whose root is a FINAL-step {@link AggregationNode} with no aggregations and an
 * empty grouping set. The aggregation reads from {@code sourceFragment} through a {@link RemoteSourceNode}
 * created with REPARTITION and {@code RetryPolicy.NONE}.
 *
 * @param name prefix of the aggregation node's id (the id is {@code name + "_id"})
 * @param sourceFragment fragment referenced by the remote source
 * @return the fragment produced by {@code createFragment} for the aggregation node
 */
static PlanFragment createAggregationFragment(String name, PlanFragment sourceFragment) {
    PlanNodeId remoteSourceId = new PlanNodeId("source_id");
    RemoteSourceNode remoteSource = new RemoteSourceNode(
            remoteSourceId,
            sourceFragment.getId(),
            ImmutableList.of(),
            Optional.empty(),
            REPARTITION,
            RetryPolicy.NONE);

    PlanNodeId aggregationId = new PlanNodeId(name + "_id");
    PlanNode aggregation = new AggregationNode(
            aggregationId,
            remoteSource,
            ImmutableMap.of(),
            singleGroupingSet(ImmutableList.of()),
            ImmutableList.of(),
            FINAL,
            Optional.empty(),
            Optional.empty());

    return createFragment(aggregation);
}
Also used : PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) RemoteSourceNode(io.trino.sql.planner.plan.RemoteSourceNode) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode)

Example 23 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

Class TestUnion, method assertAtMostOneAggregationBetweenRemoteExchanges.

/**
 * Asserts that at most one {@link AggregationNode} appears between remote exchanges.
 *
 * <p>For every remote exchange in the plan, each of its sources is searched for aggregation nodes. The
 * search does not recurse into remote exchanges, and it fails if more than one aggregation is found
 * under any single source.
 */
private static void assertAtMostOneAggregationBetweenRemoteExchanges(Plan plan) {
    List<PlanNode> remoteExchanges = searchFrom(plan.getRoot())
            .where(TestUnion::isRemoteExchange)
            .findAll();
    for (PlanNode remoteExchange : remoteExchanges) {
        for (PlanNode fragmentRoot : remoteExchange.getSources()) {
            List<PlanNode> aggregationsInFragment = searchFrom(fragmentRoot)
                    .where(AggregationNode.class::isInstance)
                    .recurseOnlyWhen(TestUnion::isNotRemoteExchange)
                    .findAll();
            assertFalse(aggregationsInFragment.size() > 1, "More than a single AggregationNode between remote exchanges");
        }
    }
}
Also used : Iterables(com.google.common.collect.Iterables) PlanNodeSearcher.searchFrom(io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom) TopNNode(io.trino.sql.planner.plan.TopNNode) OPTIMIZED_AND_VALIDATED(io.trino.sql.planner.LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED) Assert.assertEquals(org.testng.Assert.assertEquals) Test(org.testng.annotations.Test) PlanNode(io.trino.sql.planner.plan.PlanNode) List(java.util.List) Collectors.toList(java.util.stream.Collectors.toList) GATHER(io.trino.sql.planner.plan.ExchangeNode.Type.GATHER) REMOTE(io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) REPARTITION(io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION) Assert.assertTrue(org.testng.Assert.assertTrue) ExchangeNode(io.trino.sql.planner.plan.ExchangeNode) Plan(io.trino.sql.planner.Plan) JoinNode(io.trino.sql.planner.plan.JoinNode) BasePlanTest(io.trino.sql.planner.assertions.BasePlanTest) Assert.assertFalse(org.testng.Assert.assertFalse)

Example 24 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

Class DecorrelateInnerUnnestWithGlobalAggregation, method withGroupingAndMask.

/**
 * Rebuilds {@code aggregationNode} as a SINGLE-step aggregation over {@code source}, grouped by
 * {@code groupingSymbols}, with every aggregation masked by {@code mask}.
 *
 * <p>An aggregation without a mask simply receives {@code mask}. An aggregation that already has a mask
 * receives a freshly allocated BOOLEAN symbol defined as {@code and(existingMask, mask)}. When any such
 * symbols are created, a {@link ProjectNode} computing them (alongside identity assignments for all of
 * {@code source}'s outputs) is placed between {@code source} and the new aggregation.
 *
 * @return a new aggregation node reusing {@code aggregationNode}'s id
 */
private static AggregationNode withGroupingAndMask(AggregationNode aggregationNode, List<Symbol> groupingSymbols, Symbol mask, PlanNode source, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
    ImmutableMap.Builder<Symbol, Symbol> maskByAggregation = ImmutableMap.builder();
    Assignments.Builder combinedMasks = Assignments.builder();
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Symbol aggregationSymbol = entry.getKey();
        Optional<Symbol> existingMask = entry.getValue().getMask();
        if (!existingMask.isPresent()) {
            // Unmasked aggregation: the new mask applies as-is.
            maskByAggregation.put(aggregationSymbol, mask);
            continue;
        }
        // Already masked: both the existing mask and the new one must hold.
        Symbol combinedMask = symbolAllocator.newSymbol("mask", BOOLEAN);
        combinedMasks.put(combinedMask, and(existingMask.get().toSymbolReference(), mask.toSymbolReference()));
        maskByAggregation.put(aggregationSymbol, combinedMask);
    }

    Assignments maskAssignments = combinedMasks.build();
    // A projection is only needed when at least one combined mask symbol has to be computed.
    PlanNode maskedSource = maskAssignments.isEmpty()
            ? source
            : new ProjectNode(
                    idAllocator.getNextId(),
                    source,
                    Assignments.builder()
                            .putIdentities(source.getOutputSymbols())
                            .putAll(maskAssignments)
                            .build());

    return new AggregationNode(
            aggregationNode.getId(),
            maskedSource,
            rewriteWithMasks(aggregationNode.getAggregations(), maskByAggregation.buildOrThrow()),
            singleGroupingSet(groupingSymbols),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());
}
Also used : Expression(io.trino.sql.tree.Expression) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 25 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

Class DecorrelateLeftUnnestWithGlobalAggregation, method apply.

/**
 * Rewrites a correlated join whose subquery unnests correlated data and then aggregates it globally
 * into an uncorrelated plan built on a LEFT {@link UnnestNode}.
 *
 * <p>Returns {@link Result#empty()} unless the subquery contains both of the following:
 * <ul>
 *   <li>a node accepted by {@code isGlobalAggregation}, reachable through projections and grouped
 *       aggregations;</li>
 *   <li>an unnest accepted by {@code isSupportedUnnest}, reachable through projections and global or
 *       grouped aggregations.</li>
 * </ul>
 *
 * <p>When both are present, the rewrite proceeds in four steps:
 * <ol>
 *   <li>The join input is tagged with a unique BIGINT id.</li>
 *   <li>If the unnest's source is a projection, that projection is re-applied on top of the tagged input.</li>
 *   <li>The input is unnested with a LEFT unnest.</li>
 *   <li>{@code rewriteNodeSequence} restores the subquery's node sequence on top of the unnest.</li>
 * </ol>
 * The result is finally restricted to the correlated join's output symbols.
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // find global aggregation in subquery
    // (the search only descends through ProjectNodes and grouped aggregations; only the presence of a
    // match is used below, the matched node itself is not)
    Optional<AggregationNode> globalAggregation = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup()).where(DecorrelateLeftUnnestWithGlobalAggregation::isGlobalAggregation).recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node)).findFirst();
    if (globalAggregation.isEmpty()) {
        return Result.empty();
    }
    // find unnest in subquery
    // (the search descends through ProjectNodes and global or grouped aggregations)
    Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup()).where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup())).recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node) || isGroupedAggregation(node)).findFirst();
    if (subqueryUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = subqueryUnnest.get();
    // assign unique id to input rows to restore semantics of aggregations after rewrite
    PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
    // pre-project unnest symbols if they were pre-projected in subquery
    // The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols produced by a projection that uses only correlation symbols.
    // Here, any underlying projection that was a source of the correlated UnnestNode, is appended as a source of the rewritten UnnestNode.
    // If the projection is not necessary for UnnestNode (i.e. it does not produce any unnest symbols), it should be pruned afterwards.
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) unnestSource;
        // Keep every input symbol (identity) and add the projection's assignments; the original
        // projection's node id is reused for the new ProjectNode.
        input = new ProjectNode(sourceProjection.getId(), input, Assignments.builder().putIdentities(input.getOutputSymbols()).putAll(sourceProjection.getAssignments()).build());
    }
    // rewrite correlated join to UnnestNode
    // NOTE(review): the third argument (all input symbols) is presumably the set of replicated
    // symbols; mappings and ordinality are taken from the original unnest — confirm against UnnestNode.
    UnnestNode rewrittenUnnest = new UnnestNode(context.getIdAllocator().getNextId(), input, input.getOutputSymbols(), unnestNode.getMappings(), unnestNode.getOrdinalitySymbol(), LEFT, Optional.empty());
    // restore all projections, grouped aggregations and global aggregations from the subquery
    PlanNode result = rewriteNodeSequence(context.getLookup().resolve(correlatedJoinNode.getSubquery()), input.getOutputSymbols(), rewrittenUnnest, unnestNode.getId(), context.getLookup());
    // restrict outputs
    // (when restrictOutputs returns an empty Optional — presumably because no restriction is needed —
    // the unrestricted result is returned as-is)
    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())).orElse(result));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) PlanNode(io.trino.sql.planner.plan.PlanNode) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) ImmutableList(com.google.common.collect.ImmutableList) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) Lookup(io.trino.sql.planner.iterative.Lookup) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) Streams(com.google.common.collect.Streams) UnnestNode(io.trino.sql.planner.plan.UnnestNode) AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) List(java.util.List) Pattern(io.trino.matching.Pattern) BIGINT(io.trino.spi.type.BigintType.BIGINT) Captures(io.trino.matching.Captures) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Optional(java.util.Optional) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin) PlanNode(io.trino.sql.planner.plan.PlanNode) 
AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) UnnestNode(io.trino.sql.planner.plan.UnnestNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode)

Aggregations

AggregationNode (io.trino.sql.planner.plan.AggregationNode)49 Symbol (io.trino.sql.planner.Symbol)32 PlanNode (io.trino.sql.planner.plan.PlanNode)32 ProjectNode (io.trino.sql.planner.plan.ProjectNode)21 Map (java.util.Map)18 Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation)16 ImmutableMap (com.google.common.collect.ImmutableMap)15 Assignments (io.trino.sql.planner.plan.Assignments)15 Expression (io.trino.sql.tree.Expression)15 ImmutableList (com.google.common.collect.ImmutableList)14 JoinNode (io.trino.sql.planner.plan.JoinNode)12 List (java.util.List)11 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)10 CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode)10 AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId)9 Optional (java.util.Optional)9 Session (io.trino.Session)7 ResolvedFunction (io.trino.metadata.ResolvedFunction)7 Captures (io.trino.matching.Captures)6 Pattern (io.trino.matching.Pattern)6