Search in sources:

Example 26 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

The class AddIntermediateAggregations, method recurseToPartial.

/**
 * Recurse through a series of preceding ExchangeNodes and ProjectNodes to find the preceding PARTIAL aggregation
 */
/**
 * Walks down a chain of ExchangeNodes and ProjectNodes looking for a PARTIAL aggregation.
 * When one is found, it is replaced by the result of {@code addGatheringIntermediate}, and every
 * Exchange/Project on the way down is rebuilt over its rewritten children.
 *
 * @param node the node to inspect (already resolved through {@code lookup})
 * @param lookup used to resolve each child before descending into it
 * @param idAllocator passed through to {@code addGatheringIntermediate}
 * @return the rewritten subtree, or empty if any path reaches a node that is neither a PARTIAL
 *     aggregation, an ExchangeNode, nor a ProjectNode
 */
private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) {
    if (node instanceof AggregationNode) {
        AggregationNode candidate = (AggregationNode) node;
        if (candidate.getStep() == AggregationNode.Step.PARTIAL) {
            return Optional.of(addGatheringIntermediate(candidate, idAllocator));
        }
    }

    // Only exchanges and projections are transparent to this search; anything else stops it.
    boolean passThrough = node instanceof ExchangeNode || node instanceof ProjectNode;
    if (!passThrough) {
        return Optional.empty();
    }

    ImmutableList.Builder<PlanNode> rewrittenChildren = ImmutableList.builder();
    for (PlanNode child : node.getSources()) {
        Optional<PlanNode> rewrittenChild = recurseToPartial(lookup.resolve(child), lookup, idAllocator);
        if (!rewrittenChild.isPresent()) {
            // A single branch without a reachable PARTIAL aggregation aborts the whole rewrite.
            return Optional.empty();
        }
        rewrittenChildren.add(rewrittenChild.get());
    }
    return Optional.of(node.replaceChildren(rewrittenChildren.build()));
}
Also used : PlanNode(io.trino.sql.planner.plan.PlanNode) ExchangeNode(io.trino.sql.planner.plan.ExchangeNode) ImmutableList(com.google.common.collect.ImmutableList) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode)

Example 27 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

The class AddIntermediateAggregations, method addGatheringIntermediate.

/**
 * Places a LOCAL gathering exchange on top of {@code aggregation} and an INTERMEDIATE
 * aggregation on top of that exchange.
 *
 * @param aggregation an un-grouped aggregation (verified); its grouping sets, pre-grouped
 *     symbols, hash symbol and group-id symbol are copied to the new node
 * @param idAllocator supplies ids for the exchange and then for the new aggregation
 * @return the new INTERMEDIATE aggregation node
 */
