Search in sources:

Example 1 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

the class ImplementFilteredAggregations method apply.

/**
 * Rewrites an aggregation node so that no aggregate carries a filter any more.
 * <p>
 * For each aggregate, the filter symbol (if any) is folded into the mask: when both a
 * filter and a mask exist, a new BOOLEAN "mask" symbol is projected as their conjunction;
 * when only a filter exists, the filter symbol itself becomes the mask. The resulting plan
 * is Aggregation over Filter over Project, where the projection carries the new mask
 * symbols plus identities for all source outputs.
 *
 * @param aggregationNode the aggregation to rewrite
 * @param captures pattern captures (unused here)
 * @param context supplies the symbol and plan node id allocators
 * @return a result wrapping the rewritten aggregation
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    Assignments.Builder projections = Assignments.builder();
    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableList.Builder<Expression> maskReferences = ImmutableList.builder();
    boolean unmaskedAggregatePresent = false;

    for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Aggregation original = entry.getValue();
        Optional<Symbol> filter = original.getFilter();
        Optional<Symbol> effectiveMask = original.getMask();

        // Step 1: fold the filter (if any) into the mask.
        if (filter.isPresent()) {
            if (effectiveMask.isPresent()) {
                // both present: project a fresh symbol holding (mask AND filter)
                Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
                projections.put(combinedMask, and(effectiveMask.get().toSymbolReference(), filter.get().toSymbolReference()));
                effectiveMask = Optional.of(combinedMask);
            } else {
                // filter only: the filter symbol is used directly as the mask
                effectiveMask = filter;
            }
        }

        // Step 2: remember the mask for the pre-filter predicate, or note that an unrestricted aggregate exists.
        if (effectiveMask.isPresent()) {
            maskReferences.add(effectiveMask.get().toSymbolReference());
        } else {
            unmaskedAggregatePresent = true;
        }

        // Step 3: re-create the aggregate without a filter, using the effective mask.
        rewrittenAggregations.put(
                entry.getKey(),
                new Aggregation(
                        original.getResolvedFunction(),
                        original.getArguments(),
                        original.isDistinct(),
                        Optional.empty(),
                        original.getOrderingScheme(),
                        effectiveMask));
    }

    // Rows are pre-filtered by the combined masks only when there is no non-empty grouping set
    // and every aggregate is restricted by a mask; otherwise the filter predicate stays TRUE.
    Expression predicate = (aggregationNode.hasNonEmptyGroupingSet() || unmaskedAggregatePresent)
            ? TRUE_LITERAL
            : combineDisjunctsWithDefault(metadata, maskReferences.build(), TRUE_LITERAL);

    // pass every existing input through the projection unchanged
    projections.putIdentities(aggregationNode.getSource().getOutputSymbols());

    // Nodes are constructed inline so that ids are allocated in the order: aggregation, filter, project.
    return Result.ofPlanNode(
            new AggregationNode(
                    context.getIdAllocator().getNextId(),
                    new FilterNode(
                            context.getIdAllocator().getNextId(),
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    aggregationNode.getSource(),
                                    projections.build()),
                            predicate),
                    rewrittenAggregations.buildOrThrow(),
                    aggregationNode.getGroupingSets(),
                    ImmutableList.of(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol()));
}
Also used : Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.trino.sql.planner.plan.FilterNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Example 2 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

the class DecorrelateInnerUnnestWithGlobalAggregation method apply.

/**
 * Rewrites a correlated join whose subquery is a global aggregation over a correlated
 * unnest into an UnnestNode over the join input, followed by the subquery's
 * projections and aggregations rebuilt on top of it.
 *
 * @param correlatedJoinNode the correlated join to decorrelate
 * @param captures pattern captures (unused here)
 * @param context supplies lookup plus the symbol and plan node id allocators
 * @return the rewritten plan, or an empty result when the subquery contains no global
 *         aggregation or no supported unnest
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    Lookup lookup = context.getLookup();
    PlanNodeIdAllocator idAllocator = context.getIdAllocator();
    SymbolAllocator symbolAllocator = context.getSymbolAllocator();

    // locate the global aggregation(s) in the subquery
    List<PlanNode> globalAggregations = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
            .where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation)
            .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node))
            .findAll();
    if (globalAggregations.isEmpty()) {
        return Result.empty();
    }

    // With several global aggregations, the one closest to the source is the "reducing" one:
    // it is the aggregation that collapses multiple input rows into a single output row.
    int lastIndex = globalAggregations.size() - 1;
    AggregationNode reducingAggregation = (AggregationNode) globalAggregations.get(lastIndex);

    // locate the unnest below the reducing aggregation
    Optional<UnnestNode> correlatedUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), lookup)
            .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
            .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
            .findFirst();
    if (correlatedUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = correlatedUnnest.get();

    // Tag every input row with a unique id so aggregation semantics can be restored after the rewrite.
    PlanNode input = new AssignUniqueId(
            idAllocator.getNextId(),
            correlatedJoinNode.getInput(),
            symbolAllocator.newSymbol("unique", BIGINT));

    // The correlated UnnestNode unnests either correlation symbols directly, or symbols computed by a
    // projection over correlation symbols only. Any such projection is re-applied on top of the input so
    // the rewritten UnnestNode can see its symbols. If it turns out not to produce any unnest symbols,
    // it should be pruned afterwards.
    PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode preProjection = (ProjectNode) unnestSource;
        Assignments preProjectedAssignments = Assignments.builder()
                .putIdentities(input.getOutputSymbols())
                .putAll(preProjection.getAssignments())
                .build();
        input = new ProjectNode(preProjection.getId(), input, preProjectedAssignments);
    }

    // replace the correlated join with a LEFT unnest that always exposes an ordinality symbol
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
            .orElseGet(() -> symbolAllocator.newSymbol("ordinality", BIGINT));
    UnnestNode rewrittenUnnest = new UnnestNode(
            idAllocator.getNextId(),
            input,
            input.getOutputSymbols(),
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            LEFT,
            Optional.empty());

    // A mask derived from ordinality separates genuinely unnested rows from synthetic null rows.
    Symbol mask = symbolAllocator.newSymbol("mask", BOOLEAN);
    Assignments maskedAssignments = Assignments.builder()
            .putIdentities(rewrittenUnnest.getOutputSymbols())
            .put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference()))
            .build();
    ProjectNode sourceWithMask = new ProjectNode(idAllocator.getNextId(), rewrittenUnnest, maskedAssignments);

    // rebuild the subquery's projections, grouped aggregations and global aggregations on the new source
    PlanNode rewritten = rewriteNodeSequence(
            lookup.resolve(correlatedJoinNode.getSubquery()),
            input.getOutputSymbols(),
            mask,
            sourceWithMask,
            reducingAggregation.getId(),
            unnestNode.getId(),
            symbolAllocator,
            idAllocator,
            lookup);

    // expose only the symbols the original correlated join produced
    return Result.ofPlanNode(
            restrictOutputs(idAllocator, rewritten, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(rewritten));
}
Also used : CorrelatedJoin.correlation(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.correlation) QueryCardinalityUtil.isScalar(io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationDecorrelation.rewriteWithMasks(io.trino.sql.planner.iterative.rule.AggregationDecorrelation.rewriteWithMasks) CorrelatedJoin.filter(io.trino.sql.planner.plan.Patterns.CorrelatedJoin.filter) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) ImmutableList(com.google.common.collect.ImmutableList) PlanNodeSearcher(io.trino.sql.planner.optimizations.PlanNodeSearcher) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) Pattern.nonEmpty(io.trino.matching.Pattern.nonEmpty) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Lookup(io.trino.sql.planner.iterative.Lookup) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) Streams(com.google.common.collect.Streams) UnnestNode(io.trino.sql.planner.plan.UnnestNode) 
AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) List(java.util.List) Pattern(io.trino.matching.Pattern) BIGINT(io.trino.spi.type.BigintType.BIGINT) ExpressionUtils.and(io.trino.sql.ExpressionUtils.and) Captures(io.trino.matching.Captures) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Mapping(io.trino.sql.planner.plan.UnnestNode.Mapping) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) Patterns.correlatedJoin(io.trino.sql.planner.plan.Patterns.correlatedJoin) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) UnnestNode(io.trino.sql.planner.plan.UnnestNode) Symbol(io.trino.sql.planner.Symbol) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate)

Example 3 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

the class AddIntermediateAggregations method apply.

/**
 * Replaces the source of the given aggregation with the plan returned by
 * {@code recurseToPartial}; when task concurrency exceeds one, that plan is additionally
 * wrapped in a local partitioned exchange, an INTERMEDIATE aggregation and a local
 * gathering exchange.
 *
 * @param aggregation the aggregation whose source is rewritten
 * @param captures pattern captures (unused here)
 * @param context supplies lookup, the plan node id allocator and the session
 * @return the aggregation with its child replaced, or an empty result when
 *         {@code recurseToPartial} yields nothing
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    Lookup lookup = context.getLookup();
    PlanNodeIdAllocator idAllocator = context.getIdAllocator();
    Session session = context.getSession();

    Optional<PlanNode> partialSource = recurseToPartial(lookup.resolve(aggregation.getSource()), lookup, idAllocator);
    if (partialSource.isEmpty()) {
        return Result.empty();
    }

    PlanNode newSource = partialSource.get();
    if (getTaskConcurrency(session) > 1) {
        // spread rows across local drivers using an arbitrary fixed distribution
        PartitioningScheme arbitraryDistribution = new PartitioningScheme(
                Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()),
                newSource.getOutputSymbols());
        PlanNode localExchange = ExchangeNode.partitionedExchange(
                idAllocator.getNextId(),
                ExchangeNode.Scope.LOCAL,
                newSource,
                arbitraryDistribution);

        // intermediate aggregation step on top of the local exchange
        PlanNode intermediateAggregation = new AggregationNode(
                idAllocator.getNextId(),
                localExchange,
                inputsAsOutputs(aggregation.getAggregations()),
                aggregation.getGroupingSets(),
                aggregation.getPreGroupedSymbols(),
                AggregationNode.Step.INTERMEDIATE,
                aggregation.getHashSymbol(),
                aggregation.getGroupIdSymbol());

        // gather the intermediate results locally
        newSource = ExchangeNode.gatheringExchange(
                idAllocator.getNextId(),
                ExchangeNode.Scope.LOCAL,
                intermediateAggregation);
    }

    return Result.ofPlanNode(aggregation.replaceChildren(ImmutableList.of(newSource)));
}
Also used : PlanNode(io.trino.sql.planner.plan.PlanNode) PlanNodeIdAllocator(io.trino.sql.planner.PlanNodeIdAllocator) PartitioningScheme(io.trino.sql.planner.PartitioningScheme) Lookup(io.trino.sql.planner.iterative.Lookup) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Session(io.trino.Session)

Example 4 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

the class MultipleDistinctAggregationToMarkDistinct method apply.

/**
 * Rewrites DISTINCT aggregates that have neither a filter nor a mask into non-distinct
 * aggregates masked by a marker symbol, stacking one MarkDistinctNode under the
 * aggregation per distinct set of argument symbols. All other aggregates are kept as-is.
 *
 * @param parent the aggregation node to rewrite
 * @param captures pattern captures (unused here)
 * @param context supplies the session plus the symbol and plan node id allocators
 * @return the rewritten aggregation, or an empty result when the mark-distinct session
 *         property is disabled
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
        return Result.empty();
    }

    // one distinct marker per set of input columns
    Map<Set<Symbol>, Symbol> markerByInputs = new HashMap<>();
    Map<Symbol, Aggregation> rewrittenAggregations = new HashMap<>();
    PlanNode source = parent.getSource();

    for (Map.Entry<Symbol, Aggregation> entry : parent.getAggregations().entrySet()) {
        Symbol output = entry.getKey();
        Aggregation aggregation = entry.getValue();

        boolean plainDistinct = aggregation.isDistinct()
                && aggregation.getFilter().isEmpty()
                && aggregation.getMask().isEmpty();
        if (!plainDistinct) {
            // not eligible for the rewrite: carry over unchanged
            rewrittenAggregations.put(output, aggregation);
            continue;
        }

        Set<Symbol> inputs = aggregation.getArguments().stream()
                .map(Symbol::from)
                .collect(toSet());

        Symbol marker = markerByInputs.get(inputs);
        if (marker == null) {
            // first aggregate over this input set: allocate a marker and add a MarkDistinctNode for it
            marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
            markerByInputs.put(inputs, marker);

            // distinctness is determined over grouping keys, the inputs, and the group id symbol when present
            ImmutableSet.Builder<Symbol> distinctSymbols = ImmutableSet.<Symbol>builder()
                    .addAll(parent.getGroupingKeys())
                    .addAll(inputs);
            parent.getGroupIdSymbol().ifPresent(distinctSymbols::add);

            source = new MarkDistinctNode(
                    context.getIdAllocator().getNextId(),
                    source,
                    marker,
                    ImmutableList.copyOf(distinctSymbols.build()),
                    Optional.empty());
        }

        // drop the distinct flag and use the marker as the aggregate's mask
        rewrittenAggregations.put(
                output,
                new Aggregation(
                        aggregation.getResolvedFunction(),
                        aggregation.getArguments(),
                        false,
                        aggregation.getFilter(),
                        aggregation.getOrderingScheme(),
                        Optional.of(marker)));
    }

    return Result.ofPlanNode(
            new AggregationNode(
                    parent.getId(),
                    source,
                    rewrittenAggregations,
                    parent.getGroupingSets(),
                    ImmutableList.of(),
                    parent.getStep(),
                    parent.getHashSymbol(),
                    parent.getGroupIdSymbol()));
}
Also used : ImmutableSet(com.google.common.collect.ImmutableSet) Set(java.util.Set) HashSet(java.util.HashSet) Collectors.toSet(java.util.stream.Collectors.toSet) MarkDistinctNode(io.trino.sql.planner.plan.MarkDistinctNode) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableSet(com.google.common.collect.ImmutableSet) HashMap(java.util.HashMap) Map(java.util.Map)

Example 5 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

the class PushPartialAggregationThroughExchange method split.

/**
 * Splits an aggregation into a PARTIAL aggregation over the original source and a FINAL
 * aggregation on top of it. Each final aggregate consumes the intermediate symbol produced
 * by its partial counterpart, followed by any lambda arguments of the original aggregate.
 *
 * @param node the aggregation to split; the FINAL node reuses its id
 * @param context supplies the symbol and plan node id allocators
 * @return the FINAL aggregation node whose source is the new PARTIAL aggregation
 * @throws IllegalStateException if any aggregate has an ordering scheme
 */
