Search in sources:

Example 1 with ScalarValuePointer

use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.

the class TestPatternRecognitionNodeSerialization method testExpressionAndValuePointersRoundtrip.

/**
 * Passes {@link ExpressionAndValuePointers} instances to {@code assertJsonRoundTrip} together
 * with a JSON codec whose object mapper is configured with the custom {@link Expression}
 * serializer and deserializer.
 * <p>
 * Two fixtures are checked: one made of a null literal with empty collections, and one with a
 * compound IF expression, three layout symbols, three scalar value pointers, and non-empty
 * classifier and match-number symbol sets.
 */
@Test
public void testExpressionAndValuePointersRoundtrip() {
    ObjectMapperProvider mapperProvider = new ObjectMapperProvider();
    mapperProvider.setJsonSerializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionSerializer()));
    mapperProvider.setJsonDeserializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionDeserializer(new SqlParser())));
    JsonCodec<ExpressionAndValuePointers> codec = new JsonCodecFactory(mapperProvider).jsonCodec(ExpressionAndValuePointers.class);

    // Fixture 1: null literal, nothing else.
    ExpressionAndValuePointers emptyFixture = new ExpressionAndValuePointers(
            new NullLiteral(),
            ImmutableList.of(),
            ImmutableList.of(),
            ImmutableSet.of(),
            ImmutableSet.of());
    assertJsonRoundTrip(codec, emptyFixture);

    // Fixture 2: IF(classifier > x, rand(), -match_number) with three scalar value pointers.
    Expression condition = new ComparisonExpression(GREATER_THAN, new SymbolReference("classifier"), new SymbolReference("x"));
    Expression whenTrue = new FunctionCall(QualifiedName.of("rand"), ImmutableList.of());
    Expression whenFalse = new ArithmeticUnaryExpression(MINUS, new SymbolReference("match_number"));
    Expression ifExpression = new IfExpression(condition, whenTrue, whenFalse);

    ScalarValuePointer pointerOverLabelsAB = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("A"), new IrLabel("B")), false, true, 1, -1),
            new Symbol("input_symbol_a"));
    ScalarValuePointer pointerOverLabelB = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("B")), true, false, 2, 1),
            new Symbol("input_symbol_a"));
    ScalarValuePointer pointerWithoutLabels = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0),
            new Symbol("input_symbol_a"));

    ExpressionAndValuePointers populatedFixture = new ExpressionAndValuePointers(
            ifExpression,
            ImmutableList.of(new Symbol("classifier"), new Symbol("x"), new Symbol("match_number")),
            ImmutableList.of(pointerOverLabelsAB, pointerOverLabelB, pointerWithoutLabels),
            ImmutableSet.of(new Symbol("classifier")),
            ImmutableSet.of(new Symbol("match_number")));
    assertJsonRoundTrip(codec, populatedFixture);
}
Also used : IrLabel(io.trino.sql.planner.rowpattern.ir.IrLabel) IfExpression(io.trino.sql.tree.IfExpression) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) SymbolReference(io.trino.sql.tree.SymbolReference) Symbol(io.trino.sql.planner.Symbol) SqlParser(io.trino.sql.parser.SqlParser) ObjectMapperProvider(io.airlift.json.ObjectMapperProvider) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) IfExpression(io.trino.sql.tree.IfExpression) Expression(io.trino.sql.tree.Expression) ExpressionAndValuePointers(io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) LogicalIndexPointer(io.trino.sql.planner.rowpattern.LogicalIndexPointer) FunctionCall(io.trino.sql.tree.FunctionCall) JsonCodecFactory(io.airlift.json.JsonCodecFactory) NullLiteral(io.trino.sql.tree.NullLiteral) Test(org.testng.annotations.Test)

Example 2 with ScalarValuePointer

use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.

the class TestPatternRecognitionNodeSerialization method testMeasureRoundtrip.

/**
 * Passes {@link Measure} instances to {@code assertJsonRoundTrip} together with a JSON codec
 * whose object mapper is configured with the custom {@link Expression} serializer and with
 * deserializers for {@link Expression} and {@link Type}.
 * <p>
 * Two fixtures are checked: a BOOLEAN measure over a null literal with empty collections, and
 * a BIGINT measure over a compound IF expression with three scalar value pointers.
 */
