use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class PushPartialAggregationThroughExchange method split.
/**
 * Splits {@code node} into a FINAL aggregation stacked on top of a PARTIAL aggregation.
 *
 * <p>Every aggregation of {@code node} becomes two aggregations. The PARTIAL one calls the same function with the
 * original arguments, filter, DISTINCT flag and mask, and returns the function's intermediate type into a freshly
 * allocated variable. The FINAL one (bound to the original output variable) consumes that intermediate variable
 * plus any lambda arguments of the original call and returns the function's final type.
 *
 * @param node the aggregation to split; none of its aggregations may have an ORDER BY
 * @param context rule context providing the variable/id allocators and the session
 * @return a FINAL aggregation that reuses {@code node}'s id and whose source is the new PARTIAL aggregation
 * @throws IllegalStateException (via {@code checkState}) if an aggregation has an ORDER BY clause
 */
private PlanNode split(AggregationNode node, Context context) {
    Map<VariableReferenceExpression, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<VariableReferenceExpression, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation originalAggregation = entry.getValue();
        // Check the precondition before doing any per-aggregation work (function lookups, variable allocation).
        checkState(!originalAggregation.getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
        String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
        InternalAggregationFunction function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);
        CallExpression originalCall = originalAggregation.getCall();

        // PARTIAL step: same inputs as the original aggregation, producing the intermediate state.
        VariableReferenceExpression intermediateVariable = context.getVariableAllocator().newVariable(originalCall.getSourceLocation(), functionName, function.getIntermediateType());
        intermediateAggregation.put(intermediateVariable, new AggregationNode.Aggregation(
                new CallExpression(originalCall.getSourceLocation(), functionName, functionHandle, function.getIntermediateType(), originalAggregation.getArguments()),
                originalAggregation.getFilter(),
                originalAggregation.getOrderBy(),
                originalAggregation.isDistinct(),
                originalAggregation.getMask()));

        // FINAL step: rewritten in terms of the intermediate state. Lambda arguments are forwarded as well.
        List<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(intermediateVariable)
                .addAll(originalAggregation.getArguments().stream()
                        .filter(PushPartialAggregationThroughExchange::isLambda)
                        .collect(toImmutableList()))
                .build();
        // FILTER, DISTINCT and the mask were already applied by the PARTIAL step, which sees the raw input rows,
        // so the FINAL step carries none of them.
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(
                new CallExpression(originalCall.getSourceLocation(), functionName, functionHandle, function.getFinalType(), finalArguments),
                Optional.empty(),
                Optional.empty(),
                false,
                Optional.empty()));
    }

    // Streaming is always possible for partial aggregations, but if the input is not pre-grouped by the grouping
    // keys it may cause regressions. This session property forces it on when we know the execution benefits from
    // partial streaming aggregation. Deriving this from the input table properties is left for later.
    List<VariableReferenceExpression> preGroupedSymbols = ImmutableList.of();
    if (isStreamingForPartialAggregationEnabled(context.getSession())) {
        preGroupedSymbols = ImmutableList.copyOf(node.getGroupingSets().getGroupingKeys());
    }

    PlanNode partial = new AggregationNode(
            node.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            node.getSource(),
            intermediateAggregation,
            node.getGroupingSets(),
            // The PARTIAL step reads node's original source directly, so the opt-in pre-grouping above applies to it.
            preGroupedSymbols,
            PARTIAL,
            node.getHashVariable(),
            node.getGroupIdVariable());
    return new AggregationNode(
            node.getSourceLocation(),
            node.getId(),
            partial,
            finalAggregation,
            node.getGroupingSets(),
            // Pre-grouped variables describe properties of the original input. The FINAL step consumes the PARTIAL
            // step's output instead, and pushing the partial aggregation through an exchange may not preserve those
            // properties, so it is safest to declare none here.
            ImmutableList.of(),
            FINAL,
            node.getHashVariable(),
            node.getGroupIdVariable());
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class RewriteAggregationIfToFilter method apply.
/**
 * Moves the condition of {@code IF(condition, result)} aggregation inputs (optionally wrapped in a CAST) into the
 * aggregation mask.
 *
 * <p>Only aggregations accepted by {@code shouldRewriteAggregation} are rewritten.
 *
 * <p>For each distinct input variable of those aggregations, the IF condition from the child projection is assigned
 * to a new variable. That variable becomes the mask of every aggregation reading that input.
 *
 * <p>When {@code canUnwrapIf} allows it for the session's rewrite strategy, the IF's true branch is also assigned to
 * a new variable and replaces the aggregation's argument. The true branch is re-wrapped in the original CAST, if any.
 *
 * <p>If every aggregation was rewritten and the node has no non-empty grouping set, a filter ORing all masks is
 * placed below the aggregation; otherwise the filter predicate is TRUE.
 *
 * @return {@link Result#empty()} when no aggregation qualifies; otherwise a new aggregation (with a new id) over a
 *     filter over the extended child projection
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    ProjectNode sourceProject = captures.get(CHILD);
    Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream()
            .filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject))
            .collect(toImmutableSet());
    if (aggregationsToRewrite.isEmpty()) {
        return Result.empty();
    }
    context.getSession().getRuntimeStats().addMetricValue(REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED, 1);

    // Get the corresponding assignments in the input project.
    // aggregationsToRewrite only holds the aggregations to rewrite, so sourceAssignments only holds IF/CAST(IF)
    // expressions whose false result is NULL.
    // Multiple aggregations may reference the same input. Keying the map by the VariableReferenceExpression dedups
    // them, so each input IF expression is rewritten only once.
    // The iteration order of sourceAssignments determines the order in which new variables are allocated for the
    // IF conditions and results; a sorted map (VariableReferenceExpression::compareTo) makes that order deterministic.
    Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream()
            .map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0))
            .collect(toImmutableSortedMap(
                    VariableReferenceExpression::compareTo,
                    identity(),
                    variable -> sourceProject.getAssignments().get(variable),
                    (left, right) -> left));

    Assignments.Builder newAssignments = Assignments.builder();
    newAssignments.putAll(sourceProject.getAssignments());
    // Map from the aggregation input reference to the IF condition reference, which is used as the mask.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
    // Map from the aggregation input reference to the IF result reference. It only holds inputs whose IF can be
    // safely unwrapped: e.g. SUM(IF(CARDINALITY(array) > 0, array[1])) is excluded because array[1] can raise
    // errors once the IF is removed.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();
    AggregationIfToFilterRewriteStrategy rewriteStrategy = getAggregationIfToFilterRewriteStrategy(context.getSession());
    for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
        VariableReferenceExpression outputVariable = entry.getKey();
        RowExpression rowExpression = entry.getValue();
        // A CallExpression here is the CAST around the IF; its first argument is the IF itself.
        SpecialFormExpression ifExpression = (SpecialFormExpression) ((rowExpression instanceof CallExpression) ? ((CallExpression) rowExpression).getArguments().get(0) : rowExpression);
        RowExpression condition = ifExpression.getArguments().get(0);
        VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
        newAssignments.put(conditionReference, condition);
        aggregationReferenceToConditionReference.put(outputVariable, conditionReference);
        if (canUnwrapIf(ifExpression, rewriteStrategy)) {
            RowExpression trueResult = ifExpression.getArguments().get(1);
            if (rowExpression instanceof CallExpression) {
                // Re-apply the CAST to the true result. Keep the CAST's source location, as the other
                // CallExpression built by this rule does, instead of silently dropping it.
                CallExpression cast = (CallExpression) rowExpression;
                trueResult = new CallExpression(cast.getSourceLocation(), cast.getDisplayName(), cast.getFunctionHandle(), cast.getType(), ImmutableList.of(trueResult));
            }
            VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
            newAssignments.put(ifResultReference, trueResult);
            aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
        }
    }

    // Build new aggregations.
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    // Masks used to build the filter predicate; a set dedups them.
    ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
    // Tracked per output rather than by comparing sizes: aggregationsToRewrite is a Set, so identical aggregations
    // under different outputs collapse into one element and a size comparison would under-count.
    boolean allAggregationsRewritten = true;
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation aggregation = entry.getValue();
        if (!aggregationsToRewrite.contains(aggregation)) {
            allAggregationsRewritten = false;
            aggregations.put(output, aggregation);
            continue;
        }
        VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
        CallExpression callExpression = aggregation.getCall();
        VariableReferenceExpression ifResultReference = aggregationReferenceToIfResultReference.get(aggregationReference);
        if (ifResultReference != null) {
            callExpression = new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), ImmutableList.of(ifResultReference));
        }
        VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
        // NOTE(review): any FILTER on the original aggregation is dropped here; this relies on
        // shouldRewriteAggregation rejecting aggregations that already carry a filter or mask — confirm.
        aggregations.put(output, new Aggregation(callExpression, Optional.empty(), aggregation.getOrderBy(), aggregation.isDistinct(), Optional.of(mask)));
        masks.add(mask);
    }

    RowExpression predicate = TRUE_CONSTANT;
    if (!aggregationNode.hasNonEmptyGroupingSet() && allAggregationsRewritten) {
        // All aggregations are rewritten by this rule, so rows matching none of the masks contribute nothing:
        // filter them out early.
        predicate = or(masks.build());
    }

    return Result.ofPlanNode(new AggregationNode(
            aggregationNode.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            new FilterNode(
                    aggregationNode.getSourceLocation(),
                    context.getIdAllocator().getNextId(),
                    new ProjectNode(context.getIdAllocator().getNextId(), sourceProject.getSource(), newAssignments.build()),
                    predicate),
            aggregations.build(),
            aggregationNode.getGroupingSets(),
            aggregationNode.getPreGroupedVariables(),
            aggregationNode.getStep(),
            aggregationNode.getHashVariable(),
            aggregationNode.getGroupIdVariable()));
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class ImplementFilteredAggregations method apply.
/**
 * Turns every aggregation FILTER clause into an aggregation mask.
 *
 * <p>Each filter expression is projected into a new BOOLEAN variable, which becomes that aggregation's mask. All
 * source outputs are passed through with identity projections. Aggregations that carry a filter are expected to
 * have no mask yet, which is checked with {@code verify}.
 *
 * <p>When the node has no non-empty grouping set and every aggregation had a filter, a filter ORing all masks is
 * placed below the aggregation; otherwise the filter predicate is TRUE.
 *
 * <p>The rewritten node gets a new id and declares no pre-grouped variables.
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableList.Builder<Expression> maskReferences = ImmutableList.builder();
    Assignments.Builder projections = Assignments.builder();
    boolean everyAggregationFiltered = true;

    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregation.getAggregations().entrySet()) {
        Aggregation original = entry.getValue();
        Optional<VariableReferenceExpression> mask = original.getMask();
        if (!original.getFilter().isPresent()) {
            everyAggregationFiltered = false;
        }
        else {
            // TODO remove cast once assignment can be RowExpression
            Expression filterExpression = OriginalExpressionUtils.castToExpression(original.getFilter().get());
            VariableReferenceExpression filterVariable = context.getVariableAllocator().newVariable(filterExpression, BOOLEAN);
            verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern");
            projections.put(filterVariable, castToRowExpression(filterExpression));
            maskReferences.add(createSymbolReference(filterVariable));
            mask = Optional.of(filterVariable);
        }
        // The filter now lives in the mask, so it is stripped from the aggregation itself.
        rewrittenAggregations.put(entry.getKey(), new Aggregation(original.getCall(), Optional.empty(), original.getOrderBy(), original.isDistinct(), mask));
    }

    // Identity projections for all existing inputs, placed after the filter variables.
    projections.putAll(identitiesAsSymbolReferences(aggregation.getSource().getOutputVariables()));

    Expression predicate = aggregation.hasNonEmptyGroupingSet() || !everyAggregationFiltered
            ? TRUE_LITERAL
            : combineDisjunctsWithDefault(maskReferences.build(), TRUE_LITERAL);

    return Result.ofPlanNode(new AggregationNode(
            aggregation.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            new FilterNode(
                    aggregation.getSourceLocation(),
                    context.getIdAllocator().getNextId(),
                    new ProjectNode(context.getIdAllocator().getNextId(), aggregation.getSource(), projections.build()),
                    castToRowExpression(predicate)),
            rewrittenAggregations.build(),
            aggregation.getGroupingSets(),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashVariable(),
            aggregation.getGroupIdVariable()));
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class PruneOrderByInAggregation method apply.
/**
 * Drops the ORDER BY of every aggregation whose function is not order-sensitive.
 *
 * @return {@link Result#empty()} if the node has no orderings or no ORDER BY could be pruned; otherwise a copy of
 *     {@code node} (same id) with the pruned aggregations
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    if (!node.hasOrderings()) {
        return Result.empty();
    }

    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> prunedAggregations = ImmutableMap.builder();
    boolean changed = false;
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        // getAggregateFunctionImplementation can be expensive, so the && short-circuit only reaches it for
        // aggregations that actually have an ORDER BY.
        boolean orderingIrrelevant = aggregation.getOrderBy().isPresent()
                && !functionAndTypeManager.getAggregateFunctionImplementation(aggregation.getFunctionHandle()).isOrderSensitive();
        if (orderingIrrelevant) {
            changed = true;
            prunedAggregations.put(entry.getKey(), new Aggregation(aggregation.getCall(), aggregation.getFilter(), Optional.empty(), aggregation.isDistinct(), aggregation.getMask()));
        }
        else {
            prunedAggregations.put(entry);
        }
    }

    return changed
            ? Result.ofPlanNode(new AggregationNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getSource(),
                    prunedAggregations.build(),
                    node.getGroupingSets(),
                    node.getPreGroupedVariables(),
                    node.getStep(),
                    node.getHashVariable(),
                    node.getGroupIdVariable()))
            : Result.empty();
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class ScalarAggregationToJoinRewriter method rewriteScalarAggregation.
/**
 * Rewrites a lateral join over a scalar aggregation into an aggregation over a LEFT join.
 *
 * <p>The lateral input is tagged with a new BIGINT "unique" column via {@code AssignUniqueId}. It is then LEFT-joined
 * with {@code scalarAggregationSource} on {@code joinExpression}. An aggregation over that join is built by
 * {@code createAggregationNode}.
 *
 * <p>If a ProjectNode is found in the subquery (searching through EnforceSingleRowNode only), its assignments are
 * re-applied on top, together with identity projections of the aggregation outputs returned by
 * {@code getTruncatedAggregationVariables}.
 *
 * @return the rewritten plan, or {@code lateralJoinNode} unchanged when {@code createAggregationNode} yields nothing
 */
private PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<Expression> joinExpression, VariableReferenceExpression nonNull) {
    AssignUniqueId inputWithUniqueColumns = new AssignUniqueId(
            lateralJoinNode.getSourceLocation(),
            idAllocator.getNextId(),
            lateralJoinNode.getInput(),
            variableAllocator.newVariable(nonNull.getSourceLocation(), "unique", BIGINT));

    // The join exposes every column of the tagged input followed by every column of the aggregation source.
    ImmutableList<VariableReferenceExpression> joinOutputVariables = ImmutableList.<VariableReferenceExpression>builder()
            .addAll(inputWithUniqueColumns.getOutputVariables())
            .addAll(scalarAggregationSource.getOutputVariables())
            .build();
    JoinNode leftOuterJoin = new JoinNode(
            scalarAggregation.getSourceLocation(),
            idAllocator.getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueColumns,
            scalarAggregationSource,
            ImmutableList.of(),
            joinOutputVariables,
            joinExpression.map(OriginalExpressionUtils::castToRowExpression),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());

    Optional<AggregationNode> aggregationNode = createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
    if (!aggregationNode.isPresent()) {
        return lateralJoinNode;
    }
    AggregationNode rewrittenAggregation = aggregationNode.get();

    Optional<ProjectNode> subqueryProjection = searchFrom(lateralJoinNode.getSubquery(), lookup)
            .where(ProjectNode.class::isInstance)
            .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance)
            .findFirst();
    List<VariableReferenceExpression> aggregationOutputVariables = getTruncatedAggregationVariables(lateralJoinNode, rewrittenAggregation);

    Assignments assignments;
    if (subqueryProjection.isPresent()) {
        assignments = Assignments.builder()
                .putAll(identitiesAsSymbolReferences(aggregationOutputVariables))
                .putAll(subqueryProjection.get().getAssignments())
                .build();
    }
    else {
        assignments = identityAssignmentsAsSymbolReferences(aggregationOutputVariables);
    }
    return new ProjectNode(idAllocator.getNextId(), rewrittenAggregation, assignments);
}
Aggregations