Use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.
Class PruneOrderByInAggregation, method apply.
/**
 * Drops the ORDER BY clause from every aggregation whose function metadata reports that it is not
 * order-sensitive, leaving order-sensitive aggregations and aggregations without ORDER BY untouched.
 *
 * @return {@link Result#empty()} if the node has no orderings or if no aggregation was rewritten;
 *         otherwise a copy of {@code node} with the rewritten aggregation map
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    if (!node.hasOrderings()) {
        return Result.empty();
    }

    boolean anyRewritten = false;
    ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
    for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        // Keep the aggregation as-is when it has no ORDER BY or its function depends on input order.
        // The function-metadata lookup can be expensive, so the cheap ordering check runs first and
        // short-circuits the lookup whenever there is nothing to prune.
        if (aggregation.getOrderingScheme().isEmpty()
                || metadata.getAggregationFunctionMetadata(aggregation.getResolvedFunction()).isOrderSensitive()) {
            aggregations.put(entry);
            continue;
        }

        anyRewritten = true;
        aggregations.put(entry.getKey(), new Aggregation(
                aggregation.getResolvedFunction(),
                aggregation.getArguments(),
                aggregation.isDistinct(),
                aggregation.getFilter(),
                Optional.empty(), // ordering scheme removed
                aggregation.getMask()));
    }

    if (!anyRewritten) {
        return Result.empty();
    }

    return Result.ofPlanNode(new AggregationNode(
            node.getId(),
            node.getSource(),
            aggregations.buildOrThrow(),
            node.getGroupingSets(),
            node.getPreGroupedSymbols(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol()));
}
Use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.
Class TransformExistsApplyToCorrelatedJoin, method rewriteToDefaultAggregation.
private PlanNode rewriteToDefaultAggregation(ApplyNode applyNode, Context context) {
    // The EXISTS result is computed as count(*) > 0 over the subquery, evaluated as a global aggregation.
    ResolvedFunction countFunction = plannerContext.getMetadata().resolveFunction(context.getSession(), COUNT, ImmutableList.of());
    Symbol count = context.getSymbolAllocator().newSymbol(COUNT.toString(), BIGINT);
    Symbol exists = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());

    Aggregation countAll = new Aggregation(countFunction, ImmutableList.of(), false, Optional.empty(), Optional.empty(), Optional.empty());
    ComparisonExpression countIsPositive = new ComparisonExpression(
            GREATER_THAN,
            count.toSymbolReference(),
            new Cast(new LongLiteral("0"), toSqlType(BIGINT)));

    // The projection's id must be allocated before the aggregation's id, so the aggregation is
    // constructed inline as the projection's source argument.
    ProjectNode existsProjection = new ProjectNode(
            context.getIdAllocator().getNextId(),
            new AggregationNode(
                    context.getIdAllocator().getNextId(),
                    applyNode.getSubquery(),
                    ImmutableMap.of(count, countAll),
                    globalAggregation(),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty()),
            Assignments.of(exists, countIsPositive));

    return new CorrelatedJoinNode(
            applyNode.getId(),
            applyNode.getInput(),
            existsProjection,
            applyNode.getCorrelation(),
            INNER,
            TRUE_LITERAL,
            applyNode.getOriginSubquery());
}
Use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.
Class TransformCorrelatedGlobalAggregationWithoutProjection, method apply.
/**
 * Decorrelates a correlated join whose subquery is the global aggregation captured as {@code AGGREGATION}
 * over the source captured as {@code SOURCE}. The source may itself be a distinct operator, which is handled too.
 * <p>
 * The rewrite proceeds in these steps:
 * <ol>
 *   <li>Decorrelate filters in the source.</li>
 *   <li>Append a constant TRUE {@code non_null} column to the source.</li>
 *   <li>LEFT join the input, tagged by AssignUniqueId, with that source on the correlated predicates.</li>
 *   <li>Re-apply the distinct (if any), then the global aggregation, both grouped by the join's left-side outputs.
 *       Every aggregation is masked with {@code non_null}.</li>
 *   <li>Restrict outputs to the correlated join's output symbols.</li>
 * </ol>
 *
 * @return {@link Result#empty()} if the source's filters cannot be decorrelated
 * @throws IllegalArgumentException if the correlated join type is neither INNER nor LEFT
 *         (assuming checkArgument is Guava's Preconditions.checkArgument; verify)
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());
// if there is another aggregation below the AggregationNode, handle both
PlanNode source = captures.get(SOURCE);
AggregationNode distinct = null;
if (isDistinctOperator(source)) {
distinct = (AggregationNode) source;
source = distinct.getSource();
}
// decorrelate nested plan
PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
// bail out: the rule cannot apply if the correlation cannot be pulled out as filters
if (decorrelatedSource.isEmpty()) {
return Result.empty();
}
source = decorrelatedSource.get().getNode();
// append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
source = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.builder().putIdentities(source.getOutputSymbols()).put(nonNull, TRUE_LITERAL).build());
// assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
PlanNode inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
// LEFT join keeps every input row even when the subquery matches nothing; the decorrelated
// correlation predicates become the join filter
JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, inputWithUniqueId, source, ImmutableList.of(), inputWithUniqueId.getOutputSymbols(), source.getOutputSymbols(), false, decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
PlanNode root = join;
// restore distinct aggregation
// The distinct is grouped by the join's left outputs (presumably including the "unique" symbol) plus non_null
// and the original distinct keys, so deduplication happens per original input row.
if (distinct != null) {
root = new AggregationNode(distinct.getId(), join, distinct.getAggregations(), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).add(nonNull).addAll(distinct.getGroupingKeys()).build()), ImmutableList.of(), distinct.getStep(), Optional.empty(), Optional.empty());
}
// prepare mask symbols for aggregations
// Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
// already has a mask, it will be replaced with conjunction of the existing mask and non_null.
// This is necessary to restore the original aggregation result in case when:
// - the nested lateral subquery returned empty result for some input row,
// - aggregation is null-sensitive, which means that its result over a single null row is different
// than result for empty input (with global grouping)
// It applies to the following aggregate functions: count(*), checksum(), array_agg().
AggregationNode globalAggregation = captures.get(AGGREGATION);
ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
Assignments.Builder assignmentsBuilder = Assignments.builder();
for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
Aggregation aggregation = entry.getValue();
if (aggregation.getMask().isPresent()) {
// existing mask: allocate a fresh symbol for (existing_mask AND non_null)
Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
assignmentsBuilder.put(newMask, expression);
masks.put(entry.getKey(), newMask);
} else {
// no existing mask: non_null can be used directly
masks.put(entry.getKey(), nonNull);
}
}
Assignments maskAssignments = assignmentsBuilder.build();
// only add a projection when at least one combined mask expression has to be computed
if (!maskAssignments.isEmpty()) {
root = new ProjectNode(context.getIdAllocator().getNextId(), root, Assignments.builder().putIdentities(root.getOutputSymbols()).putAll(maskAssignments).build());
}
// restore global aggregation
// grouped by the join's left outputs plus the original grouping keys, with masks applied via rewriteWithMasks
globalAggregation = new AggregationNode(globalAggregation.getId(), root, rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).addAll(globalAggregation.getGroupingKeys()).build()), ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), Optional.empty());
// restrict outputs
// when restrictOutputs yields no projection (presumably nothing to prune), the aggregation is returned as-is
Optional<PlanNode> project = restrictOutputs(context.getIdAllocator(), globalAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
return Result.ofPlanNode(project.orElse(globalAggregation));
}
Use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.
Class AggregationFunctionMatcher, method getAssignedSymbol.
/**
 * Finds the output symbol of the single aggregation in {@code node} that matches the expected call.
 *
 * @return the matching aggregation's symbol, or empty if {@code node} is not an AggregationNode
 *         or nothing matches
 * @throws IllegalStateException (via checkState) if more than one aggregation matches
 */
