Search in sources:

Example 6 with AggregationNode

Use of `io.trino.sql.planner.plan.AggregationNode` in project trino by trinodb.

Method `apply` of class `PruneCountAggregationOverScalar`.

/**
 * Replaces a global aggregation consisting solely of unconditional {@code count()} calls over a scalar
 * source with a constant {@code VALUES (BIGINT '1')} row.
 *
 * @param parent the matched aggregation node
 * @param captures pattern captures (unused by this rule)
 * @param context rule context providing the session and plan lookup
 * @return the folded {@code ValuesNode}, or {@code Result.empty()} when the rewrite does not apply
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    // Only aggregations with default output and exactly one output symbol qualify.
    // NOTE(review): hasDefaultOutput() presumably identifies global aggregations; confirm in AggregationNode.
    if (!parent.hasDefaultOutput() || parent.getOutputSymbols().size() != 1) {
        return Result.empty();
    }
    Map<Symbol, AggregationNode.Aggregation> assignments = parent.getAggregations();
    // Nothing to fold: bail out before paying for function resolution.
    if (assignments.isEmpty()) {
        return Result.empty();
    }
    FunctionId countFunctionId = metadata.resolveFunction(context.getSession(), QualifiedName.of("count"), ImmutableList.of()).getFunctionId();
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : assignments.entrySet()) {
        AggregationNode.Aggregation aggregation = requireNonNull(entry.getValue(), "aggregation is null");
        ResolvedFunction resolvedFunction = aggregation.getResolvedFunction();
        if (!countFunctionId.equals(resolvedFunction.getFunctionId())) {
            return Result.empty();
        }
        // A FILTER clause or mask may exclude the single input row, in which case the count is 0,
        // not 1 -- the constant folding below is only valid for unconditional count().
        if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
            return Result.empty();
        }
    }
    // The literal 1 relies on isScalar(...) guaranteeing exactly one source row.
    if (isScalar(parent.getSource(), context.getLookup())) {
        return Result.ofPlanNode(new ValuesNode(parent.getId(), parent.getOutputSymbols(), ImmutableList.of(new Row(ImmutableList.of(new GenericLiteral("BIGINT", "1"))))));
    }
    return Result.empty();
}
Also used : FunctionId(io.trino.metadata.FunctionId) ValuesNode(io.trino.sql.planner.plan.ValuesNode) Symbol(io.trino.sql.planner.Symbol) ResolvedFunction(io.trino.metadata.ResolvedFunction) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Row(io.trino.sql.tree.Row) Map(java.util.Map) GenericLiteral(io.trino.sql.tree.GenericLiteral)

Example 7 with AggregationNode

Use of `io.trino.sql.planner.plan.AggregationNode` in project trino by trinodb.

Method `apply` of class `SingleDistinctAggregationToGroupBy`.

/**
 * Rewrites an aggregation whose DISTINCT aggregations all share a single argument set into a
 * non-distinct aggregation on top of a GROUP BY over (grouping keys + distinct arguments).
 *
 * @param aggregation the matched aggregation node
 * @param captures pattern captures (unused by this rule)
 * @param context rule context providing the id allocator
 * @return the rewritten plan
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    List<Set<Expression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    // getOnlyElement fails unless exactly one distinct argument set exists.
    Set<Symbol> symbols = Iterables.getOnlyElement(argumentSets).stream().map(Symbol::from).collect(Collectors.toSet());
    // Group by the original keys plus the distinct arguments. A distinct argument may itself be a grouping
    // key (e.g. count(DISTINCT k) ... GROUP BY k); drop the repeat so the grouping set has no duplicates.
    List<Symbol> groupingKeys = Stream.concat(aggregation.getGroupingKeys().stream(), symbols.stream())
            .distinct()
            .collect(Collectors.toList());
    AggregationNode distinctRows = new AggregationNode(context.getIdAllocator().getNextId(), aggregation.getSource(), ImmutableMap.of(), singleGroupingSet(ImmutableList.copyOf(groupingKeys)), ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty());
    // Remove the DISTINCT flag from function calls; toImmutableMap keeps the original aggregation order.
    Map<Symbol, Aggregation> aggregations = aggregation.getAggregations().entrySet().stream()
            .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, e -> removeDistinct(e.getValue())));
    return Result.ofPlanNode(new AggregationNode(aggregation.getId(), distinctRows, aggregations, aggregation.getGroupingSets(), emptyList(), aggregation.getStep(), aggregation.getHashSymbol(), aggregation.getGroupIdSymbol()));
}
Also used : Symbol(io.trino.sql.planner.Symbol) Iterables(com.google.common.collect.Iterables) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Collections.emptyList(java.util.Collections.emptyList) Set(java.util.Set) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) Collectors(java.util.stream.Collectors) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) List(java.util.List) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) Pattern(io.trino.matching.Pattern) Stream(java.util.stream.Stream) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) ImmutableList(com.google.common.collect.ImmutableList) Captures(io.trino.matching.Captures) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Optional(java.util.Optional) Rule(io.trino.sql.planner.iterative.Rule) Expression(io.trino.sql.tree.Expression)

Example 8 with AggregationNode

Use of `io.trino.sql.planner.plan.AggregationNode` in project trino by trinodb.

Method `apply` of class `TransformCorrelatedDistinctAggregationWithProjection`.

/**
 * Decorrelates a correlated join whose subquery is a projection over a distinct aggregation.
 * The subquery's correlated filters become the predicate of a LEFT join against the uniquely-tagged
 * input, and the aggregation is re-grouped by the input's columns before the projection is re-applied.
 *
 * @return the rewritten plan, or {@code Result.empty()} when the subquery's filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Pull the correlated filters out of the plan below the captured aggregation.
    PlanNodeDecorrelator planDecorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = planDecorrelator.decorrelateFilters(capturedAggregation.getSource(), correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelatedNode = decorrelated.get();
    PlanNode subquery = decorrelatedNode.getNode();

    // Tag every input row with a unique id so the original input rows stay distinguishable after the join.
    PlanNode taggedInput = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
    JoinNode leftJoin = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            taggedInput,
            subquery,
            ImmutableList.of(),
            taggedInput.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelatedNode.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Rebuild the aggregation on top of the join, grouped by the input columns plus its own keys.
    List<Symbol> regroupKeys = ImmutableList.<Symbol>builder()
            .addAll(leftJoin.getLeftOutputSymbols())
            .addAll(capturedAggregation.getGroupingKeys())
            .build();
    AggregationNode regrouped = new AggregationNode(
            capturedAggregation.getId(),
            leftJoin,
            capturedAggregation.getAggregations(),
            singleGroupingSet(regroupKeys),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Keep only the aggregation outputs the correlated join exposes, then layer the captured projection on top.
    Set<Symbol> visibleSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> passThrough = regrouped.getOutputSymbols().stream()
            .filter(visibleSymbols::contains)
            .collect(toImmutableList());
    Assignments projection = Assignments.builder()
            .putIdentities(passThrough)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), regrouped, projection));
}
Also used : CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) HashSet(java.util.HashSet)

Example 9 with AggregationNode

Use of `io.trino.sql.planner.plan.AggregationNode` in project trino by trinodb.

Method `apply` of class `TransformCorrelatedDistinctAggregationWithoutProjection`.

/**
 * Decorrelates a correlated join whose subquery is a distinct aggregation (with no projection on top).
 * The subquery's correlated filters become the predicate of a LEFT join against the uniquely-tagged
 * input, and the aggregation is re-grouped by the input's columns.
 *
 * @return the rewritten plan, or {@code Result.empty()} when the subquery's filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Pull the correlated filters out of the plan below the captured aggregation.
    PlanNodeDecorrelator planDecorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = planDecorrelator.decorrelateFilters(capturedAggregation.getSource(), correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelatedNode = decorrelated.get();
    PlanNode subquery = decorrelatedNode.getNode();

    // Tag every input row with a unique id so the original input rows stay distinguishable after the join.
    PlanNode taggedInput = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
    JoinNode leftJoin = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            taggedInput,
            subquery,
            ImmutableList.of(),
            taggedInput.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelatedNode.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Rebuild the aggregation on top of the join, grouped by the input columns plus its own keys.
    AggregationNode regrouped = new AggregationNode(
            capturedAggregation.getId(),
            leftJoin,
            capturedAggregation.getAggregations(),
            singleGroupingSet(ImmutableList.<Symbol>builder()
                    .addAll(leftJoin.getLeftOutputSymbols())
                    .addAll(capturedAggregation.getGroupingKeys())
                    .build()),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Trim outputs to the correlated join's symbols; restrictOutputs yields nothing when no pruning is needed.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), regrouped, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(regrouped));
}
Also used : PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode)

Example 10 with AggregationNode

Use of `io.trino.sql.planner.plan.AggregationNode` in project trino by trinodb.

Method `apply` of class `TransformCorrelatedGlobalAggregationWithProjection`.

/**
 * Decorrelates a correlated join (INNER or LEFT) whose subquery is a projection over a global
 * aggregation, optionally stacked on a distinct aggregation.
 *
 * <p>The rewrite proceeds in these steps:
 * <ol>
 *   <li>Extract the subquery's correlated filters.</li>
 *   <li>Append a {@code non_null = TRUE} marker column to the subquery.</li>
 *   <li>LEFT-join the uniquely-tagged input to the subquery.</li>
 *   <li>Re-group the (optional) distinct aggregation and the global aggregation by the input's columns.</li>
 *   <li>Mask each aggregation with {@code non_null}, so input rows with no matching subquery row are
 *       aggregated as empty input.</li>
 * </ol>
 *
 * @return the rewritten plan, or {@code Result.empty()} when the subquery's filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Template form: the message is only formatted when the check actually fails.
    checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: %s", correlatedJoinNode.getType());
    // if there is another aggregation below the AggregationNode, handle both
    PlanNode source = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(source)) {
        distinct = (AggregationNode) source;
        source = distinct.getSource();
    }
    // decorrelate nested plan
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
    if (decorrelatedSource.isEmpty()) {
        return Result.empty();
    }
    source = decorrelatedSource.get().getNode();
    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    source = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.builder().putIdentities(source.getOutputSymbols()).put(nonNull, TRUE_LITERAL).build());
    // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
    PlanNode inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
    // The join is always LEFT, whichever type (INNER or LEFT) the correlated join had.
    JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, inputWithUniqueId, source, ImmutableList.of(), inputWithUniqueId.getOutputSymbols(), source.getOutputSymbols(), false, decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
    PlanNode root = join;
    // restore distinct aggregation; non_null is part of the grouping key so it stays available for the masks below
    if (distinct != null) {
        root = new AggregationNode(distinct.getId(), join, distinct.getAggregations(), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).add(nonNull).addAll(distinct.getGroupingKeys()).build()), ImmutableList.of(), distinct.getStep(), Optional.empty(), Optional.empty());
    }
    // prepare mask symbols for aggregations
    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
    // already has a mask, it will be replaced with conjunction of the existing mask and non_null.
    // This is necessary to restore the original aggregation result in case when:
    // - the nested lateral subquery returned empty result for some input row,
    // - aggregation is null-sensitive, which means that its result over a single null row is different
    // than result for empty input (with global grouping)
    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
    AggregationNode globalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder assignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isPresent()) {
            Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
            assignmentsBuilder.put(newMask, expression);
            masks.put(entry.getKey(), newMask);
        } else {
            masks.put(entry.getKey(), nonNull);
        }
    }
    Assignments maskAssignments = assignmentsBuilder.build();
    // Only add a projection when at least one combined (existing AND non_null) mask must be computed.
    if (!maskAssignments.isEmpty()) {
        root = new ProjectNode(context.getIdAllocator().getNextId(), root, Assignments.builder().putIdentities(root.getOutputSymbols()).putAll(maskAssignments).build());
    }
    // restore global aggregation, now grouped by the input columns (including the unique id)
    globalAggregation = new AggregationNode(globalAggregation.getId(), root, rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).addAll(globalAggregation.getGroupingKeys()).build()), ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), Optional.empty());
    // restrict outputs and apply projection
    Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> expectedAggregationOutputs = globalAggregation.getOutputSymbols().stream().filter(outputSymbols::contains).collect(toImmutableList());
    Assignments assignments = Assignments.builder().putIdentities(expectedAggregationOutputs).putAll(captures.get(PROJECTION).getAssignments()).build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), globalAggregation, assignments));
}
Also used : Symbol(io.trino.sql.planner.Symbol) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) HashSet(java.util.HashSet)

Aggregations

AggregationNode (io.trino.sql.planner.plan.AggregationNode)49 Symbol (io.trino.sql.planner.Symbol)32 PlanNode (io.trino.sql.planner.plan.PlanNode)32 ProjectNode (io.trino.sql.planner.plan.ProjectNode)21 Map (java.util.Map)18 Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation)16 ImmutableMap (com.google.common.collect.ImmutableMap)15 Assignments (io.trino.sql.planner.plan.Assignments)15 Expression (io.trino.sql.tree.Expression)15 ImmutableList (com.google.common.collect.ImmutableList)14 JoinNode (io.trino.sql.planner.plan.JoinNode)12 List (java.util.List)11 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)10 CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode)10 AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId)9 Optional (java.util.Optional)9 Session (io.trino.Session)7 ResolvedFunction (io.trino.metadata.ResolvedFunction)7 Captures (io.trino.matching.Captures)6 Pattern (io.trino.matching.Pattern)6