Search in sources:

Example 26 with CallExpression

Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.

In the class ExtractSpatialJoins, method tryCreateSpatialJoin.

/**
 * Attempts to rewrite {@code joinNode} into a {@link SpatialJoinNode} driven by {@code spatialFunction}.
 *
 * <p>{@code checkAlignment} decides which spatial-function argument belongs to which join side; a zero
 * result aborts the rewrite. An argument that is not already a variable is materialized by a projection
 * on its side. When the join can be partitioned and the session names a spatial partitioning table, a
 * KDB tree is loaded from that table and both sides get partitioning nodes.
 *
 * @param context rule context (session, variable allocator)
 * @param joinNode the join being rewritten
 * @param filter join filter that contains {@code spatialFunction}; the call is replaced inside it
 * @param nodeId id given to the resulting spatial join node
 * @param outputVariables output variables of the resulting spatial join node
 * @param spatialFunction the two-argument spatial call extracted from {@code filter}
 * @param radius passed to {@code addPartitioningNodes} for the right side only
 * @param metadata metadata used for function/type lookup and for loading the KDB tree
 * @param splitManager used when loading the KDB tree
 * @param pageSourceManager used when loading the KDB tree
 * @return the rewritten plan, or {@link Result#empty()} when the rewrite does not apply
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, RowExpression filter, PlanNodeId nodeId, List<VariableReferenceExpression> outputVariables, CallExpression spatialFunction, Optional<RowExpression> radius, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager) {
    FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager();
    List<RowExpression> arguments = spatialFunction.getArguments();
    verify(arguments.size() == 2);
    RowExpression firstArgument = arguments.get(0);
    RowExpression secondArgument = arguments.get(1);
    // Currently, only inner joins are supported for spherical geometries.
    if (joinNode.getType() != INNER && isSphericalJoin(metadata, firstArgument, secondArgument)) {
        return Result.empty();
    }
    Set<VariableReferenceExpression> firstVariables = extractUnique(firstArgument);
    Set<VariableReferenceExpression> secondVariables = extractUnique(secondArgument);
    // An argument that references no variables cannot be attributed to either join side.
    if (firstVariables.isEmpty() || secondVariables.isEmpty()) {
        return Result.empty();
    }
    // If either firstArgument or secondArgument is not a
    // VariableReferenceExpression, replace the left/right join node
    // with a projection that adds the argument as a variable.
    Optional<VariableReferenceExpression> newFirstVariable = newGeometryVariable(context, firstArgument);
    Optional<VariableReferenceExpression> newSecondVariable = newGeometryVariable(context, secondArgument);
    PlanNode leftNode = joinNode.getLeft();
    PlanNode rightNode = joinNode.getRight();
    PlanNode newLeftNode;
    PlanNode newRightNode;
    // Check if the order of arguments of the spatial function matches the order of join sides:
    // > 0 means the first argument is on the left, < 0 means it is on the right, 0 means neither.
    int alignment = checkAlignment(joinNode, firstVariables, secondVariables);
    if (alignment > 0) {
        newLeftNode = newFirstVariable.map(variable -> addProjection(context, leftNode, variable, firstArgument)).orElse(leftNode);
        newRightNode = newSecondVariable.map(variable -> addProjection(context, rightNode, variable, secondArgument)).orElse(rightNode);
    }
    else if (alignment < 0) {
        newLeftNode = newSecondVariable.map(variable -> addProjection(context, leftNode, variable, secondArgument)).orElse(leftNode);
        newRightNode = newFirstVariable.map(variable -> addProjection(context, rightNode, variable, firstArgument)).orElse(rightNode);
    }
    else {
        return Result.empty();
    }
    RowExpression newFirstArgument = mapToExpression(newFirstVariable, firstArgument);
    RowExpression newSecondArgument = mapToExpression(newSecondVariable, secondArgument);
    // Implement partitioned spatial joins:
    // If the session parameter points to a valid spatial partitioning, use
    // that to assign to each probe and build rows the partitions that the
    // geometry intersects.  This is a projection that adds an array of ints
    // which is subsequently unnested.
    Optional<String> spatialPartitioningTableName = canPartitionSpatialJoin(joinNode) ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty();
    Optional<KdbTree> kdbTree = spatialPartitioningTableName.map(tableName -> loadKdbTree(tableName, context.getSession(), metadata, splitManager, pageSourceManager));
    Optional<VariableReferenceExpression> leftPartitionVariable = Optional.empty();
    Optional<VariableReferenceExpression> rightPartitionVariable = Optional.empty();
    if (kdbTree.isPresent()) {
        // Pick the geometry computed on each side once, so that each partition variable is attributed
        // to its own side's geometry. Previously the left variable always took its location from the
        // first argument, even when that argument lived on the right side.
        RowExpression leftGeometry = alignment > 0 ? newFirstArgument : newSecondArgument;
        RowExpression rightGeometry = alignment > 0 ? newSecondArgument : newFirstArgument;
        VariableReferenceExpression leftPid = context.getVariableAllocator().newVariable(leftGeometry.getSourceLocation(), "pid", INTEGER);
        VariableReferenceExpression rightPid = context.getVariableAllocator().newVariable(rightGeometry.getSourceLocation(), "pid", INTEGER);
        leftPartitionVariable = Optional.of(leftPid);
        rightPartitionVariable = Optional.of(rightPid);
        // Only the right side receives the radius.
        newLeftNode = addPartitioningNodes(context, functionAndTypeManager, newLeftNode, leftPid, kdbTree.get(), leftGeometry, Optional.empty());
        newRightNode = addPartitioningNodes(context, functionAndTypeManager, newRightNode, rightPid, kdbTree.get(), rightGeometry, radius);
    }
    CallExpression newSpatialFunction = new CallExpression(spatialFunction.getSourceLocation(), spatialFunction.getDisplayName(), spatialFunction.getFunctionHandle(), spatialFunction.getType(), ImmutableList.of(newFirstArgument, newSecondArgument));
    RowExpression newFilter = RowExpressionNodeInliner.replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));
    return Result.ofPlanNode(new SpatialJoinNode(joinNode.getSourceLocation(), nodeId, SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType()), newLeftNode, newRightNode, outputVariables, newFilter, leftPartitionVariable, rightPartitionVariable, kdbTree.map(KdbTreeUtils::toJson)));
}
Also used : FunctionAndTypeManager(com.facebook.presto.metadata.FunctionAndTypeManager) WarningCollector(com.facebook.presto.spi.WarningCollector) Page(com.facebook.presto.common.Page) SpatialJoinNode(com.facebook.presto.sql.planner.plan.SpatialJoinNode) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) TypeSignature(com.facebook.presto.common.type.TypeSignature) MoreFutures.getFutureValue(com.facebook.airlift.concurrent.MoreFutures.getFutureValue) NOT_PARTITIONED(com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED) SpatialJoinUtils.getFlippedFunctionHandle(com.facebook.presto.util.SpatialJoinUtils.getFlippedFunctionHandle) Pattern(com.facebook.presto.matching.Pattern) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) Capture(com.facebook.presto.matching.Capture) SpatialJoinUtils.flip(com.facebook.presto.util.SpatialJoinUtils.flip) Map(java.util.Map) UNGROUPED_SCHEDULING(com.facebook.presto.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING) LOCAL(com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL) QualifiedObjectName(com.facebook.presto.common.QualifiedObjectName) SystemSessionProperties.isSpatialJoinEnabled(com.facebook.presto.SystemSessionProperties.isSpatialJoinEnabled) Slices.utf8Slice(io.airlift.slice.Slices.utf8Slice) CallExpression(com.facebook.presto.spi.relation.CallExpression) Splitter(com.google.common.base.Splitter) SplitSource(com.facebook.presto.split.SplitSource) Lifespan(com.facebook.presto.execution.Lifespan) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Collection(java.util.Collection) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) SplitManager(com.facebook.presto.split.SplitManager) String.format(java.lang.String.format) KDB_TREE(com.facebook.presto.common.type.KdbTreeType.KDB_TREE) 
UncheckedIOException(java.io.UncheckedIOException) FunctionMetadata(com.facebook.presto.spi.function.FunctionMetadata) List(java.util.List) SpatialJoinUtils.extractSupportedSpatialFunctions(com.facebook.presto.util.SpatialJoinUtils.extractSupportedSpatialFunctions) KdbTreeUtils(com.facebook.presto.geospatial.KdbTreeUtils) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) Capture.newCapture(com.facebook.presto.matching.Capture.newCapture) INTEGER(com.facebook.presto.common.type.IntegerType.INTEGER) Optional(java.util.Optional) CAST(com.facebook.presto.metadata.CastType.CAST) VariablesExtractor.extractUnique(com.facebook.presto.sql.planner.VariablesExtractor.extractUnique) INNER(com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER) PlanNodeId(com.facebook.presto.spi.plan.PlanNodeId) Iterables(com.google.common.collect.Iterables) TableLayoutResult(com.facebook.presto.metadata.TableLayoutResult) VARCHAR(com.facebook.presto.common.type.VarcharType.VARCHAR) Captures(com.facebook.presto.matching.Captures) Assignments(com.facebook.presto.spi.plan.Assignments) Result(com.facebook.presto.sql.planner.iterative.Rule.Result) PrestoException(com.facebook.presto.spi.PrestoException) Patterns.join(com.facebook.presto.sql.planner.plan.Patterns.join) TypeSignatureProvider.fromTypes(com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes) Patterns.filter(com.facebook.presto.sql.planner.plan.Patterns.filter) FilterNode(com.facebook.presto.spi.plan.FilterNode) SplitBatch(com.facebook.presto.split.SplitSource.SplitBatch) RowExpressionNodeInliner(com.facebook.presto.expressions.RowExpressionNodeInliner) ImmutableList(com.google.common.collect.ImmutableList) PageSourceManager(com.facebook.presto.split.PageSourceManager) Verify.verify(com.google.common.base.Verify.verify) Objects.requireNonNull(java.util.Objects.requireNonNull) SpatialJoinUtils.extractSupportedSpatialComparisons(com.facebook.presto.util.SpatialJoinUtils.extractSupportedSpatialComparisons) 
ArrayType(com.facebook.presto.common.type.ArrayType) TableHandle(com.facebook.presto.spi.TableHandle) Expressions(com.facebook.presto.sql.relational.Expressions) KdbTree(com.facebook.presto.geospatial.KdbTree) UnnestNode(com.facebook.presto.sql.planner.plan.UnnestNode) Type(com.facebook.presto.common.type.Type) RowExpression(com.facebook.presto.spi.relation.RowExpression) JoinNode(com.facebook.presto.sql.planner.plan.JoinNode) INVALID_SPATIAL_PARTITIONING(com.facebook.presto.spi.StandardErrorCode.INVALID_SPATIAL_PARTITIONING) SystemSessionProperties.getSpatialPartitioningTableName(com.facebook.presto.SystemSessionProperties.getSpatialPartitioningTableName) Session(com.facebook.presto.Session) Rule(com.facebook.presto.sql.planner.iterative.Rule) Constraint(com.facebook.presto.spi.Constraint) IOException(java.io.IOException) OperatorType(com.facebook.presto.common.function.OperatorType) Patterns.source(com.facebook.presto.sql.planner.plan.Patterns.source) PlanNode(com.facebook.presto.spi.plan.PlanNode) ConnectorPageSource(com.facebook.presto.spi.ConnectorPageSource) TypeSignature.parseTypeSignature(com.facebook.presto.common.type.TypeSignature.parseTypeSignature) LEFT(com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT) ColumnHandle(com.facebook.presto.spi.ColumnHandle) FunctionHandle(com.facebook.presto.spi.function.FunctionHandle) Split(com.facebook.presto.metadata.Split) Context(com.facebook.presto.sql.planner.iterative.Rule.Context) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Metadata(com.facebook.presto.metadata.Metadata) RowExpressionNodeInliner.replaceExpression(com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression) KdbTreeUtils(com.facebook.presto.geospatial.KdbTreeUtils) KdbTree(com.facebook.presto.geospatial.KdbTree) RowExpression(com.facebook.presto.spi.relation.RowExpression) SpatialJoinNode(com.facebook.presto.sql.planner.plan.SpatialJoinNode) Constraint(com.facebook.presto.spi.Constraint) 
PlanNode(com.facebook.presto.spi.plan.PlanNode) FunctionAndTypeManager(com.facebook.presto.metadata.FunctionAndTypeManager) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) CallExpression(com.facebook.presto.spi.relation.CallExpression)

Example 27 with CallExpression

Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.

In the class PushAggregationThroughOuterJoin, method createAggregationOverNull.

/**
 * Builds a global aggregation that evaluates every aggregation of {@code referenceAggregation} over a
 * single row of NULLs.
 *
 * <p>The NULL row is a VALUES node with one column per output variable of the reference aggregation's
 * source. Every aggregation is rebound from those source variables to the matching NULL columns.
 *
 * @param referenceAggregation the aggregation whose calls are re-evaluated over NULLs
 * @param variableAllocator allocator for the NULL columns and the new aggregation outputs
 * @param idAllocator allocator for the ids of the new VALUES and aggregation nodes
 * @param lookup not used by this method
 * @return the new aggregation plus a mapping from each original aggregation output to its NULL-row
 *         counterpart, or {@code Optional.empty()} if any aggregation fails the {@code isUsingVariables}
 *         check against the source outputs
 */
