use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
the class TestCostCalculator method testReplicatedJoinWithExchange.
/**
 * Builds a REPLICATED join whose probe side is the {@code ts1} scan and whose build side is the
 * {@code ts2} scan wrapped in a remote replicated exchange followed by a local partitioned
 * exchange, then hands the plan, its stats and its symbol types to
 * {@code assertFragmentedEqualsUnfragmented}.
 */
@Test
public void testReplicatedJoinWithExchange() {
    TableScanNode probeScan = tableScan("ts1", "orderkey");
    TableScanNode buildScan = tableScan("ts2", "orderkey_0");

    // Build side: remote replicated exchange, then a local exchange partitioned on the join key.
    ExchangeNode remoteBuildExchange = replicatedExchange(new PlanNodeId("re2"), REMOTE, buildScan);
    ExchangeNode localBuildExchange = partitionedExchange(
            new PlanNodeId("le"),
            LOCAL,
            remoteBuildExchange,
            ImmutableList.of(new Symbol("orderkey_0")),
            Optional.empty());

    JoinNode replicatedJoin = join(
            "join",
            probeScan,
            localBuildExchange,
            JoinNode.DistributionType.REPLICATED,
            "orderkey",
            "orderkey_0");

    // Stats are keyed by plan node id.
    ImmutableMap.Builder<String, PlanNodeStatsEstimate> statsById = ImmutableMap.builder();
    statsById.put("join", statsEstimate(replicatedJoin, 12000));
    statsById.put("re2", statsEstimate(remoteBuildExchange, 10000));
    statsById.put("le", statsEstimate(localBuildExchange, 6000));
    statsById.put("ts1", statsEstimate(probeScan, 6000));
    statsById.put("ts2", statsEstimate(buildScan, 1000));
    Map<String, PlanNodeStatsEstimate> stats = statsById.build();

    Map<String, Type> types = ImmutableMap.of("orderkey", BIGINT, "orderkey_0", BIGINT);

    assertFragmentedEqualsUnfragmented(replicatedJoin, stats, types);
}
use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
the class TestCostCalculator method testRepartitionedJoinWithExchange.
/**
 * Builds a PARTITIONED join where both scans sit behind remote partitioned exchanges (the build
 * side additionally behind a local partitioned exchange), then hands the plan, its stats and its
 * symbol types to {@code assertFragmentedEqualsUnfragmented}.
 */
@Test
public void testRepartitionedJoinWithExchange() {
    TableScanNode probeScan = tableScan("ts1", "orderkey");
    TableScanNode buildScan = tableScan("ts2", "orderkey_0");

    // Probe side: one remote exchange partitioned on its join key.
    ExchangeNode remoteProbeExchange = partitionedExchange(
            new PlanNodeId("re1"),
            REMOTE,
            probeScan,
            ImmutableList.of(new Symbol("orderkey")),
            Optional.empty());

    // Build side: remote exchange followed by a local exchange, both partitioned on its join key.
    ExchangeNode remoteBuildExchange = partitionedExchange(
            new PlanNodeId("re2"),
            REMOTE,
            buildScan,
            ImmutableList.of(new Symbol("orderkey_0")),
            Optional.empty());
    ExchangeNode localBuildExchange = partitionedExchange(
            new PlanNodeId("le"),
            LOCAL,
            remoteBuildExchange,
            ImmutableList.of(new Symbol("orderkey_0")),
            Optional.empty());

    JoinNode partitionedJoin = join(
            "join",
            remoteProbeExchange,
            localBuildExchange,
            JoinNode.DistributionType.PARTITIONED,
            "orderkey",
            "orderkey_0");

    // Stats are keyed by plan node id.
    ImmutableMap.Builder<String, PlanNodeStatsEstimate> statsById = ImmutableMap.builder();
    statsById.put("join", statsEstimate(partitionedJoin, 12000));
    statsById.put("re1", statsEstimate(remoteProbeExchange, 10000));
    statsById.put("re2", statsEstimate(remoteBuildExchange, 10000));
    statsById.put("le", statsEstimate(localBuildExchange, 6000));
    statsById.put("ts1", statsEstimate(probeScan, 6000));
    statsById.put("ts2", statsEstimate(buildScan, 1000));
    Map<String, PlanNodeStatsEstimate> stats = statsById.build();

    Map<String, Type> types = ImmutableMap.of("orderkey", BIGINT, "orderkey_0", BIGINT);

    assertFragmentedEqualsUnfragmented(partitionedJoin, stats, types);
}
use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
the class TransformCorrelatedLateralJoinToJoin method apply.
/**
 * Replaces a correlated lateral join with a plain {@link JoinNode} when the subquery's correlated
 * filters can be decorrelated.
 *
 * @param lateralJoinNode the lateral join to rewrite
 * @param captures unused by this rule
 * @param context supplies the id allocator and the lookup
 * @return a result wrapping the new join, or {@code Result.empty()} when
 *         {@code decorrelateFilters} yields nothing
 */