@Test
public void testMeasureRoundtrip() {
    ObjectMapperProvider mapperProvider = new ObjectMapperProvider();
    mapperProvider.setJsonSerializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionSerializer()));
    mapperProvider.setJsonDeserializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionDeserializer(new SqlParser()), Type.class, new TypeDeserializer(TESTING_TYPE_MANAGER)));
    JsonCodec<Measure> codec = new JsonCodecFactory(mapperProvider).jsonCodec(Measure.class);

    // Fixture 1: BOOLEAN measure wrapping a null literal and empty collections.
    ExpressionAndValuePointers emptyPointers = new ExpressionAndValuePointers(
            new NullLiteral(),
            ImmutableList.of(),
            ImmutableList.of(),
            ImmutableSet.of(),
            ImmutableSet.of());
    assertJsonRoundTrip(codec, new Measure(emptyPointers, BOOLEAN));

    // Fixture 2: BIGINT measure over IF(match_number > x, BIGINT '10', -y).
    Expression condition = new ComparisonExpression(GREATER_THAN, new SymbolReference("match_number"), new SymbolReference("x"));
    Expression whenTrue = new GenericLiteral("BIGINT", "10");
    Expression whenFalse = new ArithmeticUnaryExpression(MINUS, new SymbolReference("y"));
    Expression ifExpression = new IfExpression(condition, whenTrue, whenFalse);

    ScalarValuePointer pointerWithoutLabels = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0),
            new Symbol("input_symbol_a"));
    ScalarValuePointer pointerOverLabelA = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("A")), false, true, 1, -1),
            new Symbol("input_symbol_a"));
    ScalarValuePointer pointerOverLabelB = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("B")), false, true, 1, -1),
            new Symbol("input_symbol_b"));

    ExpressionAndValuePointers populatedPointers = new ExpressionAndValuePointers(
            ifExpression,
            ImmutableList.of(new Symbol("match_number"), new Symbol("x"), new Symbol("y")),
            ImmutableList.of(pointerWithoutLabels, pointerOverLabelA, pointerOverLabelB),
            ImmutableSet.of(),
            ImmutableSet.of(new Symbol("match_number")));
    assertJsonRoundTrip(codec, new Measure(populatedPointers, BIGINT));
}
Also used : IrLabel(io.trino.sql.planner.rowpattern.ir.IrLabel) IfExpression(io.trino.sql.tree.IfExpression) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) SymbolReference(io.trino.sql.tree.SymbolReference) Symbol(io.trino.sql.planner.Symbol) SqlParser(io.trino.sql.parser.SqlParser) ObjectMapperProvider(io.airlift.json.ObjectMapperProvider) GenericLiteral(io.trino.sql.tree.GenericLiteral) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Type(io.trino.spi.type.Type) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) IfExpression(io.trino.sql.tree.IfExpression) Expression(io.trino.sql.tree.Expression) ExpressionAndValuePointers(io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers) Measure(io.trino.sql.planner.plan.PatternRecognitionNode.Measure) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) LogicalIndexPointer(io.trino.sql.planner.rowpattern.LogicalIndexPointer) TypeDeserializer(io.trino.type.TypeDeserializer) JsonCodecFactory(io.airlift.json.JsonCodecFactory) NullLiteral(io.trino.sql.tree.NullLiteral) Test(org.testng.annotations.Test)

Example 3 with ScalarValuePointer

use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.

the class TestPatternRecognitionNodeSerialization method testScalarValuePointerRoundtrip.

/**
 * Passes {@link ScalarValuePointer} instances to {@code assertJsonRoundTrip} together with a
 * codec declared for the {@link ValuePointer} type, built from a default
 * {@link ObjectMapperProvider}.
 * <p>
 * Two fixtures are checked: a pointer with an empty label set, and a pointer over labels
 * A and B.
 */
