Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
From the class TransformCorrelatedGlobalAggregationWithProjection, method apply.
/**
 * Rewrites a correlated join over a projected global aggregation into an uncorrelated plan:
 * an AssignUniqueId on the join input, a LEFT join with the decorrelated nested plan, an optional
 * restored distinct aggregation, an aggregation grouped by the input's symbols, and a final projection.
 *
 * @param correlatedJoinNode the matched correlated join; only INNER and LEFT types pass the argument check
 * @param captures pattern captures from which SOURCE, AGGREGATION and PROJECTION are read
 * @param context supplies the symbol allocator, id allocator and lookup used to build the new nodes
 * @return the rewritten plan wrapped in a Result, or Result.empty() when the filters of the nested
 *         plan cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
// only INNER and LEFT correlated joins are handled; any other type fails this check
checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());
// if there is another aggregation below the AggregationNode, handle both
PlanNode source = captures.get(SOURCE);
AggregationNode distinct = null;
if (isDistinctOperator(source)) {
// remember the distinct aggregation and continue with the plan below it; it is re-created above the join later
distinct = (AggregationNode) source;
source = distinct.getSource();
}
// decorrelate nested plan
PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
if (decorrelatedSource.isEmpty()) {
// the decorrelator could not handle the nested plan: give up on this rule application
return Result.empty();
}
source = decorrelatedSource.get().getNode();
// append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
source = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.builder().putIdentities(source.getOutputSymbols()).put(nonNull, TRUE_LITERAL).build());
// assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
PlanNode inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
// the join is always built as LEFT, regardless of the correlated join's type, with no equi-join criteria;
// the predicates extracted by the decorrelator are handed to the JoinNode constructor
// NOTE(review): presumably that argument is the join filter - confirm against the JoinNode constructor
JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, inputWithUniqueId, source, ImmutableList.of(), inputWithUniqueId.getOutputSymbols(), source.getOutputSymbols(), false, decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
PlanNode root = join;
// restore distinct aggregation
if (distinct != null) {
// re-created with the original id, aggregations and step; grouping keys are the join's left outputs,
// then non_null, then the distinct aggregation's original grouping keys
root = new AggregationNode(distinct.getId(), join, distinct.getAggregations(), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).add(nonNull).addAll(distinct.getGroupingKeys()).build()), ImmutableList.of(), distinct.getStep(), Optional.empty(), Optional.empty());
}
// prepare mask symbols for aggregations
// Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
// already has a mask, it will be replaced with conjunction of the existing mask and non_null.
// This is necessary to restore the original aggregation result in case when:
// - the nested lateral subquery returned empty result for some input row,
// - aggregation is null-sensitive, which means that its result over a single null row is different
// than result for empty input (with global grouping)
// It applies to the following aggregate functions: count(*), checksum(), array_agg().
AggregationNode globalAggregation = captures.get(AGGREGATION);
// masks: aggregation output symbol -> mask symbol to use in the rewritten aggregation
ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
// assignmentsBuilder: new mask symbol -> (existing mask AND non_null), only for aggregations that already had a mask
Assignments.Builder assignmentsBuilder = Assignments.builder();
for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
Aggregation aggregation = entry.getValue();
if (aggregation.getMask().isPresent()) {
// existing mask: allocate a fresh symbol holding the conjunction of the old mask and non_null
Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
assignmentsBuilder.put(newMask, expression);
masks.put(entry.getKey(), newMask);
} else {
// no existing mask: non_null itself serves as the mask
masks.put(entry.getKey(), nonNull);
}
}
Assignments maskAssignments = assignmentsBuilder.build();
// a projection is added only when at least one combined mask expression has to be computed
if (!maskAssignments.isEmpty()) {
root = new ProjectNode(context.getIdAllocator().getNextId(), root, Assignments.builder().putIdentities(root.getOutputSymbols()).putAll(maskAssignments).build());
}
// restore global aggregation
// it keeps its original id and step; grouping keys are the join's left output symbols (the outputs of
// inputWithUniqueId) followed by the aggregation's original grouping keys
globalAggregation = new AggregationNode(globalAggregation.getId(), root, rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).addAll(globalAggregation.getGroupingKeys()).build()), ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), Optional.empty());
// restrict outputs and apply projection
Set<Symbol> outputSymbols = new HashSet<>(correlatedJoinNode.getOutputSymbols());
// keep only those aggregation outputs that the correlated join itself exposed
List<Symbol> expectedAggregationOutputs = globalAggregation.getOutputSymbols().stream().filter(outputSymbols::contains).collect(toImmutableList());
// pass the retained symbols through unchanged and add the assignments of the captured projection
Assignments assignments = Assignments.builder().putIdentities(expectedAggregationOutputs).putAll(captures.get(PROJECTION).getAssignments()).build();
return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), globalAggregation, assignments));
}
Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
From the class PushProjectionThroughJoin, method inlineProjections.
/**
 * Repeatedly merges {@code parentProjection} with the projection directly below it, for as long as
 * the resolved source is a ProjectNode and InlineProjections produces a merged node.
 *
 * @return the last projection obtained; this is {@code parentProjection} itself when its resolved
 *         source is not a ProjectNode or when the first inlining attempt yields nothing
 */