@Override
public Optional<Symbol> getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    if (!(node instanceof AggregationNode)) {
        return Optional.empty();
    }

    AggregationNode aggregationNode = (AggregationNode) node;
    FunctionCall expectedCall = callMaker.getExpectedValue(symbolAliases);

    Optional<Symbol> matchedSymbol = Optional.empty();
    for (Map.Entry<Symbol, Aggregation> candidate : aggregationNode.getAggregations().entrySet()) {
        if (!aggregationMatches(candidate.getValue(), expectedCall)) {
            continue;
        }
        // a second match means the expectation is ambiguous
        checkState(matchedSymbol.isEmpty(), "Ambiguous function calls in %s", aggregationNode);
        matchedSymbol = Optional.of(candidate.getKey());
    }
    return matchedSymbol;
}
Use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.
Class QueryPlanner, method planAggregation.
/**
 * Adds a single-step AggregationNode on top of {@code subPlan} computing {@code aggregates} over the
 * given grouping sets, and returns a PlanBuilder whose translations map each aggregate call to its
 * output symbol.
 *
 * @param subPlan the plan to aggregate over
 * @param groupingSets the grouping sets; an empty set denotes a global grouping set
 * @param groupIdSymbol optional group id symbol, appended to the grouping keys when present
 * @param aggregates the aggregate calls; duplicates (by scope-aware equality) are planned once
 * @param coercions maps an argument/filter/ordering expression to the symbol holding its (coerced) value
 */
