Use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
The class AddIntermediateAggregations, method recurseToPartial.
/**
 * Walks down a chain of ExchangeNodes and ProjectNodes looking for the PARTIAL aggregation that feeds them.
 * <p>
 * When a PARTIAL AggregationNode is reached, it is replaced by the result of {@code addGatheringIntermediate}.
 * Every ExchangeNode/ProjectNode passed on the way down is rebuilt over its rewritten sources. Each source is
 * resolved through {@code lookup} before it is inspected.
 *
 * @param node the node to inspect
 * @param lookup used to resolve each source before recursing into it
 * @param idAllocator supplies ids for the nodes created by {@code addGatheringIntermediate}
 * @return the rewritten subtree, or empty if any path below {@code node} reaches a node that is neither an
 *         ExchangeNode, a ProjectNode, nor a PARTIAL AggregationNode
 */
private Optional<PlanNode> recurseToPartial(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator) {
    boolean isPartialAggregation = node instanceof AggregationNode
            && ((AggregationNode) node).getStep() == AggregationNode.Step.PARTIAL;
    if (isPartialAggregation) {
        return Optional.of(addGatheringIntermediate((AggregationNode) node, idAllocator));
    }

    // Only exchanges and projections may sit between the PARTIAL aggregation and the caller.
    boolean isPassThrough = node instanceof ExchangeNode || node instanceof ProjectNode;
    if (!isPassThrough) {
        return Optional.empty();
    }

    ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
    for (PlanNode child : node.getSources()) {
        Optional<PlanNode> rewrittenChild = recurseToPartial(lookup.resolve(child), lookup, idAllocator);
        if (!rewrittenChild.isPresent()) {
            // One unsupported branch aborts the whole rewrite.
            return Optional.empty();
        }
        rewrittenSources.add(rewrittenChild.get());
    }
    return Optional.of(node.replaceChildren(rewrittenSources.build()));
}
Use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
The class AddIntermediateAggregations, method addGatheringIntermediate.
/**
 * Places an INTERMEDIATE aggregation on top of the given PARTIAL aggregation, separated by a LOCAL gathering exchange.
 * <p>
 * The new aggregation reuses the partial aggregation's grouping sets, pre-grouped symbols, hash symbol and group-id
 * symbol. Its aggregations are produced by {@code outputsAsInputs} from the partial aggregation's aggregations
 * (presumably re-targeting them at the partial step's output symbols; see that helper).
 *
 * @param aggregation the PARTIAL aggregation; must have no grouping keys
 * @param idAllocator supplies ids for the exchange and the new aggregation, in that order
 * @return the new INTERMEDIATE AggregationNode
 */
