Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
In class ExtractSpatialJoins, method tryCreateSpatialJoin.
/**
 * Attempts to replace {@code joinNode} with a {@link SpatialJoinNode} driven by {@code spatialFunction}.
 * <p>
 * {@code checkAlignment} decides which join side each of the two spatial-function arguments belongs to;
 * an argument that is not already a plain variable is projected onto its side under a fresh variable.
 * When the session names a spatial partitioning table (and the join can be partitioned), the KDB tree
 * loaded from it is used to add partitioning nodes to both sides; only the right side receives
 * {@code radius}.
 *
 * @return the rewritten plan, or {@link Result#empty()} when the join cannot be converted
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, RowExpression filter, PlanNodeId nodeId, List<VariableReferenceExpression> outputVariables, CallExpression spatialFunction, Optional<RowExpression> radius, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) {
    FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager();

    List<RowExpression> spatialArguments = spatialFunction.getArguments();
    verify(spatialArguments.size() == 2);
    RowExpression firstArgument = spatialArguments.get(0);
    RowExpression secondArgument = spatialArguments.get(1);

    // Spherical geometries are only supported by INNER spatial joins for now.
    boolean innerJoin = joinNode.getType() == INNER;
    if (!innerJoin && isSphericalJoin(metadata, firstArgument, secondArgument)) {
        return Result.empty();
    }

    Set<VariableReferenceExpression> firstVariables = extractUnique(firstArgument);
    Set<VariableReferenceExpression> secondVariables = extractUnique(secondArgument);
    if (firstVariables.isEmpty() || secondVariables.isEmpty()) {
        // An argument that references no variable cannot be attributed to either join side.
        return Result.empty();
    }

    // Non-variable arguments get a fresh variable here; the expression itself is projected
    // onto the matching join side below.
    Optional<VariableReferenceExpression> newFirstVariable = newGeometryVariable(context, firstArgument);
    Optional<VariableReferenceExpression> newSecondVariable = newGeometryVariable(context, secondArgument);

    // Positive: the first argument belongs to the left side. Negative: it belongs to the right side.
    int alignment = checkAlignment(joinNode, firstVariables, secondVariables);
    if (alignment == 0) {
        return Result.empty();
    }
    boolean firstOnLeft = alignment > 0;

    RowExpression leftArgument = firstOnLeft ? firstArgument : secondArgument;
    RowExpression rightArgument = firstOnLeft ? secondArgument : firstArgument;
    Optional<VariableReferenceExpression> leftGeometry = firstOnLeft ? newFirstVariable : newSecondVariable;
    Optional<VariableReferenceExpression> rightGeometry = firstOnLeft ? newSecondVariable : newFirstVariable;

    PlanNode originalLeft = joinNode.getLeft();
    PlanNode originalRight = joinNode.getRight();
    PlanNode newLeftNode = leftGeometry.map(variable -> addProjection(context, originalLeft, variable, leftArgument)).orElse(originalLeft);
    PlanNode newRightNode = rightGeometry.map(variable -> addProjection(context, originalRight, variable, rightArgument)).orElse(originalRight);

    RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument);
    RowExpression newSecondArgument = mapToExpression(newSecondVariable, secondArgument);

    // Partitioned spatial join: if the session points at a valid spatial partitioning table, each probe
    // and build row is assigned the partitions its geometry intersects (a projection producing an array
    // of ints that is subsequently unnested).
    Optional<String> spatialPartitioningTableName = canPartitionSpatialJoin(joinNode) ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty();
    Optional<KdbTree> kdbTree = spatialPartitioningTableName.map(tableName -> loadKdbTree(tableName, context.getSession(), metadata, splitManager, pageSourceManager));

    Optional<VariableReferenceExpression> leftPartitionVariable = Optional.empty();
    Optional<VariableReferenceExpression> rightPartitionVariable = Optional.empty();
    if (kdbTree.isPresent()) {
        KdbTree tree = kdbTree.get();
        VariableReferenceExpression leftPid = context.getVariableAllocator().newVariable(newFirstArgument.getSourceLocation(), "pid", INTEGER);
        VariableReferenceExpression rightPid = context.getVariableAllocator().newVariable(newSecondArgument.getSourceLocation(), "pid", INTEGER);
        leftPartitionVariable = Optional.of(leftPid);
        rightPartitionVariable = Optional.of(rightPid);

        RowExpression newLeftArgument = firstOnLeft ? newFirstArgument : newSecondArgument;
        RowExpression newRightArgument = firstOnLeft ? newSecondArgument : newFirstArgument;
        // Only the right side is partitioned with the (optional) distance radius.
        newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPid, tree, newLeftArgument, Optional.empty());
        newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPid, tree, newRightArgument, radius);
    }

    CallExpression newSpatialFunction = new CallExpression(spatialFunction.getSourceLocation(), spatialFunction.getDisplayName(), spatialFunction.getFunctionHandle(), spatialFunction.getType(), ImmutableList.of(newFirstArgument, newSecondArgument));
    RowExpression newFilter = RowExpressionNodeInliner.replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));
    SpatialJoinNode.Type spatialJoinType = SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType());
    return Result.ofPlanNode(new SpatialJoinNode(joinNode.getSourceLocation(), nodeId, spatialJoinType, newLeftNode, newRightNode, outputVariables, newFilter, leftPartitionVariable, rightPartitionVariable, kdbTree.map(KdbTreeUtils::toJson)));
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
In class PushAggregationThroughOuterJoin, method createAggregationOverNull.
/**
 * Builds a SINGLE-step global aggregation that evaluates every aggregation of
 * {@code referenceAggregation} over a one-row VALUES node containing only typed NULLs.
 * <p>
 * Each output variable of the reference aggregation's source is mapped to a fresh NULL column. Arguments,
 * filter, ORDER BY and mask of each aggregation are re-pointed at those columns.
 *
 * @param referenceAggregation the aggregation whose functions are replayed over the null row
 * @param variableAllocator allocator for the NULL columns and the new aggregation outputs
 * @param idAllocator allocator for the VALUES and AGGREGATION node ids
 * @param lookup not used by this method
 * @return the aggregation over the null row plus a mapping from each reference output to its new
 *     output, or empty when {@code isUsingVariables} rejects one of the aggregations
 */
