Search in sources:

Example 1 with Expression

use of io.trino.sql.tree.Expression in project trino by trinodb.

In class FilterStatsCalculator, method extractCorrelatedGroups:

/**
 * Splits the given terms into groups such that two terms land in the same group whenever they are
 * connected through shared symbols (directly, or transitively via other terms).
 * <p>
 * Terms that reference no symbols at all are collected together in one additional group.
 *
 * @param terms the expressions to partition
 * @param filterConjunctionIndependenceFactor when exactly {@code 1}, no partitioning is done and all
 *        terms are returned as a single group
 * @return the groups of terms; every input term appears in exactly one group
 * @throws IllegalStateException if more symbol partitions than terms were produced
 */
private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor) {
    if (filterConjunctionIndependenceFactor == 1) {
        // With a factor of 1 the filters are estimated as if no terms were correlated, so one group is enough
        return ImmutableList.of(terms);
    }

    // Remember, per term, the unique symbols it references
    ListMultimap<Expression, Symbol> symbolsByTerm = ArrayListMultimap.create();
    for (Expression term : terms) {
        symbolsByTerm.putAll(term, extractUnique(term));
    }

    // Merge symbols that occur together in a term, so that symbols from different
    // disjoint sets never share an expression
    DisjointSet<Symbol> connectedSymbols = new DisjointSet<>();
    for (Expression term : terms) {
        List<Symbol> termSymbols = symbolsByTerm.get(term);
        if (!termSymbols.isEmpty()) {
            Symbol anchor = termSymbols.get(0);
            // Makes sure the anchor is registered even when it is the only symbol of the term
            connectedSymbols.find(anchor);
            for (int index = 1; index < termSymbols.size(); index++) {
                connectedSymbols.findAndUnion(anchor, termSymbols.get(index));
            }
        }
    }

    // The disjoint symbol sets now drive the partitioning of the terms themselves
    List<Set<Symbol>> symbolGroups = ImmutableList.copyOf(connectedSymbols.getEquivalentClasses());
    checkState(symbolGroups.size() <= terms.size(), "symbolPartitions size exceeds number of expressions");

    // Terms without symbols share one id placed right after the ids of the symbol groups
    int symbolFreeGroupId = symbolGroups.size();
    ListMultimap<Integer, Expression> termsByGroup = ArrayListMultimap.create();
    for (Expression term : terms) {
        List<Symbol> termSymbols = symbolsByTerm.get(term);
        int groupId = symbolFreeGroupId;
        if (!termSymbols.isEmpty()) {
            // Any symbol of the term identifies its group; the first one is used
            Symbol anchor = termSymbols.get(0);
            groupId = IntStream.range(0, symbolGroups.size())
                    .filter(candidate -> symbolGroups.get(candidate).contains(anchor))
                    .findFirst()
                    .orElseThrow();
        }
        termsByGroup.put(groupId, term);
    }

    ImmutableList.Builder<List<Expression>> groups = ImmutableList.builder();
    for (Integer groupId : termsByGroup.keySet()) {
        groups.add(termsByGroup.get(groupId));
    }
    return groups.build();
}
Also used : ArrayListMultimap(com.google.common.collect.ArrayListMultimap) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) ListMultimap(com.google.common.collect.ListMultimap) SymbolsExtractor.extractUnique(io.trino.sql.planner.SymbolsExtractor.extractUnique) SystemSessionProperties.getFilterConjunctionIndependenceFactor(io.trino.SystemSessionProperties.getFilterConjunctionIndependenceFactor) Scope(io.trino.sql.analyzer.Scope) Double.min(java.lang.Double.min) StatsUtil.toStatsRepresentation(io.trino.spi.statistics.StatsUtil.toStatsRepresentation) Node(io.trino.sql.tree.Node) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) NaN(java.lang.Double.NaN) Map(java.util.Map) DisjointSet(io.trino.util.DisjointSet) FunctionCall(io.trino.sql.tree.FunctionCall) ExpressionUtils.getExpressionTypes(io.trino.sql.ExpressionUtils.getExpressionTypes) ComparisonStatsCalculator.estimateExpressionToExpressionComparison(io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison) BetweenPredicate(io.trino.sql.tree.BetweenPredicate) Double.isInfinite(java.lang.Double.isInfinite) PlanNodeStatsEstimateMath.capStats(io.trino.cost.PlanNodeStatsEstimateMath.capStats) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ExpressionInterpreter(io.trino.sql.planner.ExpressionInterpreter) Set(java.util.Set) ComparisonStatsCalculator.estimateExpressionToLiteralComparison(io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) String.format(java.lang.String.format) InPredicate(io.trino.sql.tree.InPredicate) Preconditions.checkState(com.google.common.base.Preconditions.checkState) LESS_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.EQUAL) 
List(java.util.List) SymbolReference(io.trino.sql.tree.SymbolReference) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) InListExpression(io.trino.sql.tree.InListExpression) PlanNodeStatsEstimateMath.subtractSubsetStats(io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) IntStream(java.util.stream.IntStream) PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues) AllowAllAccessControl(io.trino.security.AllowAllAccessControl) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) ExpressionInterpreter.evaluateConstantExpression(io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression) Type(io.trino.spi.type.Type) NoOpSymbolResolver(io.trino.sql.planner.NoOpSymbolResolver) OptionalDouble(java.util.OptionalDouble) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) Inject(javax.inject.Inject) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount(io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate) ExpressionAnalyzer(io.trino.sql.analyzer.ExpressionAnalyzer) NodeRef(io.trino.sql.tree.NodeRef) NotExpression(io.trino.sql.tree.NotExpression) Objects.requireNonNull(java.util.Objects.requireNonNull) GREATER_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL) Nullable(javax.annotation.Nullable) AstVisitor(io.trino.sql.tree.AstVisitor) VerifyException(com.google.common.base.VerifyException) ExpressionUtils(io.trino.sql.ExpressionUtils) Symbol(io.trino.sql.planner.Symbol) PlanNodeStatsEstimateMath.intersectCorrelatedStats(io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats) ExpressionUtils.and(io.trino.sql.ExpressionUtils.and) 
DynamicFilters.isDynamicFilter(io.trino.sql.DynamicFilters.isDynamicFilter) Double.isNaN(java.lang.Double.isNaN) WarningCollector(io.trino.execution.warnings.WarningCollector) LogicalExpression(io.trino.sql.tree.LogicalExpression) TypeProvider(io.trino.sql.planner.TypeProvider) DisjointSet(io.trino.util.DisjointSet) Set(java.util.Set) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Expression(io.trino.sql.tree.Expression) InListExpression(io.trino.sql.tree.InListExpression) ExpressionInterpreter.evaluateConstantExpression(io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression) NotExpression(io.trino.sql.tree.NotExpression) LogicalExpression(io.trino.sql.tree.LogicalExpression) Symbol(io.trino.sql.planner.Symbol) DisjointSet(io.trino.util.DisjointSet)

