Search in sources:

Example 1 with Captures

Usage of com.facebook.presto.matching.Captures in the presto project by prestodb.

From the class PushProjectionThroughExchange, method apply.

/**
 * Pushes {@code project} below its child {@link ExchangeNode} by placing a new {@link ProjectNode} on top of
 * every exchange source.
 * <p>
 * For source {@code i} the new projection keeps, in this order: the partitioning columns, the hash column
 * (if any), the ordering columns that are not also partitioning columns (if any), and finally every original
 * assignment rewritten in terms of source {@code i}'s variables and bound to a freshly allocated variable.
 * The exchange output layout is rebuilt in exactly the same order, and {@code restrictOutputs} is then used
 * to strip the extra outputs (hash, partitioning columns, ...) that {@code project} itself did not produce.
 */
@Override
public Result apply(ProjectNode project, Captures captures, Context context) {
    ExchangeNode exchange = captures.get(CHILD);
    Set<VariableReferenceExpression> partitioningColumns = exchange.getPartitioningScheme().getPartitioning().getVariableReferences();
    ImmutableList.Builder<PlanNode> newSourceBuilder = ImmutableList.builder();
    ImmutableList.Builder<List<VariableReferenceExpression>> inputsBuilder = ImmutableList.builder();
    for (int i = 0; i < exchange.getSources().size(); i++) {
        // Translates exchange output variables into the variables produced by source i.
        // NOTE(review): assumed to cover every variable of the partitioning scheme's output layout,
        // which includes the hash column.
        Map<VariableReferenceExpression, VariableReferenceExpression> outputToInputMap = extractExchangeOutputToInput(exchange, i);
        Assignments.Builder projections = Assignments.builder();
        ImmutableList.Builder<VariableReferenceExpression> inputs = ImmutableList.builder();
        // Need to retain the partition keys for the exchange
        partitioningColumns.stream().map(outputToInputMap::get).forEach(variable -> {
            projections.put(variable, variable);
            inputs.add(variable);
        });
        // Need to retain the hash variable for the exchange. The hash column is an exchange *output* variable,
        // so, exactly like the partitioning and ordering columns, it must be translated to the variable that
        // source i produces; otherwise the new projection would reference a variable the source may not have.
        exchange.getPartitioningScheme().getHashColumn().map(outputToInputMap::get).ifPresent(hashVariable -> {
            projections.put(hashVariable, hashVariable);
            inputs.add(hashVariable);
        });
        if (exchange.getOrderingScheme().isPresent()) {
            // need to retain ordering columns for the exchange
            exchange.getOrderingScheme().get().getOrderByVariables().stream().filter(variable -> !partitioningColumns.contains(variable)).map(outputToInputMap::get).forEach(variable -> {
                projections.put(variable, variable);
                inputs.add(variable);
            });
        }
        // Rewrite every original assignment against source i and bind it to a fresh variable.
        for (Map.Entry<VariableReferenceExpression, RowExpression> projection : project.getAssignments().entrySet()) {
            RowExpression translatedExpression = RowExpressionVariableInliner.inlineVariables(outputToInputMap, projection.getValue());
            VariableReferenceExpression variable = context.getVariableAllocator().newVariable(translatedExpression);
            projections.put(variable, translatedExpression);
            inputs.add(variable);
        }
        newSourceBuilder.add(new ProjectNode(project.getSourceLocation(), context.getIdAllocator().getNextId(), exchange.getSources().get(i), projections.build(), project.getLocality()));
        inputsBuilder.add(inputs.build());
    }
    // Construct the output variables in the same order as the per-source inputs built above
    ImmutableList.Builder<VariableReferenceExpression> outputBuilder = ImmutableList.builder();
    partitioningColumns.forEach(outputBuilder::add);
    exchange.getPartitioningScheme().getHashColumn().ifPresent(outputBuilder::add);
    if (exchange.getOrderingScheme().isPresent()) {
        exchange.getOrderingScheme().get().getOrderByVariables().stream().filter(variable -> !partitioningColumns.contains(variable)).forEach(outputBuilder::add);
    }
    for (Map.Entry<VariableReferenceExpression, RowExpression> projection : project.getAssignments().entrySet()) {
        outputBuilder.add(projection.getKey());
    }
    // outputBuilder contains all partition and hash variables so simply swap the output layout
    PartitioningScheme partitioningScheme = new PartitioningScheme(exchange.getPartitioningScheme().getPartitioning(), outputBuilder.build(), exchange.getPartitioningScheme().getHashColumn(), exchange.getPartitioningScheme().isReplicateNullsAndAny(), exchange.getPartitioningScheme().getBucketToPartition());
    PlanNode result = new ExchangeNode(exchange.getSourceLocation(), exchange.getId(), exchange.getType(), exchange.getScope(), partitioningScheme, newSourceBuilder.build(), inputsBuilder.build(), exchange.isEnsureSourceOrdering(), exchange.getOrderingScheme());
    // we need to strip unnecessary variables (hash, partitioning columns).
    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(project.getOutputVariables()), true).orElse(result));
}
Also used : RowExpression(com.facebook.presto.spi.relation.RowExpression) Patterns.exchange(com.facebook.presto.sql.planner.plan.Patterns.exchange) ImmutableSet(com.google.common.collect.ImmutableSet) RowExpressionVariableInliner(com.facebook.presto.sql.planner.RowExpressionVariableInliner) Rule(com.facebook.presto.sql.planner.iterative.Rule) Captures(com.facebook.presto.matching.Captures) Patterns.project(com.facebook.presto.sql.planner.plan.Patterns.project) Assignments(com.facebook.presto.spi.plan.Assignments) Set(java.util.Set) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) HashMap(java.util.HashMap) Util.restrictOutputs(com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs) Pattern(com.facebook.presto.matching.Pattern) Patterns.source(com.facebook.presto.sql.planner.plan.Patterns.source) PlanNode(com.facebook.presto.spi.plan.PlanNode) List(java.util.List) Capture(com.facebook.presto.matching.Capture) ImmutableList(com.google.common.collect.ImmutableList) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) Capture.newCapture(com.facebook.presto.matching.Capture.newCapture) Map(java.util.Map) PartitioningScheme(com.facebook.presto.sql.planner.PartitioningScheme) ExchangeNode(com.facebook.presto.sql.planner.plan.ExchangeNode) ExchangeNode(com.facebook.presto.sql.planner.plan.ExchangeNode) ImmutableList(com.google.common.collect.ImmutableList) PartitioningScheme(com.facebook.presto.sql.planner.PartitioningScheme) Assignments(com.facebook.presto.spi.plan.Assignments) RowExpression(com.facebook.presto.spi.relation.RowExpression) PlanNode(com.facebook.presto.spi.plan.PlanNode) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) HashMap(java.util.HashMap) Map(java.util.Map)