private Optional<MappedAggregationInfo> createAggregationOverNull(AggregationNode referenceAggregation, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
    // One NULL literal and one fresh column per source output variable.
    ImmutableList.Builder<VariableReferenceExpression> nullColumns = ImmutableList.builder();
    ImmutableList.Builder<RowExpression> nullRowValues = ImmutableList.builder();
    ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> sourceToNullBuilder = ImmutableMap.builder();
    for (VariableReferenceExpression sourceVariable : referenceAggregation.getSource().getOutputVariables()) {
        RowExpression nullValue = constantNull(sourceVariable.getSourceLocation(), sourceVariable.getType());
        VariableReferenceExpression nullColumn = variableAllocator.newVariable(nullValue);
        nullRowValues.add(nullValue);
        nullColumns.add(nullColumn);
        // TODO (carried over): the column type should come from sourceVariable.getType().
        sourceToNullBuilder.put(sourceVariable, nullColumn);
    }
    ValuesNode nullRow = new ValuesNode(referenceAggregation.getSourceLocation(), idAllocator.getNextId(), nullColumns.build(), ImmutableList.of(nullRowValues.build()));
    Map<VariableReferenceExpression, VariableReferenceExpression> sourceToNull = sourceToNullBuilder.build();

    // Re-target each aggregation at the NULL row and record the output-variable correspondence.
    ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> outputMappingBuilder = ImmutableMap.builder();
    ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> overNullAggregations = ImmutableMap.builder();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : referenceAggregation.getAggregations().entrySet()) {
        VariableReferenceExpression originalOutput = entry.getKey();
        AggregationNode.Aggregation original = entry.getValue();
        if (!isUsingVariables(original, sourceToNull.keySet())) {
            // Abandon the whole rewrite when any aggregation fails this check.
            return Optional.empty();
        }
        CallExpression originalCall = original.getCall();
        List<RowExpression> reboundArguments = original.getArguments().stream()
                .map(argument -> inlineVariables(sourceToNull, argument))
                .collect(toImmutableList());
        CallExpression reboundCall = new CallExpression(originalCall.getSourceLocation(), originalCall.getDisplayName(), originalCall.getFunctionHandle(), originalCall.getType(), reboundArguments);
        AggregationNode.Aggregation overNullAggregation = new AggregationNode.Aggregation(
                reboundCall,
                original.getFilter().map(filter -> inlineVariables(sourceToNull, filter)),
                original.getOrderBy().map(orderBy -> inlineOrderByVariables(sourceToNull, orderBy)),
                original.isDistinct(),
                original.getMask().map(mask -> {
                    // Keep the mask's own type but take name and location from its NULL-row replacement.
                    VariableReferenceExpression replacement = sourceToNull.get(mask);
                    return new VariableReferenceExpression(replacement.getSourceLocation(), replacement.getName(), mask.getType());
                }));
        String objectName = functionAndTypeManager.getFunctionMetadata(overNullAggregation.getFunctionHandle()).getName().getObjectName();
        VariableReferenceExpression overNullOutput = variableAllocator.newVariable(originalCall.getSourceLocation(), objectName, originalOutput.getType());
        overNullAggregations.put(overNullOutput, overNullAggregation);
        outputMappingBuilder.put(originalOutput, overNullOutput);
    }
    Map<VariableReferenceExpression, VariableReferenceExpression> outputMapping = outputMappingBuilder.build();

    // A global (no grouping keys) SINGLE-step aggregation over the NULL row.
    AggregationNode aggregationOverNullRow = new AggregationNode(referenceAggregation.getSourceLocation(), idAllocator.getNextId(), nullRow, overNullAggregations.build(), globalAggregation(), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty());
    return Optional.of(new MappedAggregationInfo(aggregationOverNullRow, outputMapping));
}
Also used : FunctionAndTypeManager(com.facebook.presto.metadata.FunctionAndTypeManager) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) Captures(com.facebook.presto.matching.Captures) Assignments(com.facebook.presto.spi.plan.Assignments) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) Patterns.join(com.facebook.presto.sql.planner.plan.Patterns.join) Pattern(com.facebook.presto.matching.Pattern) ValuesNode(com.facebook.presto.spi.plan.ValuesNode) HashSet(java.util.HashSet) Capture(com.facebook.presto.matching.Capture) ImmutableList(com.google.common.collect.ImmutableList) DistinctOutputQueryUtil.isDistinct(com.facebook.presto.sql.planner.optimizations.DistinctOutputQueryUtil.isDistinct) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) QualifiedObjectName(com.facebook.presto.common.QualifiedObjectName) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) CallExpression(com.facebook.presto.spi.relation.CallExpression) OrderingScheme(com.facebook.presto.spi.plan.OrderingScheme) SpecialFormExpression(com.facebook.presto.spi.relation.SpecialFormExpression) RowExpression(com.facebook.presto.spi.relation.RowExpression) JoinNode(com.facebook.presto.sql.planner.plan.JoinNode) Patterns.aggregation(com.facebook.presto.sql.planner.plan.Patterns.aggregation) PlanNodeIdAllocator(com.facebook.presto.spi.plan.PlanNodeIdAllocator) AggregationNode.globalAggregation(com.facebook.presto.spi.plan.AggregationNode.globalAggregation) SortOrder(com.facebook.presto.common.block.SortOrder) ImmutableMap(com.google.common.collect.ImmutableMap) Session(com.facebook.presto.Session) Ordering(com.facebook.presto.spi.plan.Ordering) Rule(com.facebook.presto.sql.planner.iterative.Rule) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) 
RowExpressionVariableInliner.inlineVariables(com.facebook.presto.sql.planner.RowExpressionVariableInliner.inlineVariables) Expressions.constantNull(com.facebook.presto.sql.relational.Expressions.constantNull) Lookup(com.facebook.presto.sql.planner.iterative.Lookup) Preconditions.checkState(com.google.common.base.Preconditions.checkState) Patterns.source(com.facebook.presto.sql.planner.plan.Patterns.source) PlanNode(com.facebook.presto.spi.plan.PlanNode) List(java.util.List) SystemSessionProperties.shouldPushAggregationThroughJoin(com.facebook.presto.SystemSessionProperties.shouldPushAggregationThroughJoin) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) Capture.newCapture(com.facebook.presto.matching.Capture.newCapture) PlanVariableAllocator(com.facebook.presto.sql.planner.PlanVariableAllocator) Optional(java.util.Optional) ValuesNode(com.facebook.presto.spi.plan.ValuesNode) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) RowExpression(com.facebook.presto.spi.relation.RowExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) QualifiedObjectName(com.facebook.presto.common.QualifiedObjectName) AggregationNode.globalAggregation(com.facebook.presto.spi.plan.AggregationNode.globalAggregation) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) CallExpression(com.facebook.presto.spi.relation.CallExpression)

