Usage of com.facebook.presto.sql.planner.plan.GroupIdNode in the presto project (prestodb).
Class GroupIdMatcher, method detailMatches:
/**
 * Compares a {@code GroupIdNode} against the expected shape held by this matcher.
 *
 * <p>The node matches only if all three of the following hold:
 * <ul>
 *   <li>it has exactly as many grouping sets as {@code groups};</li>
 *   <li>each grouping set matches its expected counterpart at the same position;</li>
 *   <li>the keys of its argument mappings match the keys of {@code identityMappings}.</li>
 * </ul>
 * On success the result is {@code match(groupIdAlias, <group id symbol reference>)}.
 */
@Override
public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    // Reaching this point with a non-matching shape indicates a bug in the test framework itself.
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());

    GroupIdNode groupIdNode = (GroupIdNode) node;
    List<List<Symbol>> actualGroupingSets = groupIdNode.getGroupingSets();
    Map<Symbol, Symbol> actualArguments = groupIdNode.getArgumentMappings();

    boolean sameSetCount = actualGroupingSets.size() == groups.size();
    if (!sameSetCount) {
        return NO_MATCH;
    }

    // Grouping sets are compared pairwise, in order.
    int position = 0;
    for (List<Symbol> actualSet : actualGroupingSets) {
        boolean setMatches = AggregationMatcher.matches(groups.get(position), actualSet, symbolAliases);
        position++;
        if (!setMatches) {
            return NO_MATCH;
        }
    }

    // Argument columns passed through the GroupIdNode are compared by key set only.
    boolean argumentsMatch = AggregationMatcher.matches(identityMappings.keySet(), actualArguments.keySet(), symbolAliases);
    return argumentsMatch
            ? match(groupIdAlias, groupIdNode.getGroupIdSymbol().toSymbolReference())
            : NO_MATCH;
}
Usage of com.facebook.presto.sql.planner.plan.GroupIdNode in the presto project (prestodb).
Class QueryPlanner, method aggregate:
/**
 * Plans the aggregation for {@code node} on top of {@code subPlan}.
 *
 * <p>Steps, in the order implemented below:
 * <ol>
 *   <li>Project every scalar input the aggregation needs: distinct grouping expressions,
 *       aggregate arguments and aggregate FILTER expressions.</li>
 *   <li>(2.a) Give each aggregate argument a fresh "arg" symbol.</li>
 *   <li>(2.b) Give each grouping expression a fresh "gid" symbol.</li>
 *   <li>(2.c) Emit a GroupIdNode when there is more than one grouping set. Otherwise emit a
 *       ProjectNode that renames the inputs to those fresh symbols.</li>
 *   <li>(2.d) Rewrite each aggregate call over the new symbols.</li>
 *   <li>(2.e) Add one MarkDistinctNode per distinct argument set used by DISTINCT aggregates.</li>
 *   <li>Build a SINGLE-step AggregationNode.</li>
 * </ol>
 *
 * @param subPlan the plan built so far
 * @param node the query specification whose aggregates and grouping sets are planned
 * @return {@code subPlan} unchanged when {@code analysis.getGroupingSets(node)} is empty.
 *         Otherwise, a plan rooted at the new AggregationNode. If any rewritten aggregate was
 *         wrapped in a {@code Cast}, the result of {@code explicitCoercionFields} applied to that
 *         plan is returned instead.
 */
