use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
the class ExchangeStatsRule method doCalculate.
/**
 * Estimates the statistics of an exchange by combining the statistics of all of its sources.
 * <p>
 * For every source, the statistics obtained from the {@code statsProvider} are passed through
 * {@code mapToOutputSymbols} together with that source's input symbols and the exchange's output
 * symbols. The first mapped estimate seeds the result; each following one is merged into it with
 * {@code addStatsAndMaxDistinctValues}.
 *
 * @param node the exchange whose statistics are calculated
 * @param statsProvider supplies the statistics of each source node
 * @return the combined estimate; verification fails if the exchange has no sources
 */
@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) {
    Optional<PlanNodeStatsEstimate> combined = Optional.empty();
    // position of the current source; used to pick the matching list of input symbols
    int sourceIndex = 0;
    for (PlanNode source : node.getSources()) {
        PlanNodeStatsEstimate mapped = mapToOutputSymbols(
                statsProvider.getStats(source),
                node.getInputs().get(sourceIndex),
                node.getOutputSymbols());
        combined = Optional.of(combined.isPresent()
                ? addStatsAndMaxDistinctValues(combined.get(), mapped)
                : mapped);
        sourceIndex++;
    }
    verify(combined.isPresent());
    return combined;
}
use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
the class ExtractSpatialJoins method tryCreateSpatialJoin.
/**
 * Attempts to replace a join whose filter contains a two-argument spatial function call with a
 * {@link SpatialJoinNode}.
 * <p>
 * The rewrite is abandoned ({@code Result.empty()}) when either argument has the spherical geography
 * type, when either argument references no symbols, or when {@code checkAlignment} reports {@code 0}
 * for the arguments against the join sides.
 * <p>
 * When a KDB tree is available (only looked up for INNER joins), both sides additionally go through
 * {@code addPartitioningNodes} and receive a fresh "pid" symbol each; the {@code radius} is passed
 * for the right side only.
 *
 * @param filter the join filter containing {@code spatialFunction}
 * @param spatialFunction the spatial function call to rewrite; must have exactly two arguments
 * @param radius optional radius expression, forwarded when partitioning the right side
 * @return the rewritten plan node, or an empty result if the rewrite does not apply
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, FunctionCall spatialFunction, Optional<Expression> radius, PlannerContext plannerContext, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) {
    // TODO Add support for distributed left spatial joins
    Optional<String> partitioningTable = Optional.empty();
    if (joinNode.getType() == INNER) {
        partitioningTable = getSpatialPartitioningTableName(context.getSession());
    }
    Optional<KdbTree> kdbTree = partitioningTable.map(tableName -> loadKdbTree(tableName, context.getSession(), plannerContext.getMetadata(), splitManager, pageSourceManager));

    List<Expression> arguments = spatialFunction.getArguments();
    verify(arguments.size() == 2);
    Expression firstArgument = arguments.get(0);
    Expression secondArgument = arguments.get(1);

    // spherical geography arguments are not rewritten
    Type sphericalGeographyType = plannerContext.getTypeManager().getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE);
    if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType)) {
        return Result.empty();
    }
    if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) {
        return Result.empty();
    }

    Set<Symbol> firstSymbols = extractUnique(firstArgument);
    Set<Symbol> secondSymbols = extractUnique(secondArgument);
    if (firstSymbols.isEmpty() || secondSymbols.isEmpty()) {
        return Result.empty();
    }

    Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument, plannerContext.getTypeManager());
    Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument, plannerContext.getTypeManager());

    PlanNode leftNode = joinNode.getLeft();
    PlanNode rightNode = joinNode.getRight();

    // Check if the order of arguments of the spatial function matches the order of join sides
    int alignment = checkAlignment(joinNode, firstSymbols, secondSymbols);
    if (alignment == 0) {
        return Result.empty();
    }
    // positive: first argument goes with the left side; negative: first argument goes with the right side
    boolean firstArgumentOnLeft = alignment > 0;

    Optional<Symbol> leftGeometrySymbol = firstArgumentOnLeft ? newFirstSymbol : newSecondSymbol;
    Optional<Symbol> rightGeometrySymbol = firstArgumentOnLeft ? newSecondSymbol : newFirstSymbol;
    Expression leftGeometry = firstArgumentOnLeft ? firstArgument : secondArgument;
    Expression rightGeometry = firstArgumentOnLeft ? secondArgument : firstArgument;

    // a projection is added to a side only if a new geometry symbol was created for its argument
    PlanNode newLeftNode = leftGeometrySymbol.map(symbol -> addProjection(context, leftNode, symbol, leftGeometry)).orElse(leftNode);
    PlanNode newRightNode = rightGeometrySymbol.map(symbol -> addProjection(context, rightNode, symbol, rightGeometry)).orElse(rightNode);

    Expression newFirstArgument = toExpression(newFirstSymbol, firstArgument);
    Expression newSecondArgument = toExpression(newSecondSymbol, secondArgument);

    Optional<Symbol> leftPartitionSymbol = Optional.empty();
    Optional<Symbol> rightPartitionSymbol = Optional.empty();
    if (kdbTree.isPresent()) {
        leftPartitionSymbol = Optional.of(context.getSymbolAllocator().newSymbol("pid", INTEGER));
        rightPartitionSymbol = Optional.of(context.getSymbolAllocator().newSymbol("pid", INTEGER));

        Expression leftPartitionGeometry = firstArgumentOnLeft ? newFirstArgument : newSecondArgument;
        Expression rightPartitionGeometry = firstArgumentOnLeft ? newSecondArgument : newFirstArgument;
        // only the right side receives the radius
        newLeftNode = addPartitioningNodes(plannerContext, context, newLeftNode, leftPartitionSymbol.get(), kdbTree.get(), leftPartitionGeometry, Optional.empty());
        newRightNode = addPartitioningNodes(plannerContext, context, newRightNode, rightPartitionSymbol.get(), kdbTree.get(), rightPartitionGeometry, radius);
    }

    // re-resolve the spatial function over the (possibly replaced) arguments and swap it into the filter
    Expression newSpatialFunction = FunctionCallBuilder.resolve(context.getSession(), plannerContext.getMetadata())
            .setName(spatialFunction.getName())
            .addArgument(GEOMETRY_TYPE_SIGNATURE, newFirstArgument)
            .addArgument(GEOMETRY_TYPE_SIGNATURE, newSecondArgument)
            .build();
    Expression newFilter = replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));

    SpatialJoinNode spatialJoin = new SpatialJoinNode(
            nodeId,
            SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType()),
            newLeftNode,
            newRightNode,
            outputSymbols,
            newFilter,
            leftPartitionSymbol,
            rightPartitionSymbol,
            kdbTree.map(KdbTreeUtils::toJson));
    return Result.ofPlanNode(spatialJoin);
}
use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
the class DetermineSemiJoinDistributionType method canReplicate.
/**
 * Decides whether the filtering source of a semi join may be replicated.
 * <p>
 * The answer is {@code true} when the estimated output size of the filtering source does not
 * exceed the session's maximum broadcast table size, or, failing that, when the value returned by
 * {@code getSourceTablesSizeInBytes} for the filtering source does not exceed that limit.
 *
 * @param node the semi join under consideration
 * @param context rule context providing the session, statistics and symbol types
 * @return whether the filtering source fits within the broadcast size limit
 */
