use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class PruneCountAggregationOverScalar method apply.
/**
 * Replaces an aggregation that consists solely of argument-less {@code count} calls over a
 * scalar source with a single-row {@link ValuesNode} holding the constant {@code 1}.
 *
 * <p>The rule declines to fire (returns {@code Result.empty()}) when:
 * <ul>
 *   <li>the parent does not have a default output, or does not have exactly one output variable;</li>
 *   <li>any aggregation is not a count function, or takes arguments;</li>
 *   <li>any aggregation carries a filter or a mask;</li>
 *   <li>there are no aggregations, or the source is not scalar according to {@code isScalar}.</li>
 * </ul>
 *
 * @param parent the aggregation node being examined
 * @param captures pattern captures (not used by this rule)
 * @param context rule context; only its lookup is used
 * @return the replacement values node, or an empty result when the rule does not apply
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    if (!parent.hasDefaultOutput() || parent.getOutputVariables().size() != 1) {
        return Result.empty();
    }
    Map<VariableReferenceExpression, AggregationNode.Aggregation> assignments = parent.getAggregations();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : assignments.entrySet()) {
        AggregationNode.Aggregation aggregation = entry.getValue();
        requireNonNull(aggregation, "aggregation is null");
        if (!functionResolution.isCountFunction(aggregation.getFunctionHandle()) || !aggregation.getArguments().isEmpty()) {
            return Result.empty();
        }
        // A filter or mask may reject the single input row, in which case count() yields 0
        // rather than 1, so the constant replacement below would be wrong.
        if (aggregation.getFilter().isPresent() || aggregation.getMask().isPresent()) {
            return Result.empty();
        }
    }
    if (!assignments.isEmpty() && isScalar(parent.getSource(), context.getLookup())) {
        return Result.ofPlanNode(new ValuesNode(
                parent.getSourceLocation(),
                parent.getId(),
                parent.getOutputVariables(),
                ImmutableList.of(ImmutableList.of(constant(1L, BIGINT)))));
    }
    return Result.empty();
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class PushAggregationThroughOuterJoin method createAggregationOverNull.
/**
 * Builds a global, single-step aggregation over one row of typed nulls that mirrors the
 * aggregations of {@code referenceAggregation}.
 *
 * @param referenceAggregation aggregation whose source outputs and aggregation calls are mirrored
 * @param variableAllocator allocates the null-row variables and the new aggregation output variables
 * @param idAllocator supplies ids for the new values node and the new aggregation node
 * @param lookup not used by this method
 * @return the new aggregation plus a mapping from each original aggregation output variable to
 *         its counterpart over the null row, or empty when any aggregation fails the
 *         {@code isUsingVariables} check against the source output variables
 */