Example 2 with Captures

Usage of com.facebook.presto.matching.Captures in the presto project by prestodb.

From the class SingleDistinctAggregationToGroupBy, method apply.

/**
 * Replaces an aggregation whose calls share exactly one DISTINCT argument set with two stacked aggregations.
 * <p>
 * The inner aggregation gets a new id and runs as a SINGLE step. It has no aggregate calls and groups by the
 * original grouping keys followed by the distinct argument variables, which performs the de-duplication.
 * The outer aggregation keeps the original id, grouping sets, step, hash variable and group-id variable.
 * Its calls are the original ones with the DISTINCT flag removed.
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    List<Set<RowExpression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    // getOnlyElement throws unless there is exactly one argument set.
    Set<RowExpression> onlyArgumentSet = Iterables.getOnlyElement(argumentSets);
    Set<VariableReferenceExpression> distinctVariables = onlyArgumentSet.stream()
            .map(OriginalExpressionUtils::castToExpression)
            .map(context.getVariableAllocator()::toVariableReference)
            .collect(Collectors.toSet());

    List<VariableReferenceExpression> innerGroupingKeys = ImmutableList.<VariableReferenceExpression>builder()
            .addAll(aggregation.getGroupingKeys())
            .addAll(distinctVariables)
            .build();
    AggregationNode distinctingAggregation = new AggregationNode(
            aggregation.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            aggregation.getSource(),
            ImmutableMap.of(),
            singleGroupingSet(innerGroupingKeys),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    // remove DISTINCT flag from function calls; de-duplication now happens in the inner aggregation
    Map<VariableReferenceExpression, Aggregation> nonDistinctAggregations = aggregation.getAggregations().entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> removeDistinct(entry.getValue())));

    AggregationNode rewritten = new AggregationNode(
            aggregation.getSourceLocation(),
            aggregation.getId(),
            distinctingAggregation,
            nonDistinctAggregations,
            aggregation.getGroupingSets(),
            emptyList(),
            aggregation.getStep(),
            aggregation.getHashVariable(),
            aggregation.getGroupIdVariable());
    return Result.ofPlanNode(rewritten);
}
Also used : RowExpression(com.facebook.presto.spi.relation.RowExpression) Iterables(com.google.common.collect.Iterables) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) Patterns.aggregation(com.facebook.presto.sql.planner.plan.Patterns.aggregation) ImmutableMap(com.google.common.collect.ImmutableMap) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) Collections.emptyList(java.util.Collections.emptyList) Rule(com.facebook.presto.sql.planner.iterative.Rule) Captures(com.facebook.presto.matching.Captures) Set(java.util.Set) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) SINGLE(com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE) Collectors(java.util.stream.Collectors) Pattern(com.facebook.presto.matching.Pattern) HashSet(java.util.HashSet) List(java.util.List) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) Stream(java.util.stream.Stream) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) Optional(java.util.Optional) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) Set(java.util.Set) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Example 3 with Captures

Usage of com.facebook.presto.matching.Captures in the presto project by prestodb.

From the class PruneRedundantProjectionAssignments, method apply.

/**
 * Computes each repeated non-trivial expression of {@code node} only once.
 * <p>
 * Assignments whose value is a variable reference or a constant count as trivial and are passed through
 * unchanged. If no non-trivial expression appears more than once, the rule does not fire. Otherwise the node
 * is split in two:
 * <ul>
 * <li>a child projection, reusing {@code node}'s id and locality, that evaluates every distinct expression
 * once under one representative output variable;</li>
 * <li>a new LOCAL parent projection that maps every original output variable to that representative.</li>
 * </ul>
 */
