Search in sources:

Example 36 with Assignments

use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.

the class MergePatternRecognitionNodes method extractPrerequisites.

/**
 * Select the non-identity assignments of the given projection whose output symbols are
 * referenced by the PatternRecognitionNode's window functions or measures.
 *
 * @param node the pattern recognition node whose window functions and measures are inspected
 * @param project the projection whose assignments are filtered
 * @return the assignments of {@code project} that are not identities and that produce
 *         a symbol used by {@code node}'s window functions or measures
 */
private static Assignments extractPrerequisites(PatternRecognitionNode node, ProjectNode project) {
    Assignments projections = project.getAssignments();

    // gather every symbol referenced by the window functions and by the measures
    ImmutableSet.Builder<Symbol> referenced = ImmutableSet.builder();
    node.getWindowFunctions().values()
            .forEach(function -> referenced.addAll(SymbolsExtractor.extractAll(function)));
    node.getMeasures().values()
            .forEach(measure -> referenced.addAll(measure.getExpressionAndValuePointers().getInputSymbols()));
    Set<Symbol> inputs = referenced.build();

    // keep an assignment only if it is not an identity and its output symbol is one of the inputs
    return projections.filter(symbol -> !projections.isIdentity(symbol) && inputs.contains(symbol));
}
Also used : Measure(io.trino.sql.planner.plan.PatternRecognitionNode.Measure) Capture.newCapture(io.trino.matching.Capture.newCapture) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ExpressionAndValuePointers(io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers) Symbol(io.trino.sql.planner.Symbol) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) IrLabel(io.trino.sql.planner.rowpattern.ir.IrLabel) ExpressionAndValuePointersEquivalence(io.trino.sql.planner.rowpattern.ExpressionAndValuePointersEquivalence) Collection(java.util.Collection) Assignments(io.trino.sql.planner.plan.Assignments) Function(io.trino.sql.planner.plan.WindowNode.Function) Set(java.util.Set) Streams(com.google.common.collect.Streams) PatternRecognitionNode(io.trino.sql.planner.plan.PatternRecognitionNode) Capture(io.trino.matching.Capture) Patterns.patternRecognition(io.trino.sql.planner.plan.Patterns.patternRecognition) Pattern(io.trino.matching.Pattern) Patterns.source(io.trino.sql.planner.plan.Patterns.source) Captures(io.trino.matching.Captures) Util.restrictOutputs(io.trino.sql.planner.iterative.rule.Util.restrictOutputs) Patterns.project(io.trino.sql.planner.plan.Patterns.project) ImmutableSet(com.google.common.collect.ImmutableSet) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) Measure(io.trino.sql.planner.plan.PatternRecognitionNode.Measure)

Example 37 with Assignments

use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.

the class MergePatternRecognitionNodes method dependsOnSourceCreatedOutputs.

/**
 * Check whether the parent node, through the intervening projection, uses any symbol created
 * by the child node (the symbols returned by the child's {@code getCreatedSymbols()}).
 * <p>
 * Only the window functions and measures of the parent node are inspected. Other properties of
 * the parent node, such as specification and frame, are supposed to be identical to the
 * corresponding properties of the child node, as checked in the
 * `patternRecognitionSpecificationsMatch` call. As such, they cannot use symbols created by the child.
 *
 * @param parent the upper pattern recognition node
 * @param project the projection between {@code parent} and {@code child}
 * @param child the lower pattern recognition node
 * @return true if an expression assigned by {@code project} to a symbol used by the parent's
 *         window functions or measures references a symbol created by {@code child}
 */
private static boolean dependsOnSourceCreatedOutputs(PatternRecognitionNode parent, ProjectNode project, PatternRecognitionNode child) {
    Set<Symbol> createdByChild = child.getCreatedSymbols();
    Assignments projections = project.getAssignments();

    // symbols referenced by the parent's window functions and measures
    ImmutableSet.Builder<Symbol> referenced = ImmutableSet.builder();
    parent.getWindowFunctions().values()
            .forEach(function -> referenced.addAll(SymbolsExtractor.extractAll(function)));
    parent.getMeasures().values()
            .forEach(measure -> referenced.addAll(measure.getExpressionAndValuePointers().getInputSymbols()));

    // resolve each referenced symbol to the expression the projection assigns to it,
    // and stop at the first symbol in those expressions that the child created
    for (Symbol input : referenced.build()) {
        for (Symbol dependency : SymbolsExtractor.extractAll(projections.get(input))) {
            if (createdByChild.contains(dependency)) {
                return true;
            }
        }
    }
    return false;
}
Also used : ImmutableSet(com.google.common.collect.ImmutableSet) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) Measure(io.trino.sql.planner.plan.PatternRecognitionNode.Measure) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor)

Example 38 with Assignments

use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.

the class TransformCorrelatedGroupedAggregationWithProjection method apply.