Example 28 with CallExpression

Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.

In the class RewriteSpatialPartitioningAggregation, method apply.

/**
 * Rewrites each single-argument {@code NAME} aggregation into a two-argument call over the argument's
 * envelope and a {@code partition_count} constant.
 *
 * <p>A projection is inserted under the aggregation. It passes every source output through and adds
 * the {@code partition_count} variable, set to the session's hash partition count. It also adds one
 * envelope variable per rewritten aggregation. Aggregations that do not match are kept unchanged.
 *
 * @param node the aggregation node being rewritten
 * @param captures pattern captures (not used here)
 * @param context rule context (variable/id allocators, session)
 * @return a new aggregation node on top of the added projection
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    VariableReferenceExpression partitionCountVariable = context.getVariableAllocator().newVariable("partition_count", INTEGER);
    ImmutableMap.Builder<VariableReferenceExpression, RowExpression> envelopeAssignments = ImmutableMap.builder();
    // The geometry type does not depend on the aggregation being examined, so resolve it once
    // instead of once per loop iteration.
    Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE);
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        QualifiedObjectName name = metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
        if (!name.equals(NAME) || aggregation.getArguments().size() != 1) {
            // Not a rewritable call: keep it as-is.
            aggregations.put(entry);
            continue;
        }
        RowExpression geometry = getOnlyElement(aggregation.getArguments());
        VariableReferenceExpression envelopeVariable = context.getVariableAllocator().newVariable(aggregation.getCall().getSourceLocation(), "envelope", geometryType);
        if (isFunctionNameMatch(geometry, "ST_Envelope")) {
            // The argument presumably already is an ST_Envelope call; project it without wrapping it again.
            envelopeAssignments.put(envelopeVariable, geometry);
        }
        else {
            envelopeAssignments.put(envelopeVariable, castToRowExpression(new FunctionCall(QualifiedName.of("ST_Envelope"), ImmutableList.of(castToExpression(geometry)))));
        }
        CallExpression partitioningCall = new CallExpression(
                envelopeVariable.getSourceLocation(),
                name.getObjectName(),
                metadata.getFunctionAndTypeManager().lookupFunction(NAME.getObjectName(), fromTypes(geometryType, INTEGER)),
                entry.getKey().getType(),
                ImmutableList.of(castToRowExpression(asSymbolReference(envelopeVariable)), castToRowExpression(asSymbolReference(partitionCountVariable))));
        // NOTE: the original call's filter, ORDER BY and DISTINCT are not carried over; only the mask is.
        aggregations.put(entry.getKey(), new Aggregation(partitioningCall, Optional.empty(), Optional.empty(), false, aggregation.getMask()));
    }
    Assignments projectAssignments = Assignments.builder()
            .putAll(identitiesAsSymbolReferences(node.getSource().getOutputVariables()))
            .put(partitionCountVariable, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))))
            .putAll(envelopeAssignments.build())
            .build();
    ProjectNode projectedSource = new ProjectNode(context.getIdAllocator().getNextId(), node.getSource(), projectAssignments);
    return Result.ofPlanNode(new AggregationNode(node.getSourceLocation(), node.getId(), projectedSource, aggregations.build(), node.getGroupingSets(), node.getPreGroupedVariables(), node.getStep(), node.getHashVariable(), node.getGroupIdVariable()));
}
Also used : LongLiteral(com.facebook.presto.sql.tree.LongLiteral) RowExpression(com.facebook.presto.spi.relation.RowExpression) OriginalExpressionUtils.castToRowExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) QualifiedObjectName(com.facebook.presto.common.QualifiedObjectName) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) Type(com.facebook.presto.common.type.Type) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) FunctionCall(com.facebook.presto.sql.tree.FunctionCall) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) CallExpression(com.facebook.presto.spi.relation.CallExpression)

Example 29 with CallExpression

Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.

In the class PushPartialAggregationThroughExchange, method split.

/**
 * Splits {@code node} into a PARTIAL aggregation over the original source and a FINAL aggregation over
 * the PARTIAL node's output. No exchange is inserted here.
 *
 * <p>Each PARTIAL aggregation produces the function's intermediate type into a freshly allocated
 * variable. The matching FINAL aggregation consumes that variable, plus any lambda arguments of the
 * original call, and produces the function's final type under the original output variable.
 *
 * @param node the aggregation to split; must not contain aggregations with ORDER BY
 * @param context rule context (variable/id allocators, session)
 * @return the FINAL aggregation node, whose source is the new PARTIAL aggregation node
 * @throws IllegalStateException if any aggregation has an ORDER BY clause
 */