@Test
public void testScalarValuePointerRoundtrip() {
    ObjectMapperProvider mapperProvider = new ObjectMapperProvider();
    JsonCodecFactory codecFactory = new JsonCodecFactory(mapperProvider);
    JsonCodec<ValuePointer> codec = codecFactory.jsonCodec(ValuePointer.class);

    // Fixture 1: empty label set.
    LogicalIndexPointer indexWithoutLabels = new LogicalIndexPointer(ImmutableSet.of(), false, false, 5, 5);
    assertJsonRoundTrip(codec, new ScalarValuePointer(indexWithoutLabels, new Symbol("input_symbol")));

    // Fixture 2: labels A and B.
    LogicalIndexPointer indexOverLabelsAB = new LogicalIndexPointer(ImmutableSet.of(new IrLabel("A"), new IrLabel("B")), true, true, 1, -1);
    assertJsonRoundTrip(codec, new ScalarValuePointer(indexOverLabelsAB, new Symbol("input_symbol")));
}
Also used : IrLabel(io.trino.sql.planner.rowpattern.ir.IrLabel) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) Symbol(io.trino.sql.planner.Symbol) AggregationValuePointer(io.trino.sql.planner.rowpattern.AggregationValuePointer) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) ValuePointer(io.trino.sql.planner.rowpattern.ValuePointer) LogicalIndexPointer(io.trino.sql.planner.rowpattern.LogicalIndexPointer) JsonCodecFactory(io.airlift.json.JsonCodecFactory) ObjectMapperProvider(io.airlift.json.ObjectMapperProvider) Test(org.testng.annotations.Test)

Example 4 with ScalarValuePointer

use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.

the class SymbolMapper method map.

/**
 * Builds a copy of the given {@link ExpressionAndValuePointers} in which the input symbols of
 * its value pointers have been passed through this mapper.
 * <p>
 * Only the input symbols of value pointers are mapped; these are the symbols produced by the
 * source node. Other symbols present in the structure are synthetic unique symbols with no
 * outer usage or dependencies, so the expression, layout, classifier symbols and match-number
 * symbols are carried over unchanged.
 *
 * @param expressionAndValuePointers the structure whose value pointers are to be remapped
 * @return a new structure with the remapped value pointers, in the original order
 */
