Usage of io.trino.sql.planner.plan.CorrelatedJoinNode in the trino project by trinodb.
Class PruneCorrelatedJoinColumns, method pushDownProjectOff.
@Override
protected Optional<PlanNode> pushDownProjectOff(Context context, CorrelatedJoinNode correlatedJoinNode, Set<Symbol> referencedOutputs)
{
    PlanNode input = correlatedJoinNode.getInput();
    PlanNode subquery = correlatedJoinNode.getSubquery();

    // Nothing the parent references comes from the subquery: replace the whole correlated join
    // with its input when the join type/shape guarantees input rows pass through one-for-one.
    boolean subqueryUnused = subquery.getOutputSymbols().stream().noneMatch(referencedOutputs::contains);
    if (subqueryUnused) {
        boolean removableInner = correlatedJoinNode.getType() == INNER
                && isScalar(subquery, context.getLookup())
                && correlatedJoinNode.getFilter().equals(TRUE_LITERAL);
        boolean removableLeft = correlatedJoinNode.getType() == LEFT
                && isAtMostScalar(subquery, context.getLookup());
        if (removableInner || removableLeft) {
            return Optional.of(input);
        }
    }

    // Input symbols are needed either by the parent or as correlation symbols for the subquery.
    Set<Symbol> neededFromInput = ImmutableSet.<Symbol>builder()
            .addAll(referencedOutputs)
            .addAll(correlatedJoinNode.getCorrelation())
            .build();

    // Nothing needed comes from the input: replace the whole correlated join with its subquery
    // under the mirrored conditions (INNER over a scalar input, or RIGHT over an at-most-scalar input).
    boolean inputUnused = input.getOutputSymbols().stream().noneMatch(neededFromInput::contains);
    if (inputUnused) {
        boolean removableInner = correlatedJoinNode.getType() == INNER
                && isScalar(input, context.getLookup())
                && correlatedJoinNode.getFilter().equals(TRUE_LITERAL);
        boolean removableRight = correlatedJoinNode.getType() == RIGHT
                && isAtMostScalar(input, context.getLookup());
        if (removableInner || removableRight) {
            return Optional.of(subquery);
        }
    }

    // Otherwise try to narrow each side. The subquery must keep what the parent and the join filter use;
    // the input must additionally keep the correlation symbols.
    Set<Symbol> keepInSubquery = ImmutableSet.<Symbol>builder()
            .addAll(referencedOutputs)
            .addAll(extractUnique(correlatedJoinNode.getFilter()))
            .build();
    Set<Symbol> keepInInput = ImmutableSet.<Symbol>builder()
            .addAll(keepInSubquery)
            .addAll(correlatedJoinNode.getCorrelation())
            .build();

    Optional<PlanNode> narrowedSubquery = restrictOutputs(context.getIdAllocator(), subquery, keepInSubquery);
    Optional<PlanNode> narrowedInput = restrictOutputs(context.getIdAllocator(), input, keepInInput);

    // Neither side could be narrowed: report that this rule made no change.
    if (!narrowedSubquery.isPresent() && !narrowedInput.isPresent()) {
        return Optional.empty();
    }

    return Optional.of(new CorrelatedJoinNode(
            correlatedJoinNode.getId(),
            narrowedInput.orElse(input),
            narrowedSubquery.orElse(subquery),
            correlatedJoinNode.getCorrelation(),
            correlatedJoinNode.getType(),
            correlatedJoinNode.getFilter(),
            correlatedJoinNode.getOriginSubquery()));
}
Usage of io.trino.sql.planner.plan.CorrelatedJoinNode in the trino project by trinodb.
Class PruneCorrelatedJoinCorrelation, method apply.
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context)
{
    // Symbols that occur anywhere in the subquery plan.
    Set<Symbol> usedBySubquery = extractUnique(correlatedJoinNode.getSubquery(), context.getLookup());
    List<Symbol> currentCorrelation = correlatedJoinNode.getCorrelation();

    // Keep only the correlation symbols the subquery still uses, preserving their order.
    List<Symbol> prunedCorrelation = currentCorrelation.stream()
            .filter(usedBySubquery::contains)
            .collect(toImmutableList());

    // Filtering can only shrink the list; if nothing was dropped there is nothing to rewrite.
    if (prunedCorrelation.size() >= currentCorrelation.size()) {
        return Result.empty();
    }

    return Result.ofPlanNode(new CorrelatedJoinNode(
            correlatedJoinNode.getId(),
            correlatedJoinNode.getInput(),
            correlatedJoinNode.getSubquery(),
            prunedCorrelation,
            correlatedJoinNode.getType(),
            correlatedJoinNode.getFilter(),
            correlatedJoinNode.getOriginSubquery()));
}
Usage of io.trino.sql.planner.plan.CorrelatedJoinNode in the trino project by trinodb.
Class TransformCorrelatedGroupedAggregationWithProjection, method apply.
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context)
{
    // When the captured source is a distinct operator (per isDistinctOperator), detach it so it can be
    // rebuilt above the join together with the grouped aggregation.
    PlanNode aggregationSource = captures.get(SOURCE);
    AggregationNode distinctOperator = null;
    if (isDistinctOperator(aggregationSource)) {
        distinctOperator = (AggregationNode) aggregationSource;
        aggregationSource = distinctOperator.getSource();
    }

    // Lift correlated filters out of the nested plan; give up when that is not possible.
    PlanNodeDecorrelator planDecorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = planDecorrelator.decorrelateFilters(aggregationSource, correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNode decorrelatedSource = decorrelated.get().getNode();

    // Tag every input row with a unique id so the original input rows stay distinguishable after the join.
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // Inner-join the tagged input with the decorrelated source, using the lifted predicates as the join filter.
    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.INNER,
            inputWithUniqueId,
            decorrelatedSource,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            decorrelatedSource.getOutputSymbols(),
            false,
            decorrelated.get().getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Both rebuilt aggregations group by every input-side symbol in addition to their own keys.
    List<Symbol> inputSideSymbols = join.getLeftOutputSymbols();

    // Rebuild the distinct operator, if there was one, directly above the join.
    PlanNode groupedAggregationSource = join;
    if (distinctOperator != null) {
        groupedAggregationSource = new AggregationNode(
                distinctOperator.getId(),
                join,
                distinctOperator.getAggregations(),
                singleGroupingSet(ImmutableList.<Symbol>builder()
                        .addAll(inputSideSymbols)
                        .addAll(distinctOperator.getGroupingKeys())
                        .build()),
                ImmutableList.of(),
                distinctOperator.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // Rebuild the captured grouped aggregation on top.
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    AggregationNode groupedAggregation = new AggregationNode(
            capturedAggregation.getId(),
            groupedAggregationSource,
            capturedAggregation.getAggregations(),
            singleGroupingSet(ImmutableList.<Symbol>builder()
                    .addAll(inputSideSymbols)
                    .addAll(capturedAggregation.getGroupingKeys())
                    .build()),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Pass through the aggregation outputs that the correlated join exposed, then apply the captured projection.
    Set<Symbol> correlatedJoinOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> passThroughSymbols = groupedAggregation.getOutputSymbols().stream()
            .filter(correlatedJoinOutputs::contains)
            .collect(toImmutableList());
    Assignments projectAssignments = Assignments.builder()
            .putIdentities(passThroughSymbols)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();

    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            groupedAggregation,
            projectAssignments));
}
Usage of io.trino.sql.planner.plan.CorrelatedJoinNode in the trino project by trinodb.
Class TransformCorrelatedGroupedAggregationWithoutProjection, method apply.
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context)
{
    // When the captured source is a distinct operator (per isDistinctOperator), detach it so it can be
    // rebuilt above the join together with the grouped aggregation.
    PlanNode aggregationSource = captures.get(SOURCE);
    AggregationNode distinctOperator = null;
    if (isDistinctOperator(aggregationSource)) {
        distinctOperator = (AggregationNode) aggregationSource;
        aggregationSource = distinctOperator.getSource();
    }

    // Lift correlated filters out of the nested plan; give up when that is not possible.
    PlanNodeDecorrelator planDecorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = planDecorrelator.decorrelateFilters(aggregationSource, correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNode decorrelatedSource = decorrelated.get().getNode();

    // Tag every input row with a unique id so the original input rows stay distinguishable after the join.
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // Inner-join the tagged input with the decorrelated source, using the lifted predicates as the join filter.
    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.INNER,
            inputWithUniqueId,
            decorrelatedSource,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            decorrelatedSource.getOutputSymbols(),
            false,
            decorrelated.get().getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Both rebuilt aggregations group by every input-side symbol in addition to their own keys.
    List<Symbol> inputSideSymbols = join.getLeftOutputSymbols();

    // Rebuild the distinct operator, if there was one, directly above the join.
    PlanNode groupedAggregationSource = join;
    if (distinctOperator != null) {
        groupedAggregationSource = new AggregationNode(
                distinctOperator.getId(),
                join,
                distinctOperator.getAggregations(),
                singleGroupingSet(ImmutableList.<Symbol>builder()
                        .addAll(inputSideSymbols)
                        .addAll(distinctOperator.getGroupingKeys())
                        .build()),
                ImmutableList.of(),
                distinctOperator.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // Rebuild the captured grouped aggregation on top.
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    AggregationNode groupedAggregation = new AggregationNode(
            capturedAggregation.getId(),
            groupedAggregationSource,
            capturedAggregation.getAggregations(),
            singleGroupingSet(ImmutableList.<Symbol>builder()
                    .addAll(inputSideSymbols)
                    .addAll(capturedAggregation.getGroupingKeys())
                    .build()),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Narrow to the correlated join's outputs; fall back to the aggregation itself when restrictOutputs returns no node.
    Optional<PlanNode> restricted = restrictOutputs(context.getIdAllocator(), groupedAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
    return Result.ofPlanNode(restricted.orElse(groupedAggregation));
}
Usage of io.trino.sql.planner.plan.CorrelatedJoinNode in the trino project by trinodb.
Class TransformCorrelatedJoinToJoin, method apply.
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context)
{
    // Only INNER and LEFT correlated joins are accepted by this rewrite.
    checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "correlation in %s JOIN", correlatedJoinNode.getType().name());

    PlanNode input = correlatedJoinNode.getInput();
    PlanNode subquery = correlatedJoinNode.getSubquery();

    // Lift correlated filter predicates out of the subquery; give up when that is not possible.
    Optional<DecorrelatedNode> decorrelation = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup())
            .decorrelateFilters(subquery, correlatedJoinNode.getCorrelation());
    if (!decorrelation.isPresent()) {
        return Result.empty();
    }
    DecorrelatedNode decorrelated = decorrelation.get();

    // Join filter = lifted predicates AND the correlated join's own filter; a trivially TRUE filter is omitted.
    Expression joinFilter = combineConjuncts(
            plannerContext.getMetadata(),
            decorrelated.getCorrelatedPredicates().orElse(TRUE_LITERAL),
            correlatedJoinNode.getFilter());
    Optional<Expression> effectiveFilter = joinFilter.equals(TRUE_LITERAL) ? Optional.empty() : Optional.of(joinFilter);

    return Result.ofPlanNode(new JoinNode(
            correlatedJoinNode.getId(),
            correlatedJoinNode.getType().toJoinNodeType(),
            input,
            decorrelated.getNode(),
            ImmutableList.of(),
            input.getOutputSymbols(),
            // right-side outputs come from the original subquery, not from the decorrelated node
            subquery.getOutputSymbols(),
            false,
            effectiveFilter,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty()));
}
Aggregations