Use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
The class DecorrelateUnnest, method apply.
/**
 * Rewrites a correlated join whose subquery contains a supported correlated UnnestNode into an
 * UnnestNode placed over the join's input, with the remaining subquery nodes restored on top of it.
 *
 * @param correlatedJoinNode the correlated join to rewrite
 * @param captures pattern captures (not used here)
 * @param context rule context supplying lookup, session and the id / symbol allocators
 * @return the rewritten plan restricted to the correlated join's output symbols, or
 *         {@code Result.empty()} when no supported correlated UnnestNode is found
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    PlanNode subquery = correlatedJoinNode.getSubquery();

    // 1. Look for an EnforceSingleRowNode in the subquery. Recursion is disabled for this search,
    //    so presumably only the subquery root itself is examined.
    Optional<EnforceSingleRowNode> enforceSingleRow = PlanNodeSearcher.searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(planNode -> false)
            .findFirst();
    // When it is present, the search for the unnest starts from its source instead.
    PlanNode searchRoot = enforceSingleRow.isPresent() ? enforceSingleRow.get().getSource() : subquery;

    // 2. Look for the correlated UnnestNode, descending only through projections
    //    and through Limit / TopN nodes whose count is positive.
    Optional<UnnestNode> subqueryUnnest = PlanNodeSearcher.searchFrom(searchRoot, context.getLookup())
            .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), context.getLookup()))
            .recurseOnlyWhen(node -> {
                boolean positiveLimit = node instanceof LimitNode && ((LimitNode) node).getCount() > 0;
                boolean positiveTopN = node instanceof TopNNode && ((TopNNode) node).getCount() > 0;
                return node instanceof ProjectNode || positiveLimit || positiveTopN;
            })
            .findFirst();
    if (subqueryUnnest.isEmpty()) {
        return Result.empty();
    }
    UnnestNode unnestNode = subqueryUnnest.get();

    // Tag every input row with a unique id.
    Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BIGINT);
    PlanNode input = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol);

    // The correlated UnnestNode either unnests correlation symbols directly, or unnests symbols
    // produced by a projection that uses only correlation symbols. If such a projection was the
    // source of the correlated UnnestNode, re-create it (keeping its id) on top of the input so the
    // rewritten UnnestNode can consume it. If the projection turns out not to produce any unnest
    // symbols, it should be pruned afterwards.
    PlanNode unnestSource = context.getLookup().resolve(unnestNode.getSource());
    if (unnestSource instanceof ProjectNode) {
        ProjectNode sourceProjection = (ProjectNode) unnestSource;
        Assignments.Builder preProjection = Assignments.builder();
        preProjection.putIdentities(input.getOutputSymbols());
        preProjection.putAll(sourceProjection.getAssignments());
        input = new ProjectNode(sourceProjection.getId(), input, preProjection.build());
    }

    // The rewritten unnest stays INNER only if nothing requires keeping unmatched input rows:
    // no EnforceSingleRowNode, an INNER correlated join and an INNER subquery unnest.
    boolean innerPreserved = enforceSingleRow.isEmpty()
            && correlatedJoinNode.getType() == CorrelatedJoinNode.Type.INNER
            && unnestNode.getJoinType() == INNER;
    Type unnestJoinType = innerPreserved ? INNER : LEFT;

    // The rewritten node always carries ordinality, which might be necessary to restore
    // inner unnest semantics after the rewrite; reuse the existing symbol when there is one.
    Symbol ordinalitySymbol = unnestNode.getOrdinalitySymbol()
            .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BIGINT));

    // Replace the correlated join with an UnnestNode over the (possibly pre-projected) input.
    UnnestNode rewrittenUnnest = new UnnestNode(
            context.getIdAllocator().getNextId(),
            input,
            input.getOutputSymbols(),
            unnestNode.getMappings(),
            Optional.of(ordinalitySymbol),
            unnestJoinType,
            Optional.empty());

    // Restore all nodes from the subquery on top of the rewritten unnest.
    PlanNode rewrittenPlan = Rewriter.rewriteNodeSequence(
            subquery,
            input.getOutputSymbols(),
            ordinalitySymbol,
            uniqueSymbol,
            rewrittenUnnest,
            context.getSession(),
            metadata,
            context.getLookup(),
            context.getIdAllocator(),
            context.getSymbolAllocator());

    // An INNER subquery unnest that had to be rewritten as LEFT: replace every subquery output with
    // a typed null on rows whose ordinality is null.
    // NOTE(review): the original comment here was truncated; the null ordinality is what
    // distinguishes between unnested rows and synthetic rows added by left unnest.
    boolean innerBecameLeft = unnestNode.getJoinType() == INNER && rewrittenUnnest.getJoinType() == LEFT;
    if (innerBecameLeft) {
        Assignments.Builder assignments = Assignments.builder()
                .putIdentities(correlatedJoinNode.getInput().getOutputSymbols());
        for (Symbol subquerySymbol : subquery.getOutputSymbols()) {
            assignments.put(
                    subquerySymbol,
                    new IfExpression(
                            new IsNullPredicate(ordinalitySymbol.toSymbolReference()),
                            new Cast(new NullLiteral(), toSqlType(context.getSymbolAllocator().getTypes().get(subquerySymbol))),
                            subquerySymbol.toSymbolReference()));
        }
        rewrittenPlan = new ProjectNode(context.getIdAllocator().getNextId(), rewrittenPlan, assignments.build());
    }

    // Restrict outputs to what the correlated join produced; keep the plan as is if no restriction is needed.
    return Result.ofPlanNode(
            restrictOutputs(context.getIdAllocator(), rewrittenPlan, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                    .orElse(rewrittenPlan));
}
Use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
The class SubqueryPlanner, method coerceIfNecessary.
/**
 * Appends a projection that casts {@code symbol} to the coercion type, when a coercion is given.
 *
 * @param subPlan the plan built so far
 * @param symbol the symbol currently holding {@code value}
 * @param value the expression the resulting mapping is keyed by
 * @param type the current type of {@code symbol}
 * @param coercion the target type, or empty when no coercion is needed
 * @return the (possibly extended) plan together with a mapping from {@code value} to the coerced
 *         symbol, or to the original {@code symbol} when no coercion was requested
 */