private PlanNode addGatheringIntermediate(AggregationNode aggregation, PlanNodeIdAllocator idAllocator) {
    verify(aggregation.getGroupingKeys().isEmpty(), "Should be an un-grouped aggregation");

    // The exchange's id is requested first, before the aggregation's id below.
    ExchangeNode localGather = ExchangeNode.gatheringExchange(idAllocator.getNextId(), ExchangeNode.Scope.LOCAL, aggregation);
    AggregationNode.Step intermediateStep = AggregationNode.Step.INTERMEDIATE;

    return new AggregationNode(
            idAllocator.getNextId(),
            localGather,
            outputsAsInputs(aggregation.getAggregations()),
            aggregation.getGroupingSets(),
            aggregation.getPreGroupedSymbols(),
            intermediateStep,
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());
}
Use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
The class PushAggregationIntoTableScan, method pushAggregationIntoTableScan.
/**
 * Attempts to push an aggregation that sits directly over {@code tableScan} into the connector through
 * {@code Metadata.applyAggregation}.
 * <p>
 * On success, returns a ProjectNode over a new TableScanNode built from the connector's returned table handle. The
 * project assigns:
 * <ul>
 * <li>each aggregation output symbol to the corresponding translated connector projection; and</li>
 * <li>each grouping key to a reference to the scan symbol of its grouping column, after any remapping the
 * connector reported.</li>
 * </ul>
 *
 * @param plannerContext supplies metadata for function conversion, pushdown and expression translation
 * @param context rule context providing the session, symbol/id allocators and stats provider
 * @param aggregationNode the aggregation being replaced; used only to derive statistics for the new scan
 * @param tableScan the scan feeding the aggregation
 * @param aggregations aggregation output symbols mapped to their aggregation calls
 * @param groupingKeys grouping symbols; each is looked up by name among the scan's assignments
 * @return the rewritten plan, or empty when there is nothing worth pushing down or the connector declines
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
    if (groupingKeys.isEmpty() && aggregations.isEmpty()) {
        // A global aggregation with no aggregate functions has nothing for the connector to compute, and it must
        // still produce exactly one row. Keep it in the engine instead of offering it to the connector.
        return Optional.empty();
    }

    Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream()
            .collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));

    // Fix an iteration order so aggregate functions and their output symbols line up by index.
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = aggregations.entrySet().stream()
            .collect(toImmutableList());
    List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
            .map(Entry::getValue)
            .map(aggregation -> toAggregateFunction(plannerContext.getMetadata(), context, aggregation))
            .collect(toImmutableList());
    List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
            .map(Entry::getKey)
            .collect(toImmutableList());

    List<ColumnHandle> groupByColumns = groupingKeys.stream()
            .map(groupByColumn -> assignments.get(groupByColumn.getName()))
            .collect(toImmutableList());

    Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
            context.getSession(),
            tableScan.getTable(),
            aggregateFunctions,
            assignments,
            ImmutableList.of(groupByColumns));
    if (aggregationPushdownResult.isEmpty()) {
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

    // The new scan outputs are the original scan symbols plus one new symbol per connector-provided assignment.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    Map<String, Symbol> variableMappings = new HashMap<>();
    for (Assignment assignment : result.getAssignments()) {
        Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
        newScanOutputs.add(symbol);
        newScanAssignments.put(symbol, assignment.getColumn());
        variableMappings.put(assignment.getVariable(), symbol);
    }

    List<Expression> newProjections = result.getProjections().stream()
            .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, variableMappings, new LiteralEncoder(plannerContext)))
            .collect(toImmutableList());
    verify(aggregationOutputSymbols.size() == newProjections.size(),
            "Connector returned %s projections for %s aggregations", newProjections.size(), aggregationOutputSymbols.size());

    Assignments.Builder assignmentBuilder = Assignments.builder();
    IntStream.range(0, aggregationOutputSymbols.size())
            .forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));

    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
    ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
    // The project must expose both the aggregation outputs and the grouping keys, so add every grouping key as a
    // symbol reference.
    for (Symbol groupBySymbol : groupingKeys) {
        // If the connector mapped the old grouping column handle to a new one, the grouping key must reference the
        // new handle's symbol. Otherwise it keeps referencing the original handle's symbol.
        ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        Symbol scanSymbol = columnHandleToSymbol.get(groupByColumnHandle);
        verify(scanSymbol != null, "No scan symbol for grouping column handle %s", groupByColumnHandle);
        assignmentBuilder.put(groupBySymbol, scanSymbol.toSymbolReference());
    }

    return Optional.of(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new TableScanNode(
                    context.getIdAllocator().getNextId(),
                    result.getHandle(),
                    newScanOutputs.build(),
                    scanAssignments,
                    TupleDomain.all(),
                    deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode),
                    tableScan.isUpdateTarget(),
                    // table scan partitioning might have changed with new table handle
                    Optional.empty()),
            assignmentBuilder.build()));
}
Use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
The class PushAggregationThroughOuterJoin, method createAggregationOverNull.
/**
 * Builds a SINGLE-step global aggregation that evaluates each aggregation of {@code referenceAggregation} over a
 * one-row input in which every source symbol of {@code referenceAggregation} is replaced by a typed NULL.
 *
 * @param referenceAggregation the aggregation whose functions are re-evaluated over the NULL row
 * @param symbolAllocator allocates the NULL-row symbols and the new aggregation output symbols
 * @param idAllocator supplies ids for the VALUES node and the new aggregation, in that order
 * @return the new aggregation node, together with a mapping from each original aggregation output symbol to the
 *         matching output symbol of the new node
 */