private static PlanNode inlineProjections(PlannerContext plannerContext, ProjectNode parentProjection, Lookup lookup, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
    ProjectNode current = parentProjection;
    while (true) {
        PlanNode below = lookup.resolve(current.getSource());
        if (!(below instanceof ProjectNode)) {
            // nothing to merge with
            return current;
        }
        Optional<ProjectNode> merged = InlineProjections.inlineProjections(plannerContext, current, (ProjectNode) below, session, typeAnalyzer, types);
        if (!merged.isPresent()) {
            // inlining was not possible at this level: keep the projection as it stands
            return current;
        }
        // retry with the merged projection, which may again sit on top of a projection
        current = merged.get();
    }
}
Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
From the class PushProjectionThroughJoin, method pushProjectionThroughJoin.
/**
 * Tries to move a projection below an INNER join by splitting its assignments between the two
 * join sides.
 * <p>
 * The rewrite is abandoned (empty result) when any projected expression is non-deterministic, when
 * the resolved source is not a JoinNode, when the join is not INNER, or when some expression
 * references symbols that neither side provides in full on its own.
 *
 * @return a new JoinNode (same id as the original) over the two projected sides, or
 *         {@code Optional.empty()} when the projection cannot be pushed
 */
public static Optional<PlanNode> pushProjectionThroughJoin(PlannerContext plannerContext, ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
    // bail out as soon as a non-deterministic expression is found
    for (Expression projected : projectNode.getAssignments().getExpressions()) {
        if (!isDeterministic(projected, plannerContext.getMetadata())) {
            return Optional.empty();
        }
    }

    PlanNode resolved = lookup.resolve(projectNode.getSource());
    if (!(resolved instanceof JoinNode)) {
        return Optional.empty();
    }
    JoinNode join = (JoinNode) resolved;
    if (join.getType() != INNER) {
        return Optional.empty();
    }
    PlanNode left = join.getLeft();
    PlanNode right = join.getRight();

    // route every assignment to the side that provides all symbols it references (left is tried first)
    Assignments.Builder leftProjections = Assignments.builder();
    Assignments.Builder rightProjections = Assignments.builder();
    for (Map.Entry<Symbol, Expression> projection : projectNode.getAssignments().entrySet()) {
        Set<Symbol> referenced = extractUnique(projection.getValue());
        Assignments.Builder destination;
        if (left.getOutputSymbols().containsAll(referenced)) {
            destination = leftProjections;
        } else if (right.getOutputSymbols().containsAll(referenced)) {
            destination = rightProjections;
        } else {
            // the expression mixes symbols from both join sides
            return Optional.empty();
        }
        destination.put(projection.getKey(), projection.getValue());
    }

    // symbols the join itself needs must be passed through by the new projections
    for (Symbol required : getJoinRequiredSymbols(join)) {
        boolean providedByLeft = left.getOutputSymbols().contains(required);
        if (!providedByLeft) {
            checkState(right.getOutputSymbols().contains(required));
        }
        (providedByLeft ? leftProjections : rightProjections).putIdentity(required);
    }

    Assignments leftAssignments = leftProjections.build();
    Assignments rightAssignments = rightProjections.build();

    // the join exposes only what the original projection exposed
    Set<Symbol> projectedOutputs = ImmutableSet.copyOf(projectNode.getOutputSymbols());
    List<Symbol> leftOutputs = leftAssignments.getOutputs().stream()
            .filter(projectedOutputs::contains)
            .collect(toImmutableList());
    List<Symbol> rightOutputs = rightAssignments.getOutputs().stream()
            .filter(projectedOutputs::contains)
            .collect(toImmutableList());

    // left side is built (and its id allocated) before the right side
    PlanNode projectedLeft = inlineProjections(plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), left, leftAssignments), lookup, session, typeAnalyzer, types);
    PlanNode projectedRight = inlineProjections(plannerContext, new ProjectNode(planNodeIdAllocator.getNextId(), right, rightAssignments), lookup, session, typeAnalyzer, types);

    return Optional.of(new JoinNode(
            join.getId(),
            join.getType(),
            projectedLeft,
            projectedRight,
            join.getCriteria(),
            leftOutputs,
            rightOutputs,
            join.isMaySkipOutputDuplicates(),
            join.getFilter(),
            join.getLeftHashSymbol(),
            join.getRightHashSymbol(),
            join.getDistributionType(),
            join.isSpillable(),
            join.getDynamicFilters(),
            join.getReorderJoinStatsAndCost()));
}
Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
From the class PushProjectionThroughUnion, method apply.
/**
 * Pushes the projection {@code parent} below the captured union: every union source receives its
 * own ProjectNode with the parent's expressions rewritten in terms of that source's symbols, and a
 * new UnionNode (reusing the parent's id) maps the parent's output symbols to the per-source symbols.
 */
