use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
the class SimplifyCountOverConstant method apply.
/**
 * Replaces each aggregation accepted by {@code isCountOverConstant} with an argument-less
 * {@code count} call that keeps the original aggregation's mask.
 *
 * @param parent the matched aggregation node
 * @param captures pattern captures; the {@code CHILD} capture supplies the project node below {@code parent}
 * @param context rule context (not used by this method)
 * @return an empty result when no aggregation was rewritten; otherwise a new aggregation node
 *         with the same id, grouping sets, step, hash symbol, group-id symbol, aggregation type
 *         and finalize symbol as {@code parent}, placed over the captured project node
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    ProjectNode projection = captures.get(CHILD);

    // Start from a copy of every aggregation, then overwrite only the entries that qualify.
    Map<Symbol, AggregationNode.Aggregation> rewritten = new LinkedHashMap<>(parent.getAggregations());
    boolean anyRewritten = false;

    for (Entry<Symbol, AggregationNode.Aggregation> candidate : parent.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = candidate.getValue();
        if (!isCountOverConstant(original, projection.getAssignments())) {
            continue;
        }

        CallExpression countCall = new CallExpression("count", functionResolution.countFunction(), BIGINT, ImmutableList.of(), Optional.empty());
        rewritten.put(
                candidate.getKey(),
                new AggregationNode.Aggregation(countCall, ImmutableList.of(), false, Optional.empty(), Optional.empty(), original.getMask()));
        anyRewritten = true;
    }

    if (anyRewritten) {
        // An empty list is passed in the fifth constructor position (presumably pre-grouped symbols — confirm).
        return Result.ofPlanNode(new AggregationNode(
                parent.getId(),
                projection,
                rewritten,
                parent.getGroupingSets(),
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashSymbol(),
                parent.getGroupIdSymbol(),
                parent.getAggregationType(),
                parent.getFinalizeSymbol()));
    }
    return Result.empty();
}
use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
the class SingleDistinctAggregationToGroupBy method apply.
/**
 * Rewrites the matched aggregation into two stacked aggregation nodes: an inner
 * {@code SINGLE} node with no aggregate functions that groups by the original grouping keys
 * plus the aggregation arguments, and an outer node that re-applies the original aggregations
 * after passing each through {@code removeDistinct}.
 *
 * @param aggregation the matched aggregation node
 * @param captures pattern captures (not used by this method)
 * @param context rule context; supplies the id for the new inner node
 * @return the result wrapping the outer aggregation node, which reuses the original node's id
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    List<Set<Expression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    // getOnlyElement requires exactly one argument set.
    Set<Symbol> argumentSymbols = Iterables.getOnlyElement(argumentSets).stream()
            .map(SymbolUtils::from)
            .collect(Collectors.toSet());

    // Inner node: group by the original keys followed by the argument symbols.
    List<Symbol> innerGroupingKeys = ImmutableList.<Symbol>builder()
            .addAll(aggregation.getGroupingKeys())
            .addAll(argumentSymbols)
            .build();
    AggregationNode groupByArguments = new AggregationNode(
            context.getIdAllocator().getNextId(),
            aggregation.getSource(),
            ImmutableMap.of(),
            singleGroupingSet(innerGroupingKeys),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty(),
            aggregation.getAggregationType(),
            aggregation.getFinalizeSymbol());

    // remove DISTINCT flag from function calls
    Map<Symbol, AggregationNode.Aggregation> withoutDistinct = aggregation.getAggregations().entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> removeDistinct(e.getValue())));

    return Result.ofPlanNode(new AggregationNode(
            aggregation.getId(),
            groupByArguments,
            withoutDistinct,
            aggregation.getGroupingSets(),
            emptyList(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol(),
            aggregation.getAggregationType(),
            aggregation.getFinalizeSymbol()));
}
use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
the class TablePushdown method getNewIntermediateTreeAfterInnerTableUpdate.
/**
 * Rebuilds the plan fragment above the subquery table after the outer table has been pushed
 * into the subquery: a pass-through ProjectNode is placed over the subquery TableScanNode,
 * the join is recreated over that projection, and the nodes remaining on the stack are
 * re-parented onto the new join.
 *
 * <p>Side effect: the top element of {@code stack} (the TableScanNode) is popped.
 *
 * @param newInnerJoinNode is the recently created join node after the outer table has been pushed into subquery
 * @param stack is the stack of nodes from the original captured JoinNode and subquery TableScanNode
 * @return the PlanNode after pulling the subquery table's Aggregation and Group By above the join
 */