private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndValuePointers) {
    ImmutableList.Builder<ValuePointer> mappedPointers = ImmutableList.builder();

    for (ValuePointer pointer : expressionAndValuePointers.getValuePointers()) {
        if (!(pointer instanceof ScalarValuePointer)) {
            // Anything that is not a scalar pointer is treated as an aggregation pointer.
            AggregationValuePointer aggregation = (AggregationValuePointer) pointer;

            // Rewrites symbol references inside aggregation arguments, leaving the pointer's own
            // classifier and match-number symbols as they are.
            ExpressionRewriter<Void> argumentRewriter = new ExpressionRewriter<Void>() {

                @Override
                public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                    boolean isClassifier = Symbol.from(node).equals(aggregation.getClassifierSymbol());
                    if (isClassifier || Symbol.from(node).equals(aggregation.getMatchNumberSymbol())) {
                        return node;
                    }
                    return map(node);
                }
            };

            List<Expression> mappedArguments = aggregation.getArguments().stream()
                    .map(argument -> ExpressionTreeRewriter.rewriteWith(argumentRewriter, argument))
                    .collect(toImmutableList());

            mappedPointers.add(new AggregationValuePointer(
                    aggregation.getFunction(),
                    aggregation.getSetDescriptor(),
                    mappedArguments,
                    aggregation.getClassifierSymbol(),
                    aggregation.getMatchNumberSymbol()));
            continue;
        }

        ScalarValuePointer scalar = (ScalarValuePointer) pointer;
        Symbol input = scalar.getInputSymbol();
        boolean keptAsIs = expressionAndValuePointers.getClassifierSymbols().contains(input)
                || expressionAndValuePointers.getMatchNumberSymbols().contains(input);
        if (keptAsIs) {
            // Classifier and match-number symbols are not mapped; the pointer is reused.
            mappedPointers.add(scalar);
        }
        else {
            mappedPointers.add(new ScalarValuePointer(scalar.getLogicalIndexPointer(), map(input)));
        }
    }

    return new ExpressionAndValuePointers(
            expressionAndValuePointers.getExpression(),
            expressionAndValuePointers.getLayout(),
            mappedPointers.build(),
            expressionAndValuePointers.getClassifierSymbols(),
            expressionAndValuePointers.getMatchNumberSymbols());
}
Also used : TableExecuteNode(io.trino.sql.planner.plan.TableExecuteNode) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) SymbolAllocator(io.trino.sql.planner.SymbolAllocator) LimitNode(io.trino.sql.planner.plan.LimitNode) TableWriterNode(io.trino.sql.planner.plan.TableWriterNode) Measure(io.trino.sql.planner.plan.PatternRecognitionNode.Measure) HashMap(java.util.HashMap) StatisticAggregations(io.trino.sql.planner.plan.StatisticAggregations) AggregationValuePointer(io.trino.sql.planner.rowpattern.AggregationValuePointer) PartitioningScheme(io.trino.sql.planner.PartitioningScheme) Function(java.util.function.Function) PlanNode(io.trino.sql.planner.plan.PlanNode) HashSet(java.util.HashSet) ImmutableList(com.google.common.collect.ImmutableList) PlanNodeId(io.trino.sql.planner.plan.PlanNodeId) RowNumberNode(io.trino.sql.planner.plan.RowNumberNode) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) ExpressionRewriter(io.trino.sql.tree.ExpressionRewriter) ExpressionAndValuePointers(io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers) Symbol(io.trino.sql.planner.Symbol) ImmutableMap(com.google.common.collect.ImmutableMap) IrLabel(io.trino.sql.planner.rowpattern.ir.IrLabel) StatisticsWriterNode(io.trino.sql.planner.plan.StatisticsWriterNode) ExpressionTreeRewriter(io.trino.sql.tree.ExpressionTreeRewriter) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TopNNode(io.trino.sql.planner.plan.TopNNode) Set(java.util.Set) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) OrderingScheme(io.trino.sql.planner.OrderingScheme) PatternRecognitionNode(io.trino.sql.planner.plan.PatternRecognitionNode) SortOrder(io.trino.spi.connector.SortOrder) List(java.util.List) 
ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) AggregationNode.groupingSets(io.trino.sql.planner.plan.AggregationNode.groupingSets) SymbolReference(io.trino.sql.tree.SymbolReference) GroupIdNode(io.trino.sql.planner.plan.GroupIdNode) DistinctLimitNode(io.trino.sql.planner.plan.DistinctLimitNode) ValuePointer(io.trino.sql.planner.rowpattern.ValuePointer) Entry(java.util.Map.Entry) TableFinishNode(io.trino.sql.planner.plan.TableFinishNode) TopNRankingNode(io.trino.sql.planner.plan.TopNRankingNode) Expression(io.trino.sql.tree.Expression) WindowNode(io.trino.sql.planner.plan.WindowNode) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Symbol(io.trino.sql.planner.Symbol) SymbolReference(io.trino.sql.tree.SymbolReference) AggregationValuePointer(io.trino.sql.planner.rowpattern.AggregationValuePointer) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) ValuePointer(io.trino.sql.planner.rowpattern.ValuePointer) AggregationValuePointer(io.trino.sql.planner.rowpattern.AggregationValuePointer) Expression(io.trino.sql.tree.Expression) ExpressionAndValuePointers(io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers)

Example 5 with ScalarValuePointer

use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.

the class PushDownProjectionsFromPatternRecognition method rewrite.

/**
 * Rewrites the aggregation value pointers of {@code expression} so that eligible argument
 * expressions are replaced by references to newly allocated symbols, recording the
 * symbol-to-expression pairs in {@code assignments}.
 * <p>
 * Scalar value pointers are passed through untouched. For an aggregation value pointer, an
 * argument is kept as-is when it is already a {@link SymbolReference} or when it references the
 * pointer's classifier or match-number symbol; every other argument is assigned to a fresh
 * symbol of the corresponding argument type from the function signature.
 *
 * @param expression the structure whose value pointers are inspected
 * @param assignments builder receiving one entry per argument replaced by a symbol
 * @param context supplies the symbol allocator used to create new symbols
 * @return a new structure with the rewritten pointers; expression, layout, classifier symbols
 *         and match-number symbols are copied from the input
 */
