Search in sources:

Example 1 with ExpressionRewriter

Use of io.trino.sql.tree.ExpressionRewriter in project trino by trinodb.

From the class SymbolMapper, method map:

/**
 * Returns a copy of {@code expressionAndValuePointers} whose value pointers have their input symbols
 * remapped by this mapper. The expression, layout, classifier symbols and match-number symbols are
 * passed through unchanged.
 */
private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndValuePointers) {
    // Only the value pointers' input symbols are remapped: they are the symbols produced by the source node.
    // All remaining symbols held by the structure are synthetic unique symbols that are neither used nor
    // depended upon outside of it, so they are deliberately left untouched.
    ImmutableList.Builder<ValuePointer> mappedPointers = ImmutableList.builder();
    for (ValuePointer pointer : expressionAndValuePointers.getValuePointers()) {
        if (!(pointer instanceof ScalarValuePointer)) {
            // Every non-scalar pointer is handled as an aggregation pointer. Its argument expressions are
            // rewritten symbol by symbol, except for references to the pointer's own classifier and
            // match-number symbols, which must survive the mapping verbatim.
            AggregationValuePointer aggregationPointer = (AggregationValuePointer) pointer;
            ExpressionRewriter<Void> argumentRewriter = new ExpressionRewriter<Void>() {
                @Override
                public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                    Symbol referenced = Symbol.from(node);
                    boolean isPointerOwnSymbol = referenced.equals(aggregationPointer.getClassifierSymbol())
                            || referenced.equals(aggregationPointer.getMatchNumberSymbol());
                    return isPointerOwnSymbol ? node : map(node);
                }
            };
            List<Expression> mappedArguments = aggregationPointer.getArguments().stream()
                    .map(argument -> ExpressionTreeRewriter.rewriteWith(argumentRewriter, argument))
                    .collect(toImmutableList());
            mappedPointers.add(new AggregationValuePointer(
                    aggregationPointer.getFunction(),
                    aggregationPointer.getSetDescriptor(),
                    mappedArguments,
                    aggregationPointer.getClassifierSymbol(),
                    aggregationPointer.getMatchNumberSymbol()));
            continue;
        }

        // Scalar pointer: keep it verbatim when it reads a classifier or match-number symbol,
        // otherwise rebuild it around the mapped input symbol.
        ScalarValuePointer scalarPointer = (ScalarValuePointer) pointer;
        Symbol inputSymbol = scalarPointer.getInputSymbol();
        boolean keepAsIs = expressionAndValuePointers.getClassifierSymbols().contains(inputSymbol)
                || expressionAndValuePointers.getMatchNumberSymbols().contains(inputSymbol);
        mappedPointers.add(keepAsIs
                ? scalarPointer
                : new ScalarValuePointer(scalarPointer.getLogicalIndexPointer(), map(inputSymbol)));
    }

    return new ExpressionAndValuePointers(
            expressionAndValuePointers.getExpression(),
            expressionAndValuePointers.getLayout(),
            mappedPointers.build(),
            expressionAndValuePointers.getClassifierSymbols(),
            expressionAndValuePointers.getMatchNumberSymbols());
}
Also used: TableExecuteNode (io.trino.sql.planner.plan.TableExecuteNode), Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation), SymbolAllocator (io.trino.sql.planner.SymbolAllocator), LimitNode (io.trino.sql.planner.plan.LimitNode), TableWriterNode (io.trino.sql.planner.plan.TableWriterNode), Measure (io.trino.sql.planner.plan.PatternRecognitionNode.Measure), HashMap (java.util.HashMap), StatisticAggregations (io.trino.sql.planner.plan.StatisticAggregations), AggregationValuePointer (io.trino.sql.planner.rowpattern.AggregationValuePointer), PartitioningScheme (io.trino.sql.planner.PartitioningScheme), Function (java.util.function.Function), PlanNode (io.trino.sql.planner.plan.PlanNode), HashSet (java.util.HashSet), ImmutableList (com.google.common.collect.ImmutableList), PlanNodeId (io.trino.sql.planner.plan.PlanNodeId), RowNumberNode (io.trino.sql.planner.plan.RowNumberNode), Map (java.util.Map), Objects.requireNonNull (java.util.Objects.requireNonNull), AggregationNode (io.trino.sql.planner.plan.AggregationNode), ImmutableSet.toImmutableSet (com.google.common.collect.ImmutableSet.toImmutableSet), ExpressionRewriter (io.trino.sql.tree.ExpressionRewriter), ExpressionAndValuePointers (io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers), Symbol (io.trino.sql.planner.Symbol), ImmutableMap (com.google.common.collect.ImmutableMap), IrLabel (io.trino.sql.planner.rowpattern.ir.IrLabel), StatisticsWriterNode (io.trino.sql.planner.plan.StatisticsWriterNode), ExpressionTreeRewriter (io.trino.sql.tree.ExpressionTreeRewriter), ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList), TopNNode (io.trino.sql.planner.plan.TopNNode), Set (java.util.Set), ScalarValuePointer (io.trino.sql.planner.rowpattern.ScalarValuePointer), OrderingScheme (io.trino.sql.planner.OrderingScheme), PatternRecognitionNode (io.trino.sql.planner.plan.PatternRecognitionNode), SortOrder (io.trino.spi.connector.SortOrder), List (java.util.List), ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap), AggregationNode.groupingSets (io.trino.sql.planner.plan.AggregationNode.groupingSets), SymbolReference (io.trino.sql.tree.SymbolReference), GroupIdNode (io.trino.sql.planner.plan.GroupIdNode), DistinctLimitNode (io.trino.sql.planner.plan.DistinctLimitNode), ValuePointer (io.trino.sql.planner.rowpattern.ValuePointer), Entry (java.util.Map.Entry), TableFinishNode (io.trino.sql.planner.plan.TableFinishNode), TopNRankingNode (io.trino.sql.planner.plan.TopNRankingNode), Expression (io.trino.sql.tree.Expression), WindowNode (io.trino.sql.planner.plan.WindowNode)

Aggregations

ImmutableList (com.google.common.collect.ImmutableList): 1; ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 1; ImmutableMap (com.google.common.collect.ImmutableMap): 1; ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap): 1; ImmutableSet.toImmutableSet (com.google.common.collect.ImmutableSet.toImmutableSet): 1; SortOrder (io.trino.spi.connector.SortOrder): 1; OrderingScheme (io.trino.sql.planner.OrderingScheme): 1; PartitioningScheme (io.trino.sql.planner.PartitioningScheme): 1; Symbol (io.trino.sql.planner.Symbol): 1; SymbolAllocator (io.trino.sql.planner.SymbolAllocator): 1; AggregationNode (io.trino.sql.planner.plan.AggregationNode): 1; Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation): 1; AggregationNode.groupingSets (io.trino.sql.planner.plan.AggregationNode.groupingSets): 1; DistinctLimitNode (io.trino.sql.planner.plan.DistinctLimitNode): 1; GroupIdNode (io.trino.sql.planner.plan.GroupIdNode): 1; LimitNode (io.trino.sql.planner.plan.LimitNode): 1; PatternRecognitionNode (io.trino.sql.planner.plan.PatternRecognitionNode): 1; Measure (io.trino.sql.planner.plan.PatternRecognitionNode.Measure): 1; PlanNode (io.trino.sql.planner.plan.PlanNode): 1; PlanNodeId (io.trino.sql.planner.plan.PlanNodeId): 1