Use of io.trino.sql.planner.optimizations.PlanNodeDecorrelator in project trino by trinodb.
The apply method of class TransformCorrelatedDistinctAggregationWithProjection.
/**
 * Rewrites a correlated join over a projected aggregation into an uncorrelated plan:
 * the join input tagged with a unique id, LEFT-joined with the decorrelated subquery,
 * aggregated again with the input symbols added to the grouping keys, and finally projected.
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures pattern captures; {@code AGGREGATION} and {@code PROJECTION} are read here
 * @param context rule context supplying the id allocator, symbol allocator and lookup
 * @return the rewritten plan, or an empty result when the subquery filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Try to pull the correlated filters out of the aggregation's source.
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = decorrelator.decorrelateFilters(
            capturedAggregation.getSource(),
            correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelatedSubquery = decorrelated.get();
    PlanNode subquery = decorrelatedSubquery.getNode();

    // Tag every input row with a unique id so that rows originating from the same
    // input row can still be told apart (and grouped together) after the join.
    PlanNode taggedInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // The predicates extracted by the decorrelator become the join filter.
    JoinNode leftJoin = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            taggedInput,
            subquery,
            ImmutableList.of(),
            taggedInput.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelatedSubquery.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Re-create the aggregation on top of the join, grouping additionally by all left-side symbols.
    ImmutableList<Symbol> groupingKeys = ImmutableList.<Symbol>builder()
            .addAll(leftJoin.getLeftOutputSymbols())
            .addAll(capturedAggregation.getGroupingKeys())
            .build();
    AggregationNode restoredAggregation = new AggregationNode(
            capturedAggregation.getId(),
            leftJoin,
            capturedAggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Keep only the aggregation outputs the correlated join exposed, then add the captured projection.
    Set<Symbol> requiredOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> keptAggregationOutputs = restoredAggregation.getOutputSymbols().stream()
            .filter(symbol -> requiredOutputs.contains(symbol))
            .collect(toImmutableList());
    Assignments projection = Assignments.builder()
            .putIdentities(keptAggregationOutputs)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), restoredAggregation, projection));
}
Use of io.trino.sql.planner.optimizations.PlanNodeDecorrelator in project trino by trinodb.
The apply method of class TransformCorrelatedDistinctAggregationWithoutProjection.
/**
 * Rewrites a correlated join over an aggregation into an uncorrelated plan:
 * the join input tagged with a unique id, LEFT-joined with the decorrelated subquery,
 * and aggregated again with the input symbols added to the grouping keys.
 * Outputs are then restricted to those of the original correlated join.
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures pattern captures; {@code AGGREGATION} is read here
 * @param context rule context supplying the id allocator, symbol allocator and lookup
 * @return the rewritten plan, or an empty result when the subquery filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Try to pull the correlated filters out of the aggregation's source.
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = decorrelator.decorrelateFilters(
            capturedAggregation.getSource(),
            correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelatedSubquery = decorrelated.get();
    PlanNode subquery = decorrelatedSubquery.getNode();

    // Tag every input row with a unique id so that rows originating from the same
    // input row can still be told apart (and grouped together) after the join.
    PlanNode taggedInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // The predicates extracted by the decorrelator become the join filter.
    JoinNode leftJoin = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            taggedInput,
            subquery,
            ImmutableList.of(),
            taggedInput.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelatedSubquery.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Re-create the aggregation on top of the join, grouping additionally by all left-side symbols.
    ImmutableList<Symbol> groupingKeys = ImmutableList.<Symbol>builder()
            .addAll(leftJoin.getLeftOutputSymbols())
            .addAll(capturedAggregation.getGroupingKeys())
            .build();
    AggregationNode restoredAggregation = new AggregationNode(
            capturedAggregation.getId(),
            leftJoin,
            capturedAggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Fall back to the bare aggregation when restrictOutputs yields no projection.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), restoredAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(restoredAggregation));
}
Use of io.trino.sql.planner.optimizations.PlanNodeDecorrelator in project trino by trinodb.
The apply method of class TransformCorrelatedGlobalAggregationWithProjection.
/**
 * Rewrites a correlated join (INNER or LEFT) over a projected global aggregation into an
 * uncorrelated plan. The join input is tagged with a unique id and LEFT-joined with the
 * decorrelated subquery, which is extended with a constant-true {@code non_null} symbol.
 * An optional distinct aggregation found directly below the captured source is restored on
 * top of the join. The global aggregation is then restored, grouped by the input symbols,
 * with every aggregate masked by {@code non_null}, and the captured projection is re-applied.
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures pattern captures; {@code SOURCE}, {@code AGGREGATION} and {@code PROJECTION} are read here
 * @param context rule context supplying the id allocator, symbol allocator and lookup
 * @return the rewritten plan, or an empty result when the subquery filters cannot be decorrelated
 * @throws IllegalArgumentException if the join type is neither INNER nor LEFT
 *         (assuming {@code checkArgument} is Guava's {@code Preconditions.checkArgument})
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // Template overload: the message is only formatted when the check actually fails,
    // instead of concatenating a string on every rule invocation.
    checkArgument(
            correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT,
            "unexpected correlated join type: %s",
            correlatedJoinNode.getType());

    // if there is another aggregation below the AggregationNode, handle both
    PlanNode source = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(source)) {
        distinct = (AggregationNode) source;
        source = distinct.getSource();
    }

    // decorrelate nested plan
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
    if (decorrelatedSource.isEmpty()) {
        return Result.empty();
    }
    source = decorrelatedSource.get().getNode();

    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    source = new ProjectNode(
            context.getIdAllocator().getNextId(),
            source,
            Assignments.builder()
                    .putIdentities(source.getOutputSymbols())
                    .put(nonNull, TRUE_LITERAL)
                    .build());

    // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // the predicates extracted by the decorrelator become the join filter
    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueId,
            source,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            source.getOutputSymbols(),
            false,
            decorrelatedSource.get().getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    PlanNode root = join;

    // restore distinct aggregation; non_null joins the grouping keys so it stays available to the masks above
    if (distinct != null) {
        root = new AggregationNode(
                distinct.getId(),
                join,
                distinct.getAggregations(),
                singleGroupingSet(ImmutableList.<Symbol>builder()
                        .addAll(join.getLeftOutputSymbols())
                        .add(nonNull)
                        .addAll(distinct.getGroupingKeys())
                        .build()),
                ImmutableList.of(),
                distinct.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // prepare mask symbols for aggregations
    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
    // already has a mask, it will be replaced with conjunction of the existing mask and non_null.
    // This is necessary to restore the original aggregation result in case when:
    // - the nested lateral subquery returned empty result for some input row,
    // - aggregation is null-sensitive, which means that its result over a single null row is different
    //   than result for empty input (with global grouping)
    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
    AggregationNode globalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder assignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isPresent()) {
            Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
            assignmentsBuilder.put(newMask, expression);
            masks.put(entry.getKey(), newMask);
        }
        else {
            masks.put(entry.getKey(), nonNull);
        }
    }

    // a projection is only needed when at least one combined mask has to be computed
    Assignments maskAssignments = assignmentsBuilder.build();
    if (!maskAssignments.isEmpty()) {
        root = new ProjectNode(
                context.getIdAllocator().getNextId(),
                root,
                Assignments.builder()
                        .putIdentities(root.getOutputSymbols())
                        .putAll(maskAssignments)
                        .build());
    }

    // restore global aggregation, grouped additionally by all left-side (input) symbols
    globalAggregation = new AggregationNode(
            globalAggregation.getId(),
            root,
            rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()),
            singleGroupingSet(ImmutableList.<Symbol>builder()
                    .addAll(join.getLeftOutputSymbols())
                    .addAll(globalAggregation.getGroupingKeys())
                    .build()),
            ImmutableList.of(),
            globalAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // restrict outputs and apply projection
    Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> expectedAggregationOutputs = globalAggregation.getOutputSymbols().stream()
            .filter(outputSymbols::contains)
            .collect(toImmutableList());
    Assignments assignments = Assignments.builder()
            .putIdentities(expectedAggregationOutputs)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), globalAggregation, assignments));
}
Use of io.trino.sql.planner.optimizations.PlanNodeDecorrelator in project trino by trinodb.
The rewriteToNonDefaultAggregation method of class TransformExistsApplyToCorrelatedJoin.
/**
 * Rewrites an EXISTS-style apply into a LEFT correlated join over the subquery limited to
 * one row and projecting a constant-true marker. The {@code exists} symbol is then computed
 * as {@code COALESCE(marker, FALSE)} on top of the join.
 *
 * @param applyNode the apply node to rewrite; its subquery must have no output symbols
 * @param context rule context supplying the id allocator, symbol allocator and lookup
 * @return the rewritten plan, or empty when the limited subquery's filters cannot be decorrelated
 */
