Search in sources :

Example 1 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

The method groupBy of the class AggregationStatsRule.

/**
 * Builds the statistics estimate for the output of a grouped aggregation.
 *
 * <p>Each grouping symbol keeps its source statistics except for the nulls fraction, which
 * stays {@code 0.0} when the source has no nulls and otherwise becomes
 * {@code 1 / (distinctValuesCount + 1)} (presumably NULL is counted as one additional group).
 * The output row count is the value returned by {@code getRowsCount}, capped at the source
 * row count. Statistics of each aggregation output come from {@code estimateAggregationStats}.
 *
 * @param sourceStats statistics of the aggregation input
 * @param groupBySymbols symbols used as grouping keys
 * @param aggregations aggregations keyed by their output symbol
 * @return the estimate assembled for the aggregation output
 */
public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, Aggregation> aggregations) {
    PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.builder();

    // Grouping keys: carry over the source statistics, adjusting only the nulls fraction.
    groupBySymbols.forEach(key -> {
        SymbolStatsEstimate keyStats = sourceStats.getSymbolStatistics(key);
        estimate.addSymbolStatistics(key, keyStats.mapNullsFraction(fraction -> fraction == 0.0 ? 0.0 : 1.0 / (keyStats.getDistinctValuesCount() + 1)));
    });

    // Grouping can never produce more rows than it consumes.
    estimate.setOutputRowCount(min(getRowsCount(sourceStats, groupBySymbols), sourceStats.getOutputRowCount()));

    // Aggregation outputs: one estimate per output symbol.
    aggregations.forEach((outputSymbol, aggregation) -> estimate.addSymbolStatistics(outputSymbol, estimateAggregationStats(aggregation, sourceStats)));

    return estimate.build();
}
Also used : Symbol(io.trino.sql.planner.Symbol) Lookup(io.trino.sql.planner.iterative.Lookup) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Collection(java.util.Collection) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) Math.min(java.lang.Math.min) Pattern(io.trino.matching.Pattern) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) TypeProvider(io.trino.sql.planner.TypeProvider) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Optional(java.util.Optional) Session(io.trino.Session) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Symbol(io.trino.sql.planner.Symbol) Map(java.util.Map)

Example 2 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

The method apply of the class ImplementFilteredAggregations.

/**
 * Replaces the filter of every aggregation with a mask and rebuilds the plan as
 * {@code Aggregation -> Filter -> Project -> original source}.
 *
 * <p>For each aggregation: when both a filter and a mask exist, a new BOOLEAN "mask" symbol is
 * allocated and projected as {@code and(mask, filter)}; when only a filter exists, the filter
 * symbol itself becomes the mask; an existing mask without a filter is kept unchanged. The
 * rewritten aggregations carry no filter.
 *
 * <p>The inserted filter node uses {@code TRUE_LITERAL} unless the aggregation has no non-empty
 * grouping set and every aggregation ended up with a mask; in that case the predicate is built
 * by {@code combineDisjunctsWithDefault} from the collected mask references.
 *
 * @param aggregationNode the aggregation node being rewritten
 * @param captures pattern captures (not used here)
 * @param context supplies the symbol and plan-node id allocators
 * @return a result wrapping the rewritten aggregation node
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    Assignments.Builder projections = Assignments.builder();
    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableList.Builder<Expression> maskReferences = ImmutableList.builder();
    boolean sawUnmaskedAggregate = false;

    for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Aggregation original = entry.getValue();
        Optional<Symbol> filter = original.getFilter();
        Optional<Symbol> effectiveMask = original.getMask();

        // Step 1: fold the filter (if any) into the mask.
        if (filter.isPresent() && effectiveMask.isPresent()) {
            // Both present: project a fresh symbol holding their conjunction.
            Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            projections.put(combinedMask, and(effectiveMask.get().toSymbolReference(), filter.get().toSymbolReference()));
            effectiveMask = Optional.of(combinedMask);
        } else if (filter.isPresent()) {
            // Only a filter: it can serve as the mask directly.
            effectiveMask = filter;
        }

        // Step 2: record the resulting mask, or note that this aggregate is unconditional.
        if (effectiveMask.isPresent()) {
            maskReferences.add(effectiveMask.get().toSymbolReference());
        } else {
            sawUnmaskedAggregate = true;
        }

        // Step 3: re-create the aggregation without its filter.
        rewrittenAggregations.put(
                entry.getKey(),
                new Aggregation(
                        original.getResolvedFunction(),
                        original.getArguments(),
                        original.isDistinct(),
                        Optional.empty(),
                        original.getOrderingScheme(),
                        effectiveMask));
    }

    Expression predicate;
    if (aggregationNode.hasNonEmptyGroupingSet() || sawUnmaskedAggregate) {
        predicate = TRUE_LITERAL;
    } else {
        predicate = combineDisjunctsWithDefault(metadata, maskReferences.build(), TRUE_LITERAL);
    }

    // identity projection for all existing inputs
    projections.putIdentities(aggregationNode.getSource().getOutputSymbols());

    // The nodes are constructed in one nested expression so that ids are requested from the
    // allocator in the order aggregation, filter, project.
    return Result.ofPlanNode(
            new AggregationNode(
                    context.getIdAllocator().getNextId(),
                    new FilterNode(
                            context.getIdAllocator().getNextId(),
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    aggregationNode.getSource(),
                                    projections.build()),
                            predicate),
                    rewrittenAggregations.buildOrThrow(),
                    aggregationNode.getGroupingSets(),
                    ImmutableList.of(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol()));
}
Also used : Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.trino.sql.planner.plan.FilterNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Example 3 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

The method rewriteWithMasks of the class AggregationDecorrelation.

/**
 * Returns a copy of the given aggregations in which each aggregation's mask is replaced by the
 * symbol that {@code masks} associates with the aggregation's output symbol. All other
 * aggregation properties are carried over unchanged.
 *
 * @param aggregations aggregations keyed by output symbol
 * @param masks mask symbol for each output symbol; must contain an entry for every key of
 *         {@code aggregations}
 * @return an immutable map with the same keys and the re-masked aggregations
 * @throws NullPointerException if {@code masks} has no mapping for some output symbol
 */