private PlanBuilder planAggregation(PlanBuilder subPlan, List<List<Symbol>> groupingSets, Optional<Symbol> groupIdSymbol, List<FunctionCall> aggregates, Function<Expression, Symbol> coercions) {
    ImmutableList.Builder<AggregationAssignment> aggregateMappingBuilder = ImmutableList.builder();

    // deduplicate based on scope-aware equality
    for (FunctionCall function : scopeAwareDistinct(subPlan, aggregates)) {
        Symbol symbol = symbolAllocator.newSymbol(function, analysis.getType(function));

        // TODO: for ORDER BY arguments, rewrite them such that they match the actual arguments to the function. This is necessary to maintain the semantics of DISTINCT + ORDER BY,
        // which requires that ORDER BY be a subset of arguments
        // What can happen currently is that if the argument requires a coercion, the argument will take a different input than the ORDER BY clause, which is undefined behavior
        Aggregation aggregation = new Aggregation(
                analysis.getResolvedFunction(function),
                function.getArguments().stream()
                        .map(argument -> {
                            // lambda arguments are rewritten against the sub-plan; all others are read from their coerced symbol
                            if (argument instanceof LambdaExpression) {
                                return subPlan.rewrite(argument);
                            }
                            return coercions.apply(argument).toSymbolReference();
                        })
                        .collect(toImmutableList()),
                function.isDistinct(),
                function.getFilter().map(coercions),
                function.getOrderBy().map(orderBy -> translateOrderingScheme(orderBy.getSortItems(), coercions)),
                Optional.empty());
        aggregateMappingBuilder.add(new AggregationAssignment(symbol, function, aggregation));
    }
    List<AggregationAssignment> aggregateMappings = aggregateMappingBuilder.build();

    // indexes of the global (empty) grouping sets
    ImmutableSet.Builder<Integer> globalGroupingSets = ImmutableSet.builder();
    for (int i = 0; i < groupingSets.size(); i++) {
        if (groupingSets.get(i).isEmpty()) {
            globalGroupingSets.add(i);
        }
    }

    // all distinct keys across the grouping sets in first-seen order, followed by the group id symbol (if any)
    ImmutableList.Builder<Symbol> groupingKeys = ImmutableList.builder();
    groupingKeys.addAll(groupingSets.stream()
            .flatMap(List::stream)
            .distinct()
            .collect(toImmutableList()));
    groupIdSymbol.ifPresent(groupingKeys::add);

    AggregationNode aggregationNode = new AggregationNode(
            idAllocator.getNextId(),
            subPlan.getRoot(),
            aggregateMappings.stream().collect(toImmutableMap(AggregationAssignment::getSymbol, AggregationAssignment::getRewritten)),
            groupingSets(groupingKeys.build(), groupingSets.size(), globalGroupingSets.build()),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            groupIdSymbol);

    return new PlanBuilder(
            subPlan.getTranslations().withAdditionalMappings(aggregateMappings.stream()
                    .collect(toImmutableMap(
                            assignment -> scopeAwareKey(assignment.getAstExpression(), analysis, subPlan.getScope()),
                            AggregationAssignment::getSymbol))),
            aggregationNode);
}
Aggregations