Use of `io.trino.sql.planner.optimizations.joins.JoinGraph` in project trino by trinodb.
Class `EliminateCrossJoins`, method `getJoinOrder`.
/**
 * Given a JoinGraph, determine the order of joins between graph nodes
 * by traversing the JoinGraph. Any graph traversal algorithm could be used
 * here (like BFS or DFS), but we use a PriorityQueue to preserve the
 * original join order as much as possible. The PriorityQueue returns the
 * next nodes to join in order of their occurrence in the original plan.
 * <p>
 * When the graph is disconnected, traversal restarts from the first
 * not-yet-visited node returned by {@code graph.getNodes()}.
 *
 * @param graph the join graph whose nodes should be ordered
 * @return the indexes of the graph's nodes (their position {@code i} as used by
 *         {@code graph.getNode(i)}) in the order in which they should be joined;
 *         an empty list when the graph has no nodes
 * @throws IllegalStateException if the traversal did not visit every node of the graph
 */
public static List<Integer> getJoinOrder(JoinGraph graph) {
    if (graph.size() == 0) {
        // Nothing to order; also avoids calling graph.getNode(0) on an empty graph.
        return ImmutableList.of();
    }
    ImmutableList.Builder<PlanNode> joinOrder = ImmutableList.builder();
    // Priority of a node == its index in the graph, i.e. its position in the original plan.
    Map<PlanNodeId, Integer> priorities = new HashMap<>();
    for (int i = 0; i < graph.size(); i++) {
        priorities.put(graph.getNode(i).getId(), i);
    }
    PriorityQueue<PlanNode> nodesToVisit = new PriorityQueue<>(graph.size(), comparing(node -> priorities.get(node.getId())));
    Set<PlanNode> visited = new HashSet<>();
    nodesToVisit.add(graph.getNode(0));
    while (!nodesToVisit.isEmpty()) {
        PlanNode node = nodesToVisit.poll();
        // A node may be queued several times (once per incoming edge); only the first poll counts.
        if (visited.add(node)) {
            joinOrder.add(node);
            for (JoinGraph.Edge edge : graph.getEdges(node)) {
                nodesToVisit.add(edge.getTargetNode());
            }
        }
        if (nodesToVisit.isEmpty() && visited.size() < graph.size()) {
            // disconnected graph, find new starting point
            Optional<PlanNode> firstNotVisitedNode = graph.getNodes().stream().filter(graphNode -> !visited.contains(graphNode)).findFirst();
            firstNotVisitedNode.ifPresent(nodesToVisit::add);
        }
    }
    checkState(visited.size() == graph.size(), "Join order traversal did not visit all nodes of the join graph");
    return joinOrder.build().stream().map(node -> priorities.get(node.getId())).collect(toImmutableList());
}
Use of `io.trino.sql.planner.optimizations.joins.JoinGraph` in project trino by trinodb.
Class `TestEliminateCrossJoins`, method `testDoNotReorderCrossJoins`.
/**
 * Builds a three-source join and checks that getJoinOrder keeps the nodes in their original order (0, 1, 2).
 */
@Test
public void testDoNotReorderCrossJoins() {
    Session testSession = testSessionBuilder().build();
    PlanNode threeWayJoin = joinNode(
            joinNode(values("a"), values("b")),
            values("c"),
            "b", "c");
    JoinGraph graph = JoinGraph.buildFrom(
            tester().getPlannerContext(),
            threeWayJoin,
            noLookup(),
            new PlanNodeIdAllocator(),
            testSession,
            createTestingTypeAnalyzer(tester().getPlannerContext()),
            TypeProvider.empty());
    ImmutableList<Integer> expectedOrder = ImmutableList.of(0, 1, 2);
    assertEquals(getJoinOrder(graph), expectedOrder);
}
Use of `io.trino.sql.planner.optimizations.joins.JoinGraph` in project trino by trinodb.
Class `TestEliminateCrossJoins`, method `testJoinOrder`.
/**
 * Builds a three-source join and checks that getJoinOrder visits node 2 before node 1.
 */
@Test
public void testJoinOrder() {
    Session testSession = testSessionBuilder().build();
    // NOTE(review): the trailing symbol names are presumably equi-join criteria pairs (a=c, b=c) — confirm against joinNode helper
    PlanNode threeWayJoin = joinNode(
            joinNode(values("a"), values("b")),
            values("c"),
            "a", "c",
            "b", "c");
    JoinGraph graph = JoinGraph.buildFrom(
            tester().getPlannerContext(),
            threeWayJoin,
            noLookup(),
            new PlanNodeIdAllocator(),
            testSession,
            createTestingTypeAnalyzer(tester().getPlannerContext()),
            TypeProvider.empty());
    ImmutableList<Integer> expectedOrder = ImmutableList.of(0, 2, 1);
    assertEquals(getJoinOrder(graph), expectedOrder);
}
Use of `io.trino.sql.planner.optimizations.joins.JoinGraph` in project trino by trinodb.
Class `TestEliminateCrossJoins`, method `testJoinOrderWithRealCrossJoin`.
/**
 * Joins two three-source sub-plans with a join that has no criteria and checks
 * the resulting order over all six graph nodes.
 */
@Test
public void testJoinOrderWithRealCrossJoin() {
    Session testSession = testSessionBuilder().build();
    PlanNode abcJoin = joinNode(
            joinNode(values("a"), values("b")),
            values("c"),
            "a", "c",
            "b", "c");
    PlanNode xyzJoin = joinNode(
            joinNode(values("x"), values("y")),
            values("z"),
            "x", "z",
            "y", "z");
    PlanNode crossJoined = joinNode(abcJoin, xyzJoin);
    JoinGraph graph = JoinGraph.buildFrom(
            tester().getPlannerContext(),
            crossJoined,
            noLookup(),
            new PlanNodeIdAllocator(),
            testSession,
            createTestingTypeAnalyzer(tester().getPlannerContext()),
            TypeProvider.empty());
    ImmutableList<Integer> expectedOrder = ImmutableList.of(0, 2, 1, 3, 5, 4);
    assertEquals(getJoinOrder(graph), expectedOrder);
}
Use of `io.trino.sql.planner.optimizations.joins.JoinGraph` in project trino by trinodb.
Class `EliminateCrossJoins`, method `apply`.
/**
 * Builds a JoinGraph for {@code node} and, if getJoinOrder yields an order other than
 * the original one, replaces the join with a tree rebuilt in that order.
 * Returns {@code Result.empty()} when the graph has fewer than three nodes,
 * contains no cross join, or the computed order equals the original one.
 */
@Override
public Result apply(JoinNode node, Captures captures, Context context) {
    JoinGraph graph = JoinGraph.buildFrom(
            plannerContext,
            node,
            context.getLookup(),
            context.getIdAllocator(),
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Short-circuits exactly like the original guard: the cross-join check runs only for 3+ nodes.
    boolean candidate = graph.size() >= 3 && graph.isContainsCrossJoin();
    if (!candidate) {
        return Result.empty();
    }

    List<Integer> order = getJoinOrder(graph);
    if (isOriginalOrder(order)) {
        return Result.empty();
    }
    return Result.ofPlanNode(buildJoinTree(node.getOutputSymbols(), graph, order, context.getIdAllocator()));
}
Aggregations