Usage of `io.trino.sql.planner.plan.JoinNode` in the trino project (trinodb).
Class `TestJoinNodeFlattener`, method `testPushesProjectionsThroughJoin`:
@Test
public void testPushesProjectionsThroughJoin() {
    PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
    PlanBuilder builder = planBuilder(idAllocator);

    Symbol a = builder.symbol("A");
    Symbol b = builder.symbol("B");
    Symbol c = builder.symbol("C");
    Symbol d = builder.symbol("D");

    ValuesNode valuesA = builder.values(a);
    ValuesNode valuesB = builder.values(b);
    ValuesNode valuesC = builder.values(c);

    // Plan under test: ((A JOIN B ON a = b) PROJECT d := -a) JOIN C ON d = c
    // Nodes are created inner-most first so plan node ids are allocated in the same order as a nested expression.
    Assignments negateA = Assignments.of(d, new ArithmeticUnaryExpression(MINUS, a.toSymbolReference()));
    JoinNode innerJoin = builder.join(INNER, valuesA, valuesB, equiJoinClause(a, b));
    PlanNode negatingProjection = builder.project(negateA, innerJoin);
    JoinNode joinNode = builder.join(INNER, negatingProjection, valuesC, equiJoinClause(d, c));

    MultiJoinNode actual = toMultiJoinNode(
            queryRunner.getPlannerContext(),
            joinNode,
            noLookup(),
            idAllocator,
            DEFAULT_JOIN_LIMIT,
            true,
            testSessionBuilder().build(),
            createTestingTypeAnalyzer(queryRunner.getPlannerContext()),
            builder.getTypes());

    // Flattened result: outputs d and c, with both equi-join conditions combined into one conjunction
    assertEquals(actual.getOutputSymbols(), ImmutableList.of(d, c));
    assertEquals(actual.getFilter(), and(createEqualsExpression(a, b), createEqualsExpression(d, c)));
    assertTrue(actual.isPushedProjectionThroughJoin());

    // The first two sources are ProjectNodes over the original values; the third source is the untouched values of C
    List<PlanNode> sources = ImmutableList.copyOf(actual.getSources());
    assertPlan(builder.getTypes(), sources.get(0), node(ProjectNode.class, values("a")).withNumberOfOutputColumns(2));
    assertPlan(builder.getTypes(), sources.get(1), node(ProjectNode.class, values("b")).withNumberOfOutputColumns(1));
    assertPlan(builder.getTypes(), sources.get(2), values("c"));
}
Usage of `io.trino.sql.planner.plan.JoinNode` in the trino project (trinodb).
Class `TestPlanNodeSearcher`, method `joinNodePreorder`:
/**
 * Appends the {@link PlanNodeId} of every {@link JoinNode} in the tree rooted at {@code root} to
 * {@code builder} in pre-order: a join is recorded before its left subtree, which is recorded
 * before its right subtree. {@link ValuesNode}s are leaves and contribute nothing.
 *
 * @param root the subtree to walk; it may contain only JoinNodes and ValuesNodes
 * @param builder receives the join node ids in pre-order
 * @throws IllegalArgumentException if any other kind of node is encountered
 */
private static void joinNodePreorder(PlanNode root, ImmutableList.Builder<PlanNodeId> builder) {
    if (root instanceof JoinNode) {
        JoinNode join = (JoinNode) root;
        builder.add(join.getId());
        joinNodePreorder(join.getLeft(), builder);
        joinNodePreorder(join.getRight(), builder);
    } else if (!(root instanceof ValuesNode)) {
        throw new IllegalArgumentException("unsupported node type: " + root.getClass().getSimpleName());
    }
}
Usage of `io.trino.sql.planner.plan.JoinNode` in the trino project (trinodb).
Class `ExtractSpatialJoins`, method `tryCreateSpatialJoin`:
/**
 * Normalizes a distance comparison used as a spatial join condition and delegates to the
 * {@code FunctionCall}-based {@code tryCreateSpatialJoin} overload.
 *
 * <p>When the operator is {@code <} or {@code <=} the comparison is read as {@code ST_Distance(a, b) <= r}
 * and keeps its operator. Any other operator is read as the mirrored form {@code r >= ST_Distance(a, b)}
 * and is flipped. Either way, the rewritten comparison has the distance expression on the left and the
 * radius on the right. If {@code newRadiusSymbol} yields a symbol, the radius expression is projected onto
 * the join's right source under that symbol.
 *
 * @return {@code Result.empty()} if the radius references symbols that are not all produced by the right
 *     source, or that are also produced by the left source; otherwise the result of the delegated call
 *     for the rewritten join and filter
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, ComparisonExpression spatialComparison, PlannerContext plannerContext, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) {
    PlanNode leftNode = joinNode.getLeft();
    PlanNode rightNode = joinNode.getRight();
    List<Symbol> leftSymbols = leftNode.getOutputSymbols();
    List<Symbol> rightSymbols = rightNode.getOutputSymbols();

    // "ST_Distance(a, b) <= r" already has the radius on the right; otherwise treat it as "r >= ST_Distance(a, b)"
    boolean radiusOnRight = spatialComparison.getOperator() == LESS_THAN || spatialComparison.getOperator() == LESS_THAN_OR_EQUAL;
    Expression radius = radiusOnRight ? spatialComparison.getRight() : spatialComparison.getLeft();
    Expression distance = radiusOnRight ? spatialComparison.getLeft() : spatialComparison.getRight();

    // The radius must be symbol-free or depend only on symbols coming from the right source
    Set<Symbol> radiusSymbols = extractUnique(radius);
    if (!radiusSymbols.isEmpty() && (!rightSymbols.containsAll(radiusSymbols) || !containsNone(leftSymbols, radiusSymbols))) {
        return Result.empty();
    }

    Optional<Symbol> newRadiusSymbol = newRadiusSymbol(context, radius);
    ComparisonExpression newComparison = new ComparisonExpression(
            radiusOnRight ? spatialComparison.getOperator() : spatialComparison.getOperator().flip(),
            distance,
            toExpression(newRadiusSymbol, radius));

    Expression newFilter = replaceExpression(filter, ImmutableMap.of(spatialComparison, newComparison));
    // A freshly introduced radius symbol has to be computed by a projection on top of the right source
    PlanNode newRightNode = newRadiusSymbol
            .map(symbol -> addProjection(context, rightNode, symbol, radius))
            .orElse(rightNode);
    JoinNode newJoinNode = new JoinNode(joinNode.getId(), joinNode.getType(), leftNode, newRightNode, joinNode.getCriteria(), joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), joinNode.isMaySkipOutputDuplicates(), Optional.of(newFilter), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters(), joinNode.getReorderJoinStatsAndCost());
    return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, (FunctionCall) newComparison.getLeft(), Optional.of(newComparison.getRight()), plannerContext, splitManager, pageSourceManager, typeAnalyzer);
}
Usage of `io.trino.sql.planner.plan.JoinNode` in the trino project (trinodb).
Class `EliminateCrossJoins`, method `buildJoinTree`:
/**
 * Builds a left-deep tree of INNER {@link JoinNode}s over the nodes of {@code graph}, taken in
 * {@code joinOrder}. Each new node is joined to the tree built so far, using an equi-join clause
 * for every graph edge that leads from the new node back to a node already in the tree. One
 * {@link FilterNode} per graph filter is then stacked on top. Finally, the plan is passed through
 * {@code restrictOutputs} with {@code expectedOutputSymbols}; when that yields no plan, the
 * unrestricted plan is returned.
 *
 * @param expectedOutputSymbols the symbols the resulting plan should expose
 * @param graph the join graph providing nodes, edges and residual filters
 * @param joinOrder indexes of graph nodes in join order; must contain at least two entries
 * @param idAllocator source of ids for the newly created join and filter nodes
 * @return the root of the rebuilt plan
 * @throws NullPointerException if any argument is null
 * @throws IllegalArgumentException if {@code joinOrder} has fewer than two entries
 */