private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) {
    verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");

    ExchangeNode localGather = ExchangeNode.gatheringExchange(
            idAllocator.getNextId(),
            ExchangeNode.Scope.LOCAL,
            aggregation);

    // The intermediate step consumes the outputs of the aggregation below as its inputs.
    return new AggregationNode(
            idAllocator.getNextId(),
            localGather,
            outputsAsInputs(aggregation.getAggregations()),
            aggregation.getGroupingSets(),
            aggregation.getPreGroupedSymbols(),
            AggregationNode.Step.INTERMEDIATE,
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
}
Also used : ExchangeNode(io.trino.sql.planner.plan.ExchangeNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode)

Example 28 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

The class PushAggregationIntoTableScan, method pushAggregationIntoTableScan.

/**
 * Offers the given aggregations and grouping keys to the connector backing {@code tableScan}
 * through {@code Metadata.applyAggregation}.
 *
 * <p>If the connector accepts, the result is a ProjectNode over a new TableScanNode that uses
 * the connector's new table handle. The projection defines every aggregation output symbol,
 * using the connector-provided expressions, and every grouping key, as a reference to a scan
 * symbol.
 *
 * @param plannerContext provides metadata access and expression translation
 * @param context rule context: session, symbol/id allocators and stats provider
 * @param aggregationNode the node being replaced; only used to derive statistics for the new scan
 * @param tableScan the scan directly beneath the aggregation
 * @param aggregations aggregation output symbol to aggregation; the iteration order fixes the
 *     order in which functions are passed to the connector
 * @param groupingKeys grouping symbols, looked up by name in the scan's assignments
 * @return the rewritten plan, or empty if the connector did not accept the pushdown
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(
        PlannerContext plannerContext,
        Context context,
        PlanNode aggregationNode,
        TableScanNode tableScan,
        Map<Symbol, AggregationNode.Aggregation> aggregations,
        List<Symbol> groupingKeys) {
    Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream()
            .collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));

    // Snapshot the entries once so aggregateFunctions[i] and aggregationOutputSymbols[i] line up.
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = aggregations.entrySet().stream()
            .collect(toImmutableList());
    List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
            .map(Entry::getValue)
            .map(aggregation -> toAggregateFunction(plannerContext.getMetadata(), context, aggregation))
            .collect(toImmutableList());
    List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
            .map(Entry::getKey)
            .collect(toImmutableList());
    List<ColumnHandle> groupByColumns = groupingKeys.stream()
            .map(groupByColumn -> assignments.get(groupByColumn.getName()))
            .collect(toImmutableList());

    Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
            context.getSession(),
            tableScan.getTable(),
            aggregateFunctions,
            assignments,
            ImmutableList.of(groupByColumns));
    if (aggregationPushdownResult.isEmpty()) {
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

    // The new scan outputs are all of the original scan outputs, followed by one freshly allocated
    // symbol per assignment returned by the connector. The connector's projections refer to those
    // assignments by variable name, which variableMappings resolves to the new symbols.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    Map<String, Symbol> variableMappings = new HashMap<>();
    for (Assignment assignment : result.getAssignments()) {
        Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
        newScanOutputs.add(symbol);
        newScanAssignments.put(symbol, assignment.getColumn());
        variableMappings.put(assignment.getVariable(), symbol);
    }

    List<Expression> newProjections = result.getProjections().stream()
            .map(expression -> ConnectorExpressionTranslator.translate(
                    context.getSession(),
                    expression,
                    plannerContext,
                    variableMappings,
                    new LiteralEncoder(plannerContext)))
            .collect(toImmutableList());
    // The connector must return exactly one projection per requested aggregation, in request order.
    verify(
            aggregationOutputSymbols.size() == newProjections.size(),
            "Connector returned %s projections for %s aggregations",
            newProjections.size(),
            aggregationOutputSymbols.size());

    Assignments.Builder assignmentBuilder = Assignments.builder();
    IntStream.range(0, aggregationOutputSymbols.size())
            .forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));

    // Building the BiMap rejects duplicate symbols or duplicate column handles.
    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.buildOrThrow();
    ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();

    // The projection must define both the aggregation outputs and the grouping keys, so every
    // grouping key is added as a symbol reference to the corresponding scan column.
    for (Symbol groupBySymbol : groupingKeys) {
        // If the connector mapped oldColumnHandle to newColumnHandle, the grouping key must point at
        // the new handle's symbol; otherwise it keeps pointing at the original handle's symbol.
        ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        Symbol groupByScanSymbol = columnHandleToSymbol.get(groupByColumnHandle);
        verify(
                groupByScanSymbol != null,
                "Grouping column %s for symbol %s is not present in the new table scan assignments",
                groupByColumnHandle,
                groupBySymbol);
        assignmentBuilder.put(groupBySymbol, groupByScanSymbol.toSymbolReference());
    }

    return Optional.of(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new TableScanNode(
                    context.getIdAllocator().getNextId(),
                    result.getHandle(),
                    newScanOutputs.build(),
                    scanAssignments,
                    TupleDomain.all(),
                    deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode),
                    tableScan.isUpdateTarget(),
                    // table scan partitioning might have changed with new table handle
                    Optional.empty()),
            assignmentBuilder.build()));
}
Also used : IntStream(java.util.stream.IntStream) SortItem(io.trino.spi.connector.SortItem) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) Aggregation.step(io.trino.sql.planner.plan.Patterns.Aggregation.step) AggregateFunction(io.trino.spi.connector.AggregateFunction) HashMap(java.util.HashMap) Variable(io.trino.spi.expression.Variable) Capture.newCapture(io.trino.matching.Capture.newCapture) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) TableScanNode(io.trino.sql.planner.plan.TableScanNode) Symbol(io.trino.sql.planner.Symbol) Rules.deriveTableStatisticsForPushdown(io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ConnectorExpressionTranslator(io.trino.sql.planner.ConnectorExpressionTranslator) Assignments(io.trino.sql.planner.plan.Assignments) TupleDomain(io.trino.spi.predicate.TupleDomain) OrderingScheme(io.trino.sql.planner.OrderingScheme) Patterns.tableScan(io.trino.sql.planner.plan.Patterns.tableScan) Capture(io.trino.matching.Capture) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TableHandle(io.trino.metadata.TableHandle) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) 
SystemSessionProperties.isAllowPushdownIntoConnectors(io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) BoundSignature(io.trino.metadata.BoundSignature) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) Metadata(io.trino.metadata.Metadata) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) TableScanNode(io.trino.sql.planner.plan.TableScanNode) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) Expression(io.trino.sql.tree.Expression) AggregateFunction(io.trino.spi.connector.AggregateFunction) TableHandle(io.trino.metadata.TableHandle) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult)

Example 29 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

The class PushAggregationThroughOuterJoin, method createAggregationOverNull.

/**
 * Builds a SINGLE-step global aggregation that evaluates every aggregation of
 * {@code referenceAggregation} over a one-row VALUES node made entirely of typed NULLs.
 *
 * @param referenceAggregation supplies the aggregations to copy and, through its source's output
 *     symbols, the columns of the null row
 * @param symbolAllocator allocates the null-row symbols and the new aggregation output symbols,
 *     and provides their types
 * @param idAllocator supplies ids for the VALUES node and then for the new aggregation node
 * @return the new aggregation, together with a map from each original aggregation symbol to its
 *     counterpart over the null row
 */
private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
    // Step 1: one typed-NULL column per output of the reference aggregation's source, remembering
    // which new symbol stands in for which source symbol.
    ImmutableList.Builder<Symbol> valuesColumns = ImmutableList.builder();
    ImmutableList.Builder<Expression> typedNulls = ImmutableList.builder();
    ImmutableMap.Builder<Symbol, Symbol> sourceToNullSymbol = ImmutableMap.builder();
    referenceAggregation.getSource().getOutputSymbols().forEach(sourceSymbol -> {
        Type columnType = symbolAllocator.getTypes().get(sourceSymbol);
        typedNulls.add(new Cast(new NullLiteral(), toSqlType(columnType)));
        Symbol replacement = symbolAllocator.newSymbol("null", columnType);
        valuesColumns.add(replacement);
        sourceToNullSymbol.put(sourceSymbol, replacement);
    });
    ValuesNode nullRow = new ValuesNode(
            idAllocator.getNextId(),
            valuesColumns.build(),
            ImmutableList.of(new Row(typedNulls.build())));

    // Step 2: rewrite every aggregation so its inputs refer to the null-row symbols, giving each
    // rewritten aggregation a fresh output symbol with the original output's type.
    ImmutableMap.Builder<Symbol, Symbol> originalToOverNull = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsOverNull = ImmutableMap.builder();
    SymbolMapper remapToNullRow = symbolMapper(sourceToNullSymbol.buildOrThrow());
    referenceAggregation.getAggregations().forEach((originalSymbol, originalAggregation) -> {
        Aggregation rewritten = remapToNullRow.map(originalAggregation);
        Symbol overNullSymbol = symbolAllocator.newSymbol(
                rewritten.getResolvedFunction().getSignature().getName(),
                symbolAllocator.getTypes().get(originalSymbol));
        aggregationsOverNull.put(overNullSymbol, rewritten);
        originalToOverNull.put(originalSymbol, overNullSymbol);
    });
    Map<Symbol, Symbol> aggregationsSymbolMapping = originalToOverNull.buildOrThrow();

    // Step 3: a global, SINGLE-step aggregation reading from the null row.
    AggregationNode aggregationOverNullRow = new AggregationNode(
            idAllocator.getNextId(),
            nullRow,
            aggregationsOverNull.buildOrThrow(),
            globalAggregation(),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());
    return new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping);
}
Also used : Cast(io.trino.sql.tree.Cast) ValuesNode(io.trino.sql.planner.plan.ValuesNode) SymbolMapper(io.trino.sql.planner.optimizations.SymbolMapper) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) AggregationNode.globalAggregation(io.trino.sql.planner.plan.AggregationNode.globalAggregation) Type(io.trino.spi.type.Type) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) CoalesceExpression(io.trino.sql.tree.CoalesceExpression) Expression(io.trino.sql.tree.Expression) Row(io.trino.sql.tree.Row) NullLiteral(io.trino.sql.tree.NullLiteral) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 30 with AggregationNode