private Optional<MappedAggregationInfo> createAggregationOverNull(AggregationNode referenceAggregation, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
    // One typed null per output variable of the reference aggregation's source, plus a mapping
    // from each source variable to the variable that stands for it in the null row.
    ImmutableList.Builder<VariableReferenceExpression> nullColumns = ImmutableList.builder();
    ImmutableList.Builder<RowExpression> nullValues = ImmutableList.builder();
    ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> sourceToNullBuilder = ImmutableMap.builder();
    for (VariableReferenceExpression sourceOutput : referenceAggregation.getSource().getOutputVariables()) {
        RowExpression typedNull = constantNull(sourceOutput.getSourceLocation(), sourceOutput.getType());
        VariableReferenceExpression nullColumn = variableAllocator.newVariable(typedNull);
        nullValues.add(typedNull);
        nullColumns.add(nullColumn);
        // TODO The type should be from sourceOutput.getType
        sourceToNullBuilder.put(sourceOutput, nullColumn);
    }
    ValuesNode nullRow = new ValuesNode(
            referenceAggregation.getSourceLocation(),
            idAllocator.getNextId(),
            nullColumns.build(),
            ImmutableList.of(nullValues.build()));
    Map<VariableReferenceExpression, VariableReferenceExpression> sourceToNull = sourceToNullBuilder.build();

    // Re-create every aggregation of the reference node so that it reads from the null row,
    // remembering which new output variable corresponds to which original one.
    ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> originalToOverNull = ImmutableMap.builder();
    ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> overNullAggregations = ImmutableMap.builder();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
        VariableReferenceExpression originalOutput = entry.getKey();
        AggregationNode.Aggregation aggregation = entry.getValue();
        if (!isUsingVariables(aggregation, sourceToNull.keySet())) {
            return Optional.empty();
        }

        CallExpression originalCall = aggregation.getCall();
        List<RowExpression> remappedArguments = aggregation.getArguments().stream()
                .map(argument -> inlineVariables(sourceToNull, argument))
                .collect(toImmutableList());
        CallExpression remappedCall = new CallExpression(
                originalCall.getSourceLocation(),
                originalCall.getDisplayName(),
                originalCall.getFunctionHandle(),
                originalCall.getType(),
                remappedArguments);
        AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation(
                remappedCall,
                aggregation.getFilter().map(filter -> inlineVariables(sourceToNull, filter)),
                aggregation.getOrderBy().map(orderBy -> inlineOrderByVariables(sourceToNull, orderBy)),
                aggregation.isDistinct(),
                aggregation.getMask().map(mask -> {
                    // Location and name come from the null-row variable; the type is kept from the original mask.
                    VariableReferenceExpression nullMask = sourceToNull.get(mask);
                    return new VariableReferenceExpression(nullMask.getSourceLocation(), nullMask.getName(), mask.getType());
                }));

        QualifiedObjectName functionName = functionAndTypeManager.getFunctionMetadata(overNullAggregation.getFunctionHandle()).getName();
        VariableReferenceExpression overNullOutput = variableAllocator.newVariable(
                originalCall.getSourceLocation(),
                functionName.getObjectName(),
                originalOutput.getType());
        overNullAggregations.put(overNullOutput, overNullAggregation);
        originalToOverNull.put(originalOutput, overNullOutput);
    }
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationsSymbolMapping = originalToOverNull.build();

    // Global aggregation whose only input is the null row.
    AggregationNode aggregationOverNullRow = new AggregationNode(
            referenceAggregation.getSourceLocation(),
            idAllocator.getNextId(),
            nullRow,
            overNullAggregations.build(),
            globalAggregation(),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());
    return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping));
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class PushAggregationThroughOuterJoin method apply.
/**
 * Moves the aggregation below the captured LEFT or RIGHT join, onto the join's inner table,
 * and then wraps the result via {@code coalesceWithNullAggregation}.
 *
 * <p>The rule does not fire when the join has a filter, is neither LEFT nor RIGHT, the
 * aggregation fails the {@code groupsOnAllColumns} check against the outer table's outputs,
 * the outer table fails the {@code isDistinct} check, or the null-aggregation wrapper
 * cannot be built.
 *
 * @param aggregation the aggregation sitting above the join
 * @param captures pattern captures; must contain the join under {@code JOIN}
 * @param context supplies the lookup and the variable / id allocators
 * @return the rewritten plan, or an empty result when the rule does not apply
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    JoinNode join = captures.get(JOIN);

    // Guards are evaluated in this order on purpose: the outer table is only asked for
    // once the join is known to be LEFT or RIGHT.
    if (join.getFilter().isPresent()) {
        return Result.empty();
    }
    boolean leftJoin = join.getType() == JoinNode.Type.LEFT;
    if (!leftJoin && join.getType() != JoinNode.Type.RIGHT) {
        return Result.empty();
    }
    if (!groupsOnAllColumns(aggregation, getOuterTable(join).getOutputVariables())) {
        return Result.empty();
    }
    if (!isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve)) {
        return Result.empty();
    }

    // Group the pushed-down aggregation by the join keys that come from the inner side.
    List<VariableReferenceExpression> groupingKeys = join.getCriteria().stream()
            .map(leftJoin ? JoinNode.EquiJoinClause::getRight : JoinNode.EquiJoinClause::getLeft)
            .collect(toImmutableList());
    AggregationNode rewrittenAggregation = new AggregationNode(
            aggregation.getSourceLocation(),
            aggregation.getId(),
            getInnerTable(join),
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashVariable(),
            aggregation.getGroupIdVariable());

    // The aggregation replaces the inner side; the outer side and its output position are kept.
    PlanNode newLeft;
    PlanNode newRight;
    ImmutableList.Builder<VariableReferenceExpression> joinOutputs = ImmutableList.builder();
    if (leftJoin) {
        newLeft = join.getLeft();
        newRight = rewrittenAggregation;
        joinOutputs.addAll(join.getLeft().getOutputVariables());
        joinOutputs.addAll(rewrittenAggregation.getAggregations().keySet());
    } else {
        newLeft = rewrittenAggregation;
        newRight = join.getRight();
        joinOutputs.addAll(rewrittenAggregation.getAggregations().keySet());
        joinOutputs.addAll(join.getRight().getOutputVariables());
    }
    JoinNode rewrittenJoin = new JoinNode(
            join.getSourceLocation(),
            join.getId(),
            join.getType(),
            newLeft,
            newRight,
            join.getCriteria(),
            joinOutputs.build(),
            join.getFilter(),
            join.getLeftHashVariable(),
            join.getRightHashVariable(),
            join.getDistributionType(),
            join.getDynamicFilters());

    return coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getVariableAllocator(), context.getIdAllocator(), context.getLookup())
            .map(Result::ofPlanNode)
            .orElseGet(Result::empty);
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class PushAggregationThroughOuterJoin method coalesceWithNullAggregation.
/**
 * Wraps {@code outerJoin} so that every aggregation output is coalesced with the value the
 * same aggregation produces over a single row of nulls.
 *
 * <p>When the aggregation is done after the join, a null value gets aggregated wherever rows
 * did not exist in the inner table. For some aggregate functions, such as count, the result of
 * aggregating over a single null row is one or zero rather than null. To keep results correct
 * once the aggregation has been pushed below the join, the join output is cross joined with an
 * aggregation over a single null row and each aggregation output is coalesced with it.
 *
 * @param aggregationNode the aggregation that was pushed below the join
 * @param outerJoin the rewritten outer join
 * @param variableAllocator passed through to {@code createAggregationOverNull}
 * @param idAllocator supplies ids for the cross join and the projection
 * @param lookup passed through to {@code createAggregationOverNull}
 * @return the projection over the cross join, or empty when the aggregation over null
 *         could not be created
 */
private Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
    // Aggregation over a row of nulls; bail out if it cannot be built.
    Optional<MappedAggregationInfo> overNullInfo = createAggregationOverNull(aggregationNode, variableAllocator, idAllocator, lookup);
    if (!overNullInfo.isPresent()) {
        return Optional.empty();
    }
    AggregationNode aggregationOverNull = overNullInfo.get().getAggregation();
    Map<VariableReferenceExpression, VariableReferenceExpression> originalToOverNull = overNullInfo.get().getVariableMapping();

    // Cross join (inner join without criteria) with the single-row aggregation over null.
    ImmutableList.Builder<VariableReferenceExpression> crossJoinOutputs = ImmutableList.builder();
    crossJoinOutputs.addAll(outerJoin.getOutputVariables());
    crossJoinOutputs.addAll(aggregationOverNull.getOutputVariables());
    JoinNode crossJoin = new JoinNode(
            outerJoin.getSourceLocation(),
            idAllocator.getNextId(),
            JoinNode.Type.INNER,
            outerJoin,
            aggregationOverNull,
            ImmutableList.of(),
            crossJoinOutputs.build(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());

    // Aggregation outputs are coalesced with their over-null counterpart; everything else passes through.
    Map<VariableReferenceExpression, AggregationNode.Aggregation> pushedAggregations = aggregationNode.getAggregations();
    Assignments.Builder projections = Assignments.builder();
    for (VariableReferenceExpression output : outerJoin.getOutputVariables()) {
        if (!pushedAggregations.containsKey(output)) {
            projections.put(output, output);
            continue;
        }
        projections.put(output, coalesce(ImmutableList.of(output, originalToOverNull.get(output))));
    }
    return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, projections.build()));
}
use of com.facebook.presto.spi.plan.AggregationNode in project presto by prestodb.
the class RewriteSpatialPartitioningAggregation method apply.
/**
 * Rewrites every single-argument aggregation whose function name equals {@code NAME} into a
 * two-argument call over (envelope, partition count), and inserts a projection below the
 * aggregation node that provides both inputs.
 *
 * <p>The inserted projection passes all source outputs through, adds a {@code partition_count}
 * variable set to the session's hash partition count, and adds one {@code envelope} variable
 * per rewritten aggregation. The envelope is the original argument when
 * {@code isFunctionNameMatch} reports it is already an {@code ST_Envelope} call, otherwise the
 * argument wrapped in {@code ST_Envelope}. A rewritten aggregation keeps only the original
 * mask: its filter and ordering are empty and it is non-distinct. Other aggregations are
 * copied unchanged.
 *
 * @param node the aggregation node to rewrite
 * @param captures pattern captures (not used by this rule)
 * @param context supplies the variable allocator, the id allocator and the session
 * @return the rewritten aggregation node on top of the new projection
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    VariableReferenceExpression partitionCountVariable = context.getVariableAllocator().newVariable("partition_count", INTEGER);
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableMap.Builder<VariableReferenceExpression, RowExpression> envelopeProjections = ImmutableMap.builder();

    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        QualifiedObjectName name = metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
        Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE);

        boolean needsRewrite = name.equals(NAME) && aggregation.getArguments().size() == 1;
        if (!needsRewrite) {
            rewrittenAggregations.put(entry);
            continue;
        }

        RowExpression geometry = getOnlyElement(aggregation.getArguments());
        VariableReferenceExpression envelopeVariable = context.getVariableAllocator().newVariable(aggregation.getCall().getSourceLocation(), "envelope", geometryType);

        // Do not wrap an argument that is already an ST_Envelope call.
        RowExpression envelope = isFunctionNameMatch(geometry, "ST_Envelope")
                ? geometry
                : castToRowExpression(new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(castToExpression(geometry))));
        envelopeProjections.put(envelopeVariable, envelope);

        CallExpression twoArgumentCall = new CallExpression(
                envelopeVariable.getSourceLocation(),
                name.getObjectName(),
                metadata.getFunctionAndTypeManager().lookupFunction(NAME.getObjectName(), fromTypes(geometryType, INTEGER)),
                entry.getKey().getType(),
                ImmutableList.of(
                        castToRowExpression(asSymbolReference(envelopeVariable)),
                        castToRowExpression(asSymbolReference(partitionCountVariable))));
        rewrittenAggregations.put(entry.getKey(), new Aggregation(twoArgumentCall, Optional.empty(), Optional.empty(), false, aggregation.getMask()));
    }

    // Projection feeding the aggregation: source outputs, the partition count, and the envelopes.
    ProjectNode projection = new ProjectNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            Assignments.builder()
                    .putAll(identitiesAsSymbolReferences(node.getSource().getOutputVariables()))
                    .put(partitionCountVariable, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))))
                    .putAll(envelopeProjections.build())
                    .build());

    AggregationNode rewritten = new AggregationNode(
            node.getSourceLocation(),
            node.getId(),
            projection,
            rewrittenAggregations.build(),
            node.getGroupingSets(),
            node.getPreGroupedVariables(),
            node.getStep(),
            node.getHashVariable(),
            node.getGroupIdVariable());
    return Result.ofPlanNode(rewritten);
}
Aggregations