private Optional<MappedAggregationInfo> createAggregationOverNull(AggregationNode referenceAggregation, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
    // Step 1: one NULL column per source output, collected into a single-row VALUES node.
    ImmutableList.Builder<VariableReferenceExpression> nullColumns = ImmutableList.builder();
    ImmutableList.Builder<RowExpression> nullRowValues = ImmutableList.builder();
    ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> sourceToNullColumn = ImmutableMap.builder();
    for (VariableReferenceExpression sourceVariable : referenceAggregation.getSource().getOutputVariables()) {
        RowExpression typedNull = constantNull(sourceVariable.getSourceLocation(), sourceVariable.getType());
        nullRowValues.add(typedNull);
        // NOTE(review): an older TODO asked for this type to come from sourceVariable.getType();
        // typedNull is already built with that type, so this is presumably resolved — confirm.
        VariableReferenceExpression nullColumn = variableAllocator.newVariable(typedNull);
        nullColumns.add(nullColumn);
        sourceToNullColumn.put(sourceVariable, nullColumn);
    }
    ValuesNode nullRow = new ValuesNode(referenceAggregation.getSourceLocation(), idAllocator.getNextId(), nullColumns.build(), ImmutableList.of(nullRowValues.build()));
    Map<VariableReferenceExpression, VariableReferenceExpression> sourcesVariableMapping = sourceToNullColumn.build();

    // Step 2: replay every reference aggregation against the NULL columns.
    ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> referenceToOverNull = ImmutableMap.builder();
    ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> overNullAggregations = ImmutableMap.builder();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
        VariableReferenceExpression referenceOutput = entry.getKey();
        AggregationNode.Aggregation aggregation = entry.getValue();
        if (!isUsingVariables(aggregation, sourcesVariableMapping.keySet())) {
            return Optional.empty();
        }

        CallExpression call = aggregation.getCall();
        CallExpression callOverNull = new CallExpression(
                call.getSourceLocation(),
                call.getDisplayName(),
                call.getFunctionHandle(),
                call.getType(),
                aggregation.getArguments().stream()
                        .map(argument -> inlineVariables(sourcesVariableMapping, argument))
                        .collect(toImmutableList()));
        AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation(
                callOverNull,
                aggregation.getFilter().map(filter -> inlineVariables(sourcesVariableMapping, filter)),
                aggregation.getOrderBy().map(orderBy -> inlineOrderByVariables(sourcesVariableMapping, orderBy)),
                aggregation.isDistinct(),
                aggregation.getMask().map(mask -> {
                    // Keep the mask's own type but point it at the matching NULL column.
                    VariableReferenceExpression nullMask = sourcesVariableMapping.get(mask);
                    return new VariableReferenceExpression(nullMask.getSourceLocation(), nullMask.getName(), mask.getType());
                }));

        QualifiedObjectName functionName = functionAndTypeManager.getFunctionMetadata(overNullAggregation.getFunctionHandle()).getName();
        VariableReferenceExpression overNullOutput = variableAllocator.newVariable(call.getSourceLocation(), functionName.getObjectName(), referenceOutput.getType());
        overNullAggregations.put(overNullOutput, overNullAggregation);
        referenceToOverNull.put(referenceOutput, overNullOutput);
    }
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationsSymbolMapping = referenceToOverNull.build();

    // Step 3: a global (no grouping keys) aggregation reading from the null row.
    AggregationNode aggregationOverNullRow = new AggregationNode(referenceAggregation.getSourceLocation(), idAllocator.getNextId(), nullRow, overNullAggregations.build(), globalAggregation(), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
    return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, aggregationsSymbolMapping));
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
In class RewriteSpatialPartitioningAggregation, method apply.
/**
 * Rewrites each single-argument call of the aggregation function identified by {@code NAME}
 * (presumably {@code spatial_partitioning(geometry)}) into the two-argument form
 * {@code NAME(envelope, partition_count)}.
 * <p>
 * A projection is inserted below the aggregation. It passes every source column through unchanged and
 * adds {@code partition_count}, set to the session's hash partition count. It also adds one
 * {@code envelope} column per rewritten aggregation: the argument itself when it is already an
 * {@code ST_Envelope} call, otherwise {@code ST_Envelope(argument)}. All other aggregations are kept
 * as they are.
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    VariableReferenceExpression partitionCountVariable = context.getVariableAllocator().newVariable("partition_count", INTEGER);
    ImmutableMap.Builder<VariableReferenceExpression, RowExpression> envelopeAssignments = ImmutableMap.builder();
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        QualifiedObjectName name = metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
        Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE);
        if (!name.equals(NAME) || aggregation.getArguments().size() != 1) {
            aggregations.put(entry);
            continue;
        }

        RowExpression geometry = getOnlyElement(aggregation.getArguments());
        VariableReferenceExpression envelopeVariable = context.getVariableAllocator().newVariable(aggregation.getCall().getSourceLocation(), "envelope", geometryType);
        if (isFunctionNameMatch(geometry, "ST_Envelope")) {
            // Already an envelope: project it as-is instead of wrapping it a second time.
            envelopeAssignments.put(envelopeVariable, geometry);
        } else {
            envelopeAssignments.put(envelopeVariable, castToRowExpression(new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(castToExpression(geometry)))));
        }

        CallExpression partitioningCall = new CallExpression(
                envelopeVariable.getSourceLocation(),
                name.getObjectName(),
                metadata.getFunctionAndTypeManager().lookupFunction(NAME.getObjectName(), fromTypes(geometryType, INTEGER)),
                entry.getKey().getType(),
                ImmutableList.of(castToRowExpression(asSymbolReference(envelopeVariable)), castToRowExpression(asSymbolReference(partitionCountVariable))));
        // Keep the FILTER clause alongside the mask. Both restrict which rows are aggregated, and dropping
        // only the filter would silently widen the input. The projection below passes every source column
        // through, so the filter's variables stay available.
        // NOTE(review): ORDER BY and DISTINCT are still dropped, as before; confirm that is intended.
        aggregations.put(entry.getKey(), new Aggregation(partitioningCall, aggregation.getFilter(), Optional.empty(), false, aggregation.getMask()));
    }

    Assignments projectAssignments = Assignments.builder()
            .putAll(identitiesAsSymbolReferences(node.getSource().getOutputVariables()))
            .put(partitionCountVariable, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))))
            .putAll(envelopeAssignments.build())
            .build();
    ProjectNode projectWithEnvelopes = new ProjectNode(context.getIdAllocator().getNextId(), node.getSource(), projectAssignments);
    return Result.ofPlanNode(new AggregationNode(node.getSourceLocation(), node.getId(), projectWithEnvelopes, aggregations.build(), node.getGroupingSets(), node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), node.getGroupIdVariable()));
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
In class PushPartialAggregationThroughExchange, method split.
/**
 * Splits {@code node} into a PARTIAL aggregation over the original source, feeding a FINAL aggregation
 * that combines the partial results.
 * <p>
 * Every aggregation gets a fresh intermediate variable typed with the function's intermediate type. The
 * FINAL call receives that intermediate variable followed by any lambda arguments of the original call.
 * The exchange between the two steps is not created here; the FINAL node reads the PARTIAL node
 * directly, so presumably the caller moves the PARTIAL step below an exchange — confirm.
 *
 * @throws IllegalStateException if an aggregation has an ORDER BY clause
 */