Example 2 with Expression

use of io.trino.sql.tree.Expression in project trino by trinodb.

In class ExtractSpatialJoins, method tryCreateSpatialJoin:

/**
 * Tries to replace a join whose filter contains the given two-argument spatial function with a
 * {@link SpatialJoinNode}.
 * <p>
 * The rewrite is abandoned ({@code Result.empty()}) when either function argument has the spherical
 * geography type, when either argument references no symbols, or when {@code checkAlignment}
 * reports {@code 0} for the argument/join-side pairing.
 * <p>
 * A KDB tree is loaded only for {@code INNER} joins that have a spatial partitioning table name
 * configured in the session; when it is present, both sides are additionally wrapped by
 * {@code addPartitioningNodes} and given a fresh {@code "pid"} partition symbol.
 *
 * @param filter the join filter in which {@code spatialFunction} is replaced by the rebuilt call
 * @param nodeId the id given to the resulting spatial join node
 * @param radius passed to the partitioning of the right side only
 * @return the spatial join wrapped in a {@code Result}, or an empty result when the rewrite does not apply
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, FunctionCall spatialFunction, Optional<Expression> radius, PlannerContext plannerContext, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) {
    // TODO Add support for distributed left spatial joins
    Optional<String> partitioningTable = Optional.empty();
    if (joinNode.getType() == INNER) {
        partitioningTable = getSpatialPartitioningTableName(context.getSession());
    }
    Optional<KdbTree> kdbTree = partitioningTable.map(tableName -> loadKdbTree(tableName, context.getSession(), plannerContext.getMetadata(), splitManager, pageSourceManager));

    List<Expression> arguments = spatialFunction.getArguments();
    verify(arguments.size() == 2);
    Expression firstArgument = arguments.get(0);
    Expression secondArgument = arguments.get(1);

    // Bail out as soon as one of the arguments turns out to be of the spherical geography type
    Type sphericalGeographyType = plannerContext.getTypeManager().getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE);
    if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType)) {
        return Result.empty();
    }
    if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) {
        return Result.empty();
    }

    // Both arguments have to reference at least one symbol
    Set<Symbol> firstSymbols = extractUnique(firstArgument);
    Set<Symbol> secondSymbols = extractUnique(secondArgument);
    if (firstSymbols.isEmpty() || secondSymbols.isEmpty()) {
        return Result.empty();
    }

    Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument, plannerContext.getTypeManager());
    Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument, plannerContext.getTypeManager());

    PlanNode leftNode = joinNode.getLeft();
    PlanNode rightNode = joinNode.getRight();

    // Determine whether the argument order of the spatial function matches the order of the join sides
    int alignment = checkAlignment(joinNode, firstSymbols, secondSymbols);
    if (alignment == 0) {
        return Result.empty();
    }
    boolean firstArgumentOnLeft = alignment > 0;

    PlanNode newLeftNode;
    PlanNode newRightNode;
    if (firstArgumentOnLeft) {
        newLeftNode = newFirstSymbol
                .map(symbol -> addProjection(context, leftNode, symbol, firstArgument))
                .orElse(leftNode);
        newRightNode = newSecondSymbol
                .map(symbol -> addProjection(context, rightNode, symbol, secondArgument))
                .orElse(rightNode);
    } else {
        newLeftNode = newSecondSymbol
                .map(symbol -> addProjection(context, leftNode, symbol, secondArgument))
                .orElse(leftNode);
        newRightNode = newFirstSymbol
                .map(symbol -> addProjection(context, rightNode, symbol, firstArgument))
                .orElse(rightNode);
    }

    Expression newFirstArgument = toExpression(newFirstSymbol, firstArgument);
    Expression newSecondArgument = toExpression(newSecondSymbol, secondArgument);

    Optional<Symbol> leftPartitionSymbol = Optional.empty();
    Optional<Symbol> rightPartitionSymbol = Optional.empty();
    if (kdbTree.isPresent()) {
        KdbTree tree = kdbTree.get();
        Symbol leftPid = context.getSymbolAllocator().newSymbol("pid", INTEGER);
        leftPartitionSymbol = Optional.of(leftPid);
        Symbol rightPid = context.getSymbolAllocator().newSymbol("pid", INTEGER);
        rightPartitionSymbol = Optional.of(rightPid);

        // Each side is partitioned on the argument that was projected onto it; only the right side receives the radius
        Expression leftGeometry = firstArgumentOnLeft ? newFirstArgument : newSecondArgument;
        Expression rightGeometry = firstArgumentOnLeft ? newSecondArgument : newFirstArgument;
        newLeftNode = addPartitioningNodes(plannerContext, context, newLeftNode, leftPid, tree, leftGeometry, Optional.empty());
        newRightNode = addPartitioningNodes(plannerContext, context, newRightNode, rightPid, tree, rightGeometry, radius);
    }

    // Rebuild the spatial function call over the new arguments and substitute it into the filter
    Expression newSpatialFunction = FunctionCallBuilder.resolve(context.getSession(), plannerContext.getMetadata())
            .setName(spatialFunction.getName())
            .addArgument(GEOMETRY_TYPE_SIGNATURE, newFirstArgument)
            .addArgument(GEOMETRY_TYPE_SIGNATURE, newSecondArgument)
            .build();
    Expression newFilter = replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));

    return Result.ofPlanNode(new SpatialJoinNode(
            nodeId,
            SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType()),
            newLeftNode,
            newRightNode,
            outputSymbols,
            newFilter,
            leftPartitionSymbol,
            rightPartitionSymbol,
            kdbTree.map(KdbTreeUtils::toJson)));
}
Also used : EMPTY(io.trino.spi.connector.DynamicFilter.EMPTY) SpatialJoinUtils.extractSupportedSpatialComparisons(io.trino.util.SpatialJoinUtils.extractSupportedSpatialComparisons) SymbolsExtractor.extractUnique(io.trino.sql.planner.SymbolsExtractor.extractUnique) SplitBatch(io.trino.split.SplitSource.SplitBatch) SplitManager(io.trino.split.SplitManager) SystemSessionProperties.getSpatialPartitioningTableName(io.trino.SystemSessionProperties.getSpatialPartitioningTableName) FilterNode(io.trino.sql.planner.plan.FilterNode) PlanNode(io.trino.sql.planner.plan.PlanNode) LEFT(io.trino.sql.planner.plan.JoinNode.Type.LEFT) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) Map(java.util.Map) SpatialJoinNode(io.trino.sql.planner.plan.SpatialJoinNode) ConnectorPageSource(io.trino.spi.connector.ConnectorPageSource) JoinNode(io.trino.sql.planner.plan.JoinNode) INTEGER(io.trino.spi.type.IntegerType.INTEGER) Splitter(com.google.common.base.Splitter) FunctionCall(io.trino.sql.tree.FunctionCall) Patterns.join(io.trino.sql.planner.plan.Patterns.join) TypeSignature(io.trino.spi.type.TypeSignature) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Collection(java.util.Collection) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) KdbTree(io.trino.geospatial.KdbTree) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) TrinoException(io.trino.spi.TrinoException) ArrayType(io.trino.spi.type.ArrayType) SplitSource(io.trino.split.SplitSource) Context(io.trino.sql.planner.iterative.Rule.Context) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) String.format(java.lang.String.format) Constraint.alwaysTrue(io.trino.spi.connector.Constraint.alwaysTrue) LESS_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) 
UncheckedIOException(java.io.UncheckedIOException) List(java.util.List) INVALID_SPATIAL_PARTITIONING(io.trino.spi.StandardErrorCode.INVALID_SPATIAL_PARTITIONING) NOT_PARTITIONED(io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED) Pattern(io.trino.matching.Pattern) SymbolReference(io.trino.sql.tree.SymbolReference) Split(io.trino.metadata.Split) DynamicFilter(io.trino.spi.connector.DynamicFilter) Optional(java.util.Optional) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) Iterables(com.google.common.collect.Iterables) INNER(io.trino.sql.planner.plan.JoinNode.Type.INNER) Type(io.trino.spi.type.Type) Patterns.filter(io.trino.sql.planner.plan.Patterns.filter) Page(io.trino.spi.Page) Capture.newCapture(io.trino.matching.Capture.newCapture) Cast(io.trino.sql.tree.Cast) KdbTreeUtils(io.trino.geospatial.KdbTreeUtils) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) FunctionCallBuilder(io.trino.sql.planner.FunctionCallBuilder) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) UNGROUPED_SCHEDULING(io.trino.spi.connector.ConnectorSplitManager.SplitSchedulingStrategy.UNGROUPED_SCHEDULING) Objects.requireNonNull(java.util.Objects.requireNonNull) Result(io.trino.sql.planner.iterative.Rule.Result) ColumnHandle(io.trino.spi.connector.ColumnHandle) Rule(io.trino.sql.planner.iterative.Rule) Lifespan(io.trino.execution.Lifespan) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) StringLiteral(io.trino.sql.tree.StringLiteral) SystemSessionProperties.isSpatialJoinEnabled(io.trino.SystemSessionProperties.isSpatialJoinEnabled) IOException(java.io.IOException) PageSourceManager(io.trino.split.PageSourceManager) LESS_THAN(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN) 
MoreFutures.getFutureValue(io.airlift.concurrent.MoreFutures.getFutureValue) UnnestNode(io.trino.sql.planner.plan.UnnestNode) Capture(io.trino.matching.Capture) QualifiedName(io.trino.sql.tree.QualifiedName) DOUBLE(io.trino.spi.type.DoubleType.DOUBLE) TableHandle(io.trino.metadata.TableHandle) TypeAnalyzer(io.trino.sql.planner.TypeAnalyzer) QualifiedObjectName(io.trino.metadata.QualifiedObjectName) Patterns.source(io.trino.sql.planner.plan.Patterns.source) Captures(io.trino.matching.Captures) Metadata(io.trino.metadata.Metadata) VisibleForTesting(com.google.common.annotations.VisibleForTesting) TypeManager(io.trino.spi.type.TypeManager) SpatialJoinUtils.extractSupportedSpatialFunctions(io.trino.util.SpatialJoinUtils.extractSupportedSpatialFunctions) KdbTreeUtils(io.trino.geospatial.KdbTreeUtils) KdbTree(io.trino.geospatial.KdbTree) Symbol(io.trino.sql.planner.Symbol) SpatialJoinNode(io.trino.sql.planner.plan.SpatialJoinNode) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) ArrayType(io.trino.spi.type.ArrayType) Type(io.trino.spi.type.Type) PlanNode(io.trino.sql.planner.plan.PlanNode) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression)

Example 3 with Expression

use of io.trino.sql.tree.Expression in project trino by trinodb.

In class ImplementFilteredAggregations, method apply:

/**
 * Rewrites the aggregation node so that none of its aggregations carries a filter anymore: the
 * filter symbol becomes the aggregation's mask, or, when a mask already exists, a new
 * {@code "mask"} symbol projected as {@code mask AND filter} takes its place.
 * <p>
 * The source is wrapped in a projection (new mask assignments plus identities for all existing
 * source outputs) and a filter. The filter predicate is {@code TRUE_LITERAL} unless the node has
 * no non-empty grouping set and every aggregation ends up masked; in that case it is the
 * disjunction of all mask references.
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    Assignments.Builder newAssignments = Assignments.builder();
    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableList.Builder<Expression> maskReferences = ImmutableList.builder();
    boolean unmaskedAggregationPresent = false;

    for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Aggregation original = entry.getValue();
        Optional<Symbol> mask = original.getMask();
        Optional<Symbol> filter = original.getFilter();

        if (filter.isPresent() && mask.isPresent()) {
            // Both are set: project their conjunction under a fresh symbol and mask by that
            Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
            newAssignments.put(combinedMask, and(mask.get().toSymbolReference(), filter.get().toSymbolReference()));
            mask = Optional.of(combinedMask);
        } else if (filter.isPresent()) {
            // Only a filter: it simply takes over the role of the mask
            mask = Optional.of(filter.get());
        }

        if (mask.isPresent()) {
            maskReferences.add(mask.get().toSymbolReference());
        } else {
            unmaskedAggregationPresent = true;
        }

        // Same aggregation as before, except that the filter is dropped and the mask replaced
        rewrittenAggregations.put(
                entry.getKey(),
                new Aggregation(
                        original.getResolvedFunction(),
                        original.getArguments(),
                        original.isDistinct(),
                        Optional.empty(),
                        original.getOrderingScheme(),
                        mask));
    }

    Expression predicate = (aggregationNode.hasNonEmptyGroupingSet() || unmaskedAggregationPresent)
            ? TRUE_LITERAL
            : combineDisjunctsWithDefault(metadata, maskReferences.build(), TRUE_LITERAL);

    // Pass every existing input through unchanged
    newAssignments.putIdentities(aggregationNode.getSource().getOutputSymbols());

    return Result.ofPlanNode(new AggregationNode(
            context.getIdAllocator().getNextId(),
            new FilterNode(
                    context.getIdAllocator().getNextId(),
                    new ProjectNode(
                            context.getIdAllocator().getNextId(),
                            aggregationNode.getSource(),
                            newAssignments.build()),
                    predicate),
            rewrittenAggregations.buildOrThrow(),
            aggregationNode.getGroupingSets(),
            ImmutableList.of(),
            aggregationNode.getStep(),
            aggregationNode.getHashSymbol(),
            aggregationNode.getGroupIdSymbol()));
}
Also used : Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.trino.sql.planner.plan.FilterNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Example 4 with Expression

use of io.trino.sql.tree.Expression in project trino by trinodb.

In class ImplementIntersectDistinctAsUnion, method apply:

/**
 * Replaces the intersect node by the plan obtained from
 * {@code SetOperationNodeTranslator.makeSetContainmentPlanForDistinct}, filtered so that every
 * count symbol of that plan is at least 1, and projected back to the intersect node's output symbols.
 */