private PlanNode split(AggregationNode node, Context context) {
    Map<Symbol, AggregationNode.Aggregation> partialAggregations = new HashMap<>();
    Map<Symbol, AggregationNode.Aggregation> finalAggregations = new HashMap<>();

    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = entry.getValue();
        ResolvedFunction function = original.getResolvedFunction();
        AggregationFunctionMetadata functionMetadata = plannerContext.getMetadata().getAggregationFunctionMetadata(function);

        // a single intermediate type is used directly; several are wrapped into an anonymous row type
        List<Type> intermediateTypes = functionMetadata.getIntermediateTypes().stream()
                .map(plannerContext.getTypeManager()::getType)
                .collect(toImmutableList());
        Type intermediateType;
        if (intermediateTypes.size() == 1) {
            intermediateType = intermediateTypes.get(0);
        } else {
            intermediateType = RowType.anonymous(intermediateTypes);
        }
        Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(function.getSignature().getName(), intermediateType);

        checkState(original.getOrderingScheme().isEmpty(), "Aggregate with ORDER BY does not support partial aggregation");

        // the partial step keeps the original arguments, distinct flag, filter and mask
        partialAggregations.put(
                intermediateSymbol,
                new AggregationNode.Aggregation(
                        function,
                        original.getArguments(),
                        original.isDistinct(),
                        original.getFilter(),
                        original.getOrderingScheme(),
                        original.getMask()));

        // the final step is expressed over the intermediate symbol, keeping only lambda arguments
        ImmutableList<Expression> finalArguments = ImmutableList.<Expression>builder()
                .add(intermediateSymbol.toSymbolReference())
                .addAll(original.getArguments().stream()
                        .filter(LambdaExpression.class::isInstance)
                        .collect(toImmutableList()))
                .build();
        finalAggregations.put(
                entry.getKey(),
                new AggregationNode.Aggregation(
                        function,
                        finalArguments,
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty()));
    }

    // Pushing the partial aggregation through the exchange may or may not preserve pre-grouping
    // properties, so preGroupedSymbols are dropped (empty list) on both nodes.
    PlanNode partial = new AggregationNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            partialAggregations,
            node.getGroupingSets(),
            ImmutableList.of(),
            PARTIAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol());

    return new AggregationNode(
            node.getId(),
            partial,
            finalAggregations,
            node.getGroupingSets(),
            ImmutableList.of(),
            FINAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol());
}
Also used : HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) ResolvedFunction(io.trino.metadata.ResolvedFunction) AggregationNode(io.trino.sql.planner.plan.AggregationNode) AggregationFunctionMetadata(io.trino.metadata.AggregationFunctionMetadata) SystemSessionProperties.preferPartialAggregation(io.trino.SystemSessionProperties.preferPartialAggregation) Type(io.trino.spi.type.Type) RowType(io.trino.spi.type.RowType) PlanNode(io.trino.sql.planner.plan.PlanNode) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

AggregationNode (io.trino.sql.planner.plan.AggregationNode): 49, Symbol (io.trino.sql.planner.Symbol): 32, PlanNode (io.trino.sql.planner.plan.PlanNode): 32, ProjectNode (io.trino.sql.planner.plan.ProjectNode): 21, Map (java.util.Map): 18, Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation): 16, ImmutableMap (com.google.common.collect.ImmutableMap): 15, Assignments (io.trino.sql.planner.plan.Assignments): 15, Expression (io.trino.sql.tree.Expression): 15, ImmutableList (com.google.common.collect.ImmutableList): 14, JoinNode (io.trino.sql.planner.plan.JoinNode): 12, List (java.util.List): 11, ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 10, CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode): 10, AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId): 9, Optional (java.util.Optional): 9, Session (io.trino.Session): 7, ResolvedFunction (io.trino.metadata.ResolvedFunction): 7, Captures (io.trino.matching.Captures): 6, Pattern (io.trino.matching.Pattern): 6