Usage of `io.trino.sql.planner.plan.WindowNode.Function` in the trino project (trinodb).
Class `WindowFunctionMatcher`, method `getAssignedSymbol`:
/**
 * Finds the single output symbol of a window-like node whose window function matches this
 * matcher's expectations.
 *
 * <p>Only {@link WindowNode} and {@link PatternRecognitionNode} carry window-function
 * assignments; for any other node the result is empty. A candidate matches when its resolved
 * function equals the expected one (or no expected resolved function was configured) and
 * {@code windowFunctionMatches} accepts its call and frame.
 *
 * @return the symbol bound to the matching function, or empty when nothing matches
 * @throws IllegalStateException if more than one assignment matches
 */
@Override
public Optional<Symbol> getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases)
{
    Map<Symbol, Function> windowFunctions;
    if (node instanceof WindowNode) {
        windowFunctions = ((WindowNode) node).getWindowFunctions();
    }
    else if (node instanceof PatternRecognitionNode) {
        windowFunctions = ((PatternRecognitionNode) node).getWindowFunctions();
    }
    else {
        // not a node that assigns window functions
        return Optional.empty();
    }

    FunctionCall expectedCall = callMaker.getExpectedValue(symbolAliases);
    Optional<WindowNode.Frame> expectedFrame = frameMaker.map(maker -> maker.getExpectedValue(symbolAliases));

    Optional<Symbol> match = Optional.empty();
    for (Map.Entry<Symbol, Function> entry : windowFunctions.entrySet()) {
        Function candidate = entry.getValue();

        // an absent expected resolved function accepts any signature
        boolean sameSignature = resolvedFunction.map(candidate.getResolvedFunction()::equals).orElse(true);
        if (!sameSignature) {
            continue;
        }
        if (!windowFunctionMatches(candidate, expectedCall, expectedFrame)) {
            continue;
        }

        // a second match means the pattern cannot tell the assignments apart
        checkState(match.isEmpty(), "Ambiguous function calls in %s", node);
        match = Optional.of(entry.getKey());
    }
    return match;
}
Usage of `io.trino.sql.planner.plan.WindowNode.Function` in the trino project (trinodb).
Class `TestPatternRecognitionNodeSerialization`, method `testPatternRecognitionNodeRoundtrip`:
/**
 * Serializes a {@link PatternRecognitionNode} to JSON and back, and checks that the
 * pattern-recognition specific fields survive the round trip unchanged.
 */