private PlanNode split(AggregationNode node, Context context) {
    Map<VariableReferenceExpression, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<VariableReferenceExpression, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
    for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation originalAggregation = entry.getValue();
        // Fail fast, before any function lookup or variable allocation for an aggregation that cannot be split.
        checkState(!originalAggregation.getOrderBy().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");
        CallExpression originalCall = originalAggregation.getCall();
        FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
        String functionName = functionAndTypeManager.getFunctionMetadata(functionHandle).getName().getObjectName();
        InternalAggregationFunction function = functionAndTypeManager.getAggregateFunctionImplementation(functionHandle);
        VariableReferenceExpression intermediateVariable = context.getVariableAllocator().newVariable(originalCall.getSourceLocation(), functionName, function.getIntermediateType());
        // The partial step keeps the original arguments, filter, ordering, distinctness and mask.
        intermediateAggregation.put(intermediateVariable, new AggregationNode.Aggregation(new CallExpression(originalCall.getSourceLocation(), functionName, functionHandle, function.getIntermediateType(), originalAggregation.getArguments()), originalAggregation.getFilter(), originalAggregation.getOrderBy(), originalAggregation.isDistinct(), originalAggregation.getMask()));
        // Rewrite the final aggregation in terms of the intermediate value. Lambda arguments are forwarded as well;
        // filter and mask are dropped because they were already applied by the partial step.
        List<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(intermediateVariable)
                .addAll(originalAggregation.getArguments().stream().filter(PushPartialAggregationThroughExchange::isLambda).collect(toImmutableList()))
                .build();
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(new CallExpression(originalCall.getSourceLocation(), functionName, functionHandle, function.getFinalType(), finalArguments), Optional.empty(), Optional.empty(), false, Optional.empty()));
    }
    // We can always enable streaming aggregation for partial aggregations. But if the table is not pre-grouped by the group-by columns, it may regress.
    // This session property is a way to force-enable it when we know the execution would benefit from partial streaming aggregation.
    // We can work on determining it based on the input table properties later.
    List<VariableReferenceExpression> preGroupedSymbols = ImmutableList.of();
    if (isStreamingForPartialAggregationEnabled(context.getSession())) {
        preGroupedSymbols = ImmutableList.copyOf(node.getGroupingSets().getGroupingKeys());
    }
    // The partial aggregation reads the original input directly, so it may use preGroupedSymbols
    // (empty unless streaming for partial aggregation is enabled, see above).
    PlanNode partial = new AggregationNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getSource(), intermediateAggregation, node.getGroupingSets(), preGroupedSymbols, PARTIAL, node.getHashVariable(), node.getGroupIdVariable());
    // preGroupedSymbols reflect properties of the input. Splitting the aggregation and pushing the partial aggregation
    // through the exchange may or may not preserve these properties. Hence, it is safest to drop preGroupedSymbols here.
    return new AggregationNode(node.getSourceLocation(), node.getId(), partial, finalAggregation, node.getGroupingSets(), ImmutableList.of(), FINAL, node.getHashVariable(), node.getGroupIdVariable());
}
Also used : HashMap(java.util.HashMap) RowExpression(com.facebook.presto.spi.relation.RowExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) InternalAggregationFunction(com.facebook.presto.operator.aggregation.InternalAggregationFunction) PlanNode(com.facebook.presto.spi.plan.PlanNode) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) HashMap(java.util.HashMap) Map(java.util.Map) FunctionHandle(com.facebook.presto.spi.function.FunctionHandle) CallExpression(com.facebook.presto.spi.relation.CallExpression)

