Search in sources:

Example 11 with AggregationNode

Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openLooKeng.

Method apply of class SimplifyCountOverConstant.

/**
 * Rewrites every aggregation accepted by {@code isCountOverConstant} into an argument-less {@code count} call.
 * Acceptance is judged against the captured child ProjectNode's assignments. Each rewritten aggregation keeps
 * only the original mask.
 *
 * <p>NOTE(review): the replacement aggregation is built with {@code distinct = false} and no filter or ordering
 * scheme. This is presumably safe because {@code isCountOverConstant} or earlier rules exclude such
 * aggregations — verify.
 *
 * @param parent the matched AggregationNode
 * @param captures holds the child ProjectNode under {@code CHILD}
 * @param context rule context (not used here)
 * @return {@code Result.empty()} when nothing was rewritten; otherwise a new AggregationNode. The new node keeps
 *         {@code parent}'s id, grouping sets, step, hash/group-id symbols, aggregation type and finalize symbol.
 *         It uses the captured ProjectNode as its source and has an empty pre-grouped symbol list.
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    ProjectNode projectChild = captures.get(CHILD);
    // Start from a copy so untouched aggregations keep their position.
    Map<Symbol, AggregationNode.Aggregation> rewritten = new LinkedHashMap<>(parent.getAggregations());
    int rewriteCount = 0;
    for (Map.Entry<Symbol, AggregationNode.Aggregation> item : parent.getAggregations().entrySet()) {
        AggregationNode.Aggregation current = item.getValue();
        if (!isCountOverConstant(current, projectChild.getAssignments())) {
            continue;
        }
        rewriteCount++;
        CallExpression argumentlessCount = new CallExpression("count", functionResolution.countFunction(), BIGINT, ImmutableList.of(), Optional.empty());
        AggregationNode.Aggregation replacement = new AggregationNode.Aggregation(argumentlessCount, ImmutableList.of(), false, Optional.empty(), Optional.empty(), current.getMask());
        rewritten.put(item.getKey(), replacement);
    }
    if (rewriteCount == 0) {
        return Result.empty();
    }
    AggregationNode simplified = new AggregationNode(parent.getId(), projectChild, rewritten, parent.getGroupingSets(), ImmutableList.of(), parent.getStep(), parent.getHashSymbol(), parent.getGroupIdSymbol(), parent.getAggregationType(), parent.getFinalizeSymbol());
    return Result.ofPlanNode(simplified);
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) ProjectNode(io.prestosql.spi.plan.ProjectNode) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) LinkedHashMap(java.util.LinkedHashMap)

Example 12 with AggregationNode

Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openLooKeng.

Method apply of class SingleDistinctAggregationToGroupBy.

/**
 * Rewrites a DISTINCT aggregation into a GROUP BY followed by the same aggregations without DISTINCT.
 *
 * <p>{@code extractArgumentSets} must yield exactly one argument set; {@code Iterables.getOnlyElement} throws
 * otherwise. That set's expressions are converted to symbols and appended to the original grouping keys. The
 * result is the grouping of a new SINGLE-step AggregationNode with no aggregations, which therefore only removes
 * duplicate rows. On top of it, the original aggregations are re-applied with their DISTINCT flag removed. The
 * outer node keeps the original node's id, grouping sets, step, hash/group-id symbols, aggregation type and
 * finalize symbol.
 *
 * @param aggregation the matched aggregation node
 * @param captures pattern captures (not used here)
 * @param context supplies the id allocator for the new inner node
 * @return the two-level rewritten plan
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    List<Set<Expression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    Set<Symbol> distinctSymbols = Iterables.getOnlyElement(argumentSets).stream().map(SymbolUtils::from).collect(Collectors.toSet());

    // Inner node: GROUP BY (original keys + distinct argument symbols) with no aggregations, i.e. de-duplication.
    List<Symbol> innerGroupingKeys = ImmutableList.<Symbol>builder()
            .addAll(aggregation.getGroupingKeys())
            .addAll(distinctSymbols)
            .build();
    AggregationNode groupBy = new AggregationNode(context.getIdAllocator().getNextId(), aggregation.getSource(), ImmutableMap.of(), singleGroupingSet(innerGroupingKeys), ImmutableList.of(), SINGLE, Optional.empty(), Optional.empty(), aggregation.getAggregationType(), aggregation.getFinalizeSymbol());

    // Outer node: remove DISTINCT flag from function calls. toImmutableMap keeps the input node's aggregation
    // order. Collectors.toMap would produce a HashMap, making the order (and presumably the node's output-symbol
    // order) hash-dependent.
    Map<Symbol, AggregationNode.Aggregation> withoutDistinct = aggregation.getAggregations().entrySet().stream()
            .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> removeDistinct(entry.getValue())));

    return Result.ofPlanNode(new AggregationNode(aggregation.getId(), groupBy, withoutDistinct, aggregation.getGroupingSets(), emptyList(), aggregation.getStep(), aggregation.getHashSymbol(), aggregation.getGroupIdSymbol(), aggregation.getAggregationType(), aggregation.getFinalizeSymbol()));
}
Also used : Iterables(com.google.common.collect.Iterables) Patterns.aggregation(io.prestosql.sql.planner.plan.Patterns.aggregation) Pattern(io.prestosql.matching.Pattern) AggregationNode(io.prestosql.spi.plan.AggregationNode) HashSet(java.util.HashSet) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) ImmutableList(com.google.common.collect.ImmutableList) SINGLE(io.prestosql.spi.plan.AggregationNode.Step.SINGLE) Map(java.util.Map) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils(io.prestosql.sql.planner.SymbolUtils) ImmutableMap(com.google.common.collect.ImmutableMap) Rule(io.prestosql.sql.planner.iterative.Rule) Collections.emptyList(java.util.Collections.emptyList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) Collectors(java.util.stream.Collectors) Captures(io.prestosql.matching.Captures) List(java.util.List) Stream(java.util.stream.Stream) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) Set(java.util.Set) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 13 with AggregationNode

Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openLooKeng.

Method getNewIntermediateTreeAfterInnerTableUpdate of class TablePushdown.

/**
 * Rebuilds the plan above the subquery table once the outer table has been pushed into the subquery. The
 * subquery table's aggregation (GROUP BY) is pulled above the recreated join.
 *
 * <p>Steps:
 * <ol>
 *   <li>Wrap the subquery TableScanNode (top of {@code stack}) in an identity ProjectNode.</li>
 *   <li>Recreate the join with that ProjectNode as a source. It goes on the left, with the equi-join criteria
 *       flipped, when the current left source's row count is known and not larger than the ProjectNode's.
 *       Otherwise it stays on the right.</li>
 *   <li>Pop the TableScanNode entry and re-parent the node now on top of the stack onto the new join.</li>
 *   <li>Attach the result as the only child of the stack's bottom element.</li>
 * </ol>
 * The new join outputs only the ProjectNode's (i.e. the subquery table's) symbols.
 *
 * <p>Side effect: pops one entry from {@code stack}.
 *
 * @param newInnerJoinNode is the recently created join node after the outer table has been pushed into subquery
 * @param stack is the stack of nodes from the original captured JoinNode and subquery TableScanNode. Its top is
 *        cast to TableScanNode. Re-parenting the entry below it is cast to AggregationNode, so either cast fails
 *        if the stack has a different shape.
 * @return the PlanNode after pulling the subquery table's Aggregation and Group By above the join
 */