@Test
public void testPatternRecognitionNodeRoundtrip()
{
    ObjectMapperProvider provider = new ObjectMapperProvider();
    provider.setJsonSerializers(ImmutableMap.of(Expression.class, new ExpressionSerialization.ExpressionSerializer()));
    provider.setJsonDeserializers(ImmutableMap.of(
            Expression.class, new ExpressionSerialization.ExpressionDeserializer(new SqlParser()),
            Type.class, new TypeDeserializer(TESTING_TYPE_MANAGER)));
    // allows map keys of type TypeSignature to be deserialized
    provider.setKeyDeserializers(ImmutableMap.of(TypeSignature.class, new TypeSignatureKeyDeserializer()));
    JsonCodec<PatternRecognitionNode> codec = new JsonCodecFactory(provider).jsonCodec(PatternRecognitionNode.class);

    ResolvedFunction rankFunction = createTestMetadataManager().resolveFunction(TEST_SESSION, QualifiedName.of("rank"), ImmutableList.of());

    // ROWS frame from CURRENT_ROW to UNBOUNDED_FOLLOWING; used both for the window function and as the common base frame
    Frame frame = new Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
    // NULL literal with all remaining components empty; used for the measure and the variable definitions
    ExpressionAndValuePointers nullExpression = new ExpressionAndValuePointers(new NullLiteral(), ImmutableList.of(), ImmutableList.of(), ImmutableSet.of(), ImmutableSet.of());

    // test remaining fields inside PatternRecognitionNode specific to pattern recognition:
    // windowFunctions, measures, commonBaseFrame, rowsPerMatch, skipToLabel, skipToPosition, initial, pattern, subsets, variableDefinitions
    PatternRecognitionNode node = new PatternRecognitionNode(
            new PlanNodeId("0"),
            new ValuesNode(new PlanNodeId("1"), 1),
            new Specification(ImmutableList.of(), Optional.empty()),
            Optional.empty(),
            ImmutableSet.of(),
            0,
            ImmutableMap.of(new Symbol("rank"), new Function(rankFunction, ImmutableList.of(), frame, false)),
            ImmutableMap.of(new Symbol("measure"), new Measure(nullExpression, BOOLEAN)),
            Optional.of(frame),
            WINDOW,
            Optional.of(new IrLabel("B")),
            LAST,
            true,
            new IrConcatenation(ImmutableList.of(new IrLabel("A"), new IrLabel("B"), new IrLabel("C"))),
            ImmutableMap.of(
                    new IrLabel("U"), ImmutableSet.of(new IrLabel("A"), new IrLabel("B")),
                    new IrLabel("V"), ImmutableSet.of(new IrLabel("B"), new IrLabel("C"))),
            ImmutableMap.of(
                    new IrLabel("B"), nullExpression,
                    new IrLabel("C"), nullExpression));

    PatternRecognitionNode roundtripNode = codec.fromJson(codec.toJson(node));

    // windowFunctions and commonBaseFrame are listed above as covered, so assert them too
    assertEquals(roundtripNode.getWindowFunctions(), node.getWindowFunctions());
    assertEquals(roundtripNode.getMeasures(), node.getMeasures());
    assertEquals(roundtripNode.getCommonBaseFrame(), node.getCommonBaseFrame());
    assertEquals(roundtripNode.getRowsPerMatch(), node.getRowsPerMatch());
    assertEquals(roundtripNode.getSkipToLabel(), node.getSkipToLabel());
    assertEquals(roundtripNode.getSkipToPosition(), node.getSkipToPosition());
    assertEquals(roundtripNode.isInitial(), node.isInitial());
    assertEquals(roundtripNode.getPattern(), node.getPattern());
    assertEquals(roundtripNode.getSubsets(), node.getSubsets());
    assertEquals(roundtripNode.getVariableDefinitions(), node.getVariableDefinitions());
}
Usage of `io.trino.sql.planner.plan.WindowNode.Function` in the trino project (trinodb).
Class `PrunePattenRecognitionColumns`, method `pushDownProjectOff`:
/**
 * Prunes a {@link PatternRecognitionNode} down to the outputs that are actually referenced.
 *
 * <p>Unreferenced window functions and measures are dropped, and the source is restricted to the
 * symbols still required by the remaining computation. If nothing the node itself computes is
 * referenced and {@code passesAllRows} holds, the node is replaced by its source altogether.
 *
 * @return the rewritten plan, or empty when neither the node nor its source would change
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(Context context, PatternRecognitionNode patternRecognitionNode, Set<Symbol> referencedOutputs)
{
    PlanNode source = patternRecognitionNode.getSource();
    Map<Symbol, Function> allFunctions = patternRecognitionNode.getWindowFunctions();
    Map<Symbol, Measure> allMeasures = patternRecognitionNode.getMeasures();

    Map<Symbol, Function> keptFunctions = filterKeys(allFunctions, referencedOutputs::contains);
    Map<Symbol, Measure> keptMeasures = filterKeys(allMeasures, referencedOutputs::contains);

    boolean noOwnOutputReferenced = keptFunctions.isEmpty() && keptMeasures.isEmpty();
    if (noOwnOutputReferenced && passesAllRows(patternRecognitionNode)) {
        return Optional.of(source);
    }

    // TODO prune unreferenced variable definitions when we have pattern transformations that delete parts of the pattern
    ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.builder();

    // source symbols that are passed through and referenced upstream
    for (Symbol symbol : source.getOutputSymbols()) {
        if (referencedOutputs.contains(symbol)) {
            requiredInputs.add(symbol);
        }
    }

    // symbols used by the node's specification and hash symbol
    requiredInputs.addAll(patternRecognitionNode.getPartitionBy());
    patternRecognitionNode.getOrderingScheme().ifPresent(scheme -> requiredInputs.addAll(scheme.getOrderBy()));
    patternRecognitionNode.getHashSymbol().ifPresent(requiredInputs::add);

    // inputs of the window functions and measures that survive pruning
    for (Function function : keptFunctions.values()) {
        requiredInputs.addAll(SymbolsExtractor.extractUnique(function));
    }
    for (Measure measure : keptMeasures.values()) {
        requiredInputs.addAll(measure.getExpressionAndValuePointers().getInputSymbols());
    }

    // end-bound symbol of the common base frame, if one is present
    patternRecognitionNode.getCommonBaseFrame()
            .flatMap(Frame::getEndValue)
            .ifPresent(requiredInputs::add);

    // variable definitions are never pruned, so all of their inputs stay required
    for (ExpressionAndValuePointers definition : patternRecognitionNode.getVariableDefinitions().values()) {
        requiredInputs.addAll(definition.getInputSymbols());
    }

    Optional<PlanNode> prunedSource = restrictOutputs(context.getIdAllocator(), source, requiredInputs.build());

    // neither the source nor the node's own assignments changed
    if (prunedSource.isEmpty()
            && keptFunctions.size() == allFunctions.size()
            && keptMeasures.size() == allMeasures.size()) {
        return Optional.empty();
    }

    return Optional.of(new PatternRecognitionNode(
            patternRecognitionNode.getId(),
            prunedSource.orElse(source),
            patternRecognitionNode.getSpecification(),
            patternRecognitionNode.getHashSymbol(),
            patternRecognitionNode.getPrePartitionedInputs(),
            patternRecognitionNode.getPreSortedOrderPrefix(),
            keptFunctions,
            keptMeasures,
            patternRecognitionNode.getCommonBaseFrame(),
            patternRecognitionNode.getRowsPerMatch(),
            patternRecognitionNode.getSkipToLabel(),
            patternRecognitionNode.getSkipToPosition(),
            patternRecognitionNode.isInitial(),
            patternRecognitionNode.getPattern(),
            patternRecognitionNode.getSubsets(),
            patternRecognitionNode.getVariableDefinitions()));
}
Aggregations