@Override
public Result apply(IntersectNode node, Captures captures, Context context) {
    SetOperationNodeTranslator containmentTranslator = new SetOperationNodeTranslator(
            context.getSession(),
            metadata,
            context.getSymbolAllocator(),
            context.getIdAllocator());
    SetOperationNodeTranslator.TranslationResult containment = containmentTranslator.makeSetContainmentPlanForDistinct(node);

    // A row belongs to the intersection only if each source contributed it, i.e. every count is >= 1
    Expression presentInEverySource = and(containment.getCountSymbols().stream()
            .map(count -> new ComparisonExpression(GREATER_THAN_OR_EQUAL, count.toSymbolReference(), new GenericLiteral("BIGINT", "1")))
            .collect(toImmutableList()));

    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new FilterNode(context.getIdAllocator().getNextId(), containment.getPlanNode(), presentInEverySource),
            Assignments.identity(node.getOutputSymbols())));
}
Also used : ComparisonExpression(io.trino.sql.tree.ComparisonExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Expression(io.trino.sql.tree.Expression) FilterNode(io.trino.sql.planner.plan.FilterNode) ProjectNode(io.trino.sql.planner.plan.ProjectNode) GenericLiteral(io.trino.sql.tree.GenericLiteral)

Example 5 with Expression

use of io.trino.sql.tree.Expression in project trino by trinodb.

In class InlineProjectIntoFilter, method apply:

/**
 * Inlines projected expressions into the filter for conjuncts that are plain symbol references.
 * <p>
 * A symbol-reference conjunct is a candidate only if it occurs exactly once among such conjuncts
 * and its symbol is not used by any other conjunct. A candidate is replaced by the expression the
 * captured projection assigns to it, unless there is no such assignment or the assignment is itself
 * a symbol reference. Inlined symbols are removed from the underlying projection and re-created on
 * top of the filter as {@code TRUE_LITERAL}.
 *
 * @return the rewritten project-filter-project plan, or an empty result when nothing was inlined
 */
@Override
public Result apply(FilterNode node, Captures captures, Context context) {
    ProjectNode projectNode = captures.get(PROJECTION);
    List<Expression> filterConjuncts = extractConjuncts(node.getPredicate());

    // true -> conjuncts that are bare symbol references, false -> everything else
    Map<Boolean, List<Expression>> bySimplicity = filterConjuncts.stream()
            .collect(partitioningBy(SymbolReference.class::isInstance));
    List<Expression> symbolConjuncts = bySimplicity.get(true);
    List<Expression> otherConjuncts = bySimplicity.get(false);

    // A symbol appearing several times among the simple conjuncts must not be inlined
    Map<Expression, Long> occurrences = symbolConjuncts.stream()
            .collect(Collectors.groupingBy(identity(), counting()));
    Set<Expression> usedOnce = occurrences.entrySet().stream()
            .filter(entry -> entry.getValue() == 1)
            .map(Map.Entry::getKey)
            .collect(toImmutableSet());

    // Neither must a symbol that is also referenced from one of the complex conjuncts
    Set<Expression> referencedElsewhere = SymbolsExtractor.extractUnique(otherConjuncts).stream()
            .map(Symbol::toSymbolReference)
            .collect(toImmutableSet());

    Set<Expression> inlineCandidates = Sets.difference(usedOnce, referencedElsewhere);
    if (inlineCandidates.isEmpty()) {
        return Result.empty();
    }

    ImmutableList.Builder<Expression> newConjuncts = ImmutableList.builder();
    Assignments.Builder newAssignments = Assignments.builder();
    Assignments.Builder postFilterAssignmentsBuilder = Assignments.builder();
    for (Expression conjunct : filterConjuncts) {
        if (!inlineCandidates.contains(conjunct)) {
            newConjuncts.add(conjunct);
            continue;
        }

        Symbol conjunctSymbol = Symbol.from(conjunct);
        Expression projected = projectNode.getAssignments().get(conjunctSymbol);
        if (projected != null && !(projected instanceof SymbolReference)) {
            newConjuncts.add(projected);
            newAssignments.putIdentities(SymbolsExtractor.extractUnique(projected));
            postFilterAssignmentsBuilder.put(conjunctSymbol, TRUE_LITERAL);
        } else {
            // null: the symbol is not produced by the underlying projection (i.e. it is a correlation symbol).
            // SymbolReference: trivial projections are not inlined.
            newConjuncts.add(conjunct);
        }
    }

    Assignments postFilterAssignments = postFilterAssignmentsBuilder.build();
    if (postFilterAssignments.isEmpty()) {
        return Result.empty();
    }

    // Drop the inlined expressions from the underlying projection
    Set<Symbol> inlinedSymbols = postFilterAssignments.getSymbols();
    newAssignments.putAll(projectNode.getAssignments().filter(symbol -> !inlinedSymbols.contains(symbol)));

    // Start from identities of the filter's outputs, then restore the inlined symbols on top
    Map<Symbol, Expression> outputAssignments = new HashMap<>();
    outputAssignments.putAll(Assignments.identity(node.getOutputSymbols()).getMap());
    outputAssignments.putAll(postFilterAssignments.getMap());

    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new FilterNode(
                    node.getId(),
                    new ProjectNode(projectNode.getId(), projectNode.getSource(), newAssignments.build()),
                    combineConjuncts(metadata, newConjuncts.build())),
            Assignments.builder().putAll(outputAssignments).build()));
}
Also used : Collectors.partitioningBy(java.util.stream.Collectors.partitioningBy) Patterns.filter(io.trino.sql.planner.plan.Patterns.filter) Collectors.counting(java.util.stream.Collectors.counting) HashMap(java.util.HashMap) Capture.newCapture(io.trino.matching.Capture.newCapture) FilterNode(io.trino.sql.planner.plan.FilterNode) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) Rule(io.trino.sql.planner.iterative.Rule) SymbolsExtractor(io.trino.sql.planner.SymbolsExtractor) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) Set(java.util.Set) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) Capture(io.trino.matching.Capture) List(java.util.List) Pattern(io.trino.matching.Pattern) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) Function.identity(java.util.function.Function.identity) Metadata(io.trino.metadata.Metadata) Expression(io.trino.sql.tree.Expression) ExpressionUtils.extractConjuncts(io.trino.sql.ExpressionUtils.extractConjuncts) Patterns.project(io.trino.sql.planner.plan.Patterns.project) ExpressionUtils.combineConjuncts(io.trino.sql.ExpressionUtils.combineConjuncts) HashMap(java.util.HashMap) ImmutableList(com.google.common.collect.ImmutableList) SymbolReference(io.trino.sql.tree.SymbolReference) Symbol(io.trino.sql.planner.Symbol) FilterNode(io.trino.sql.planner.plan.FilterNode) Assignments(io.trino.sql.planner.plan.Assignments) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) ImmutableList(com.google.common.collect.ImmutableList) List(java.util.List) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

Expression (io.trino.sql.tree.Expression)257 ComparisonExpression (io.trino.sql.tree.ComparisonExpression)111 ImmutableList (com.google.common.collect.ImmutableList)88 Symbol (io.trino.sql.planner.Symbol)77 ProjectNode (io.trino.sql.planner.plan.ProjectNode)65 PlanNode (io.trino.sql.planner.plan.PlanNode)63 Map (java.util.Map)63 NotExpression (io.trino.sql.tree.NotExpression)62 Test (org.testng.annotations.Test)58 Objects.requireNonNull (java.util.Objects.requireNonNull)57 Assignments (io.trino.sql.planner.plan.Assignments)55 InListExpression (io.trino.sql.tree.InListExpression)55 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)54 SymbolReference (io.trino.sql.tree.SymbolReference)53 Type (io.trino.spi.type.Type)52 List (java.util.List)52 Set (java.util.Set)52 SubscriptExpression (io.trino.sql.tree.SubscriptExpression)48 ImmutableMap (com.google.common.collect.ImmutableMap)47 Optional (java.util.Optional)47