private PlanNode getNewIntermediateTreeAfterInnerTableUpdate(JoinNode newInnerJoinNode, Stack<NodeWithTreeDirection> stack) {
    TableScanNode subqueryTableNode = (TableScanNode) stack.peek().getNode();

    // New ProjectNode between the join and the subquery TableScanNode: every scan symbol is copied over as-is.
    Assignments.Builder assignmentsBuilder = Assignments.builder();
    for (Symbol symbol : subqueryTableNode.getAssignments().keySet()) {
        assignmentsBuilder.put(symbol, castToRowExpression(new SymbolReference(symbol.getName())));
    }
    ProjectNode parentOfSubqueryTableNode = new ProjectNode(ruleContext.getIdAllocator().getNextId(), subqueryTableNode, assignmentsBuilder.build());
    List<Symbol> joinOutputSymbols = ImmutableList.copyOf(parentOfSubqueryTableNode.getOutputSymbols());

    PlanNodeStatsEstimate leftSourceStats = ruleContext.getStatsProvider().getStats(newInnerJoinNode.getLeft());
    PlanNodeStatsEstimate rightSourceStats = ruleContext.getStatsProvider().getStats(parentOfSubqueryTableNode);
    // Reorder so the source with more rows is on the left. CAUTION: when the left stats are unavailable, no
    // reordering is done, and the query may fail.
    boolean flipSides = !leftSourceStats.isOutputRowCountUnknown()
            && leftSourceStats.getOutputRowCount() <= rightSourceStats.getOutputRowCount();

    PlanNode left = flipSides ? parentOfSubqueryTableNode : newInnerJoinNode.getLeft();
    PlanNode right = flipSides ? newInnerJoinNode.getLeft() : parentOfSubqueryTableNode;
    List<JoinNode.EquiJoinClause> criteria = flipSides
            ? newInnerJoinNode.getCriteria().stream().map(JoinNode.EquiJoinClause::flip).collect(toImmutableList())
            : newInnerJoinNode.getCriteria();

    JoinNode newJoinNode = new JoinNode(newInnerJoinNode.getId(), newInnerJoinNode.getType(), left, right, criteria, joinOutputSymbols, newInnerJoinNode.getFilter(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), newInnerJoinNode.getDynamicFilters());

    // Remove the TableScanNode from the stack; the next entry is re-parented onto the new join.
    stack.pop();
    AggregationNode newAggNode = (AggregationNode) stack.peek().getNode().replaceChildren(ImmutableList.of(newJoinNode));
    return stack.firstElement().getNode().replaceChildren(ImmutableList.of(newAggNode));
}
Also used : ColumnHandle(io.prestosql.spi.connector.ColumnHandle) PlanNodeStatsEstimate(io.prestosql.cost.PlanNodeStatsEstimate) Symbol(io.prestosql.spi.plan.Symbol) SymbolReference(io.prestosql.sql.tree.SymbolReference) JoinNode(io.prestosql.spi.plan.JoinNode) Assignments(io.prestosql.spi.plan.Assignments) AggregationNode(io.prestosql.spi.plan.AggregationNode) TableScanNode(io.prestosql.spi.plan.TableScanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) Map(java.util.Map) SpecialCommentFormatter.getUniqueColumnTableMap(io.prestosql.sql.util.SpecialCommentFormatter.getUniqueColumnTableMap)