Example 30 with CallExpression

Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.

In the class RewriteAggregationIfToFilter, method apply.

@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    ProjectNode sourceProject = captures.get(CHILD);
    Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream().filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject)).collect(toImmutableSet());
    if (aggregationsToRewrite.isEmpty()) {
        return Result.empty();
    }
    context.getSession().getRuntimeStats().addMetricValue(REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED, 1);
    // Get the corresponding assignments in the input project.
    // The aggregationReferences only has the aggregations to rewrite, thus the sourceAssignments only has IF/CAST(IF) expressions with NULL false results.
    // Multiple aggregations may reference the same input. We use a map to dedup them based on the VariableReferenceExpression, so that we only do the rewrite once per input
    // IF expression.
    // The order of sourceAssignments determines the order of generating the new variables for the IF conditions and results. We use a sorted map to get a deterministic
    // order based on the name of the VariableReferenceExpressions.
    Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream().map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0)).collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left));
    Assignments.Builder newAssignments = Assignments.builder();
    newAssignments.putAll(sourceProject.getAssignments());
    // Map from the aggregation reference to the IF condition reference which will be put in the mask.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
    // Map from the aggregation reference to the IF result reference. This only contains the aggregates where the IF can be safely unwrapped.
    // E.g., SUM(IF(CARDINALITY(array) > 0, array[1])) will not be included in this map as array[1] can return errors if we unwrap the IF.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();
    AggregationIfToFilterRewriteStrategy rewriteStrategy = getAggregationIfToFilterRewriteStrategy(context.getSession());
    for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
        VariableReferenceExpression outputVariable = entry.getKey();
        RowExpression rowExpression = entry.getValue();
        SpecialFormExpression ifExpression = (SpecialFormExpression) ((rowExpression instanceof CallExpression) ? ((CallExpression) rowExpression).getArguments().get(0) : rowExpression);
        RowExpression condition = ifExpression.getArguments().get(0);
        VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
        newAssignments.put(conditionReference, condition);
        aggregationReferenceToConditionReference.put(outputVariable, conditionReference);
        if (canUnwrapIf(ifExpression, rewriteStrategy)) {
            RowExpression trueResult = ifExpression.getArguments().get(1);
            if (rowExpression instanceof CallExpression) {
                // Wrap the result with CAST().
                trueResult = new CallExpression(((CallExpression) rowExpression).getDisplayName(), ((CallExpression) rowExpression).getFunctionHandle(), rowExpression.getType(), ImmutableList.of(trueResult));
            }
            VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
            newAssignments.put(ifResultReference, trueResult);
            aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
        }
    }
    // Build new aggregations.
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    // Stores the masks used to build the filter predicates. Use set to dedup the predicates.
    ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation aggregation = entry.getValue();
        if (!aggregationsToRewrite.contains(aggregation)) {
            aggregations.put(output, aggregation);
            continue;
        }
        VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
        CallExpression callExpression = aggregation.getCall();
        VariableReferenceExpression ifResultReference = aggregationReferenceToIfResultReference.get(aggregationReference);
        if (ifResultReference != null) {
            callExpression = new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), ImmutableList.of(ifResultReference));
        }
        VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
        aggregations.put(output, new Aggregation(callExpression, Optional.empty(), aggregation.getOrderBy(), aggregation.isDistinct(), Optional.of(aggregationReferenceToConditionReference.get(aggregationReference))));
        masks.add(mask);
    }
    RowExpression predicate = TRUE_CONSTANT;
    if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) {
        // All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient.
        predicate = or(masks.build());
    }
    return Result.ofPlanNode(new AggregationNode(aggregationNode.getSourceLocation(), context.getIdAllocator().getNextId(), new FilterNode(aggregationNode.getSourceLocation(), context.getIdAllocator().getNextId(), new ProjectNode(context.getIdAllocator().getNextId(), sourceProject.getSource(), newAssignments.build()), predicate), aggregations.build(), aggregationNode.getGroupingSets(), aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), aggregationNode.getGroupIdVariable()));
}
Also used : FunctionAndTypeManager(com.facebook.presto.metadata.FunctionAndTypeManager) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) Captures(com.facebook.presto.matching.Captures) Assignments(com.facebook.presto.spi.plan.Assignments) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) ImmutableSortedMap.toImmutableSortedMap(com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) HashMap(java.util.HashMap) RowExpressionDeterminismEvaluator(com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator) SystemSessionProperties.getAggregationIfToFilterRewriteStrategy(com.facebook.presto.SystemSessionProperties.getAggregationIfToFilterRewriteStrategy) StandardFunctionResolution(com.facebook.presto.spi.function.StandardFunctionResolution) Pattern(com.facebook.presto.matching.Pattern) FilterNode(com.facebook.presto.spi.plan.FilterNode) Capture(com.facebook.presto.matching.Capture) ImmutableList(com.google.common.collect.ImmutableList) UNWRAP_IF(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy.UNWRAP_IF) IF(com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Expressions(com.facebook.presto.sql.relational.Expressions) FunctionResolution(com.facebook.presto.sql.relational.FunctionResolution) CallExpression(com.facebook.presto.spi.relation.CallExpression) SpecialFormExpression(com.facebook.presto.spi.relation.SpecialFormExpression) VariablesExtractor(com.facebook.presto.sql.planner.VariablesExtractor) RowExpression(com.facebook.presto.spi.relation.RowExpression) ImmutableSortedSet(com.google.common.collect.ImmutableSortedSet) Patterns.aggregation(com.facebook.presto.sql.planner.plan.Patterns.aggregation) 
ImmutableMap(com.google.common.collect.ImmutableMap) Session(com.facebook.presto.Session) Rule(com.facebook.presto.sql.planner.iterative.Rule) Patterns.project(com.facebook.presto.sql.planner.plan.Patterns.project) Set(java.util.Set) LambdaDefinitionExpression(com.facebook.presto.spi.relation.LambdaDefinitionExpression) OperatorType(com.facebook.presto.common.function.OperatorType) AggregationIfToFilterRewriteStrategy(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy) TRUE_CONSTANT(com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT) Patterns.source(com.facebook.presto.sql.planner.plan.Patterns.source) FILTER_WITH_IF(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy.FILTER_WITH_IF) DefaultRowExpressionTraversalVisitor(com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) LogicalRowExpressions.or(com.facebook.presto.expressions.LogicalRowExpressions.or) Capture.newCapture(com.facebook.presto.matching.Capture.newCapture) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) Function.identity(java.util.function.Function.identity) Optional(java.util.Optional) REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED(com.facebook.presto.common.RuntimeMetricName.REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED) SystemSessionProperties.getAggregationIfToFilterRewriteStrategy(com.facebook.presto.SystemSessionProperties.getAggregationIfToFilterRewriteStrategy) AggregationIfToFilterRewriteStrategy(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy) HashMap(java.util.HashMap) FilterNode(com.facebook.presto.spi.plan.FilterNode) Assignments(com.facebook.presto.spi.plan.Assignments) RowExpression(com.facebook.presto.spi.relation.RowExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) 
Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) ImmutableSortedSet(com.google.common.collect.ImmutableSortedSet) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) ImmutableSortedMap.toImmutableSortedMap(com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) SpecialFormExpression(com.facebook.presto.spi.relation.SpecialFormExpression) CallExpression(com.facebook.presto.spi.relation.CallExpression)