private PlanNode getNewIntermediateTreeAfterInnerTableUpdate(JoinNode newInnerJoinNode, Stack<NodeWithTreeDirection> stack) {
    TableScanNode subqueryScan = (TableScanNode) stack.peek().getNode();

    // Intermediate ProjectNode between the join and the subquery table scan:
    // every symbol of the scan is copied over unchanged.
    Assignments.Builder passThrough = Assignments.builder();
    for (Symbol scanSymbol : subqueryScan.getAssignments().keySet()) {
        passThrough.put(scanSymbol, castToRowExpression(new SymbolReference(scanSymbol.getName())));
    }
    ProjectNode subqueryProjection = new ProjectNode(ruleContext.getIdAllocator().getNextId(), subqueryScan, passThrough.build());
    List<Symbol> joinOutputSymbols = ImmutableList.<Symbol>builder()
            .addAll(subqueryProjection.getOutputSymbols())
            .build();

    PlanNodeStatsEstimate leftStats = ruleContext.getStatsProvider().getStats(newInnerJoinNode.getLeft());
    PlanNodeStatsEstimate rightStats = ruleContext.getStatsProvider().getStats(subqueryProjection);

    // CAUTION: when the left stats are not available, source reordering is not allowed. Query may fail.
    // Otherwise the sources are swapped when the left side has no more rows than the right,
    // so that the table with more rows ends up on the left.
    boolean swapSides = !leftStats.isOutputRowCountUnknown()
            && leftStats.getOutputRowCount() <= rightStats.getOutputRowCount();

    PlanNode leftSource;
    PlanNode rightSource;
    List<JoinNode.EquiJoinClause> criteria;
    if (swapSides) {
        leftSource = subqueryProjection;
        rightSource = newInnerJoinNode.getLeft();
        // The clauses must be flipped to follow the swapped sources.
        criteria = newInnerJoinNode.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(toImmutableList());
    }
    else {
        leftSource = newInnerJoinNode.getLeft();
        rightSource = subqueryProjection;
        criteria = newInnerJoinNode.getCriteria();
    }

    // Recreate the inner join using the new ProjectNode as one of its sources.
    JoinNode rebuiltJoin = new JoinNode(
            newInnerJoinNode.getId(),
            newInnerJoinNode.getType(),
            leftSource,
            rightSource,
            criteria,
            joinOutputSymbols,
            newInnerJoinNode.getFilter(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            newInnerJoinNode.getDynamicFilters());

    // Remove the TableScanNode from the stack; the new top is cast to AggregationNode after re-parenting.
    stack.pop();
    AggregationNode aggregationOverJoin = (AggregationNode) stack.peek().getNode().replaceChildren(ImmutableList.of(rebuiltJoin));
    return stack.firstElement().getNode().replaceChildren(ImmutableList.of(aggregationOverJoin));
}
use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
the class TablePushdown method verifyIfGroupByInPath.
/**
 * Check if the path to the TableScanNode has a GROUP BY clause, i.e. whether any node on the
 * stack is an AggregationNode with at least one grouping key.
 *
 * @param stack relevant to the side of the join node which is being checked
 * @return boolean denoting if the condition is satisfied
 */
private boolean verifyIfGroupByInPath(Stack<NodeWithTreeDirection> stack) {
    return stack.stream()
            .map(NodeWithTreeDirection::getNode)
            .filter(AggregationNode.class::isInstance)
            .map(AggregationNode.class::cast)
            .anyMatch(aggregation -> !aggregation.getGroupingKeys().isEmpty());
}
use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
the class PushPartialAggregationThroughExchange method split.
/**
 * Splits an aggregation node into a {@code PARTIAL} node feeding a {@code FINAL} node.
 *
 * <p>For each original aggregation a new intermediate symbol is allocated; the partial node
 * computes the function into that symbol, and the final node applies the same function to the
 * intermediate symbol plus any original arguments accepted by {@code isLambda}.
 *
 * @param node the aggregation node to split
 * @param context rule context; supplies the symbol allocator and the id for the partial node
 * @return the final aggregation node (reusing {@code node}'s id) whose source is the new partial node
 * @throws IllegalStateException if an aggregation has an ordering scheme
 */
private PlanNode split(AggregationNode node, Context context) {
    // otherwise, add a partial and final with an exchange in between
    Map<Symbol, AggregationNode.Aggregation> partialAggregations = new HashMap<>();
    Map<Symbol, AggregationNode.Aggregation> finalAggregations = new HashMap<>();

    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = entry.getValue();
        FunctionHandle handle = original.getFunctionHandle();
        String name = metadata.getFunctionAndTypeManager().getFunctionMetadata(handle).getName().getObjectName();
        InternalAggregationFunction implementation = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(handle);
        Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(name, implementation.getIntermediateType());

        checkState(!original.getOrderingScheme().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        // Partial step: same function and arguments, producing the intermediate type.
        CallExpression partialCall = new CallExpression(name, handle, implementation.getIntermediateType(), original.getArguments(), Optional.empty());
        partialAggregations.put(
                intermediateSymbol,
                new AggregationNode.Aggregation(
                        partialCall,
                        original.getArguments(),
                        original.isDistinct(),
                        original.getFilter(),
                        original.getOrderingScheme(),
                        original.getMask()));

        // rewrite final aggregation in terms of intermediate function
        ImmutableList<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(new VariableReferenceExpression(intermediateSymbol.getName(), implementation.getIntermediateType()))
                .addAll(original.getArguments().stream()
                        .filter(PushPartialAggregationThroughExchange::isLambda)
                        .collect(toImmutableList()))
                .build();
        CallExpression finalCall = new CallExpression(name, handle, implementation.getFinalType(), finalArguments, Optional.empty());
        finalAggregations.put(
                entry.getKey(),
                new AggregationNode.Aggregation(finalCall, finalArguments, false, Optional.empty(), Optional.empty(), Optional.empty()));
    }

    // The exchange may or may not preserve the pre-grouping properties of the input.
    // Hence, it is safest to drop preGroupedSymbols here (empty list in both nodes).
    PlanNode partialNode = new AggregationNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            partialAggregations,
            node.getGroupingSets(),
            ImmutableList.of(),
            PARTIAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol());

    return new AggregationNode(
            node.getId(),
            partialNode,
            finalAggregations,
            node.getGroupingSets(),
            ImmutableList.of(),
            FINAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol());
}
Aggregations