private PlanNode split(AggregationNode node, Context context) {
    Map<VariableReferenceExpression, AggregationNode.Aggregation> partialAggregations = new HashMap<>();
    Map<VariableReferenceExpression, AggregationNode.Aggregation> finalAggregations = new HashMap<>();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = entry.getValue();
        CallExpression originalCall = original.getCall();
        FunctionHandle functionHandle = original.getFunctionHandle();
        String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
        InternalAggregationFunction implementation = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);
        VariableReferenceExpression intermediateVariable = context.getVariableAllocator().newVariable(originalCall.getSourceLocation(), functionName, implementation.getIntermediateType());
        checkState(!original.getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        // PARTIAL: same arguments, filter, distinct flag and mask; only the result type becomes intermediate.
        CallExpression partialCall = new CallExpression(originalCall.getSourceLocation(), functionName, functionHandle, implementation.getIntermediateType(), original.getArguments());
        partialAggregations.put(intermediateVariable, new AggregationNode.Aggregation(partialCall, original.getFilter(), original.getOrderBy(), original.isDistinct(), original.getMask()));

        // FINAL: consumes the intermediate value plus the original lambda arguments, in their original order.
        ImmutableList.Builder<RowExpression> finalArguments = ImmutableList.builder();
        finalArguments.add(intermediateVariable);
        for (RowExpression argument : original.getArguments()) {
            if (isLambda(argument)) {
                finalArguments.add(argument);
            }
        }
        CallExpression finalCall = new CallExpression(originalCall.getSourceLocation(), functionName, functionHandle, implementation.getFinalType(), finalArguments.build());
        finalAggregations.put(entry.getKey(), new AggregationNode.Aggregation(finalCall, Optional.empty(), Optional.empty(), false, Optional.empty()));
    }

    // Streaming partial aggregation is always possible, but it can regress when the input is not actually
    // pre-grouped by the grouping keys. The session property forces it on for cases known to benefit.
    List<VariableReferenceExpression> preGroupedSymbols = isStreamingForPartialAggregationEnabled(context.getSession())
            ? ImmutableList.copyOf(node.getGroupingSets().getGroupingKeys())
            : ImmutableList.of();
    PlanNode partial = new AggregationNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getSource(), partialAggregations, node.getGroupingSets(), preGroupedSymbols, PARTIAL, node.getHashVariable(), node.getGroupIdVariable());

    // Pre-grouped variables describe properties of the input. Pushing the partial aggregation through the
    // exchange may or may not preserve them, so the FINAL step declares none.
    return new AggregationNode(node.getSourceLocation(), node.getId(), partial, finalAggregations, node.getGroupingSets(), ImmutableList.of(), FINAL, node.getHashVariable(), node.getGroupIdVariable());
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
In class RewriteAggregationIfToFilter, method apply.
/**
 * Rewrites aggregations over a projected {@code IF(condition, value)} or {@code CAST(IF(...))} with a
 * NULL false result into masked aggregations.
 * <p>
 * Each distinct IF input gets a condition variable, which becomes the aggregation's mask. When
 * {@code canUnwrapIf} allows it, the IF's true result (re-wrapped in the original CAST, if any) is
 * projected and becomes the aggregation's argument. When every aggregation is rewritten and the node has
 * no non-empty grouping set, a filter OR-ing all masks is placed below the aggregation; otherwise the
 * filter predicate is TRUE.
 *
 * @return the rewritten plan, or {@link Result#empty()} when no aggregation qualifies
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    ProjectNode sourceProject = captures.get(CHILD);
    Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream()
            .filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject))
            .collect(toImmutableSet());
    if (aggregationsToRewrite.isEmpty()) {
        return Result.empty();
    }
    context.getSession().getRuntimeStats().addMetricValue(REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED, 1);

    // Look up the input-project assignments referenced by the aggregations being rewritten; these are
    // IF / CAST(IF) expressions with NULL false results. Several aggregations may share one input, so the
    // map dedups by variable and each IF is rewritten once. The sorted map makes the order in which new
    // variables are allocated deterministic (by variable name).
    Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream()
            .map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0))
            .collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left));

    Assignments.Builder newAssignments = Assignments.builder();
    newAssignments.putAll(sourceProject.getAssignments());
    // Aggregation input -> variable holding the IF condition (used as the mask).
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
    // Aggregation input -> variable holding the unwrapped IF result. Only present when unwrapping is safe;
    // e.g. SUM(IF(CARDINALITY(array) > 0, array[1])) is excluded because array[1] can fail once unwrapped.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();
    AggregationIfToFilterRewriteStrategy rewriteStrategy = getAggregationIfToFilterRewriteStrategy(context.getSession());
    for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
        VariableReferenceExpression outputVariable = entry.getKey();
        RowExpression rowExpression = entry.getValue();
        // A CallExpression here is the CAST wrapping the IF.
        SpecialFormExpression ifExpression = (SpecialFormExpression) ((rowExpression instanceof CallExpression) ? ((CallExpression) rowExpression).getArguments().get(0) : rowExpression);
        RowExpression condition = ifExpression.getArguments().get(0);
        VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
        newAssignments.put(conditionReference, condition);
        aggregationReferenceToConditionReference.put(outputVariable, conditionReference);

        if (canUnwrapIf(ifExpression, rewriteStrategy)) {
            RowExpression trueResult = ifExpression.getArguments().get(1);
            if (rowExpression instanceof CallExpression) {
                // Re-apply the CAST around the true result, keeping the cast's source location as the other
                // CallExpression rebuilds do.
                CallExpression cast = (CallExpression) rowExpression;
                trueResult = new CallExpression(cast.getSourceLocation(), cast.getDisplayName(), cast.getFunctionHandle(), cast.getType(), ImmutableList.of(trueResult));
            }
            VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
            newAssignments.put(ifResultReference, trueResult);
            aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
        }
    }

    // Build the new aggregations.
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    // Masks used to build the filter predicate; the set dedups shared conditions.
    ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation aggregation = entry.getValue();
        if (!aggregationsToRewrite.contains(aggregation)) {
            aggregations.put(output, aggregation);
            continue;
        }
        VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
        CallExpression callExpression = aggregation.getCall();
        VariableReferenceExpression ifResultReference = aggregationReferenceToIfResultReference.get(aggregationReference);
        if (ifResultReference != null) {
            callExpression = new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), ImmutableList.of(ifResultReference));
        }
        VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
        aggregations.put(output, new Aggregation(callExpression, Optional.empty(), aggregation.getOrderBy(), aggregation.isDistinct(), Optional.of(mask)));
        masks.add(mask);
    }

    RowExpression predicate = TRUE_CONSTANT;
    if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) {
        // Every aggregation was rewritten, so rows matching none of the masks contribute nothing and can be
        // filtered out early.
        predicate = or(masks.build());
    }

    ProjectNode newProject = new ProjectNode(context.getIdAllocator().getNextId(), sourceProject.getSource(), newAssignments.build());
    FilterNode filter = new FilterNode(aggregationNode.getSourceLocation(), context.getIdAllocator().getNextId(), newProject, predicate);
    return Result.ofPlanNode(new AggregationNode(aggregationNode.getSourceLocation(), context.getIdAllocator().getNextId(), filter, aggregations.build(), aggregationNode.getGroupingSets(), aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), aggregationNode.getGroupIdVariable()));
}
Aggregations