@Override
public Result apply(ProjectNode parent, Captures captures, Context context) {
    UnionNode union = captures.get(CHILD);
    List<PlanNode> unionSources = union.getSources();
    // the resulting union exposes exactly the symbols of the projection
    List<Symbol> projectedOutputs = parent.getOutputSymbols();
    // projection output symbol -> the symbol carrying it in each source, in source order
    ImmutableListMultimap.Builder<Symbol, Symbol> outputToSourceSymbols = ImmutableListMultimap.builder();
    ImmutableList.Builder<PlanNode> projectedSources = ImmutableList.builder();

    for (int sourceIndex = 0; sourceIndex < unionSources.size(); sourceIndex++) {
        // union output symbol -> reference to the corresponding symbol of this source
        Map<Symbol, SymbolReference> unionToSource = union.sourceSymbolMap(sourceIndex);
        Assignments.Builder sourceAssignments = Assignments.builder();
        // symbol assigned by the parent projection -> fresh symbol assigned by this source's projection
        Map<Symbol, Symbol> parentToSourceSymbol = new HashMap<>();

        for (Map.Entry<Symbol, Expression> projection : parent.getAssignments().entrySet()) {
            Symbol parentSymbol = projection.getKey();
            // express the parent's expression with the symbols of this source
            Expression rewritten = inlineSymbols(unionToSource, projection.getValue());
            Type type = context.getSymbolAllocator().getTypes().get(parentSymbol);
            Symbol sourceSymbol = context.getSymbolAllocator().newSymbol(rewritten, type);
            sourceAssignments.put(sourceSymbol, rewritten);
            parentToSourceSymbol.put(parentSymbol, sourceSymbol);
        }

        projectedSources.add(new ProjectNode(context.getIdAllocator().getNextId(), unionSources.get(sourceIndex), sourceAssignments.build()));
        for (Symbol output : projectedOutputs) {
            outputToSourceSymbols.put(output, parentToSourceSymbol.get(output));
        }
    }

    ImmutableListMultimap<Symbol, Symbol> symbolMapping = outputToSourceSymbols.build();
    return Result.ofPlanNode(new UnionNode(parent.getId(), projectedSources.build(), symbolMapping, ImmutableList.copyOf(symbolMapping.keySet())));
}
Use of io.trino.sql.planner.plan.PlanNode in project trino by trinodb.
From the class ReplaceRedundantJoinWithSource, method apply.
/**
 * Replaces a join with one of its sources when the other source is scalar and has no output symbols.
 * <p>
 * The surviving side is chosen per join type:
 * INNER - whichever side is opposite a no-output scalar (a redundant left is checked first), wrapped
 * in a FilterNode when the join has a filter;
 * LEFT - the left side, when the right is a no-output scalar;
 * RIGHT - the right side, when the left is a no-output scalar;
 * FULL - as above, but only when the surviving side passes {@code isAtLeastScalar}.
 * The survivor's outputs are restricted to the symbols the join exposed from that side.
 *
 * @return the replacement plan, or {@code Result.empty()} when the rule does not apply, including
 *         when {@code isAtMost(side, lookup, 0)} holds for either side
 */
@Override
public Result apply(JoinNode node, Captures captures, Context context) {
    if (isAtMost(node.getLeft(), context.getLookup(), 0) || isAtMost(node.getRight(), context.getLookup(), 0)) {
        return Result.empty();
    }

    boolean leftRedundant = node.getLeft().getOutputSymbols().isEmpty() && isScalar(node.getLeft(), context.getLookup());
    boolean rightRedundant = node.getRight().getOutputSymbols().isEmpty() && isScalar(node.getRight(), context.getLookup());

    // the switch only selects the surviving side; the output restriction happens once, below
    PlanNode survivor = null;
    List<Symbol> survivorOutputs = null;
    switch (node.getType()) {
        case INNER:
            if (leftRedundant) {
                survivor = node.getRight();
                survivorOutputs = node.getRightOutputSymbols();
            } else if (rightRedundant) {
                survivor = node.getLeft();
                survivorOutputs = node.getLeftOutputSymbols();
            }
            // the join filter still has to be applied to the surviving side
            if (survivor != null && node.getFilter().isPresent()) {
                survivor = new FilterNode(context.getIdAllocator().getNextId(), survivor, node.getFilter().get());
            }
            break;
        case LEFT:
            if (rightRedundant) {
                survivor = node.getLeft();
                survivorOutputs = node.getLeftOutputSymbols();
            }
            break;
        case RIGHT:
            if (leftRedundant) {
                survivor = node.getRight();
                survivorOutputs = node.getRightOutputSymbols();
            }
            break;
        case FULL:
            if (leftRedundant && isAtLeastScalar(node.getRight(), context.getLookup())) {
                survivor = node.getRight();
                survivorOutputs = node.getRightOutputSymbols();
            } else if (rightRedundant && isAtLeastScalar(node.getLeft(), context.getLookup())) {
                survivor = node.getLeft();
                survivorOutputs = node.getLeftOutputSymbols();
            }
            break;
    }

    if (survivor == null) {
        return Result.empty();
    }
    return Result.ofPlanNode(restrictOutputs(context.getIdAllocator(), survivor, ImmutableSet.copyOf(survivorOutputs)).orElse(survivor));
}
Aggregations