use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.
the class TestPatternRecognitionNodeSerialization method testExpressionAndValuePointersRoundtrip.
/**
 * Checks that {@code ExpressionAndValuePointers} survives a JSON round trip, first for an
 * instance with empty collections and then for one with a populated layout, three scalar
 * value pointers, one classifier symbol and one match-number symbol.
 */
@Test
public void testExpressionAndValuePointersRoundtrip() {
    // Register the custom JSON serializer / deserializer used for Expression.
    ObjectMapperProvider mapperProvider = new ObjectMapperProvider();
    mapperProvider.setJsonSerializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionSerializer()));
    mapperProvider.setJsonDeserializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionDeserializer(new SqlParser())));
    JsonCodecFactory codecFactory = new JsonCodecFactory(mapperProvider);
    JsonCodec<ExpressionAndValuePointers> codec = codecFactory.jsonCodec(ExpressionAndValuePointers.class);

    // Case 1: null literal with every collection empty.
    ExpressionAndValuePointers empty = new ExpressionAndValuePointers(
            new NullLiteral(),
            ImmutableList.of(),
            ImmutableList.of(),
            ImmutableSet.of(),
            ImmutableSet.of());
    assertJsonRoundTrip(codec, empty);

    // Case 2: IF expression referencing three layout symbols, each backed by a scalar pointer.
    Expression condition = new ComparisonExpression(GREATER_THAN, new SymbolReference("classifier"), new SymbolReference("x"));
    Expression whenTrue = new FunctionCall(QualifiedName.of("rand"), ImmutableList.of());
    Expression whenFalse = new ArithmeticUnaryExpression(MINUS, new SymbolReference("match_number"));
    Expression ifExpression = new IfExpression(condition, whenTrue, whenFalse);

    ScalarValuePointer firstPointer = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("A"), new IrLabel("B")), false, true, 1, -1),
            new Symbol("input_symbol_a"));
    ScalarValuePointer secondPointer = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("B")), true, false, 2, 1),
            new Symbol("input_symbol_a"));
    ScalarValuePointer thirdPointer = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0),
            new Symbol("input_symbol_a"));

    ExpressionAndValuePointers populated = new ExpressionAndValuePointers(
            ifExpression,
            ImmutableList.of(new Symbol("classifier"), new Symbol("x"), new Symbol("match_number")),
            ImmutableList.of(firstPointer, secondPointer, thirdPointer),
            ImmutableSet.of(new Symbol("classifier")),
            ImmutableSet.of(new Symbol("match_number")));
    assertJsonRoundTrip(codec, populated);
}
use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.
the class TestPatternRecognitionNodeSerialization method testMeasureRoundtrip.
/**
 * Checks that {@code Measure} survives a JSON round trip, once with an empty
 * {@code ExpressionAndValuePointers} typed as BOOLEAN and once with a populated one
 * typed as BIGINT.
 */
