Usage of io.trino.sql.planner.plan.AssignUniqueId in the trinodb/trino project:
class TransformCorrelatedInPredicateToJoin, method buildInPredicateEquivalent.
/**
 * Builds an uncorrelated plan that evaluates a correlated {@code value IN (subquery)} predicate
 * with SQL three-valued semantics and exposes the result as {@code inPredicateOutputSymbol}.
 * <p>
 * How the plan works:
 * <ul>
 *   <li>Every probe (input) row gets a unique id.</li>
 *   <li>Each probe row is left-outer-joined to the decorrelated subquery. Pairs are kept when the
 *   values are equal or either value is null, and the correlation predicates must also hold.</li>
 *   <li>The joined rows are grouped back by the probe side's outputs. For each probe row the plan
 *   counts definite matches (both values non-null) and null-involved matches (a build row was
 *   joined but one of the values is null).</li>
 *   <li>The final projection yields TRUE if any definite match exists, NULL if only
 *   null-involved matches exist, and FALSE otherwise.</li>
 * </ul>
 *
 * @param apply the apply node whose input is the probe side
 * @param inPredicate the IN predicate; its value and value list are expected to be symbol references
 * @param inPredicateOutputSymbol symbol that receives the boolean result of the predicate
 * @param decorrelated the decorrelated subquery plan plus its extracted correlation predicates
 * @param idAllocator allocator for new plan node ids
 * @param symbolAllocator allocator for new symbols
 * @param session session passed to the filtered count aggregations
 * @return a projection over the per-row aggregation that keeps the apply input's symbols and adds the result symbol
 */