@Override
public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context context) {
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(context.getIdAllocator(), context.getLookup());
    Optional<DecorrelatedNode> decorrelated = decorrelator.decorrelateFilters(
            lateralJoinNode.getSubquery(),
            lateralJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }

    DecorrelatedNode decorrelatedSubquery = decorrelated.get();
    // The join filter is the conjunction of the extracted correlated predicates (TRUE when there
    // are none) and the lateral join's own filter.
    Expression joinFilter = combineConjuncts(
            decorrelatedSubquery.getCorrelatedPredicates().orElse(TRUE_LITERAL),
            lateralJoinNode.getFilter());

    JoinNode joinNode = new JoinNode(
            context.getIdAllocator().getNextId(),
            lateralJoinNode.getType().toJoinNodeType(),
            lateralJoinNode.getInput(),
            decorrelatedSubquery.getNode(),
            ImmutableList.of(),
            lateralJoinNode.getOutputSymbols(),
            // A filter that is just TRUE is passed as "no filter".
            joinFilter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(castToRowExpression(joinFilter)),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());
    return Result.ofPlanNode(joinNode);
}
use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
the class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate method transformProjectNode.
/**
 * Tries to rewrite the subquery rooted at {@code projectNode} (Project over Filter over Join,
 * optionally with a CTE scan directly under the project) into Project over Filter over
 * Aggregation over a single table scan.
 *
 * <p>Returns {@code Optional.empty()} when:
 * <ul>
 * <li>the project has more than one output symbol;</li>
 * <li>the (CTE-unwrapped) source is not a {@link FilterNode} whose source resolves to a
 * {@link JoinNode};</li>
 * <li>{@code isSelfJoin} rejects the join and neither join input is a {@link ProjectNode}
 * that this method can transform recursively.</li>
 * </ul>
 *
 * <p>When {@code isSelfJoin} rejects the join but at least one input was transformed, the
 * original Project/Filter/Join shape is rebuilt around the transformed input(s).
 *
 * @param context rule context supplying lookup, id allocator and symbol allocator
 * @param projectNode root of the candidate subquery
 * @return the rewritten project node, or empty when no rewrite applies
 */