@Override
public Result apply(ProjectNode node, Captures captures, Context context) {
    // true -> trivial assignments (variable references and constants); false -> everything else
    Map<Boolean, List<Map.Entry<VariableReferenceExpression, RowExpression>>> projections = node.getAssignments().entrySet().stream()
            .collect(Collectors.partitioningBy(entry -> {
                RowExpression value = entry.getValue();
                return value instanceof VariableReferenceExpression || value instanceof ConstantExpression;
            }));
    List<Map.Entry<VariableReferenceExpression, RowExpression>> trivialProjections = projections.get(true);
    List<Map.Entry<VariableReferenceExpression, RowExpression>> nonTrivialProjections = projections.get(false);

    // Group the non-trivial assignments by expression, keeping every output variable bound to it.
    Map<RowExpression, ImmutableMap<VariableReferenceExpression, RowExpression>> outputsByExpression = nonTrivialProjections.stream()
            .collect(Collectors.groupingBy(Map.Entry::getValue, toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)));
    // One group per assignment means no expression is duplicated: nothing to prune.
    if (outputsByExpression.size() == nonTrivialProjections.size()) {
        return Result.empty();
    }

    Assignments.Builder childAssignments = Assignments.builder();
    Assignments.Builder parentAssignments = Assignments.builder();
    for (Map.Entry<VariableReferenceExpression, RowExpression> trivial : trivialProjections) {
        childAssignments.put(trivial.getKey(), trivial.getValue());
        parentAssignments.put(trivial.getKey(), trivial.getKey());
    }
    outputsByExpression.forEach((expression, outputs) -> {
        // The first output variable of the group computes the expression; the others alias it in the parent.
        VariableReferenceExpression representative = getFirst(outputs.keySet(), null);
        checkState(representative != null, "variable should not be null");
        childAssignments.put(representative, expression);
        outputs.keySet().forEach(output -> parentAssignments.put(output, representative));
    });

    ProjectNode child = new ProjectNode(node.getSourceLocation(), node.getId(), node.getSource(), childAssignments.build(), node.getLocality());
    return Result.ofPlanNode(new ProjectNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), child, parentAssignments.build(), LOCAL));
}
Also used : RowExpression(com.facebook.presto.spi.relation.RowExpression) ImmutableMap(com.google.common.collect.ImmutableMap) Rule(com.facebook.presto.sql.planner.iterative.Rule) Captures(com.facebook.presto.matching.Captures) Patterns.project(com.facebook.presto.sql.planner.plan.Patterns.project) Assignments(com.facebook.presto.spi.plan.Assignments) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) ConstantExpression(com.facebook.presto.spi.relation.ConstantExpression) Collectors(java.util.stream.Collectors) Preconditions.checkState(com.google.common.base.Preconditions.checkState) Pattern(com.facebook.presto.matching.Pattern) Iterables.getFirst(com.google.common.collect.Iterables.getFirst) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) Map(java.util.Map) LOCAL(com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL) ConstantExpression(com.facebook.presto.spi.relation.ConstantExpression) Assignments(com.facebook.presto.spi.plan.Assignments) RowExpression(com.facebook.presto.spi.relation.RowExpression) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) List(java.util.List) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Map(java.util.Map)