private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) {
List<List<Expression>> groupingSets = analysis.getGroupingSets(node);
// No grouping sets recorded for this specification: return the plan unchanged.
if (groupingSets.isEmpty()) {
return subPlan;
}
// 1. Pre-project all scalar inputs (arguments and non-trivial group by expressions)
// Union of grouping expressions across all grouping sets, each expression appearing once.
Set<Expression> distinctGroupingColumns = groupingSets.stream().flatMap(Collection::stream).collect(toImmutableSet());
// Arguments of every aggregate, in order. Duplicates are kept here and de-duplicated in 2.a.
ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
analysis.getAggregates(node).stream().map(FunctionCall::getArguments).flatMap(List::stream).forEach(arguments::add);
// filter expressions need to be projected first
analysis.getAggregates(node).stream().map(FunctionCall::getFilter).filter(Optional::isPresent).map(Optional::get).forEach(arguments::add);
Iterable<Expression> inputs = Iterables.concat(distinctGroupingColumns, arguments.build());
// NOTE(review): handleSubqueries is not shown here; presumably it plans subqueries
// that appear in these inputs before they are projected. Confirm against its definition.
subPlan = handleSubqueries(subPlan, node, inputs);
if (!Iterables.isEmpty(inputs)) {
// avoid an empty projection if the only aggregation is COUNT (which has no arguments)
subPlan = project(subPlan, inputs);
}
// 2. Aggregate
// 2.a. Rewrite aggregate arguments
TranslationMap argumentTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
ImmutableMap.Builder<Symbol, Symbol> argumentMappingBuilder = ImmutableMap.builder();
for (Expression argument : arguments.build()) {
// Substitute query parameters. The intermediate mapping lets the original argument
// expression resolve through its parameter-substituted form.
Expression parametersReplaced = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), argument);
argumentTranslations.addIntermediateMapping(argument, parametersReplaced);
Symbol input = subPlan.translate(parametersReplaced);
// Only the first occurrence of an expression allocates an "arg" symbol, so repeated
// arguments share one output symbol and one entry in argumentMappings.
if (!argumentTranslations.containsSymbol(parametersReplaced)) {
Symbol output = symbolAllocator.newSymbol(parametersReplaced, analysis.getTypeWithCoercions(parametersReplaced), "arg");
argumentMappingBuilder.put(output, input);
argumentTranslations.put(parametersReplaced, output);
}
}
// Maps each fresh "arg" output symbol to the input symbol it is fed from.
Map<Symbol, Symbol> argumentMappings = argumentMappingBuilder.build();
// 2.b. Rewrite grouping columns
TranslationMap groupingTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
// Maps each "gid" output symbol to its input symbol. A grouping expression that appears in
// several grouping sets reuses its symbol (else-branch below), so the same key is put
// repeatedly. A plain HashMap tolerates that repetition; presumably this is why it is used
// instead of an ImmutableMap.Builder, whose build() rejects duplicate keys.
Map<Symbol, Symbol> groupingSetMappings = new HashMap<>();
// One list of output symbols per grouping set, in grouping-set order.
List<List<Symbol>> groupingSymbols = new ArrayList<>();
for (List<Expression> groupingSet : groupingSets) {
ImmutableList.Builder<Symbol> symbols = ImmutableList.builder();
for (Expression expression : groupingSet) {
Expression parametersReplaced = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), expression);
groupingTranslations.addIntermediateMapping(expression, parametersReplaced);
Symbol input = subPlan.translate(expression);
Symbol output;
if (!groupingTranslations.containsSymbol(parametersReplaced)) {
output = symbolAllocator.newSymbol(parametersReplaced, analysis.getTypeWithCoercions(expression), "gid");
groupingTranslations.put(parametersReplaced, output);
} else {
output = groupingTranslations.get(parametersReplaced);
}
groupingSetMappings.put(output, input);
symbols.add(output);
}
groupingSymbols.add(symbols.build());
}
// 2.c. Generate GroupIdNode (multiple grouping sets) or ProjectNode (single grouping set)
// groupIdSymbol stays empty in the single-set case. It is later added to the distinct keys
// (2.e) and passed to the AggregationNode.
Optional<Symbol> groupIdSymbol = Optional.empty();
if (groupingSets.size() > 1) {
groupIdSymbol = Optional.of(symbolAllocator.newSymbol("groupId", BIGINT));
GroupIdNode groupId = new GroupIdNode(idAllocator.getNextId(), subPlan.getRoot(), groupingSymbols, groupingSetMappings, argumentMappings, groupIdSymbol.get());
subPlan = new PlanBuilder(groupingTranslations, groupId, analysis.getParameters());
} else {
// Single grouping set: a plain projection renames every input to its fresh "arg"/"gid" symbol.
Assignments.Builder assignments = Assignments.builder();
for (Symbol output : argumentMappings.keySet()) {
assignments.put(output, argumentMappings.get(output).toSymbolReference());
}
for (Symbol output : groupingSetMappings.keySet()) {
assignments.put(output, groupingSetMappings.get(output).toSymbolReference());
}
ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build());
subPlan = new PlanBuilder(groupingTranslations, project, analysis.getParameters());
}
// Translations above the aggregation start from the grouping translations, so grouping
// expressions keep resolving to their "gid" symbols.
TranslationMap aggregationTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
aggregationTranslations.copyMappingsFrom(groupingTranslations);
// 2.d. Rewrite aggregates
ImmutableMap.Builder<Symbol, FunctionCall> aggregationAssignments = ImmutableMap.builder();
ImmutableMap.Builder<Symbol, Signature> functions = ImmutableMap.builder();
// Set when any rewritten aggregate had to be unwrapped from a Cast; triggers the post-projection at the end.
boolean needPostProjectionCoercion = false;
for (FunctionCall aggregate : analysis.getAggregates(node)) {
Expression parametersReplaced = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(analysis.getParameters(), analysis), aggregate);
aggregationTranslations.addIntermediateMapping(aggregate, parametersReplaced);
// Re-express the aggregate call in terms of the "arg" symbols allocated in 2.a.
Expression rewritten = argumentTranslations.rewrite(parametersReplaced);
Symbol newSymbol = symbolAllocator.newSymbol(rewritten, analysis.getType(aggregate));
// Coercions are applied to the output of expressions rather than to function arguments
// (see the TODO near the end of this method). The rewritten aggregate can therefore come back
// wrapped in an implicit Cast. Strip it here; the coercion is re-applied in a post-projection
// via needPostProjectionCoercion.
if (rewritten instanceof Cast) {
rewritten = ((Cast) rewritten).getExpression();
needPostProjectionCoercion = true;
}
// NOTE(review): assumes that, after any Cast is stripped, the rewritten expression is a
// FunctionCall. Anything else fails here with ClassCastException.
aggregationAssignments.put(newSymbol, (FunctionCall) rewritten);
aggregationTranslations.put(parametersReplaced, newSymbol);
functions.put(newSymbol, analysis.getFunctionSignature(aggregate));
}
// 2.e. Mark distinct rows for each aggregate that has DISTINCT
// Map from aggregate function arguments to marker symbols, so that we can reuse the markers, if two aggregates have the same argument
Map<Set<Expression>, Symbol> argumentMarkers = new HashMap<>();
// Map from aggregate functions to marker symbols
Map<Symbol, Symbol> masks = new HashMap<>();
for (FunctionCall aggregate : Iterables.filter(analysis.getAggregates(node), FunctionCall::isDistinct)) {
Set<Expression> args = ImmutableSet.copyOf(aggregate.getArguments());
Symbol marker = argumentMarkers.get(args);
Symbol aggregateSymbol = aggregationTranslations.get(aggregate);
if (marker == null) {
// Name the marker after the single argument if there is exactly one,
// otherwise after the aggregate's output symbol.
if (args.size() == 1) {
marker = symbolAllocator.newSymbol(getOnlyElement(args), BOOLEAN, "distinct");
} else {
marker = symbolAllocator.newSymbol(aggregateSymbol.getName(), BOOLEAN, "distinct");
}
argumentMarkers.put(args, marker);
}
masks.put(aggregateSymbol, marker);
}
// Stack one MarkDistinctNode per distinct argument set. Its distinct keys are:
// all grouping symbols (de-duplicated), the groupId symbol if present, and the argument symbols.
for (Map.Entry<Set<Expression>, Symbol> entry : argumentMarkers.entrySet()) {
ImmutableList.Builder<Symbol> builder = ImmutableList.builder();
// NOTE(review): this list does not depend on `entry` and could be computed once before the loop.
builder.addAll(groupingSymbols.stream().flatMap(Collection::stream).distinct().collect(Collectors.toList()));
groupIdSymbol.ifPresent(builder::add);
for (Expression expression : entry.getKey()) {
builder.add(argumentTranslations.get(expression));
}
subPlan = subPlan.withNewRoot(new MarkDistinctNode(idAllocator.getNextId(), subPlan.getRoot(), entry.getValue(), builder.build(), Optional.empty()));
}
// Single-step aggregation over all grouping sets. `masks` links each DISTINCT aggregate to its marker.
AggregationNode aggregationNode = new AggregationNode(idAllocator.getNextId(), subPlan.getRoot(), aggregationAssignments.build(), functions.build(), masks, groupingSymbols, AggregationNode.Step.SINGLE, Optional.empty(), groupIdSymbol);
subPlan = new PlanBuilder(aggregationTranslations, aggregationNode, analysis.getParameters());
// TODO: this is a hack, we should change type coercions to coerce the inputs to functions/operators instead of coercing the output
if (needPostProjectionCoercion) {
return explicitCoercionFields(subPlan, distinctGroupingColumns, analysis.getAggregates(node));
}
return subPlan;
}
Aggregations