Aggregations

CallExpression (com.facebook.presto.spi.relation.CallExpression)64 RowExpression (com.facebook.presto.spi.relation.RowExpression)33 VariableReferenceExpression (com.facebook.presto.spi.relation.VariableReferenceExpression)33 Test (org.testng.annotations.Test)22 AggregationNode (com.facebook.presto.spi.plan.AggregationNode)20 FunctionHandle (com.facebook.presto.spi.function.FunctionHandle)19 ImmutableList (com.google.common.collect.ImmutableList)18 FunctionAndTypeManager (com.facebook.presto.metadata.FunctionAndTypeManager)16 Type (com.facebook.presto.common.type.Type)14 Map (java.util.Map)14 ConstantExpression (com.facebook.presto.spi.relation.ConstantExpression)13 ImmutableMap (com.google.common.collect.ImmutableMap)13 Optional (java.util.Optional)12 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)11 OperatorType (com.facebook.presto.common.function.OperatorType)10 Aggregation (com.facebook.presto.spi.plan.AggregationNode.Aggregation)10 PlanNode (com.facebook.presto.spi.plan.PlanNode)10 ProjectNode (com.facebook.presto.spi.plan.ProjectNode)10 SpecialFormExpression (com.facebook.presto.spi.relation.SpecialFormExpression)10 Page (com.facebook.presto.common.Page)8