Example 4 with Captures

Usage of com.facebook.presto.matching.Captures in the presto project by prestodb.

From the class RuntimeReorderJoinSides, method apply.

/**
 * Swaps the probe (left) and build (right) sides of {@code joinNode} at runtime when the estimated output size
 * of the right side is strictly larger than that of the left side.
 * <p>
 * The rule only fires when all of the following hold:
 * <ul>
 * <li>every leaf under the join is a {@link TableScanNode};</li>
 * <li>both size estimates are available;</li>
 * <li>{@code isSwappedJoinValid} accepts the join.</li>
 * </ul>
 * After flipping, a single-source local exchange directly under the new probe side is removed when its source
 * already satisfies the probe-side properties. A local exchange is added on top of the new build side when that
 * side does not satisfy the build-side partitioning properties.
 */
@Override
public Result apply(JoinNode joinNode, Captures captures, Context context) {
    // Early exit if the leaves of the joinNode subtree include non tableScan nodes.
    if (searchFrom(joinNode, context.getLookup()).where(node -> node.getSources().isEmpty() && !(node instanceof TableScanNode)).matches()) {
        return Result.empty();
    }
    double leftOutputSizeInBytes = Double.NaN;
    double rightOutputSizeInBytes = Double.NaN;
    StatsProvider statsProvider = context.getStatsProvider();
    if (searchFrom(joinNode, context.getLookup()).where(node -> !(node instanceof TableScanNode) && !(node instanceof ExchangeNode)).findAll().size() == 1) {
        // Simple plan is characterized as Join directly on tableScanNodes only with exchangeNode in between.
        // For simple plans, directly fetch the overall table sizes as the size of the join sides to have
        // accurate input bytes statistics and meanwhile avoid non-negligible cost of collecting and processing
        // per-column statistics.
        leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes();
        rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes();
    }
    if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
        // Per-column estimate left and right output size for complex plans or when size statistics is unavailable.
        leftOutputSizeInBytes = statsProvider.getStats(joinNode.getLeft()).getOutputSizeInBytes(joinNode.getLeft().getOutputVariables());
        rightOutputSizeInBytes = statsProvider.getStats(joinNode.getRight()).getOutputSizeInBytes(joinNode.getRight().getOutputVariables());
    }
    if (Double.isNaN(leftOutputSizeInBytes) || Double.isNaN(rightOutputSizeInBytes)) {
        return Result.empty();
    }
    // Only swap when the current build (right) side is strictly larger than the current probe (left) side.
    if (rightOutputSizeInBytes <= leftOutputSizeInBytes) {
        return Result.empty();
    }
    // Check if the swapped join is valid.
    if (!isSwappedJoinValid(joinNode)) {
        return Result.empty();
    }
    JoinNode swapped = joinNode.flipChildren();
    PlanNode newLeft = swapped.getLeft();
    Optional<VariableReferenceExpression> leftHashVariable = swapped.getLeftHashVariable();
    // Remove unnecessary LocalExchange in the current probe side. If the immediate left child (new probe side) of the join node
    // is a localExchange, there are two cases: an Exchange introduced by the current probe side (previous build side); or it is a UnionNode.
    // If the exchangeNode has more than 1 sources, it corresponds to the second case, otherwise it corresponds to the first case and could be safe to remove
    PlanNode resolvedSwappedLeft = context.getLookup().resolve(newLeft);
    if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) {
        PlanNode exchangeSource = resolvedSwappedLeft.getSources().get(0);
        // Ensure the new probe after skipping the local exchange will satisfy the required probe side property
        if (checkProbeSidePropertySatisfied(exchangeSource, context)) {
            newLeft = exchangeSource;
            // If the swapped join has a left hash variable, it is an output of the removed exchange: find the
            // exchange source's output at the same position and use it as the leftHashVariable of the swapped join node.
            if (swapped.getLeftHashVariable().isPresent()) {
                int hashVariableIndex = resolvedSwappedLeft.getOutputVariables().indexOf(swapped.getLeftHashVariable().get());
                leftHashVariable = Optional.of(exchangeSource.getOutputVariables().get(hashVariableIndex));
                // If the join itself outputs the exchange's hash variable, that output would change once the exchange is removed.
                // This is against typical iterativeOptimizer behavior and given this case is rare, just abort the swapping for this scenario.
                if (swapped.getOutputVariables().contains(swapped.getLeftHashVariable().get())) {
                    return Result.empty();
                }
            }
        }
    }
    // Add additional localExchange if the new build side does not satisfy the partitioning conditions.
    List<VariableReferenceExpression> buildJoinVariables = swapped.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight).collect(toImmutableList());
    PlanNode newRight = swapped.getRight();
    if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, context)) {
        if (getTaskConcurrency(context.getSession()) > 1) {
            newRight = systemPartitionedExchange(context.getIdAllocator().getNextId(), LOCAL, swapped.getRight(), buildJoinVariables, swapped.getRightHashVariable());
        }
        else {
            newRight = gatheringExchange(context.getIdAllocator().getNextId(), LOCAL, swapped.getRight());
        }
    }
    JoinNode newJoinNode = new JoinNode(swapped.getSourceLocation(), swapped.getId(), swapped.getType(), newLeft, newRight, swapped.getCriteria(), swapped.getOutputVariables(), swapped.getFilter(), leftHashVariable, swapped.getRightHashVariable(), swapped.getDistributionType(), swapped.getDynamicFilters());
    // Use the logger's format-args overload so the message is only formatted when debug logging is enabled.
    log.debug("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.", leftOutputSizeInBytes, rightOutputSizeInBytes, newJoinNode.getId());
    return Result.ofPlanNode(newJoinNode);
}
Also used : ExchangeNode.systemPartitionedExchange(com.facebook.presto.sql.planner.plan.ExchangeNode.systemPartitionedExchange) Logger(com.facebook.airlift.log.Logger) StreamPropertyDerivations(com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations) Captures(com.facebook.presto.matching.Captures) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) StreamPreferredProperties.exactlyPartitionedOn(com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.exactlyPartitionedOn) Patterns.join(com.facebook.presto.sql.planner.plan.Patterns.join) Pattern(com.facebook.presto.matching.Pattern) StreamPreferredProperties.fixedParallelism(com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.fixedParallelism) LOCAL(com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL) Objects.requireNonNull(java.util.Objects.requireNonNull) SystemSessionProperties.isJoinSpillingEnabled(com.facebook.presto.SystemSessionProperties.isJoinSpillingEnabled) StreamPreferredProperties.singleStream(com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.singleStream) SystemSessionProperties.getTaskConcurrency(com.facebook.presto.SystemSessionProperties.getTaskConcurrency) PlanNodeSearcher.searchFrom(com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom) ExchangeNode.gatheringExchange(com.facebook.presto.sql.planner.plan.ExchangeNode.gatheringExchange) JoinNode(com.facebook.presto.sql.planner.plan.JoinNode) Rule(com.facebook.presto.sql.planner.iterative.Rule) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) StatsProvider(com.facebook.presto.cost.StatsProvider) String.format(java.lang.String.format) SqlParser(com.facebook.presto.sql.parser.SqlParser) StreamPreferredProperties.defaultParallelism(com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties.defaultParallelism) 
RIGHT(com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT) SystemSessionProperties.isSpillEnabled(com.facebook.presto.SystemSessionProperties.isSpillEnabled) PlanNode(com.facebook.presto.spi.plan.PlanNode) List(java.util.List) REPLICATED(com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED) LEFT(com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT) TableScanNode(com.facebook.presto.spi.plan.TableScanNode) StreamProperties(com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties) StreamPreferredProperties(com.facebook.presto.sql.planner.optimizations.StreamPreferredProperties) Optional(java.util.Optional) PARTITIONED(com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED) Metadata(com.facebook.presto.metadata.Metadata) ExchangeNode(com.facebook.presto.sql.planner.plan.ExchangeNode) PlanNode(com.facebook.presto.spi.plan.PlanNode) TableScanNode(com.facebook.presto.spi.plan.TableScanNode) StatsProvider(com.facebook.presto.cost.StatsProvider) ExchangeNode(com.facebook.presto.sql.planner.plan.ExchangeNode) JoinNode(com.facebook.presto.sql.planner.plan.JoinNode) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression)