private Optional<PlanNode> rewriteToNonDefaultAggregation(ApplyNode applyNode, Context context) {
    checkState(applyNode.getSubquery().getOutputSymbols().isEmpty(), "Expected subquery output symbols to be pruned");

    Symbol subqueryTrue = context.getSymbolAllocator().newSymbol("subqueryTrue", BOOLEAN);

    // Kept as one nested expression on purpose: the project's id is allocated before the limit's id.
    PlanNode markedSubquery = new ProjectNode(
            context.getIdAllocator().getNextId(),
            new LimitNode(context.getIdAllocator().getNextId(), applyNode.getSubquery(), 1L, false),
            Assignments.of(subqueryTrue, TRUE_LITERAL));

    // Only used as a feasibility check here; the decorrelated plan itself is discarded.
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    boolean decorrelatable = decorrelator.decorrelateFilters(markedSubquery, applyNode.getCorrelation()).isPresent();
    if (!decorrelatable) {
        return Optional.empty();
    }

    Symbol exists = getOnlyElement(applyNode.getSubqueryAssignments().getSymbols());

    CorrelatedJoinNode correlatedJoin = new CorrelatedJoinNode(
            applyNode.getId(),
            applyNode.getInput(),
            markedSubquery,
            applyNode.getCorrelation(),
            LEFT,
            TRUE_LITERAL,
            applyNode.getOriginSubquery());

    // An unmatched input row yields a null marker after the LEFT join, which coalesces to FALSE.
    Assignments existsProjection = Assignments.builder()
            .putIdentities(applyNode.getInput().getOutputSymbols())
            .put(exists, new CoalesceExpression(ImmutableList.of(subqueryTrue.toSymbolReference(), BooleanLiteral.FALSE_LITERAL)))
            .build();

    return Optional.of(new ProjectNode(context.getIdAllocator().getNextId(), correlatedJoin, existsProjection));
}
Use of io.trino.sql.planner.optimizations.PlanNodeDecorrelator in project trino by trinodb.
The apply method of class TransformCorrelatedGroupedAggregationWithProjection.
/**
 * Rewrites a correlated join over a projected grouped aggregation into an uncorrelated plan:
 * the join input tagged with a unique id and INNER-joined with the decorrelated subquery.
 * An optional distinct aggregation found directly below the captured source is restored on
 * top of the join, followed by the grouped aggregation; both are grouped additionally by the
 * input symbols. The captured projection is re-applied last.
 *
 * @param correlatedJoinNode the matched correlated join
 * @param captures pattern captures; {@code SOURCE}, {@code AGGREGATION} and {@code PROJECTION} are read here
 * @param context rule context supplying the id allocator, symbol allocator and lookup
 * @return the rewritten plan, or an empty result when the subquery filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // A distinct operator directly below the aggregation is peeled off and re-applied after the join.
    PlanNode capturedSource = captures.get(SOURCE);
    AggregationNode capturedDistinct = isDistinctOperator(capturedSource) ? (AggregationNode) capturedSource : null;
    PlanNode nestedPlan = capturedDistinct == null ? capturedSource : capturedDistinct.getSource();

    // Try to pull the correlated filters out of the nested plan.
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated = decorrelator.decorrelateFilters(nestedPlan, correlatedJoinNode.getCorrelation());
    if (!decorrelated.isPresent()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelatedSubquery = decorrelated.get();
    PlanNode subquery = decorrelatedSubquery.getNode();

    // Tag every input row with a unique id so that rows originating from the same
    // input row can still be told apart (and grouped together) after the join.
    PlanNode taggedInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    // The predicates extracted by the decorrelator become the join filter.
    JoinNode innerJoin = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.INNER,
            taggedInput,
            subquery,
            ImmutableList.of(),
            taggedInput.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelatedSubquery.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Re-create the distinct aggregation (if there was one) directly above the join.
    PlanNode aggregationInput = innerJoin;
    if (capturedDistinct != null) {
        ImmutableList<Symbol> distinctKeys = ImmutableList.<Symbol>builder()
                .addAll(innerJoin.getLeftOutputSymbols())
                .addAll(capturedDistinct.getGroupingKeys())
                .build();
        aggregationInput = new AggregationNode(
                capturedDistinct.getId(),
                innerJoin,
                capturedDistinct.getAggregations(),
                singleGroupingSet(distinctKeys),
                ImmutableList.of(),
                capturedDistinct.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // Re-create the grouped aggregation, grouping additionally by all left-side symbols.
    AggregationNode capturedAggregation = captures.get(AGGREGATION);
    ImmutableList<Symbol> groupingKeys = ImmutableList.<Symbol>builder()
            .addAll(innerJoin.getLeftOutputSymbols())
            .addAll(capturedAggregation.getGroupingKeys())
            .build();
    AggregationNode restoredAggregation = new AggregationNode(
            capturedAggregation.getId(),
            aggregationInput,
            capturedAggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            capturedAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Keep only the aggregation outputs the correlated join exposed, then add the captured projection.
    Set<Symbol> requiredOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> keptAggregationOutputs = restoredAggregation.getOutputSymbols().stream()
            .filter(symbol -> requiredOutputs.contains(symbol))
            .collect(toImmutableList());
    Assignments projection = Assignments.builder()
            .putIdentities(keptAggregationOutputs)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), restoredAggregation, projection));
}
Aggregations