private MappedAggregationInfo createAggregationOverNull(AggregationNode referenceAggregation, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) {
    // Step 1: a VALUES node holding a single row of typed NULLs, with one fresh symbol per source symbol.
    ImmutableList.Builder<Symbol> valuesSymbols = ImmutableList.builder();
    ImmutableList.Builder<Expression> rowCells = ImmutableList.builder();
    ImmutableMap.Builder<Symbol, Symbol> sourceToNull = ImmutableMap.builder();
    for (Symbol originalSymbol : referenceAggregation.getSource().getOutputSymbols()) {
        Type originalType = symbolAllocator.getTypes().get(originalSymbol);
        rowCells.add(new Cast(new NullLiteral(), toSqlType(originalType)));
        Symbol replacement = symbolAllocator.newSymbol("null", originalType);
        valuesSymbols.add(replacement);
        sourceToNull.put(originalSymbol, replacement);
    }
    ValuesNode nullRow = new ValuesNode(idAllocator.getNextId(), valuesSymbols.build(), ImmutableList.of(new Row(rowCells.build())));

    // Step 2: re-target every aggregation at the NULL row and give each one a fresh output symbol.
    ImmutableMap.Builder<Symbol, Symbol> originalToRewritten = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> rewrittenAggregations = ImmutableMap.builder();
    SymbolMapper toNullRow = symbolMapper(sourceToNull.buildOrThrow());
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
        Symbol originalOutput = entry.getKey();
        AggregationNode.Aggregation rewritten = toNullRow.map(entry.getValue());
        String functionName = rewritten.getResolvedFunction().getSignature().getName();
        Symbol rewrittenOutput = symbolAllocator.newSymbol(functionName, symbolAllocator.getTypes().get(originalOutput));
        rewrittenAggregations.put(rewrittenOutput, rewritten);
        originalToRewritten.put(originalOutput, rewrittenOutput);
    }
    Map<Symbol, Symbol> outputMapping = originalToRewritten.buildOrThrow();

    // Step 3: a global, SINGLE-step aggregation reading from the NULL row.
    AggregationNode aggregationOverNullRow = new AggregationNode(
            idAllocator.getNextId(),
            nullRow,
            rewrittenAggregations.buildOrThrow(),
            globalAggregation(),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());
    return new MappedAggregationInfo(aggregationOverNullRow, outputMapping);
}
Use of io.trino.sql.planner.plan.AggregationNode in project trino by trinodb.
The class PushAggregationThroughOuterJoin, method apply.
/**
 * Pushes an aggregation below a LEFT or RIGHT outer join, onto the join's inner side.
 * <p>
 * The rule applies only when all of the following hold:
 * <ul>
 * <li>the join has no filter;</li>
 * <li>the aggregation groups on all outer-side columns;</li>
 * <li>the outer side is distinct;</li>
 * <li>the aggregation is over inner-side symbols.</li>
 * </ul>
 * When it applies, the inner side is aggregated, grouped by its join keys. The join is rebuilt over that
 * aggregation, and the result is passed through {@code coalesceWithNullAggregation}.
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    // This rule doesn't deal with AggregationNode's hash symbol. Hash symbols are not yet present at this stage of optimization.
    checkArgument(aggregation.getHashSymbol().isEmpty(), "unexpected hash symbol");
    JoinNode join = captures.get(JOIN);

    // Applicability checks, evaluated in the same order as a single short-circuiting condition would be.
    if (join.getFilter().isPresent()) {
        return Result.empty();
    }
    if (join.getType() != JoinNode.Type.LEFT && join.getType() != JoinNode.Type.RIGHT) {
        return Result.empty();
    }
    if (!groupsOnAllColumns(aggregation, getOuterTable(join).getOutputSymbols())) {
        return Result.empty();
    }
    if (!isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve)) {
        return Result.empty();
    }
    if (!isAggregationOnSymbols(aggregation, getInnerTable(join))) {
        return Result.empty();
    }

    // Past the checks the join is either LEFT (outer side on the left) or RIGHT (outer side on the right).
    boolean outerIsLeft = join.getType() == JoinNode.Type.LEFT;

    // The inner side is grouped by its own join-key symbols.
    List<Symbol> groupingKeys = join.getCriteria().stream()
            .map(outerIsLeft ? JoinNode.EquiJoinClause::getRight : JoinNode.EquiJoinClause::getLeft)
            .collect(toImmutableList());
    AggregationNode rewrittenAggregation = new AggregationNode(
            aggregation.getId(),
            getInnerTable(join),
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol());

    // The rewritten aggregation replaces the inner side; the outer side and its outputs are kept as they were.
    List<Symbol> aggregationOutputs = ImmutableList.copyOf(rewrittenAggregation.getAggregations().keySet());
    PlanNode newLeft = outerIsLeft ? join.getLeft() : rewrittenAggregation;
    PlanNode newRight = outerIsLeft ? rewrittenAggregation : join.getRight();
    List<Symbol> newLeftOutputs = outerIsLeft ? join.getLeft().getOutputSymbols() : aggregationOutputs;
    List<Symbol> newRightOutputs = outerIsLeft ? aggregationOutputs : join.getRight().getOutputSymbols();
    JoinNode rewrittenJoin = new JoinNode(
            join.getId(),
            join.getType(),
            newLeft,
            newRight,
            join.getCriteria(),
            newLeftOutputs,
            newRightOutputs,
            // there are no duplicate rows possible since outer rows were guaranteed to be distinct
            false,
            join.getFilter(),
            join.getLeftHashSymbol(),
            join.getRightHashSymbol(),
            join.getDistributionType(),
            join.isSpillable(),
            join.getDynamicFilters(),
            join.getReorderJoinStatsAndCost());

    return coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator())
            .map(Result::ofPlanNode)
            .orElseGet(Result::empty);
}
Aggregations