private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers expression, Assignments.Builder assignments, Context context) {
    ImmutableList.Builder<ValuePointer> resultPointers = ImmutableList.builder();

    for (ValuePointer pointer : expression.getValuePointers()) {
        if (pointer instanceof ScalarValuePointer) {
            resultPointers.add(pointer);
            continue;
        }

        AggregationValuePointer aggregation = (AggregationValuePointer) pointer;
        // Symbols in this set mark an argument as one that must be kept unchanged.
        Set<Symbol> runtimeSymbols = ImmutableSet.of(aggregation.getClassifierSymbol(), aggregation.getMatchNumberSymbol());
        List<Type> signatureTypes = aggregation.getFunction().getSignature().getArgumentTypes();

        ImmutableList.Builder<Expression> resultArguments = ImmutableList.builder();
        for (int index = 0; index < aggregation.getArguments().size(); index++) {
            Expression candidate = aggregation.getArguments().get(index);
            boolean keepInPlace = candidate instanceof SymbolReference
                    || SymbolsExtractor.extractUnique(candidate).stream().anyMatch(runtimeSymbols::contains);
            if (keepInPlace) {
                resultArguments.add(candidate);
                continue;
            }
            // Record the argument as an assignment and refer to it by its new symbol.
            Symbol replacement = context.getSymbolAllocator().newSymbol(candidate, signatureTypes.get(index));
            assignments.put(replacement, candidate);
            resultArguments.add(replacement.toSymbolReference());
        }

        resultPointers.add(new AggregationValuePointer(
                aggregation.getFunction(),
                aggregation.getSetDescriptor(),
                resultArguments.build(),
                aggregation.getClassifierSymbol(),
                aggregation.getMatchNumberSymbol()));
    }

    return new ExpressionAndValuePointers(
            expression.getExpression(),
            expression.getLayout(),
            resultPointers.build(),
            expression.getClassifierSymbols(),
            expression.getMatchNumberSymbols());
}
Also used : ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) ImmutableList(com.google.common.collect.ImmutableList) Symbol(io.trino.sql.planner.Symbol) SymbolReference(io.trino.sql.tree.SymbolReference) AggregationValuePointer(io.trino.sql.planner.rowpattern.AggregationValuePointer) ScalarValuePointer(io.trino.sql.planner.rowpattern.ScalarValuePointer) ValuePointer(io.trino.sql.planner.rowpattern.ValuePointer) AggregationValuePointer(io.trino.sql.planner.rowpattern.AggregationValuePointer) Type(io.trino.spi.type.Type) Expression(io.trino.sql.tree.Expression) ExpressionAndValuePointers(io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers)

Aggregations

Symbol (io.trino.sql.planner.Symbol): 5; ScalarValuePointer (io.trino.sql.planner.rowpattern.ScalarValuePointer): 5; ExpressionAndValuePointers (io.trino.sql.planner.rowpattern.LogicalIndexExtractor.ExpressionAndValuePointers): 4; IrLabel (io.trino.sql.planner.rowpattern.ir.IrLabel): 4; Expression (io.trino.sql.tree.Expression): 4; SymbolReference (io.trino.sql.tree.SymbolReference): 4; JsonCodecFactory (io.airlift.json.JsonCodecFactory): 3; ObjectMapperProvider (io.airlift.json.ObjectMapperProvider): 3; AggregationValuePointer (io.trino.sql.planner.rowpattern.AggregationValuePointer): 3; LogicalIndexPointer (io.trino.sql.planner.rowpattern.LogicalIndexPointer): 3; ValuePointer (io.trino.sql.planner.rowpattern.ValuePointer): 3; Test (org.testng.annotations.Test): 3; ImmutableList (com.google.common.collect.ImmutableList): 2; Type (io.trino.spi.type.Type): 2; SqlParser (io.trino.sql.parser.SqlParser): 2; Measure (io.trino.sql.planner.plan.PatternRecognitionNode.Measure): 2; ArithmeticUnaryExpression (io.trino.sql.tree.ArithmeticUnaryExpression): 2; ComparisonExpression (io.trino.sql.tree.ComparisonExpression): 2; IfExpression (io.trino.sql.tree.IfExpression): 2; NullLiteral (io.trino.sql.tree.NullLiteral): 2