Search in sources:

Example 11 with AssignUniqueId

Use of io.trino.sql.planner.plan.AssignUniqueId in project trino by trinodb.

From the class TransformCorrelatedInPredicateToJoin, method buildInPredicateEquivalent.

/**
 * Builds an uncorrelated plan equivalent to {@code value IN (correlated subquery)}.
 *
 * <p>Shape of the produced plan:
 * <pre>
 *   AssignUniqueId(apply input)  LEFT JOIN  (decorrelated subquery + non-null marker)
 *     -> project "match" / "null match" flags
 *     -> aggregate filtered counts, grouped per probe row (probe columns incl. unique id)
 *     -> CASE: TRUE if any match, NULL if any null-match, otherwise FALSE
 * </pre>
 *
 * <p>Plan node ids and symbols are taken from {@code idAllocator} / {@code symbolAllocator}
 * strictly in statement order; keep that order when editing so generated ids/names stay stable.
 *
 * @return a projection exposing the apply input's symbols plus {@code inPredicateOutputSymbol}
 */
private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) {
    Expression correlationFilter = and(decorrelated.getCorrelatedPredicates());
    PlanNode subqueryPlan = decorrelated.getDecorrelatedNode();

    // Probe side: the apply input tagged with a unique id, so every original row can be
    // regrouped into exactly one row after the outer join.
    AssignUniqueId probeWithRowId = new AssignUniqueId(
            idAllocator.getNextId(),
            apply.getInput(),
            symbolAllocator.newSymbol("unique", BIGINT));

    // Build side: the subquery plus a constant column. After the LEFT join it is non-null only
    // on rows that actually paired with a build row.
    Symbol buildRowMarker = symbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT);
    Assignments buildAssignments = Assignments.builder()
            .putIdentities(subqueryPlan.getOutputSymbols())
            .put(buildRowMarker, bigint(0))
            .build();
    ProjectNode markedBuild = new ProjectNode(idAllocator.getNextId(), subqueryPlan, buildAssignments);

    Symbol probeValue = Symbol.from(inPredicate.getValue());
    Symbol buildValue = Symbol.from(inPredicate.getValueList());

    // Pair rows on equality, and also whenever either value is NULL. Those pairings are what
    // later turns the IN result into NULL instead of FALSE.
    Expression probeIsNull = new IsNullPredicate(probeValue.toSymbolReference());
    Expression valuesEqual = new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeValue.toSymbolReference(), buildValue.toSymbolReference());
    Expression buildIsNull = new IsNullPredicate(buildValue.toSymbolReference());
    Expression joinFilter = and(or(probeIsNull, valuesEqual, buildIsNull), correlationFilter);

    // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results
    JoinNode joined = leftOuterJoin(idAllocator, probeWithRowId, markedBuild, joinFilter);

    // A joined row with both values non-null can only have passed the join through the
    // equality branch, so it is a genuine match.
    Symbol matchFlag = symbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN);
    Expression bothValuesPresent = and(isNotNull(probeValue), isNotNull(buildValue));

    // A build row was paired, but not as a genuine match, so it came through a NULL branch.
    Symbol nullMatchFlag = symbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN);
    Expression pairedWithoutMatch = and(isNotNull(buildRowMarker), not(bothValuesPresent));

    ProjectNode flagged = new ProjectNode(
            idAllocator.getNextId(),
            joined,
            Assignments.builder()
                    .putIdentities(joined.getOutputSymbols())
                    .put(matchFlag, bothValuesPresent)
                    .put(nullMatchFlag, pairedWithoutMatch)
                    .build());

    // Count both kinds of pairings per probe row. The grouping set contains the unique id,
    // so duplicate input rows are not merged.
    Symbol matchCount = symbolAllocator.newSymbol("countMatches", BIGINT);
    Symbol nullMatchCount = symbolAllocator.newSymbol("countNullMatches", BIGINT);
    AggregationNode perProbeRowCounts = new AggregationNode(
            idAllocator.getNextId(),
            flagged,
            ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
                    .put(matchCount, countWithFilter(session, matchFlag))
                    .put(nullMatchCount, countWithFilter(session, nullMatchFlag))
                    .buildOrThrow(),
            singleGroupingSet(probeWithRowId.getOutputSymbols()),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());

    // Three-valued IN result: TRUE on any genuine match, else NULL on any NULL pairing, else FALSE.
    WhenClause anyMatch = new WhenClause(isGreaterThan(matchCount, 0), booleanConstant(true));
    WhenClause anyNullMatch = new WhenClause(isGreaterThan(nullMatchCount, 0), booleanConstant(null));
    SearchedCaseExpression inResult = new SearchedCaseExpression(ImmutableList.of(anyMatch, anyNullMatch), Optional.of(booleanConstant(false)));

    Assignments outputAssignments = Assignments.builder()
            .putIdentities(apply.getInput().getOutputSymbols())
            .put(inPredicateOutputSymbol, inResult)
            .build();
    return new ProjectNode(idAllocator.getNextId(), perProbeRowCounts, outputAssignments);
}
Also used: ComparisonExpression (io.trino.sql.tree.ComparisonExpression), WhenClause (io.trino.sql.tree.WhenClause), PlanNode (io.trino.sql.planner.plan.PlanNode), AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId), SearchedCaseExpression (io.trino.sql.tree.SearchedCaseExpression), Expression (io.trino.sql.tree.Expression), NotExpression (io.trino.sql.tree.NotExpression), Symbol (io.trino.sql.planner.Symbol), JoinNode (io.trino.sql.planner.plan.JoinNode), IsNullPredicate (io.trino.sql.tree.IsNullPredicate), ProjectNode (io.trino.sql.planner.plan.ProjectNode), AggregationNode (io.trino.sql.planner.plan.AggregationNode)