private boolean canReplicate(SemiJoinNode node, Context context) {
    DataSize broadcastLimit = getJoinMaxBroadcastTableSize(context.getSession());
    PlanNode filteringSource = node.getFilteringSource();
    double estimatedSizeInBytes = context.getStatsProvider()
            .getStats(filteringSource)
            .getOutputSizeInBytes(filteringSource.getOutputSymbols(), context.getSymbolAllocator().getTypes());
    if (estimatedSizeInBytes <= broadcastLimit.toBytes()) {
        return true;
    }
    // the estimate did not pass the check; fall back to the size of the source tables
    return getSourceTablesSizeInBytes(filteringSource, context) <= broadcastLimit.toBytes();
}
use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
the class EliminateCrossJoins method getJoinOrder.
/**
 * Determines the order in which the nodes of the given {@link JoinGraph} are joined.
 * <p>
 * The graph is traversed starting from node 0. Any traversal algorithm (such as BFS or DFS) would
 * do, but a {@link PriorityQueue} keyed by each node's index in the graph is used so that the
 * original join order is preserved as much as possible: the next node to join is always the
 * reachable one that occurs earliest in the original plan. If the graph is disconnected, the
 * traversal restarts from the first node that has not been visited yet.
 *
 * @param graph the join graph to traverse
 * @return the graph indexes of the nodes, in the order they were visited
 */