use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.

The class PushAggregationThroughOuterJoin, method apply.

/**
 * Pushes {@code aggregation} below a LEFT or RIGHT outer join onto the join's inner side,
 * grouping by the inner side's join keys. The result is combined with an aggregation over a
 * null row via {@code coalesceWithNullAggregation}.
 *
 * <p>The rewrite is skipped when any of these holds:
 * <ul>
 *   <li>the join has a filter;</li>
 *   <li>the join is neither LEFT nor RIGHT;</li>
 *   <li>the aggregation does not group on all outer-side columns;</li>
 *   <li>the outer side is not known to be distinct;</li>
 *   <li>the aggregation does not only use inner-side symbols.</li>
 * </ul>
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    // This rule doesn't deal with AggregationNode's hash symbol. Hash symbols are not yet present at this stage of optimization.
    checkArgument(aggregation.getHashSymbol().isEmpty(), "unexpected hash symbol");
    JoinNode join = captures.get(JOIN);

    // Guard clauses, evaluated in the same short-circuit order as a single combined condition.
    if (join.getFilter().isPresent()) {
        return Result.empty();
    }
    if (join.getType() != JoinNode.Type.LEFT && join.getType() != JoinNode.Type.RIGHT) {
        return Result.empty();
    }
    if (!groupsOnAllColumns(aggregation, getOuterTable(join).getOutputSymbols())) {
        return Result.empty();
    }
    if (!isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve)) {
        return Result.empty();
    }
    if (!isAggregationOnSymbols(aggregation, getInnerTable(join))) {
        return Result.empty();
    }

    // For a LEFT join the inner (aggregated) side is the right input; for a RIGHT join, the left.
    boolean aggregateOnRight = join.getType() == JoinNode.Type.LEFT;

    List<Symbol> groupingKeys = join.getCriteria().stream()
            .map(aggregateOnRight ? JoinNode.EquiJoinClause::getRight : JoinNode.EquiJoinClause::getLeft)
            .collect(toImmutableList());
    AggregationNode rewrittenAggregation = new AggregationNode(
            aggregation.getId(),
            getInnerTable(join),
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());

    List<Symbol> aggregationSymbols = ImmutableList.copyOf(rewrittenAggregation.getAggregations().keySet());
    JoinNode rewrittenJoin = new JoinNode(
            join.getId(),
            join.getType(),
            aggregateOnRight ? join.getLeft() : rewrittenAggregation,
            aggregateOnRight ? rewrittenAggregation : join.getRight(),
            join.getCriteria(),
            aggregateOnRight ? join.getLeft().getOutputSymbols() : aggregationSymbols,
            aggregateOnRight ? aggregationSymbols : join.getRight().getOutputSymbols(),
            // there are no duplicate rows possible since outer rows were guaranteed to be distinct
            false,
            join.getFilter(),
            join.getLeftHashSymbol(),
            join.getRightHashSymbol(),
            join.getDistributionType(),
            join.isSpillable(),
            join.getDynamicFilters(),
            join.getReorderJoinStatsAndCost());

    return coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator())
            .map(Result::ofPlanNode)
            .orElseGet(Result::empty);
}
Also used : PlanNode(io.trino.sql.planner.plan.PlanNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode)

Aggregations

AggregationNode (io.trino.sql.planner.plan.AggregationNode)49 Symbol (io.trino.sql.planner.Symbol)32 PlanNode (io.trino.sql.planner.plan.PlanNode)32 ProjectNode (io.trino.sql.planner.plan.ProjectNode)21 Map (java.util.Map)18 Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation)16 ImmutableMap (com.google.common.collect.ImmutableMap)15 Assignments (io.trino.sql.planner.plan.Assignments)15 Expression (io.trino.sql.tree.Expression)15 ImmutableList (com.google.common.collect.ImmutableList)14 JoinNode (io.trino.sql.planner.plan.JoinNode)12 List (java.util.List)11 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)10 CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode)10 AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId)9 Optional (java.util.Optional)9 Session (io.trino.Session)7 ResolvedFunction (io.trino.metadata.ResolvedFunction)7 Captures (io.trino.matching.Captures)6 Pattern (io.trino.matching.Pattern)6