Example 14 with AggregationNode

Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openLooKeng.

Method verifyIfGroupByInPath of class TablePushdown.

/**
 * Checks whether the path to the TableScanNode has a GROUP BY clause. A GROUP BY is present when some
 * AggregationNode on the path has at least one grouping key.
 *
 * @param stack path of nodes for the side of the join node which is being checked
 * @return {@code true} if any AggregationNode in {@code stack} has non-empty grouping keys, {@code false} otherwise
 */
private boolean verifyIfGroupByInPath(Stack<NodeWithTreeDirection> stack) {
    for (NodeWithTreeDirection entry : stack) {
        PlanNode candidate = entry.getNode();
        if (candidate instanceof AggregationNode && !((AggregationNode) candidate).getGroupingKeys().isEmpty()) {
            return true;
        }
    }
    return false;
}
Also used : PlanNode(io.prestosql.spi.plan.PlanNode) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Example 15 with AggregationNode

Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openLooKeng.

Method split of class PushPartialAggregationThroughExchange.

/**
 * Splits {@code node} into a PARTIAL AggregationNode over {@code node}'s source, topped by a FINAL
 * AggregationNode. No exchange node is created here.
 *
 * <p>Each aggregation gets a freshly allocated intermediate symbol.
 * <ul>
 *   <li>The partial step calls the original function handle with its intermediate type. It keeps the original
 *       arguments, DISTINCT flag, filter, ordering scheme and mask.</li>
 *   <li>The final step calls the same handle with its final type. Its arguments are the intermediate symbol
 *       followed by any lambda arguments of the original call, with no DISTINCT, filter, ordering or mask.</li>
 * </ul>
 *
 * @param node the aggregation to split
 * @param context supplies the symbol allocator and plan-node id allocator
 * @return the FINAL AggregationNode (reusing {@code node}'s id) whose source is the new PARTIAL node
 * @throws IllegalStateException (via {@code checkState}) if an aggregation has an ORDER BY
 */