Example 12 with AssignUniqueId

Use of io.trino.sql.planner.plan.AssignUniqueId in project trino by trinodb.

From the class TransformCorrelatedGlobalAggregationWithoutProjection, method apply.

/**
 * Decorrelates a CorrelatedJoinNode whose subquery is the aggregation captured as AGGREGATION,
 * optionally sitting on top of a distinct aggregation captured as SOURCE.
 *
 * <p>Rewrite shape: AssignUniqueId(input) LEFT JOIN (decorrelated subquery + non_null marker),
 * then the optional distinct aggregation, then the captured aggregation regrouped per original
 * input row with every aggregate masked by non_null.
 *
 * <p>Only INNER and LEFT correlated joins are accepted; both are rewritten into a LEFT join.
 * NOTE(review): treating INNER like LEFT presumably relies on the rule pattern matching only
 * aggregations that yield exactly one row per input row (global aggregation) — confirm.
 *
 * <p>Plan node ids and symbols are drawn from the context's allocators in statement order
 * (including once per masked aggregation inside the loop), so reordering statements changes
 * the generated ids and names.
 *
 * @return the rewritten plan, or {@code Result.empty()} when the subquery's filters cannot be decorrelated
 * @throws IllegalArgumentException presumably, via {@code checkArgument}, when the join type is neither INNER nor LEFT
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());
    // if there is another aggregation below the AggregationNode, handle both
    // (a distinct aggregation is peeled off here and re-created above the join further down)
    PlanNode source = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(source)) {
        distinct = (AggregationNode) source;
        source = distinct.getSource();
    }
    // decorrelate nested plan
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
    if (decorrelatedSource.isEmpty()) {
        // No decorrelated form is available, so the rule does not apply.
        return Result.empty();
    }
    source = decorrelatedSource.get().getNode();
    // append non-null symbol on nested plan. It will be used to restore semantics of null-sensitive aggregations after LEFT join
    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    source = new ProjectNode(context.getIdAllocator().getNextId(), source, Assignments.builder().putIdentities(source.getOutputSymbols()).put(nonNull, TRUE_LITERAL).build());
    // assign unique id on correlated join's input. It will be used to distinguish between original input rows after join
    PlanNode inputWithUniqueId = new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), context.getSymbolAllocator().newSymbol("unique", BIGINT));
    // LEFT join with no equi-criteria (empty list). The correlated predicates from decorrelation are
    // passed in what is presumably the join-filter position of the JoinNode constructor — confirm.
    JoinNode join = new JoinNode(context.getIdAllocator().getNextId(), JoinNode.Type.LEFT, inputWithUniqueId, source, ImmutableList.of(), inputWithUniqueId.getOutputSymbols(), source.getOutputSymbols(), false, decorrelatedSource.get().getCorrelatedPredicates(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of(), Optional.empty());
    PlanNode root = join;
    // restore distinct aggregation
    // Grouping by all left (input) symbols keeps distinctness per original input row. non_null is
    // added to the keys so the symbol survives this aggregation and can serve as a mask below.
    if (distinct != null) {
        root = new AggregationNode(distinct.getId(), join, distinct.getAggregations(), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).add(nonNull).addAll(distinct.getGroupingKeys()).build()), ImmutableList.of(), distinct.getStep(), Optional.empty(), Optional.empty());
    }
    // prepare mask symbols for aggregations
    // Every original aggregation agg() will be rewritten to agg() mask(non_null). If the aggregation
    // already has a mask, it will be replaced with conjunction of the existing mask and non_null.
    // This is necessary to restore the original aggregation result in case when:
    // - the nested lateral subquery returned empty result for some input row,
    // - aggregation is null-sensitive, which means that its result over a single null row is different
    // than result for empty input (with global grouping)
    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
    AggregationNode globalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder assignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isPresent()) {
            // Combined mask needs a new symbol, computed in the projection below.
            Symbol newMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            Expression expression = and(aggregation.getMask().get().toSymbolReference(), nonNull.toSymbolReference());
            assignmentsBuilder.put(newMask, expression);
            masks.put(entry.getKey(), newMask);
        } else {
            // No existing mask: non_null itself is the mask, no projection needed.
            masks.put(entry.getKey(), nonNull);
        }
    }
    Assignments maskAssignments = assignmentsBuilder.build();
    // Only add a projection when at least one combined mask has to be computed.
    if (!maskAssignments.isEmpty()) {
        root = new ProjectNode(context.getIdAllocator().getNextId(), root, Assignments.builder().putIdentities(root.getOutputSymbols()).putAll(maskAssignments).build());
    }
    // restore global aggregation
    // Reuses the original node id. Grouping by all left (input) symbols, which include the unique id,
    // plus the original grouping keys yields one aggregation result per original input row.
    globalAggregation = new AggregationNode(globalAggregation.getId(), root, rewriteWithMasks(globalAggregation.getAggregations(), masks.buildOrThrow()), singleGroupingSet(ImmutableList.<Symbol>builder().addAll(join.getLeftOutputSymbols()).addAll(globalAggregation.getGroupingKeys()).build()), ImmutableList.of(), globalAggregation.getStep(), Optional.empty(), Optional.empty());
    // restrict outputs
    // Prune to the correlated join's original outputs. If restrictOutputs returns empty
    // (presumably because nothing needs pruning), the aggregation is returned as is.
    Optional<PlanNode> project = restrictOutputs(context.getIdAllocator(), globalAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
    return Result.ofPlanNode(project.orElse(globalAggregation));
}
Also used: Symbol (io.trino.sql.planner.Symbol), CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode), JoinNode (io.trino.sql.planner.plan.JoinNode), Assignments (io.trino.sql.planner.plan.Assignments), AggregationNode (io.trino.sql.planner.plan.AggregationNode), ImmutableMap (com.google.common.collect.ImmutableMap), Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation), PlanNodeDecorrelator (io.trino.sql.planner.optimizations.PlanNodeDecorrelator), PlanNode (io.trino.sql.planner.plan.PlanNode), AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId), Expression (io.trino.sql.tree.Expression), ProjectNode (io.trino.sql.planner.plan.ProjectNode), Map (java.util.Map)

Example 13 with AssignUniqueId

Use of io.trino.sql.planner.plan.AssignUniqueId in project trino by trinodb.

From the class PushRemoteExchangeThroughAssignUniqueId, method apply.

/**
 * Swaps an exchange with the AssignUniqueId captured as ASSIGN_UNIQUE_ID (presumably the
 * exchange's only source). It turns exchange(AssignUniqueId(source)) into
 * AssignUniqueId(exchange(source)), reusing both original node ids and the same id column.
 *
 * <p>Merge (ordered) exchanges are rejected via {@code checkArgument}. The rule does not fire
 * when the exchange partitions on the id column.
 */