public static PlanNode buildJoinTree(List<Symbol> expectedOutputSymbols, JoinGraph graph, List<Integer> joinOrder, PlanNodeIdAllocator idAllocator) {
    requireNonNull(expectedOutputSymbols, "expectedOutputSymbols is null");
    requireNonNull(idAllocator, "idAllocator is null");
    requireNonNull(graph, "graph is null");
    List<Integer> order = ImmutableList.copyOf(requireNonNull(joinOrder, "joinOrder is null"));
    checkArgument(order.size() >= 2);

    PlanNode joined = graph.getNode(order.get(0));
    Set<PlanNodeId> joinedIds = new HashSet<>();
    joinedIds.add(joined.getId());

    for (Integer nodeIndex : order.subList(1, order.size())) {
        PlanNode next = graph.getNode(nodeIndex);
        // Registered before scanning edges, so the new node itself already counts as part of the tree
        joinedIds.add(next.getId());

        // Edge target = node already in the tree (left side), edge source = the node being added (right side)
        ImmutableList.Builder<JoinNode.EquiJoinClause> clauses = ImmutableList.builder();
        for (JoinGraph.Edge edge : graph.getEdges(next)) {
            if (joinedIds.contains(edge.getTargetNode().getId())) {
                clauses.add(new JoinNode.EquiJoinClause(edge.getTargetSymbol(), edge.getSourceSymbol()));
            }
        }
        joined = new JoinNode(idAllocator.getNextId(), JoinNode.Type.INNER, joined, next, clauses.build(), joined.getOutputSymbols(), next.getOutputSymbols(), false, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
    }

    for (Expression filter : graph.getFilters()) {
        joined = new FilterNode(idAllocator.getNextId(), joined, filter);
    }

    // Some nodes are sensitive to what's produced (e.g., DistinctLimit node)
    return restrictOutputs(idAllocator, joined, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(joined);
}
Usage of `io.trino.sql.planner.plan.JoinNode` in the trino project (trinodb).
Class `TestPlanNodeSearcher`, method `testFindAllMultipleSources`:
@Test
public void testFindAllMultipleSources() {
    // Four leaf joins over empty values, paired into two intermediate joins that meet at the root
    List<JoinNode> leafJoins = new ArrayList<>();
    while (leafJoins.size() < 4) {
        leafJoins.add(BUILDER.join(INNER, BUILDER.values(), BUILDER.values()));
    }
    JoinNode root = BUILDER.join(
            INNER,
            BUILDER.join(INNER, leafJoins.get(0), leafJoins.get(1)),
            BUILDER.join(INNER, leafJoins.get(2), leafJoins.get(3)));

    ImmutableList.Builder<PlanNodeId> expectedIds = ImmutableList.builder();
    joinNodePreorder(root, expectedIds);

    List<PlanNodeId> foundIds = PlanNodeSearcher.searchFrom(root)
            .where(JoinNode.class::isInstance)
            .findAll()
            .stream()
            .map(PlanNode::getId)
            .collect(toImmutableList());

    // findAll is expected to report the joins in the same pre-order as the manual traversal
    assertEquals(expectedIds.build(), foundIds);
}
Aggregations