private PlanAndMappings coerceIfNecessary(PlanBuilder subPlan, Symbol symbol, Expression value, Type type, Optional<? extends Type> coercion) {
    // Nothing to coerce: the plan is returned untouched and the value maps to its original symbol.
    if (coercion.isEmpty()) {
        return new PlanAndMappings(subPlan, ImmutableMap.of(NodeRef.of(value), symbol));
    }

    Type targetType = coercion.get();
    Symbol coerced = symbolAllocator.newSymbol(value, targetType);

    // Keep all current outputs and add the cast under the freshly allocated symbol.
    Cast cast = new Cast(
            symbol.toSymbolReference(),
            toSqlType(targetType),
            false,
            typeCoercion.isTypeOnlyCoercion(type, targetType));
    Assignments assignments = Assignments.builder()
            .putIdentities(subPlan.getRoot().getOutputSymbols())
            .put(coerced, cast)
            .build();

    PlanBuilder coercedPlan = subPlan.withNewRoot(new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments));
    return new PlanAndMappings(coercedPlan, ImmutableMap.of(NodeRef.of(value), coerced));
}
Use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
The class PushAggregationIntoTableScan, method pushAggregationIntoTableScan.
/**
 * Asks the connector, through {@code Metadata.applyAggregation}, to take over the given
 * aggregations and grouping keys for a table scan.
 *
 * @param plannerContext supplies the metadata used for the pushdown call
 * @param context rule context supplying session, stats provider and the id / symbol allocators
 * @param aggregationNode the node the aggregation originates from; used for statistics derivation
 * @param tableScan the scan the aggregation should be pushed into
 * @param aggregations aggregations keyed by their output symbol
 * @param groupingKeys grouping symbols
 * @return a ProjectNode over a new TableScanNode when pushdown was applied,
 *         or empty when {@code applyAggregation} returned no result
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
    // Column handles of the current scan, keyed by symbol name.
    Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream()
            .collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));

    // Snapshot the aggregations once so functions and output symbols stay positionally aligned.
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = ImmutableList.copyOf(aggregations.entrySet());
    ImmutableList.Builder<AggregateFunction> aggregateFunctionsBuilder = ImmutableList.builder();
    ImmutableList.Builder<Symbol> aggregationOutputSymbolsBuilder = ImmutableList.builder();
    for (Entry<Symbol, AggregationNode.Aggregation> aggregationEntry : aggregationsList) {
        aggregateFunctionsBuilder.add(toAggregateFunction(plannerContext.getMetadata(), context, aggregationEntry.getValue()));
        aggregationOutputSymbolsBuilder.add(aggregationEntry.getKey());
    }
    List<AggregateFunction> aggregateFunctions = aggregateFunctionsBuilder.build();
    List<Symbol> aggregationOutputSymbols = aggregationOutputSymbolsBuilder.build();

    List<ColumnHandle> groupByColumns = groupingKeys.stream()
            .map(groupByColumn -> assignments.get(groupByColumn.getName()))
            .collect(toImmutableList());

    Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
            context.getSession(),
            tableScan.getTable(),
            aggregateFunctions,
            assignments,
            ImmutableList.of(groupByColumns));
    if (aggregationPushdownResult.isEmpty()) {
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

    // The new scan keeps the existing scan outputs / assignments and gains one fresh symbol
    // per assignment returned by the connector.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    Map<String, Symbol> variableMappings = new HashMap<>();
    for (Assignment assignment : result.getAssignments()) {
        Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
        newScanOutputs.add(symbol);
        newScanAssignments.put(symbol, assignment.getColumn());
        variableMappings.put(assignment.getVariable(), symbol);
    }

    // Translate the connector's projections back into planner expressions over the new symbols.
    List<Expression> newProjections = result.getProjections().stream()
            .map(expression -> ConnectorExpressionTranslator.translate(
                    context.getSession(),
                    expression,
                    plannerContext,
                    variableMappings,
                    new LiteralEncoder(plannerContext)))
            .collect(toImmutableList());
    verify(aggregationOutputSymbols.size() == newProjections.size());

    // Each original aggregation output symbol is assigned the projection at the same position.
    Assignments.Builder assignmentBuilder = Assignments.builder();
    for (int index = 0; index < aggregationOutputSymbols.size(); index++) {
        assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index));
    }

    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
    ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();

    // The projection must expose the group-by symbols as well as the aggregates, so every
    // grouping key is added as a symbol reference.
    for (Symbol groupBySymbol : groupingKeys) {
        // If the connector returned a new mapping from the old column handle to a new one, the
        // group-by symbol has to reference the new handle's symbol; otherwise it keeps pointing
        // at the old column handle.
        ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
    }

    // The nodes are built as one nested expression on purpose: the ProjectNode id is allocated
    // before the TableScanNode id.
    return Optional.of(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new TableScanNode(
                    context.getIdAllocator().getNextId(),
                    result.getHandle(),
                    newScanOutputs.build(),
                    scanAssignments,
                    TupleDomain.all(),
                    deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode),
                    tableScan.isUpdateTarget(),
                    // table scan partitioning might have changed with new table handle
                    Optional.empty()),
            assignmentBuilder.build()));
}
Use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
The class PushDownDereferenceThroughJoin, method apply.
/**
 * Pushes row-subscript (dereference) expressions used by the projection or by the join filter
 * below the captured join, into projections on the join side that produces their base symbol.
 * Dereferences whose base symbol is a join criteria symbol are left in place.
 *
 * @param projectNode the projection sitting on top of the join
 * @param captures pattern captures; the join is read from {@code CHILD}
 * @param context rule context supplying session and the id / symbol allocators
 * @return a new projection over a rewritten join, or {@code Result.empty()} when there is
 *         no dereference to push down
 * @throws IllegalArgumentException if a dereference's base symbol comes from neither join side
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    JoinNode joinNode = captures.get(CHILD);

    // Candidates come from the projection's expressions and, when present, the join filter.
    ImmutableList.Builder<Expression> expressionsBuilder = ImmutableList.builder();
    expressionsBuilder.addAll(projectNode.getAssignments().getExpressions());
    joinNode.getFilter().ifPresent(expressionsBuilder::add);
    Set<SubscriptExpression> candidates = extractRowSubscripts(
            expressionsBuilder.build(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());

    // Symbols taking part in the join criteria are excluded from pushdown.
    ImmutableSet.Builder<Symbol> criteriaSymbolsBuilder = ImmutableSet.builder();
    joinNode.getCriteria().forEach(clause -> {
        criteriaSymbolsBuilder.add(clause.getLeft());
        criteriaSymbolsBuilder.add(clause.getRight());
    });
    Set<Symbol> excludeSymbols = criteriaSymbolsBuilder.build();

    Set<SubscriptExpression> dereferences = candidates.stream()
            .filter(candidate -> !excludeSymbols.contains(getBase(candidate)))
            .collect(toImmutableSet());
    if (dereferences.isEmpty()) {
        return Result.empty();
    }

    // Allocate a new symbol for every dereference expression.
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Expression -> new symbol reference, used to rewrite both the projection and the join filter.
    Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap())
            .inverse()
            .entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));

    // Route each dereference to the join side that outputs its base symbol.
    Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
    Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Expression> dereference : dereferenceAssignments.entrySet()) {
        Symbol baseSymbol = getOnlyElement(extractAll(dereference.getValue()));
        if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) {
            leftAssignmentsBuilder.put(dereference.getKey(), dereference.getValue());
            continue;
        }
        if (joinNode.getRight().getOutputSymbols().contains(baseSymbol)) {
            rightAssignmentsBuilder.put(dereference.getKey(), dereference.getValue());
            continue;
        }
        throw new IllegalArgumentException(format("Unexpected symbol %s in projectNode", baseSymbol));
    }
    Assignments leftAssignments = leftAssignmentsBuilder.build();
    Assignments rightAssignments = rightAssignmentsBuilder.build();

    PlanNode leftNode = createProjectNodeIfRequired(joinNode.getLeft(), leftAssignments, context.getIdAllocator());
    PlanNode rightNode = createProjectNodeIfRequired(joinNode.getRight(), rightAssignments, context.getIdAllocator());

    // The join outputs exactly the symbols the rewritten projection refers to, split by the side providing them.
    List<Symbol> referredSymbolsInAssignments = newAssignments.getExpressions().stream()
            .flatMap(expression -> extractAll(expression).stream())
            .collect(toList());
    List<Symbol> newLeftOutputSymbols = referredSymbolsInAssignments.stream()
            .filter(referred -> leftNode.getOutputSymbols().contains(referred))
            .collect(toList());
    List<Symbol> newRightOutputSymbols = referredSymbolsInAssignments.stream()
            .filter(referred -> rightNode.getOutputSymbols().contains(referred))
            .collect(toList());

    JoinNode newJoinNode = new JoinNode(
            context.getIdAllocator().getNextId(),
            joinNode.getType(),
            leftNode,
            rightNode,
            joinNode.getCriteria(),
            newLeftOutputSymbols,
            newRightOutputSymbols,
            joinNode.isMaySkipOutputDuplicates(),
            // Use newly created symbols in filter
            joinNode.getFilter().map(filter -> replaceExpression(filter, mappings)),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType(),
            joinNode.isSpillable(),
            joinNode.getDynamicFilters(),
            joinNode.getReorderJoinStatsAndCost());

    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, newAssignments));
}
Use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.
The class InlineProjections, method inlineProjections.
/**
 * Inlines selected assignments of a child projection into its parent projection.
 *
 * @param plannerContext passed through to the selection of inlining targets
 * @param parent the upper projection
 * @param child the projection directly below {@code parent}
 * @param session passed through to the selection of inlining targets
 * @param typeAnalyzer passed through to the selection of inlining targets
 * @param types passed through to the selection of inlining targets
 * @return the rewritten parent projection, or empty when there is nothing to inline
 */