private PlanNode split(AggregationNode node, Context context) {
    Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation originalAggregation = entry.getValue();
        // Validate first so nothing is allocated for an aggregation that cannot be split.
        checkState(!originalAggregation.getOrderingScheme().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
        String functionName = metadata.getFunctionAndTypeManager().getFunctionMetadata(functionHandle).getName().getObjectName();
        InternalAggregationFunction function = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(functionHandle);
        Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(functionName, function.getIntermediateType());

        // Partial step: same function and inputs, producing the intermediate type.
        intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(
                new CallExpression(functionName, functionHandle, function.getIntermediateType(), originalAggregation.getArguments(), Optional.empty()),
                originalAggregation.getArguments(),
                originalAggregation.isDistinct(),
                originalAggregation.getFilter(),
                originalAggregation.getOrderingScheme(),
                originalAggregation.getMask()));

        // Final step: rewrite in terms of the intermediate symbol, keeping the original lambda arguments.
        // Built once and shared by the call expression and the aggregation's argument list.
        ImmutableList<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(new VariableReferenceExpression(intermediateSymbol.getName(), function.getIntermediateType()))
                .addAll(originalAggregation.getArguments().stream().filter(PushPartialAggregationThroughExchange::isLambda).collect(toImmutableList()))
                .build();
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(
                new CallExpression(functionName, functionHandle, function.getFinalType(), finalArguments, Optional.empty()),
                finalArguments,
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty()));
    }
    // Both nodes get an empty pre-grouped symbol list (the ImmutableList.of() argument). Splitting the aggregation
    // and pushing the partial step through an exchange may or may not preserve the input's grouping properties.
    // Hence, it is safest to drop preGroupedSymbols here.
    PlanNode partial = new AggregationNode(context.getIdAllocator().getNextId(), node.getSource(), intermediateAggregation, node.getGroupingSets(), ImmutableList.of(), PARTIAL, node.getHashSymbol(), node.getGroupIdSymbol(), node.getAggregationType(), node.getFinalizeSymbol());
    return new AggregationNode(node.getId(), partial, finalAggregation, node.getGroupingSets(), ImmutableList.of(), FINAL, node.getHashSymbol(), node.getGroupIdSymbol(), node.getAggregationType(), node.getFinalizeSymbol());
}
Also used : HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) RowExpression(io.prestosql.spi.relation.RowExpression) AggregationNode(io.prestosql.spi.plan.AggregationNode) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) SystemSessionProperties.preferPartialAggregation(io.prestosql.SystemSessionProperties.preferPartialAggregation) PlanNode(io.prestosql.spi.plan.PlanNode) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) HashMap(java.util.HashMap) Map(java.util.Map) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression)

Aggregations

AggregationNode (io.prestosql.spi.plan.AggregationNode)52 Symbol (io.prestosql.spi.plan.Symbol)40 PlanNode (io.prestosql.spi.plan.PlanNode)30 ProjectNode (io.prestosql.spi.plan.ProjectNode)24 Map (java.util.Map)23 CallExpression (io.prestosql.spi.relation.CallExpression)20 ImmutableMap (com.google.common.collect.ImmutableMap)17 Assignments (io.prestosql.spi.plan.Assignments)16 Aggregation (io.prestosql.spi.plan.AggregationNode.Aggregation)14 RowExpression (io.prestosql.spi.relation.RowExpression)14 Expression (io.prestosql.sql.tree.Expression)14 List (java.util.List)14 ImmutableList (com.google.common.collect.ImmutableList)13 TableScanNode (io.prestosql.spi.plan.TableScanNode)13 Optional (java.util.Optional)13 OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression)12 Test (org.testng.annotations.Test)12 Session (io.prestosql.Session)11 FilterNode (io.prestosql.spi.plan.FilterNode)11 Metadata (io.prestosql.metadata.Metadata)9