Usage of io.trino.sql.planner.plan.MarkDistinctNode in the trino project (trinodb).
Class MultipleDistinctAggregationToMarkDistinct, method apply:
/**
 * Rewrites each DISTINCT aggregation that has neither a filter nor a mask into a non-distinct
 * aggregation whose mask is a boolean marker symbol produced by a {@link MarkDistinctNode}.
 * <p>
 * Aggregations over the same set of argument symbols share one marker, so one MarkDistinctNode
 * is stacked on top of the source per distinct argument set. Each MarkDistinctNode is keyed on
 * the grouping keys, the argument symbols and, when present, the group id symbol. All other
 * aggregations are carried over unchanged.
 * <p>
 * Returns an empty result when {@code useMarkDistinct} is disabled for the session.
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context)
{
    if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
        return Result.empty();
    }

    // one marker symbol per distinct set of aggregation input symbols
    Map<Set<Symbol>, Symbol> markerByInputs = new HashMap<>();
    Map<Symbol, Aggregation> rewritten = new HashMap<>();
    PlanNode source = parent.getSource();

    for (Map.Entry<Symbol, Aggregation> entry : parent.getAggregations().entrySet()) {
        Symbol output = entry.getKey();
        Aggregation aggregation = entry.getValue();

        boolean eligible = aggregation.isDistinct()
                && aggregation.getFilter().isEmpty()
                && aggregation.getMask().isEmpty();
        if (!eligible) {
            // aggregations this rule does not handle are kept exactly as they were
            rewritten.put(output, aggregation);
            continue;
        }

        Set<Symbol> inputs = aggregation.getArguments().stream()
                .map(Symbol::from)
                .collect(toSet());

        Symbol marker = markerByInputs.get(inputs);
        if (marker == null) {
            // first aggregation over this input set: allocate its marker and stack a MarkDistinctNode
            marker = context.getSymbolAllocator().newSymbol(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
            markerByInputs.put(inputs, marker);

            ImmutableSet.Builder<Symbol> keys = ImmutableSet.<Symbol>builder()
                    .addAll(parent.getGroupingKeys())
                    .addAll(inputs);
            parent.getGroupIdSymbol().ifPresent(keys::add);

            source = new MarkDistinctNode(
                    context.getIdAllocator().getNextId(),
                    source,
                    marker,
                    ImmutableList.copyOf(keys.build()),
                    Optional.empty());
        }

        // drop the DISTINCT flag and attach the marker as the aggregation's mask instead
        rewritten.put(output, new Aggregation(
                aggregation.getResolvedFunction(),
                aggregation.getArguments(),
                false,
                aggregation.getFilter(),
                aggregation.getOrderingScheme(),
                Optional.of(marker)));
    }

    return Result.ofPlanNode(new AggregationNode(
            parent.getId(),
            source,
            rewritten,
            parent.getGroupingSets(),
            ImmutableList.of(),
            parent.getStep(),
            parent.getHashSymbol(),
            parent.getGroupIdSymbol()));
}
Usage of io.trino.sql.planner.plan.MarkDistinctNode in the trino project (trinodb).
Class TransformCorrelatedScalarSubquery, method apply:
/**
 * Removes the EnforceSingleRowNode from a correlated scalar subquery (searching only through
 * ProjectNodes) and rewrites the correlated join accordingly.
 * <p>
 * If the subquery's extracted cardinality is within [0, 1], only the join is rebuilt. The
 * original join type is kept when the cardinality is exactly 1; otherwise LEFT is used.
 * <p>
 * Otherwise the single-row requirement is checked at runtime:
 * <ul>
 * <li>the input is wrapped in an AssignUniqueId;</li>
 * <li>the subquery is LEFT-joined;</li>
 * <li>a MarkDistinctNode keyed on the join input's output symbols is added;</li>
 * <li>rows are filtered with {@code CASE marker WHEN true THEN true ELSE CAST(fail(...) AS boolean) END},
 * where {@code fail} receives the SUBQUERY_MULTIPLE_ROWS error code.</li>
 * </ul>
 * A final identity projection restores the original join's output symbols.
 * <p>
 * Returns an empty result when no such EnforceSingleRowNode is found.
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context)
{
    PlanNode subquery = context.getLookup().resolve(correlatedJoinNode.getSubquery());

    boolean hasEnforceSingleRow = searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(ProjectNode.class::isInstance)
            .matches();
    if (!hasEnforceSingleRow) {
        return Result.empty();
    }

    PlanNode subqueryWithoutCheck = searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(ProjectNode.class::isInstance)
            .removeFirst();

    Range<Long> cardinality = extractCardinality(subqueryWithoutCheck, context.getLookup());
    if (Range.closed(0L, 1L).encloses(cardinality)) {
        // at most one row is statically known: no runtime check needed
        boolean exactlyOneRow = Range.singleton(1L).encloses(cardinality);
        return Result.ofPlanNode(new CorrelatedJoinNode(
                context.getIdAllocator().getNextId(),
                correlatedJoinNode.getInput(),
                subqueryWithoutCheck,
                correlatedJoinNode.getCorrelation(),
                exactlyOneRow ? correlatedJoinNode.getType() : LEFT,
                correlatedJoinNode.getFilter(),
                correlatedJoinNode.getOriginSubquery()));
    }

    Symbol uniqueId = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);
    // AssignUniqueId stays nested: arguments evaluate left to right, so the join's id
    // is allocated before the AssignUniqueId's id
    CorrelatedJoinNode leftJoin = new CorrelatedJoinNode(
            context.getIdAllocator().getNextId(),
            new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueId),
            subqueryWithoutCheck,
            correlatedJoinNode.getCorrelation(),
            LEFT,
            correlatedJoinNode.getFilter(),
            correlatedJoinNode.getOriginSubquery());

    Symbol distinctMarker = context.getSymbolAllocator().newSymbol("is_distinct", BOOLEAN);
    MarkDistinctNode markDistinct = new MarkDistinctNode(
            context.getIdAllocator().getNextId(),
            leftJoin,
            distinctMarker,
            leftJoin.getInput().getOutputSymbols(),
            Optional.empty());

    // ELSE branch: fail(SUBQUERY_MULTIPLE_ROWS code, message), cast to boolean
    Cast failure = new Cast(
            FunctionCallBuilder.resolve(context.getSession(), metadata)
                    .setName(QualifiedName.of("fail"))
                    .addArgument(INTEGER, new LongLiteral(Integer.toString(SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode())))
                    .addArgument(VARCHAR, new StringLiteral("Scalar sub-query has returned multiple rows"))
                    .build(),
            toSqlType(BOOLEAN));
    SimpleCaseExpression singleRowCheck = new SimpleCaseExpression(
            distinctMarker.toSymbolReference(),
            ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)),
            Optional.of(failure));
    FilterNode filter = new FilterNode(context.getIdAllocator().getNextId(), markDistinct, singleRowCheck);

    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            filter,
            Assignments.identity(correlatedJoinNode.getOutputSymbols())));
}
Usage of io.trino.sql.planner.plan.MarkDistinctNode in the trino project (trinodb).
Class MarkDistinctMatcher, method detailMatches:
/**
 * Checks a {@link MarkDistinctNode} against the expected pattern.
 * <p>
 * The node's hash symbol must equal the expected optional alias resolved through
 * {@code symbolAliases}. Its distinct symbols must equal the expected aliases when both are
 * compared as sets, so order and duplicates are ignored.
 * <p>
 * On success, returns {@code match(...)} associating the expected marker alias name with
 * the node's marker symbol reference; otherwise returns {@code NO_MATCH}.
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
{
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    MarkDistinctNode markDistinct = (MarkDistinctNode) node;

    boolean hashMatches = markDistinct.getHashSymbol()
            .equals(hashSymbol.map(alias -> alias.toSymbol(symbolAliases)));
    if (!hashMatches) {
        return NO_MATCH;
    }

    // set comparison: the order of the distinct symbols is not significant
    boolean distinctSymbolsMatch = ImmutableSet.copyOf(markDistinct.getDistinctSymbols())
            .equals(distinctSymbols.stream()
                    .map(alias -> alias.toSymbol(symbolAliases))
                    .collect(toImmutableSet()));
    if (!distinctSymbolsMatch) {
        return NO_MATCH;
    }

    return match(markerSymbol.toString(), markDistinct.getMarkerSymbol().toSymbolReference());
}
Usage of io.trino.sql.planner.plan.MarkDistinctNode in the trino project (trinodb).
Class PushDownDereferencesThroughMarkDistinct, method apply:
/**
 * Pushes row subscript expressions referenced by the parent projection below the captured
 * MarkDistinctNode. A new ProjectNode is inserted between the MarkDistinctNode and its source;
 * it passes every source symbol through and computes each pushed-down subscript into a fresh
 * symbol. The parent projection is rewritten to reference those symbols instead.
 * <p>
 * Subscripts whose base symbol is one of the MarkDistinctNode's distinct symbols are not
 * pushed. Returns an empty result when nothing remains to push down.
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context)
{
    MarkDistinctNode markDistinctNode = captures.get(CHILD);

    // row subscripts found in the projection's assignment expressions
    Set<SubscriptExpression> candidates = extractRowSubscripts(projectNode.getAssignments().getExpressions(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());

    // Skip dereferences on the distinct symbols used by markDistinctNode. The marker symbol
    // needs no filtering since it is supposed to be of boolean type.
    Set<SubscriptExpression> dereferences = candidates.stream()
            .filter(expression -> !markDistinctNode.getDistinctSymbols().contains(getBase(expression)))
            .collect(toImmutableSet());
    if (dereferences.isEmpty()) {
        return Result.empty();
    }

    // one fresh symbol per pushed-down dereference expression
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // dereference expression -> reference to its new symbol, used to rewrite the parent projection
    Map<Expression, SymbolReference> replacements = HashBiMap.create(dereferenceAssignments.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments rewrittenAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, replacements));

    // identity for every source symbol plus the new dereference symbols
    Assignments pushedDown = Assignments.builder()
            .putIdentities(markDistinctNode.getSource().getOutputSymbols())
            .putAll(dereferenceAssignments)
            .build();

    // the outer ProjectNode's id is allocated before the inner one's (left-to-right evaluation)
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            markDistinctNode.replaceChildren(ImmutableList.of(new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    markDistinctNode.getSource(),
                    pushedDown))),
            rewrittenAssignments));
}
Aggregations