/**
 * Rewrite a correlated join over a grouped aggregation with projection into an inner join
 * followed by the restored aggregation(s) and a final projection.
 *
 * @return the rewritten plan, or an empty result when the subquery filters cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    // The captured source may itself be a distinct aggregation sitting below the grouped
    // aggregation. If so, remember it and continue with its source; it is rebuilt on top of the join.
    PlanNode subquery = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(subquery)) {
        distinct = (AggregationNode) subquery;
        subquery = distinct.getSource();
    }

    // decorrelate the nested plan; bail out when that is not possible
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelationResult = decorrelator.decorrelateFilters(subquery, correlatedJoinNode.getCorrelation());
    if (decorrelationResult.isEmpty()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelated = decorrelationResult.get();
    subquery = decorrelated.getNode();

    // Tag every input row of the correlated join with a unique id, so that the original
    // input rows can still be told apart after the join.
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.INNER,
            inputWithUniqueId,
            subquery,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelated.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // restore the distinct aggregation (if there was one), additionally grouping by the join's left outputs
    PlanNode aggregationInput = join;
    if (distinct != null) {
        List<Symbol> distinctKeys = ImmutableList.<Symbol>builder()
                .addAll(join.getLeftOutputSymbols())
                .addAll(distinct.getGroupingKeys())
                .build();
        aggregationInput = new AggregationNode(
                distinct.getId(),
                join,
                distinct.getAggregations(),
                singleGroupingSet(distinctKeys),
                ImmutableList.of(),
                distinct.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // restore the grouped aggregation, additionally grouping by the join's left outputs
    AggregationNode originalAggregation = captures.get(AGGREGATION);
    List<Symbol> groupingKeys = ImmutableList.<Symbol>builder()
            .addAll(join.getLeftOutputSymbols())
            .addAll(originalAggregation.getGroupingKeys())
            .build();
    AggregationNode groupedAggregation = new AggregationNode(
            originalAggregation.getId(),
            aggregationInput,
            originalAggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            originalAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Keep only those aggregation outputs that the correlated join exposes,
    // and re-apply the captured projection on top of them.
    Set<Symbol> requiredOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> passedThrough = groupedAggregation.getOutputSymbols().stream()
            .filter(requiredOutputs::contains)
            .collect(toImmutableList());
    Assignments projections = Assignments.builder()
            .putIdentities(passedThrough)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();

    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), groupedAggregation, projections));
}
Also used : CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) ProjectNode(io.trino.sql.planner.plan.ProjectNode) HashSet(java.util.HashSet)

Example 39 with Assignments

use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.

the class TransformCorrelatedInPredicateToJoin method apply.

/**
 * Entry point of the rule: proceeds only when the apply node has exactly one subquery
 * assignment and that assignment is an {@code InPredicate}; otherwise the rule does not fire.
 *
 * @return the result of the delegated {@code apply} overload, or an empty result when the
 *         subquery assignments do not have the expected shape
 */
@Override
public Result apply(ApplyNode apply, Captures captures, Context context) {
    Assignments assignments = apply.getSubqueryAssignments();
    boolean hasSingleAssignment = assignments.size() == 1;
    if (!hasSingleAssignment) {
        return Result.empty();
    }

    Expression onlyExpression = getOnlyElement(assignments.getExpressions());
    if (onlyExpression instanceof InPredicate) {
        // the single assignment's symbol carries the result of the IN predicate
        Symbol outputSymbol = getOnlyElement(assignments.getSymbols());
        return apply(
                apply,
                (InPredicate) onlyExpression,
                outputSymbol,
                context.getLookup(),
                context.getIdAllocator(),
                context.getSymbolAllocator(),
                context.getSession());
    }
    return Result.empty();
}
Also used : SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Expression(io.trino.sql.tree.Expression) NotExpression(io.trino.sql.tree.NotExpression) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) InPredicate(io.trino.sql.tree.InPredicate)

Example 40 with Assignments

use of io.trino.sql.planner.plan.Assignments in project trino by trinodb.

the class TransformCorrelatedGlobalAggregationWithoutProjection method apply.