private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session)
{
    Expression correlationFilter = and(decorrelated.getCorrelatedPredicates());
    PlanNode subqueryPlan = decorrelated.getDecorrelatedNode();

    // Give each probe row an id so the joined rows can be folded back onto their origin row.
    AssignUniqueId probeSide = new AssignUniqueId(
            idAllocator.getNextId(),
            apply.getInput(),
            symbolAllocator.newSymbol("unique", BIGINT));

    // Constant column on the build side. After the outer join it is null only on padded rows,
    // i.e. when the probe row did not join with any build row.
    Symbol buildRowMarker = symbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT);
    Assignments buildSideAssignments = Assignments.builder()
            .putIdentities(subqueryPlan.getOutputSymbols())
            .put(buildRowMarker, bigint(0))
            .build();
    ProjectNode buildSide = new ProjectNode(idAllocator.getNextId(), subqueryPlan, buildSideAssignments);

    Symbol probeValue = Symbol.from(inPredicate.getValue());
    Symbol buildValue = Symbol.from(inPredicate.getValueList());

    // Keep equal pairs and any pair with a null on either side; the null pairs decide FALSE vs NULL.
    Expression joinFilter = and(
            or(
                    new IsNullPredicate(probeValue.toSymbolReference()),
                    new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeValue.toSymbolReference(), buildValue.toSymbolReference()),
                    new IsNullPredicate(buildValue.toSymbolReference())),
            correlationFilter);
    JoinNode joined = leftOuterJoin(idAllocator, probeSide, buildSide, joinFilter);

    // Both values non-null on a joined row means the equality branch of the join filter held.
    Symbol matchFlag = symbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN);
    Expression isMatch = and(isNotNull(probeValue), isNotNull(buildValue));

    // A real build row was joined, but it was not a definite match (some value was null).
    Symbol nullMatchFlag = symbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN);
    Expression isNullMatch = and(isNotNull(buildRowMarker), not(isMatch));

    ProjectNode flaggedRows = new ProjectNode(
            idAllocator.getNextId(),
            joined,
            Assignments.builder()
                    .putIdentities(joined.getOutputSymbols())
                    .put(matchFlag, isMatch)
                    .put(nullMatchFlag, isNullMatch)
                    .build());

    Symbol matchCount = symbolAllocator.newSymbol("countMatches", BIGINT);
    Symbol nullMatchCount = symbolAllocator.newSymbol("countNullMatches", BIGINT);

    AggregationNode perProbeRow = new AggregationNode(
            idAllocator.getNextId(),
            flaggedRows,
            ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
                    .put(matchCount, countWithFilter(session, matchFlag))
                    .put(nullMatchCount, countWithFilter(session, nullMatchFlag))
                    .buildOrThrow(),
            singleGroupingSet(probeSide.getOutputSymbols()),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());

    // TODO: only "some count > 0" matters here. A specialized node could replace the left outer
    // join and avoid materializing the join results.
    SearchedCaseExpression threeValuedResult = new SearchedCaseExpression(
            ImmutableList.of(
                    new WhenClause(isGreaterThan(matchCount, 0), booleanConstant(true)),
                    new WhenClause(isGreaterThan(nullMatchCount, 0), booleanConstant(null))),
            Optional.of(booleanConstant(false)));

    return new ProjectNode(
            idAllocator.getNextId(),
            perProbeRow,
            Assignments.builder()
                    .putIdentities(apply.getInput().getOutputSymbols())
                    .put(inPredicateOutputSymbol, threeValuedResult)
                    .build());
}
Usage of io.trino.sql.planner.plan.AssignUniqueId in the trinodb/trino project:
class TransformCorrelatedGlobalAggregationWithoutProjection, method apply.
/**
 * Decorrelates a correlated join whose subquery is a global aggregation.
 * <p>
 * The subquery may also have a distinct aggregation directly beneath the global aggregation.
 * <p>
 * How the rewrite works:
 * <ul>
 *   <li>The correlated input gets a unique id.</li>
 *   <li>The input is LEFT-joined to the decorrelated subquery, using the extracted correlation
 *   predicates as the join filter.</li>
 *   <li>The aggregation(s) are re-applied grouped by the join's left outputs, so the results are
 *   computed per original input row.</li>
 * </ul>
 * The subquery rows carry a constant TRUE column ({@code non_null}), which is null on rows padded
 * by the LEFT join. Every global aggregation is masked with it, ANDed with any existing mask.
 * This keeps null-sensitive aggregations (e.g. {@code count(*)}) correct for input rows whose
 * subquery produced nothing.
 *
 * @return {@code Result.empty()} if the subquery's filters cannot be decorrelated; otherwise the
 *         rebuilt aggregation, with outputs restricted to the correlated join's outputs
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context)
{
    checkArgument(
            correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT,
            "unexpected correlated join type: " + correlatedJoinNode.getType());

    // A distinct aggregation sitting directly under the global aggregation is peeled off here
    // and rebuilt above the join further down.
    PlanNode captured = captures.get(SOURCE);
    AggregationNode distinctBelow = isDistinctOperator(captured) ? (AggregationNode) captured : null;
    PlanNode subquery = (distinctBelow != null) ? distinctBelow.getSource() : captured;

    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated =
            new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup())
                    .decorrelateFilters(subquery, correlatedJoinNode.getCorrelation());
    if (decorrelated.isEmpty()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelatedNode = decorrelated.get();
    PlanNode decorrelatedSubquery = decorrelatedNode.getNode();

    // Constant TRUE column on every subquery row. After the LEFT join it is null exactly on padded rows.
    Symbol nonNullTag = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    PlanNode taggedSubquery = new ProjectNode(
            context.getIdAllocator().getNextId(),
            decorrelatedSubquery,
            Assignments.builder()
                    .putIdentities(decorrelatedSubquery.getOutputSymbols())
                    .put(nonNullTag, TRUE_LITERAL)
                    .build());

    // The unique id lets rows coming out of the join be grouped back to their original input row.
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueId,
            taggedSubquery,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            taggedSubquery.getOutputSymbols(),
            false,
            decorrelatedNode.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    PlanNode root = join;
    if (distinctBelow != null) {
        // Re-apply the distinct per input row: input columns, the tag, then the original keys.
        ImmutableList<Symbol> distinctKeys = ImmutableList.<Symbol>builder()
                .addAll(join.getLeftOutputSymbols())
                .add(nonNullTag)
                .addAll(distinctBelow.getGroupingKeys())
                .build();
        root = new AggregationNode(
                distinctBelow.getId(),
                join,
                distinctBelow.getAggregations(),
                singleGroupingSet(distinctKeys),
                ImmutableList.of(),
                distinctBelow.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // Each aggregation agg() is masked by non_null; an existing mask becomes (mask AND non_null).
    // Without this, input rows with an empty subquery result would feed a single all-null row to
    // null-sensitive aggregations such as count(*), checksum() or array_agg(). Those aggregations
    // give a different result for one null row than for empty input under global grouping.
    AggregationNode originalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder maskProjections = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : originalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isEmpty()) {
            masks.put(entry.getKey(), nonNullTag);
            continue;
        }
        Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
        maskProjections.put(combinedMask, and(aggregation.getMask().get().toSymbolReference(), nonNullTag.toSymbolReference()));
        masks.put(entry.getKey(), combinedMask);
    }

    Assignments combinedMaskAssignments = maskProjections.build();
    if (!combinedMaskAssignments.isEmpty()) {
        root = new ProjectNode(
                context.getIdAllocator().getNextId(),
                root,
                Assignments.builder()
                        .putIdentities(root.getOutputSymbols())
                        .putAll(combinedMaskAssignments)
                        .build());
    }

    // Re-apply the global aggregation per input row.
    ImmutableList<Symbol> globalKeys = ImmutableList.<Symbol>builder()
            .addAll(join.getLeftOutputSymbols())
            .addAll(originalAggregation.getGroupingKeys())
            .build();
    AggregationNode rebuiltAggregation = new AggregationNode(
            originalAggregation.getId(),
            root,
            rewriteWithMasks(originalAggregation.getAggregations(), masks.buildOrThrow()),
            singleGroupingSet(globalKeys),
            ImmutableList.of(),
            originalAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Only expose the symbols the correlated join itself produced.
    PlanNode result = restrictOutputs(context.getIdAllocator(), rebuiltAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
            .orElse(rebuiltAggregation);
    return Result.ofPlanNode(result);
}
Usage of io.trino.sql.planner.plan.AssignUniqueId in the trinodb/trino project:
class PushRemoteExchangeThroughAssignUniqueId, method apply.
/**
 * Swaps an exchange with the AssignUniqueId captured beneath it.
 * <p>
 * The result is a new AssignUniqueId (same id and id column) on top of a new exchange.
 * <ul>
 *   <li>The new exchange reads directly from the AssignUniqueId's source.</li>
 *   <li>The unique-id column is removed from the exchange's output layout and from its single
 *   input list.</li>
 *   <li>The exchange's partitioning, hash column, null/any replication and bucket mapping are
 *   carried over unchanged.</li>
 * </ul>
 *
 * @return {@code Result.empty()} when the exchange partitions on the unique-id column; otherwise
 *         the swapped plan. Ordered (merge) exchanges are rejected via {@code checkArgument}.
 */