Example 5 with Captures

Usage of com.facebook.presto.matching.Captures in the presto project by prestodb.

From the class RewriteAggregationIfToFilter, method apply.

/**
 * Rewrites aggregations of the form {@code agg(IF(condition, value))} (or {@code agg(CAST(IF(...)))}) into
 * aggregations masked by the IF condition. The IF has a NULL false result and is computed by the captured
 * source {@link ProjectNode}.
 * <p>
 * For every argument variable being rewritten, the IF condition is projected to a new variable, which becomes
 * the aggregation's mask. When {@code canUnwrapIf} allows it for the session's rewrite strategy, the IF's true
 * result is also projected and replaces the aggregation argument; it is re-wrapped in the original CAST, if any.
 * If the aggregation has no non-empty grouping set and every aggregation was rewritten, the new input is also
 * filtered by the OR of all masks. Otherwise the filter predicate is {@code TRUE}.
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    ProjectNode sourceProject = captures.get(CHILD);
    Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream().filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject)).collect(toImmutableSet());
    if (aggregationsToRewrite.isEmpty()) {
        return Result.empty();
    }
    context.getSession().getRuntimeStats().addMetricValue(REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED, 1);
    // Get the corresponding assignments in the input project.
    // The aggregationReferences only has the aggregations to rewrite, thus the sourceAssignments only has IF/CAST(IF) expressions with NULL false results.
    // Multiple aggregations may reference the same input. We use a map to dedup them based on the VariableReferenceExpression, so that we only do the rewrite once per input
    // IF expression.
    // The order of sourceAssignments determines the order of generating the new variables for the IF conditions and results. We use a sorted map to get a deterministic
    // order based on the name of the VariableReferenceExpressions.
    Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream().map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0)).collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left));
    Assignments.Builder newAssignments = Assignments.builder();
    newAssignments.putAll(sourceProject.getAssignments());
    // Map from the aggregation reference to the IF condition reference which will be put in the mask.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
    // Map from the aggregation reference to the IF result reference. This only contains the aggregates where the IF can be safely unwrapped.
    // E.g., SUM(IF(CARDINALITY(array) > 0, array[1])) will not be included in this map as array[1] can return errors if we unwrap the IF.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();
    AggregationIfToFilterRewriteStrategy rewriteStrategy = getAggregationIfToFilterRewriteStrategy(context.getSession());
    for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
        VariableReferenceExpression outputVariable = entry.getKey();
        RowExpression rowExpression = entry.getValue();
        // rowExpression is IF(...) or CAST(IF(...)); in the latter case the IF is the call's first argument.
        SpecialFormExpression ifExpression = (SpecialFormExpression) ((rowExpression instanceof CallExpression) ? ((CallExpression) rowExpression).getArguments().get(0) : rowExpression);
        RowExpression condition = ifExpression.getArguments().get(0);
        VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
        newAssignments.put(conditionReference, condition);
        aggregationReferenceToConditionReference.put(outputVariable, conditionReference);
        if (canUnwrapIf(ifExpression, rewriteStrategy)) {
            RowExpression trueResult = ifExpression.getArguments().get(1);
            if (rowExpression instanceof CallExpression) {
                // Wrap the result with the original CAST(), keeping its source location.
                CallExpression cast = (CallExpression) rowExpression;
                trueResult = new CallExpression(cast.getSourceLocation(), cast.getDisplayName(), cast.getFunctionHandle(), cast.getType(), ImmutableList.of(trueResult));
            }
            VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
            newAssignments.put(ifResultReference, trueResult);
            aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
        }
    }
    // Build new aggregations.
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    // Stores the masks used to build the filter predicates. Use set to dedup the predicates.
    ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation aggregation = entry.getValue();
        if (!aggregationsToRewrite.contains(aggregation)) {
            aggregations.put(output, aggregation);
            continue;
        }
        VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
        CallExpression callExpression = aggregation.getCall();
        VariableReferenceExpression ifResultReference = aggregationReferenceToIfResultReference.get(aggregationReference);
        if (ifResultReference != null) {
            callExpression = new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), ImmutableList.of(ifResultReference));
        }
        VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
        aggregations.put(output, new Aggregation(callExpression, Optional.empty(), aggregation.getOrderBy(), aggregation.isDistinct(), Optional.of(mask)));
        masks.add(mask);
    }
    RowExpression predicate = TRUE_CONSTANT;
    // aggregationsToRewrite is a Set, so outputs sharing an identical Aggregation collapse into one element and
    // its size can be smaller than the number of outputs even when all of them were rewritten; check each output.
    boolean allAggregationsRewritten = aggregationNode.getAggregations().values().stream().allMatch(aggregationsToRewrite::contains);
    if (!aggregationNode.hasNonEmptyGroupingSet() && allAggregationsRewritten) {
        // All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient.
        predicate = or(masks.build());
    }
    return Result.ofPlanNode(new AggregationNode(aggregationNode.getSourceLocation(), context.getIdAllocator().getNextId(), new FilterNode(aggregationNode.getSourceLocation(), context.getIdAllocator().getNextId(), new ProjectNode(context.getIdAllocator().getNextId(), sourceProject.getSource(), newAssignments.build()), predicate), aggregations.build(), aggregationNode.getGroupingSets(), aggregationNode.getPreGroupedVariables(), aggregationNode.getStep(), aggregationNode.getHashVariable(), aggregationNode.getGroupIdVariable()));
}
Also used : FunctionAndTypeManager(com.facebook.presto.metadata.FunctionAndTypeManager) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) Captures(com.facebook.presto.matching.Captures) Assignments(com.facebook.presto.spi.plan.Assignments) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) ImmutableSortedMap.toImmutableSortedMap(com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) HashMap(java.util.HashMap) RowExpressionDeterminismEvaluator(com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator) SystemSessionProperties.getAggregationIfToFilterRewriteStrategy(com.facebook.presto.SystemSessionProperties.getAggregationIfToFilterRewriteStrategy) StandardFunctionResolution(com.facebook.presto.spi.function.StandardFunctionResolution) Pattern(com.facebook.presto.matching.Pattern) FilterNode(com.facebook.presto.spi.plan.FilterNode) Capture(com.facebook.presto.matching.Capture) ImmutableList(com.google.common.collect.ImmutableList) UNWRAP_IF(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy.UNWRAP_IF) IF(com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Expressions(com.facebook.presto.sql.relational.Expressions) FunctionResolution(com.facebook.presto.sql.relational.FunctionResolution) CallExpression(com.facebook.presto.spi.relation.CallExpression) SpecialFormExpression(com.facebook.presto.spi.relation.SpecialFormExpression) VariablesExtractor(com.facebook.presto.sql.planner.VariablesExtractor) RowExpression(com.facebook.presto.spi.relation.RowExpression) ImmutableSortedSet(com.google.common.collect.ImmutableSortedSet) Patterns.aggregation(com.facebook.presto.sql.planner.plan.Patterns.aggregation) 
ImmutableMap(com.google.common.collect.ImmutableMap) Session(com.facebook.presto.Session) Rule(com.facebook.presto.sql.planner.iterative.Rule) Patterns.project(com.facebook.presto.sql.planner.plan.Patterns.project) Set(java.util.Set) LambdaDefinitionExpression(com.facebook.presto.spi.relation.LambdaDefinitionExpression) OperatorType(com.facebook.presto.common.function.OperatorType) AggregationIfToFilterRewriteStrategy(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy) TRUE_CONSTANT(com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT) Patterns.source(com.facebook.presto.sql.planner.plan.Patterns.source) FILTER_WITH_IF(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy.FILTER_WITH_IF) DefaultRowExpressionTraversalVisitor(com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) LogicalRowExpressions.or(com.facebook.presto.expressions.LogicalRowExpressions.or) Capture.newCapture(com.facebook.presto.matching.Capture.newCapture) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) Function.identity(java.util.function.Function.identity) Optional(java.util.Optional) REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED(com.facebook.presto.common.RuntimeMetricName.REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED) SystemSessionProperties.getAggregationIfToFilterRewriteStrategy(com.facebook.presto.SystemSessionProperties.getAggregationIfToFilterRewriteStrategy) AggregationIfToFilterRewriteStrategy(com.facebook.presto.sql.analyzer.FeaturesConfig.AggregationIfToFilterRewriteStrategy) HashMap(java.util.HashMap) FilterNode(com.facebook.presto.spi.plan.FilterNode) Assignments(com.facebook.presto.spi.plan.Assignments) RowExpression(com.facebook.presto.spi.relation.RowExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) 
Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) ImmutableSortedSet(com.google.common.collect.ImmutableSortedSet) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) ImmutableSortedMap.toImmutableSortedMap(com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) SpecialFormExpression(com.facebook.presto.spi.relation.SpecialFormExpression) CallExpression(com.facebook.presto.spi.relation.CallExpression)

Aggregations

Captures (com.facebook.presto.matching.Captures): 7, Pattern (com.facebook.presto.matching.Pattern): 7, Rule (com.facebook.presto.sql.planner.iterative.Rule): 7, VariableReferenceExpression (com.facebook.presto.spi.relation.VariableReferenceExpression): 6, RowExpression (com.facebook.presto.spi.relation.RowExpression): 5, List (java.util.List): 5, Map (java.util.Map): 5, Assignments (com.facebook.presto.spi.plan.Assignments): 4, ProjectNode (com.facebook.presto.spi.plan.ProjectNode): 4, Patterns.project (com.facebook.presto.sql.planner.plan.Patterns.project): 4, ImmutableList (com.google.common.collect.ImmutableList): 4, ImmutableMap (com.google.common.collect.ImmutableMap): 4, Optional (java.util.Optional): 4, Set (java.util.Set): 4, Capture (com.facebook.presto.matching.Capture): 3, Capture.newCapture (com.facebook.presto.matching.Capture.newCapture): 3, PlanNode (com.facebook.presto.spi.plan.PlanNode): 3, Patterns.source (com.facebook.presto.sql.planner.plan.Patterns.source): 3, OriginalExpressionUtils (com.facebook.presto.sql.relational.OriginalExpressionUtils): 3, Collectors (java.util.stream.Collectors): 3