private Optional<ProjectNode> transformProjectNode(Context context, ProjectNode projectNode) {
    // IN predicate requires only one projection
    if (projectNode.getOutputSymbols().size() > 1) {
        return Optional.empty();
    }

    // Resolved once and reused for every later "is the project sitting on a CTE scan?" check.
    PlanNode projectSource = context.getLookup().resolve(projectNode.getSource());
    PlanNode source = projectSource;
    if (source instanceof CTEScanNode) {
        source = getChildFilterNode(context, context.getLookup().resolve(((CTEScanNode) source).getSource()));
    }
    if (!(source instanceof FilterNode)) {
        return Optional.empty();
    }
    FilterNode filter = (FilterNode) source;
    PlanNode filterSource = context.getLookup().resolve(filter.getSource());
    if (!(filterSource instanceof JoinNode)) {
        return Optional.empty();
    }
    JoinNode joinNode = (JoinNode) filterSource;

    Expression predicate = OriginalExpressionUtils.castToExpression(filter.getPredicate());
    List<SymbolReference> allPredicateSymbols = new ArrayList<>();
    getAllSymbols(predicate, allPredicateSymbols);

    if (!isSelfJoin(projectNode, predicate, joinNode, context.getLookup())) {
        // Check next level for Self Join
        boolean changed = false;
        PlanNode left = context.getLookup().resolve(joinNode.getLeft());
        if (left instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) left);
            if (transformResult.isPresent()) {
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), transformResult.get(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        PlanNode right = context.getLookup().resolve(joinNode.getRight());
        if (right instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) right);
            if (transformResult.isPresent()) {
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), transformResult.get(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        if (!changed) {
            return Optional.empty();
        }
        // Rebuild the original Project -> Filter -> Join chain on top of the transformed input(s).
        FilterNode transformedFilter = new FilterNode(filter.getId(), joinNode, filter.getPredicate());
        return Optional.of(new ProjectNode(projectNode.getId(), transformedFilter, projectNode.getAssignments()));
    }

    // Choose the table to use based on projected output.
    // NOTE(review): these casts presumably rely on isSelfJoin having verified that both join
    // inputs resolve to table scans — confirm against isSelfJoin.
    TableScanNode leftTable = (TableScanNode) context.getLookup().resolve(joinNode.getLeft());
    TableScanNode rightTable = (TableScanNode) context.getLookup().resolve(joinNode.getRight());
    TableScanNode tableToUse;
    List<RowExpression> aggregationSymbols;
    AggregationNode.GroupingSetDescriptor groupingSetDescriptor;
    Assignments projectionsForCTE = null;
    Assignments projectionsFromFilter = null;
    boolean projectIsOverCte = projectSource instanceof CTEScanNode;

    // Use non-projected column for aggregation
    if (projectIsOverCte) {
        CTEScanNode cteScanNode = (CTEScanNode) projectSource;
        ProjectNode childProjectOfCte = (ProjectNode) context.getLookup().resolve(cteScanNode.getSource());
        Map<Symbol, RowExpression> cteAssignments = childProjectOfCte.getAssignments().getMap();

        List<Symbol> completeOutputSymbols = new ArrayList<>();
        completeOutputSymbols.addAll(rightTable.getOutputSymbols());
        completeOutputSymbols.addAll(leftTable.getOutputSymbols());

        // Table symbols whose name matches the symbol reference the CTE's child project assigns
        // to one of this project's outputs.
        List<Symbol> outputSymbols = new ArrayList<>();
        for (Symbol outputSymbol : completeOutputSymbols) {
            for (Symbol symbol : projectNode.getOutputSymbols()) {
                if (cteAssignments.containsKey(symbol)) {
                    SymbolReference assigned = (SymbolReference) OriginalExpressionUtils.castToExpression(cteAssignments.get(symbol));
                    if (assigned.getName().equals(outputSymbol.getName())) {
                        outputSymbols.add(outputSymbol);
                    }
                }
            }
        }

        Symbol projectedSymbol = getOnlyElement(projectNode.getOutputSymbols());
        Map<Symbol, RowExpression> projectionsForCTEMap = new HashMap<>();
        Map<Symbol, RowExpression> projectionsFromFilterMap = new HashMap<>();
        for (Map.Entry<Symbol, RowExpression> entry : cteAssignments.entrySet()) {
            if (entry.getKey().equals(projectedSymbol)) {
                // Same expression under two keys: the CTE-level symbol and the table-level symbol.
                projectionsForCTEMap.put(entry.getKey(), entry.getValue());
                projectionsFromFilterMap.put(getOnlyElement(outputSymbols), entry.getValue());
            }
        }
        projectionsForCTE = new Assignments(projectionsForCTEMap);
        projectionsFromFilter = new Assignments(projectionsFromFilterMap);

        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(outputSymbols)) ? leftTable : rightTable;
        aggregationSymbols = allPredicateSymbols.stream()
                .filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s)))
                .filter(s -> !outputSymbols.contains(SymbolUtils.from(s)))
                .map(OriginalExpressionUtils::castToRowExpression)
                .collect(Collectors.toList());
        // Create aggregation
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(outputSymbols), 1, ImmutableSet.of());
    }
    else {
        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(projectNode.getOutputSymbols())) ? leftTable : rightTable;
        aggregationSymbols = allPredicateSymbols.stream()
                .filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s)))
                .filter(s -> !projectNode.getOutputSymbols().contains(SymbolUtils.from(s)))
                .map(OriginalExpressionUtils::castToRowExpression)
                .collect(Collectors.toList());
        // Create aggregation
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(projectNode.getOutputSymbols()), 1, ImmutableSet.of());
    }

    AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
            new CallExpression("count", functionResolution.countFunction(), BIGINT, aggregationSymbols),
            aggregationSymbols,
            // mark DISTINCT since NOT_EQUALS predicate
            true,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), BIGINT);
    aggregationsBuilder.put(countSymbol, aggregation);
    AggregationNode aggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), tableToUse, aggregationsBuilder.build(), groupingSetDescriptor, ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());

    // Keep only the groups whose distinct count is greater than 1 (drops groups with count <= 1)
    // to match the NOT_EQUALS clause in the original query.
    FilterNode filterNode = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, SymbolUtils.toSymbolReference(countSymbol), new GenericLiteral("BIGINT", "1"))));

    // Project the aggregated+filtered rows.
    ProjectNode transformedSubquery = new ProjectNode(projectNode.getId(), filterNode, projectNode.getAssignments());
    if (projectIsOverCte) {
        // Re-insert the CTE scan: Project -> CTEScan -> Project(CTE) -> Project(filter) -> Filter.
        CTEScanNode cteScanNode = (CTEScanNode) projectSource;
        PlanNode projectNodeForCTE = context.getLookup().resolve(cteScanNode.getSource());
        PlanNode projectNodeFromFilter = context.getLookup().resolve(projectNodeForCTE);
        // NOTE(review): projectNodeFromFilter is resolved from the already-resolved
        // projectNodeForCTE, so the two new project nodes below appear to share one plan node
        // id — confirm this is intended.
        projectNodeFromFilter = new ProjectNode(projectNodeFromFilter.getId(), filterNode, projectionsFromFilter);
        projectNodeForCTE = new ProjectNode(projectNodeForCTE.getId(), projectNodeFromFilter, projectionsForCTE);
        cteScanNode = (CTEScanNode) cteScanNode.replaceChildren(ImmutableList.of(projectNodeForCTE));
        cteScanNode.setOutputSymbols(projectNode.getOutputSymbols());
        transformedSubquery = new ProjectNode(projectNode.getId(), cteScanNode, projectNode.getAssignments());
    }
    return Optional.of(transformedSubquery);
}
use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
the class TablePushdown method getNewIntermediateTreeAfterInnerTableUpdate.
/**
 * Rebuilds the plan fragment above the subquery table after the outer table has been pushed into
 * the subquery.
 *
 * <p>Steps, in order:
 * <ol>
 * <li>Wrap the subquery {@link TableScanNode} (top of {@code stack}) in a {@link ProjectNode}
 * that passes every scan symbol through unchanged.</li>
 * <li>Recreate {@code newInnerJoinNode} with that project as one input. When the left input's
 * row count is known and is less than or equal to the project's row count, the project becomes
 * the left input and every equi-join clause is flipped; otherwise the project is the right
 * input.</li>
 * <li>Pop the table scan off {@code stack}, replace the children of the node now on top (cast
 * to {@link AggregationNode}) with the new join, and replace the children of the stack's first
 * element with that aggregation.</li>
 * </ol>
 *
 * @param newInnerJoinNode the join node created after the outer table was pushed into the subquery
 * @param stack nodes from the originally captured JoinNode down to the subquery TableScanNode;
 *        its top element is removed by this call
 * @return the stack's first node with its children replaced by the rebuilt aggregation
 */