static Optional<ProjectNode> inlineProjections(PlannerContext plannerContext, ProjectNode parent, ProjectNode child, Session session, TypeAnalyzer typeAnalyzer, TypeProvider types) {
    // Two identity projections in a row: drop the child and re-parent onto its source.
    if (parent.isIdentity() && child.isIdentity()) {
        return Optional.of((ProjectNode) parent.replaceChildren(ImmutableList.of(child.getSource())));
    }

    Set<Symbol> targets = extractInliningTargets(plannerContext, parent, child, session, typeAnalyzer, types);
    if (targets.isEmpty()) {
        return Optional.empty();
    }

    // Substitute the targeted child expressions into every parent assignment.
    Assignments assignments = child.getAssignments().filter(targets::contains);
    Map<Symbol, Expression> parentAssignments = parent.getAssignments()
            .entrySet().stream()
            .collect(Collectors.toMap(
                    Map.Entry::getKey,
                    entry -> inlineReferences(entry.getValue(), assignments)));

    // The inlined expressions still need their inputs, so collect the symbols they refer to;
    // these are passed through the child projection as identity assignments.
    Set<Symbol> inputs = child.getAssignments()
            .entrySet().stream()
            .filter(entry -> targets.contains(entry.getKey()))
            .map(Map.Entry::getValue)
            .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
            .collect(toSet());

    // New child = the assignments that were not inlined + identities for the inlined inputs.
    Assignments.Builder newChildAssignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Expression> assignment : child.getAssignments().entrySet()) {
        if (targets.contains(assignment.getKey())) {
            continue;
        }
        newChildAssignmentsBuilder.put(assignment);
    }
    for (Symbol input : inputs) {
        newChildAssignmentsBuilder.putIdentity(input);
    }
    Assignments newChildAssignments = newChildAssignmentsBuilder.build();

    // A child reduced to pure identities is redundant, so the parent goes straight onto its source.
    PlanNode newChild = newChildAssignments.isIdentity()
            ? child.getSource()
            : new ProjectNode(child.getId(), child.getSource(), newChildAssignments);

    return Optional.of(new ProjectNode(parent.getId(), newChild, Assignments.copyOf(parentAssignments)));
}
Aggregations