@Test
public void testMeasureRoundtrip() {
    // Register the custom Expression serializer plus deserializers for Expression and Type.
    ObjectMapperProvider mapperProvider = new ObjectMapperProvider();
    mapperProvider.setJsonSerializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionSerializer()));
    mapperProvider.setJsonDeserializers(ImmutableMap.of(
            Expression.class, new ExpressionSerialization.ExpressionDeserializer(new SqlParser()),
            Type.class, new TypeDeserializer(TESTING_TYPE_MANAGER)));
    JsonCodecFactory codecFactory = new JsonCodecFactory(mapperProvider);
    JsonCodec<Measure> codec = codecFactory.jsonCodec(Measure.class);

    // Case 1: null literal, all collections empty, BOOLEAN result type.
    ExpressionAndValuePointers emptyPointers = new ExpressionAndValuePointers(
            new NullLiteral(),
            ImmutableList.of(),
            ImmutableList.of(),
            ImmutableSet.of(),
            ImmutableSet.of());
    assertJsonRoundTrip(codec, new Measure(emptyPointers, BOOLEAN));

    // Case 2: IF expression over three layout symbols, BIGINT result type.
    Expression condition = new ComparisonExpression(GREATER_THAN, new SymbolReference("match_number"), new SymbolReference("x"));
    Expression whenTrue = new GenericLiteral("BIGINT", "10");
    Expression whenFalse = new ArithmeticUnaryExpression(MINUS, new SymbolReference("y"));
    Expression ifExpression = new IfExpression(condition, whenTrue, whenFalse);

    ScalarValuePointer firstPointer = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(), true, true, 0, 0),
            new Symbol("input_symbol_a"));
    ScalarValuePointer secondPointer = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("A")), false, true, 1, -1),
            new Symbol("input_symbol_a"));
    ScalarValuePointer thirdPointer = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("B")), false, true, 1, -1),
            new Symbol("input_symbol_b"));

    ExpressionAndValuePointers populatedPointers = new ExpressionAndValuePointers(
            ifExpression,
            ImmutableList.of(new Symbol("match_number"), new Symbol("x"), new Symbol("y")),
            ImmutableList.of(firstPointer, secondPointer, thirdPointer),
            ImmutableSet.of(),
            ImmutableSet.of(new Symbol("match_number")));
    assertJsonRoundTrip(codec, new Measure(populatedPointers, BIGINT));
}
use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.
the class TestPatternRecognitionNodeSerialization method testScalarValuePointerRoundtrip.
/**
 * Checks that a {@code ScalarValuePointer} survives a JSON round trip when serialized through
 * a codec for the {@code ValuePointer} type, with an empty label set and with two labels.
 */
@Test
public void testScalarValuePointerRoundtrip() {
    JsonCodecFactory codecFactory = new JsonCodecFactory(new ObjectMapperProvider());
    JsonCodec<ValuePointer> codec = codecFactory.jsonCodec(ValuePointer.class);

    // Pointer whose logical index carries no labels.
    ValuePointer withoutLabels = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(), false, false, 5, 5),
            new Symbol("input_symbol"));
    assertJsonRoundTrip(codec, withoutLabels);

    // Pointer whose logical index carries labels A and B.
    ValuePointer withLabels = new ScalarValuePointer(
            new LogicalIndexPointer(ImmutableSet.of(new IrLabel("A"), new IrLabel("B")), true, true, 1, -1),
            new Symbol("input_symbol"));
    assertJsonRoundTrip(codec, withLabels);
}
use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.
the class SymbolMapper method map.
/**
 * Builds a copy of the given {@code ExpressionAndValuePointers} in which the value pointers
 * have been run through this mapper. The expression, layout, classifier symbols and
 * match-number symbols are passed to the result unchanged.
 *
 * @param expressionAndValuePointers the structure whose value pointers are to be mapped
 * @return a new instance carrying the mapped value pointers
 */