private PlanNode getNewIntermediateTreeAfterInnerTableUpdate(JoinNode newInnerJoinNode, Stack<NodeWithTreeDirection> stack) {
    TableScanNode subqueryScan = (TableScanNode) stack.peek().getNode();

    // Identity projection over the subquery scan: every scan symbol maps to a reference to itself.
    Assignments.Builder passThrough = Assignments.builder();
    for (Symbol scanSymbol : subqueryScan.getAssignments().keySet()) {
        passThrough.put(scanSymbol, castToRowExpression(new SymbolReference(scanSymbol.getName())));
    }
    ProjectNode scanProjection = new ProjectNode(ruleContext.getIdAllocator().getNextId(), subqueryScan, passThrough.build());
    List<Symbol> scanProjectionOutputs = scanProjection.getOutputSymbols();

    PlanNodeStatsEstimate leftStats = ruleContext.getStatsProvider().getStats(newInnerJoinNode.getLeft());
    PlanNodeStatsEstimate projectionStats = ruleContext.getStatsProvider().getStats(scanProjection);

    // CAUTION: when the left stats are not available, source reordering is not allowed and the
    // query may fail. Otherwise the inputs are swapped when the left row count does not exceed the
    // projection's, so that the table with more rows ends up on the left.
    boolean swapSides = !leftStats.isOutputRowCountUnknown()
            && leftStats.getOutputRowCount() <= projectionStats.getOutputRowCount();

    PlanNode leftSource;
    PlanNode rightSource;
    List<JoinNode.EquiJoinClause> joinCriteria;
    if (swapSides) {
        leftSource = scanProjection;
        rightSource = newInnerJoinNode.getLeft();
        joinCriteria = newInnerJoinNode.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(toImmutableList());
    }
    else {
        leftSource = newInnerJoinNode.getLeft();
        rightSource = scanProjection;
        joinCriteria = newInnerJoinNode.getCriteria();
    }

    // The rebuilt join outputs only the projection's symbols.
    JoinNode rebuiltJoin = new JoinNode(
            newInnerJoinNode.getId(),
            newInnerJoinNode.getType(),
            leftSource,
            rightSource,
            joinCriteria,
            ImmutableList.<Symbol>builder().addAll(scanProjectionOutputs).build(),
            newInnerJoinNode.getFilter(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            newInnerJoinNode.getDynamicFilters());

    // Remove the TableScanNode from the stack
    stack.pop();
    AggregationNode rebuiltAggregation = (AggregationNode) stack.peek().getNode().replaceChildren(ImmutableList.of(rebuiltJoin));
    return stack.firstElement().getNode().replaceChildren(ImmutableList.of(rebuiltAggregation));
}
Aggregations