public static Map<Symbol, Aggregation> rewriteWithMasks(Map<Symbol, Aggregation> aggregations, Map<Symbol, Symbol> masks) {
    ImmutableMap.Builder<Symbol, Aggregation> result = ImmutableMap.builder();
    aggregations.forEach((output, original) -> result.put(
            output,
            new Aggregation(
                    original.getResolvedFunction(),
                    original.getArguments(),
                    original.isDistinct(),
                    original.getFilter(),
                    original.getOrderingScheme(),
                    // Optional.of rejects null, so a missing mask fails here
                    Optional.of(masks.get(output)))));
    return result.buildOrThrow();
}
Also used : Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Symbol(io.trino.sql.planner.Symbol) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 4 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

The method apply of the class MultipleDistinctAggregationToMarkDistinct.

/**
 * Rewrites DISTINCT aggregations that have neither a filter nor a mask into non-distinct
 * aggregations masked by a marker symbol produced by a {@code MarkDistinctNode}.
 *
 * <p>One marker (and one {@code MarkDistinctNode}, stacked on the current sub-plan) is created
 * per distinct set of argument symbols; aggregations with the same argument set share it. The
 * symbols handed to the mark-distinct node are the parent's grouping keys, the argument symbols
 * and, when present, the group id symbol. All other aggregations are passed through unchanged.
 *
 * @param parent the aggregation node being rewritten
 * @param captures pattern captures (not used here)
 * @param context supplies the session and the symbol and plan-node id allocators
 * @return an empty result when the session's {@code useMarkDistinct} check is false, otherwise
 *         a result wrapping the rewritten aggregation node, which reuses the parent's id
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
        return Result.empty();
    }

    // the distinct marker for the given set of input columns
    Map<Set<Symbol>, Symbol> markerByInputs = new HashMap<>();
    Map<Symbol, Aggregation> rewritten = new HashMap<>();
    PlanNode source = parent.getSource();

    for (Map.Entry<Symbol, Aggregation> entry : parent.getAggregations().entrySet()) {
        Symbol output = entry.getKey();
        Aggregation aggregation = entry.getValue();

        boolean plainDistinct = aggregation.isDistinct() && aggregation.getFilter().isEmpty() && aggregation.getMask().isEmpty();
        if (!plainDistinct) {
            // Nothing to rewrite: keep the aggregation as it is.
            rewritten.put(output, aggregation);
            continue;
        }

        Set<Symbol> inputs = aggregation.getArguments().stream().map(Symbol::from).collect(toSet());
        Symbol marker = markerByInputs.get(inputs);
        if (marker == null) {
            // First aggregation over this argument set: allocate a marker and add a
            // mark-distinct node on top of the sub-plan built so far.
            marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
            markerByInputs.put(inputs, marker);

            ImmutableSet.Builder<Symbol> distinctSymbols = ImmutableSet.<Symbol>builder()
                    .addAll(parent.getGroupingKeys())
                    .addAll(inputs);
            parent.getGroupIdSymbol().ifPresent(distinctSymbols::add);

            source = new MarkDistinctNode(
                    context.getIdAllocator().getNextId(),
                    source,
                    marker,
                    ImmutableList.copyOf(distinctSymbols.build()),
                    Optional.empty());
        }

        // remove the distinct flag and set the distinct marker
        rewritten.put(
                output,
                new Aggregation(
                        aggregation.getResolvedFunction(),
                        aggregation.getArguments(),
                        false,
                        aggregation.getFilter(),
                        aggregation.getOrderingScheme(),
                        Optional.of(marker)));
    }

    return Result.ofPlanNode(
            new AggregationNode(
                    parent.getId(),
                    source,
                    rewritten,
                    parent.getGroupingSets(),
                    ImmutableList.of(),
                    parent.getStep(),
                    parent.getHashSymbol(),
                    parent.getGroupIdSymbol()));
}
Also used : ImmutableSet(com.google.common.collect.ImmutableSet) Set(java.util.Set) HashSet(java.util.HashSet) Collectors.toSet(java.util.stream.Collectors.toSet) MarkDistinctNode(io.trino.sql.planner.plan.MarkDistinctNode) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableSet(com.google.common.collect.ImmutableSet) HashMap(java.util.HashMap) Map(java.util.Map)

Example 5 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

The method apply of the class SingleDistinctAggregationToGroupBy.

/**
 * Rewrites an aggregation whose DISTINCT aggregates all share a single argument set into two
 * stacked aggregations.
 *
 * <p>The inner node is a {@code SINGLE}-step aggregation with no aggregate functions that groups
 * by the original grouping keys plus the distinct argument symbols. The outer node reuses the
 * original node's id, grouping sets, step, hash symbol and group id symbol, and applies
 * {@code removeDistinct} to every original aggregation.
 *
 * @param aggregation the aggregation node being rewritten
 * @param captures pattern captures (not used here)
 * @param context supplies the plan-node id allocator
 * @return a result wrapping the outer aggregation node
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    // Exactly one argument set is expected; getOnlyElement fails otherwise.
    List<Set<Expression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    Set<Expression> distinctArguments = Iterables.getOnlyElement(argumentSets);
    Set<Symbol> distinctSymbols = distinctArguments.stream()
            .map(Symbol::from)
            .collect(Collectors.toSet());

    // Inner aggregation: group by the original keys plus the distinct arguments.
    List<Symbol> innerGroupingKeys = ImmutableList.<Symbol>builder()
            .addAll(aggregation.getGroupingKeys())
            .addAll(distinctSymbols)
            .build();
    AggregationNode innerGroupBy = new AggregationNode(
            context.getIdAllocator().getNextId(),
            aggregation.getSource(),
            ImmutableMap.of(),
            singleGroupingSet(innerGroupingKeys),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    return Result.ofPlanNode(
            new AggregationNode(
                    aggregation.getId(),
                    innerGroupBy,
                    // remove DISTINCT flag from function calls
                    aggregation.getAggregations().entrySet().stream()
                            .collect(Collectors.toMap(Map.Entry::getKey, e -> removeDistinct(e.getValue()))),
                    aggregation.getGroupingSets(),
                    emptyList(),
                    aggregation.getStep(),
                    aggregation.getHashSymbol(),
                    aggregation.getGroupIdSymbol()));
}
Also used : Symbol(io.trino.sql.planner.Symbol) Iterables(com.google.common.collect.Iterables) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Collections.emptyList(java.util.Collections.emptyList) Set(java.util.Set) SINGLE(io.trino.sql.planner.plan.AggregationNode.Step.SINGLE) Collectors(java.util.stream.Collectors) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) List(java.util.List) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) Pattern(io.trino.matching.Pattern) Stream(java.util.stream.Stream) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) ImmutableList(com.google.common.collect.ImmutableList) Captures(io.trino.matching.Captures) Map(java.util.Map) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Optional(java.util.Optional) Rule(io.trino.sql.planner.iterative.Rule) Expression(io.trino.sql.tree.Expression) Set(java.util.Set) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(io.trino.sql.planner.plan.AggregationNode.singleGroupingSet) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Aggregations

Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation): 20, AggregationNode (io.trino.sql.planner.plan.AggregationNode): 17, Symbol (io.trino.sql.planner.Symbol): 14, Map (java.util.Map): 14, ImmutableMap (com.google.common.collect.ImmutableMap): 11, PlanNode (io.trino.sql.planner.plan.PlanNode): 9, Expression (io.trino.sql.tree.Expression): 9, ProjectNode (io.trino.sql.planner.plan.ProjectNode): 7, ImmutableList (com.google.common.collect.ImmutableList): 5, ResolvedFunction (io.trino.metadata.ResolvedFunction): 5, Assignments (io.trino.sql.planner.plan.Assignments): 5, Type (io.trino.spi.type.Type): 4, CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode): 4, ComparisonExpression (io.trino.sql.tree.ComparisonExpression): 4, Test (org.testng.annotations.Test): 4, Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument): 3, ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 3, ImmutableSet (com.google.common.collect.ImmutableSet): 3, Session (io.trino.Session): 3, Cast (io.trino.sql.tree.Cast): 3