public static List<Integer> getJoinOrder(JoinGraph graph) {
    // index of every node in the graph, looked up by plan node id
    Map<PlanNodeId, Integer> indexById = new HashMap<>();
    for (int index = 0; index < graph.size(); index++) {
        indexById.put(graph.getNode(index).getId(), index);
    }

    PriorityQueue<PlanNode> pending = new PriorityQueue<>(graph.size(), comparing(candidate -> indexById.get(candidate.getId())));
    Set<PlanNode> seen = new HashSet<>();
    ImmutableList.Builder<PlanNode> visitOrder = ImmutableList.builder();

    pending.add(graph.getNode(0));
    while (!pending.isEmpty()) {
        PlanNode current = pending.poll();
        // Set.add returns false for a node that was already visited
        if (seen.add(current)) {
            visitOrder.add(current);
            for (JoinGraph.Edge edge : graph.getEdges(current)) {
                pending.add(edge.getTargetNode());
            }
        }
        if (pending.isEmpty() && seen.size() < graph.size()) {
            // disconnected graph, find new starting point
            graph.getNodes().stream()
                    .filter(graphNode -> !seen.contains(graphNode))
                    .findFirst()
                    .ifPresent(pending::add);
        }
    }

    checkState(seen.size() == graph.size());
    return visitOrder.build().stream()
            .map(visitedNode -> indexById.get(visitedNode.getId()))
            .collect(toImmutableList());
}
use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
the class DecorrelateInnerUnnestWithGlobalAggregation method apply.
/**
 * Rewrites a correlated join whose subquery contains a global aggregation over a supported unnest
 * into a plan built on a single {@code UnnestNode} over the join's input.
 * <p>
 * Steps, in order: locate the global aggregation(s) in the subquery, locate the unnest below the
 * last one found, tag input rows with a unique id, optionally re-apply the projection that fed the
 * original unnest, create a LEFT unnest with an ordinality symbol, add a mask symbol derived from
 * the ordinality, restore the remaining subquery nodes via {@code rewriteNodeSequence}, and finally
 * restrict the outputs to those of the correlated join.
 * <p>
 * The {@code captures} argument is not used by this method.
 *
 * @return the rewritten plan, or {@code Result.empty()} if no global aggregation or no supported unnest is found
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
// find global aggregation in subquery
// the search only descends through projections and global aggregations
List<PlanNode> globalAggregations = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), context.getLookup()).where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation).recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node)).findAll();
if (globalAggregations.isEmpty()) {
return Result.empty();
}
// if there are multiple global aggregations, the one that is closest to the source is the "reducing" aggregation, because it reduces multiple input rows to single output row
// NOTE(review): taking the last element assumes findAll() lists nodes from the top of the subquery downwards - confirm against PlanNodeSearcher
AggregationNode reducingAggregation = (AggregationNode) globalAggregations.get(globalAggregations.size() - 1);
// find unnest in subquery
// the search starts below the reducing aggregation and only descends through projections and grouped aggregations
Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), context.getLookup()).where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup())).recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node)).findFirst();
if (subqueryUnnest.isEmpty()) {
return Result.empty();
}
UnnestNode unnestNode = subqueryUnnest.get();
// assign unique id to input rows to restore semantics of aggregations after rewrite
PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
// pre-project unnest symbols if they were pre-projected in subquery
// The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols produced by a projection that uses only correlation symbols.
// Here, any underlying projection that was a source of the correlated UnnestNode, is appended as a source of the rewritten UnnestNode.
// If the projection is not necessary for UnnestNode (i.e. it does not produce any unnest symbols), it should be pruned afterwards.
PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
if (unnestSource instanceof ProjectNode) {
ProjectNode sourceProjection = (ProjectNode) unnestSource;
// the new projection reuses the id of the subquery projection, passes through every symbol of the input and adds the subquery projection's assignments
input = new ProjectNode(sourceProjection.getId(), input, Assignments.builder().putIdentities(input.getOutputSymbols()).putAll(sourceProjection.getAssignments()).build());
}
// rewrite correlated join to UnnestNode
// reuse the ordinality symbol of the original unnest if it has one; otherwise allocate a new one, because the mask below depends on it
Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol().orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));
UnnestNode rewrittenUnnest = new UnnestNode(context.getIdAllocator().getNextId(), input, input.getOutputSymbols(), unnestNode.getMappings(), Optional.of(ordinalitySymbol), LEFT, Optional.empty());
// append mask symbol based on ordinality to distinguish between the unnested rows and synthetic null rows
// the mask is the predicate "ordinality IS NOT NULL"
Symbol mask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
ProjectNode sourceWithMask = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenUnnest, Assignments.builder().putIdentities(rewrittenUnnest.getOutputSymbols()).put(mask, new IsNotNullPredicate(ordinalitySymbol.toSymbolReference())).build());
// restore all projections, grouped aggregations and global aggregations from the subquery
// the ids of the reducing aggregation and of the original unnest are handed to rewriteNodeSequence
PlanNode result = rewriteNodeSequence(context.getLookup().resolve(correlatedJoinNode.getSubquery()), input.getOutputSymbols(), mask, sourceWithMask, reducingAggregation.getId(), unnestNode.getId(), context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup());
// restrict outputs
// restrictOutputs yields an Optional; when it is empty the unrestricted result is returned as is
return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols())).orElse(result));
}
Aggregations