@Override
public Result apply(ExchangeNode node, Captures captures, Context context) {
    checkArgument(node.getOrderingScheme().isEmpty(), "Merge exchange over AssignUniqueId not supported");
    AssignUniqueId uniqueIdNode = captures.get(ASSIGN_UNIQUE_ID);
    PartitioningScheme originalScheme = node.getPartitioningScheme();

    if (!originalScheme.getPartitioning().getColumns().contains(uniqueIdNode.getIdColumn())) {
        // Below the AssignUniqueId the id column does not exist yet, so drop it from the
        // exchange's output layout and from its single input list.
        PartitioningScheme schemeWithoutId = new PartitioningScheme(
                originalScheme.getPartitioning(),
                removeSymbol(originalScheme.getOutputLayout(), uniqueIdNode.getIdColumn()),
                originalScheme.getHashColumn(),
                originalScheme.isReplicateNullsAndAny(),
                originalScheme.getBucketToPartition());
        ExchangeNode exchangeBelow = new ExchangeNode(
                node.getId(),
                node.getType(),
                node.getScope(),
                schemeWithoutId,
                ImmutableList.of(uniqueIdNode.getSource()),
                ImmutableList.of(removeSymbol(getOnlyElement(node.getInputs()), uniqueIdNode.getIdColumn())),
                Optional.empty());
        return Result.ofPlanNode(new AssignUniqueId(uniqueIdNode.getId(), exchangeBelow, uniqueIdNode.getIdColumn()));
    }

    // The column produced by the AssignUniqueId is used in the partitioning scheme of the exchange.
    // Hence, AssignUniqueId node has to stay below the exchange node.
    return Result.empty();
}
Also used: AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId), ExchangeNode (io.trino.sql.planner.plan.ExchangeNode), PartitioningScheme (io.trino.sql.planner.PartitioningScheme)