private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndValuePointers) {
    // Only the input symbols of the ValuePointers are mapped: those are the symbols produced
    // by the source node. The remaining symbols in the ExpressionAndValuePointers structure
    // are synthetic unique symbols with no outer usage or dependencies.
    ImmutableList.Builder<ValuePointer> mappedPointers = ImmutableList.builder();

    for (ValuePointer pointer : expressionAndValuePointers.getValuePointers()) {
        if (pointer instanceof ScalarValuePointer) {
            ScalarValuePointer scalar = (ScalarValuePointer) pointer;
            Symbol input = scalar.getInputSymbol();
            boolean isClassifierOrMatchNumber = expressionAndValuePointers.getClassifierSymbols().contains(input)
                    || expressionAndValuePointers.getMatchNumberSymbols().contains(input);
            if (isClassifierOrMatchNumber) {
                // Classifier / match-number symbols are kept exactly as they are.
                mappedPointers.add(scalar);
            }
            else {
                mappedPointers.add(new ScalarValuePointer(scalar.getLogicalIndexPointer(), map(input)));
            }
            continue;
        }

        // Anything that is not a scalar pointer is treated as an aggregation pointer.
        AggregationValuePointer aggregation = (AggregationValuePointer) pointer;

        // Rewrites symbol references inside aggregation arguments, leaving this pointer's own
        // classifier and match-number symbols untouched.
        ExpressionRewriter<Void> symbolRewriter = new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Symbol referenced = Symbol.from(node);
                boolean keepAsIs = referenced.equals(aggregation.getClassifierSymbol())
                        || referenced.equals(aggregation.getMatchNumberSymbol());
                if (!keepAsIs) {
                    return map(node);
                }
                return node;
            }
        };

        ImmutableList.Builder<Expression> mappedArguments = ImmutableList.builder();
        for (Expression argument : aggregation.getArguments()) {
            mappedArguments.add(ExpressionTreeRewriter.rewriteWith(symbolRewriter, argument));
        }
        List<Expression> newArguments = mappedArguments.build();

        mappedPointers.add(new AggregationValuePointer(
                aggregation.getFunction(),
                aggregation.getSetDescriptor(),
                newArguments,
                aggregation.getClassifierSymbol(),
                aggregation.getMatchNumberSymbol()));
    }

    return new ExpressionAndValuePointers(
            expressionAndValuePointers.getExpression(),
            expressionAndValuePointers.getLayout(),
            mappedPointers.build(),
            expressionAndValuePointers.getClassifierSymbols(),
            expressionAndValuePointers.getMatchNumberSymbols());
}
use of io.trino.sql.planner.rowpattern.ScalarValuePointer in project trino by trinodb.
the class PushDownProjectionsFromPatternRecognition method rewrite.
/**
 * Rewrites the aggregation value pointers of {@code expression} so that each argument which
 * is neither a plain symbol reference nor dependent on the pointer's classifier or
 * match-number symbol is replaced by a reference to a freshly allocated symbol. Every such
 * replacement is recorded in {@code assignments} as {@code symbol -> original argument}.
 * Scalar value pointers are copied through unchanged.
 *
 * @param expression the structure whose value pointers are rewritten
 * @param assignments builder that receives one entry per replaced argument
 * @param context supplies the symbol allocator used for the new symbols
 * @return a new instance with the rewritten value pointers and all other parts unchanged
 */
private static ExpressionAndValuePointers rewrite(ExpressionAndValuePointers expression, Assignments.Builder assignments, Context context) {
    ImmutableList.Builder<ValuePointer> resultPointers = ImmutableList.builder();

    for (ValuePointer pointer : expression.getValuePointers()) {
        if (pointer instanceof ScalarValuePointer) {
            resultPointers.add(pointer);
            continue;
        }

        AggregationValuePointer aggregation = (AggregationValuePointer) pointer;
        Set<Symbol> runtimeEvaluatedSymbols = ImmutableSet.of(aggregation.getClassifierSymbol(), aggregation.getMatchNumberSymbol());
        List<Type> argumentTypes = aggregation.getFunction().getSignature().getArgumentTypes();

        ImmutableList.Builder<Expression> newArguments = ImmutableList.builder();
        for (int index = 0; index < aggregation.getArguments().size(); index++) {
            Expression argument = aggregation.getArguments().get(index);

            // An argument is replaced only when it is not already a bare symbol reference
            // and mentions neither the classifier nor the match-number symbol.
            boolean replaceWithSymbol = !(argument instanceof SymbolReference)
                    && SymbolsExtractor.extractUnique(argument).stream().noneMatch(runtimeEvaluatedSymbols::contains);

            if (!replaceWithSymbol) {
                newArguments.add(argument);
                continue;
            }

            Symbol replacement = context.getSymbolAllocator().newSymbol(argument, argumentTypes.get(index));
            assignments.put(replacement, argument);
            newArguments.add(replacement.toSymbolReference());
        }

        resultPointers.add(new AggregationValuePointer(
                aggregation.getFunction(),
                aggregation.getSetDescriptor(),
                newArguments.build(),
                aggregation.getClassifierSymbol(),
                aggregation.getMatchNumberSymbol()));
    }

    return new ExpressionAndValuePointers(
            expression.getExpression(),
            expression.getLayout(),
            resultPointers.build(),
            expression.getClassifierSymbols(),
            expression.getMatchNumberSymbols());
}
Aggregations