/**
 * Rewrite a correlated join over a global aggregation (without projection) into a LEFT join
 * followed by the restored aggregation(s), with every aggregation masked by a non-null marker.
 *
 * @return the rewritten plan, or an empty result when the subquery filters cannot be decorrelated
 * @throws IllegalArgumentException if the correlated join type is neither INNER nor LEFT
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());

    // The captured source may itself be a distinct aggregation sitting below the global
    // aggregation. If so, remember it and continue with its source; it is rebuilt on top of the join.
    PlanNode subquery = captures.get(SOURCE);
    AggregationNode distinct = null;
    if (isDistinctOperator(subquery)) {
        distinct = (AggregationNode) subquery;
        subquery = distinct.getSource();
    }

    // decorrelate the nested plan; bail out when that is not possible
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelationResult = decorrelator.decorrelateFilters(subquery, correlatedJoinNode.getCorrelation());
    if (decorrelationResult.isEmpty()) {
        return Result.empty();
    }
    PlanNodeDecorrelator.DecorrelatedNode decorrelated = decorrelationResult.get();
    subquery = decorrelated.getNode();

    // Append a constant-true "non_null" column to the nested plan. It is used to restore
    // the semantics of null-sensitive aggregations after the LEFT join.
    Symbol nonNull = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    Assignments withNonNull = Assignments.builder()
            .putIdentities(subquery.getOutputSymbols())
            .put(nonNull, TRUE_LITERAL)
            .build();
    subquery = new ProjectNode(context.getIdAllocator().getNextId(), subquery, withNonNull);

    // Tag every input row of the correlated join with a unique id, so that the original
    // input rows can still be told apart after the join.
    PlanNode inputWithUniqueId = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueId,
            subquery,
            ImmutableList.of(),
            inputWithUniqueId.getOutputSymbols(),
            subquery.getOutputSymbols(),
            false,
            decorrelated.getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // restore the distinct aggregation (if there was one); group by the join's left outputs,
    // the non-null marker, and the original distinct keys
    PlanNode root = join;
    if (distinct != null) {
        ImmutableList<Symbol> distinctKeys = ImmutableList.<Symbol>builder()
                .addAll(join.getLeftOutputSymbols())
                .add(nonNull)
                .addAll(distinct.getGroupingKeys())
                .build();
        root = new AggregationNode(
                distinct.getId(),
                join,
                distinct.getAggregations(),
                singleGroupingSet(distinctKeys),
                ImmutableList.of(),
                distinct.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // Prepare mask symbols for the aggregations.
    // Every original aggregation agg() is rewritten to agg() mask(non_null). An aggregation that
    // already has a mask gets a new mask: the conjunction of its existing mask and non_null.
    // This is needed to restore the original aggregation result when:
    // - the nested lateral subquery returned empty result for some input row,
    // - the aggregation is null-sensitive, i.e. its result over a single null row differs
    //   from its result for empty input (with global grouping).
    // It applies to the following aggregate functions: count(*), checksum(), array_agg().
    AggregationNode originalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
    Assignments.Builder combinedMasks = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : originalAggregation.getAggregations().entrySet()) {
        Optional<Symbol> existingMask = entry.getValue().getMask();
        if (existingMask.isEmpty()) {
            masks.put(entry.getKey(), nonNull);
            continue;
        }
        Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
        Expression conjunction = and(existingMask.get().toSymbolReference(), nonNull.toSymbolReference());
        combinedMasks.put(combinedMask, conjunction);
        masks.put(entry.getKey(), combinedMask);
    }

    // compute the combined masks in a projection, but only when there are any
    Assignments maskAssignments = combinedMasks.build();
    if (!maskAssignments.isEmpty()) {
        Assignments withMasks = Assignments.builder()
                .putIdentities(root.getOutputSymbols())
                .putAll(maskAssignments)
                .build();
        root = new ProjectNode(context.getIdAllocator().getNextId(), root, withMasks);
    }

    // restore the global aggregation, now grouped by the join's left outputs and using the masks
    ImmutableList<Symbol> groupingKeys = ImmutableList.<Symbol>builder()
            .addAll(join.getLeftOutputSymbols())
            .addAll(originalAggregation.getGroupingKeys())
            .build();
    AggregationNode globalAggregation = new AggregationNode(
            originalAggregation.getId(),
            root,
            rewriteWithMasks(originalAggregation.getAggregations(), masks.buildOrThrow()),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            originalAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // restrict outputs to those of the correlated join; fall back to the aggregation itself
    // when restrictOutputs produces no projection
    Optional<PlanNode> restricted = restrictOutputs(context.getIdAllocator(), globalAggregation, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()));
    return Result.ofPlanNode(restricted.orElse(globalAggregation));
}
Also used : Symbol(io.trino.sql.planner.Symbol) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Aggregations

Assignments (io.trino.sql.planner.plan.Assignments): 61; ProjectNode (io.trino.sql.planner.plan.ProjectNode): 50; Symbol (io.trino.sql.planner.Symbol): 39; Expression (io.trino.sql.tree.Expression): 39; ImmutableList (com.google.common.collect.ImmutableList): 36; Map (java.util.Map): 33; PlanNode (io.trino.sql.planner.plan.PlanNode): 26; SymbolReference (io.trino.sql.tree.SymbolReference): 26; Captures (io.trino.matching.Captures): 23; Pattern (io.trino.matching.Pattern): 23; Rule (io.trino.sql.planner.iterative.Rule): 23; ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap): 21; Objects.requireNonNull (java.util.Objects.requireNonNull): 21; Set (java.util.Set): 21; Capture (io.trino.matching.Capture): 20; Capture.newCapture (io.trino.matching.Capture.newCapture): 20; TypeAnalyzer (io.trino.sql.planner.TypeAnalyzer): 19; Patterns.source (io.trino.sql.planner.plan.Patterns.source): 19; SubscriptExpression (io.trino.sql.tree.SubscriptExpression): 18; ImmutableMap (com.google.common.collect.ImmutableMap): 17