Aggregations

AssignUniqueId (io.trino.sql.planner.plan.AssignUniqueId): 13; Symbol (io.trino.sql.planner.Symbol): 11; PlanNode (io.trino.sql.planner.plan.PlanNode): 11; CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode): 10; ProjectNode (io.trino.sql.planner.plan.ProjectNode): 10; AggregationNode (io.trino.sql.planner.plan.AggregationNode): 9; Assignments (io.trino.sql.planner.plan.Assignments): 8; JoinNode (io.trino.sql.planner.plan.JoinNode): 7; PlanNodeDecorrelator (io.trino.sql.planner.optimizations.PlanNodeDecorrelator): 6; Expression (io.trino.sql.tree.Expression): 6; ImmutableList (com.google.common.collect.ImmutableList): 4; ImmutableMap (com.google.common.collect.ImmutableMap): 4; Captures (io.trino.matching.Captures): 4; Pattern (io.trino.matching.Pattern): 4; Rule (io.trino.sql.planner.iterative.Rule): 4; ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 3; ImmutableSet (com.google.common.collect.ImmutableSet): 3; Pattern.nonEmpty (io.trino.matching.Pattern.nonEmpty): 3; BIGINT (io.trino.spi.type.BigintType.BIGINT): 3; SymbolsExtractor (io.trino.sql.planner.SymbolsExtractor): 3