@Override
public Result apply(ExchangeNode node, Captures captures, Context context)
{
    checkArgument(node.getOrderingScheme().isEmpty(), "Merge exchange over AssignUniqueId not supported");

    AssignUniqueId assignUniqueId = captures.get(ASSIGN_UNIQUE_ID);
    PartitioningScheme originalScheme = node.getPartitioningScheme();

    boolean partitionsOnUniqueId = originalScheme.getPartitioning().getColumns().contains(assignUniqueId.getIdColumn());
    if (partitionsOnUniqueId) {
        // The exchange partitions on the column that AssignUniqueId produces, so that column must
        // already exist below the exchange. Hence the AssignUniqueId node has to stay below it.
        return Result.empty();
    }

    PartitioningScheme schemeWithoutUniqueId = new PartitioningScheme(
            originalScheme.getPartitioning(),
            removeSymbol(originalScheme.getOutputLayout(), assignUniqueId.getIdColumn()),
            originalScheme.getHashColumn(),
            originalScheme.isReplicateNullsAndAny(),
            originalScheme.getBucketToPartition());

    ExchangeNode exchangeBelow = new ExchangeNode(
            node.getId(),
            node.getType(),
            node.getScope(),
            schemeWithoutUniqueId,
            ImmutableList.of(assignUniqueId.getSource()),
            ImmutableList.of(removeSymbol(getOnlyElement(node.getInputs()), assignUniqueId.getIdColumn())),
            Optional.empty());

    return Result.ofPlanNode(new AssignUniqueId(assignUniqueId.getId(), exchangeBelow, assignUniqueId.getIdColumn()));
}
Aggregations