Usage of io.trino.sql.planner.plan.PatternRecognitionNode in project trino by trinodb.
Class MeasureMatcher, method getAssignedSymbol.
/**
 * Looks up the output symbol under which the given PatternRecognitionNode computes the measure
 * this matcher expects.
 *
 * <p>The expected measure is built from this matcher's {@code expression} (rewritten against
 * {@code subsets}) and {@code type}. It is then compared with every measure of the node using
 * {@code measuresEquivalent}.
 *
 * @return the symbol assigned to the single equivalent measure, or empty when {@code node} is not
 *     a PatternRecognitionNode or no measure is equivalent
 * @throws IllegalStateException (via {@code checkState}) if more than one measure is equivalent
 */
@Override
public Optional<Symbol> getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    if (!(node instanceof PatternRecognitionNode)) {
        return Optional.empty();
    }
    PatternRecognitionNode recognitionNode = (PatternRecognitionNode) node;
    Measure wanted = new Measure(rewrite(expression, subsets), type);

    Optional<Symbol> found = Optional.empty();
    for (Map.Entry<Symbol, Measure> entry : recognitionNode.getMeasures().entrySet()) {
        if (!measuresEquivalent(entry.getValue(), wanted, symbolAliases)) {
            continue;
        }
        // A second equivalent measure means the expectation cannot be resolved unambiguously.
        checkState(found.isEmpty(), "Ambiguous measures in %s", recognitionNode);
        found = Optional.of(entry.getKey());
    }
    return found;
}
Usage of io.trino.sql.planner.plan.PatternRecognitionNode in project trino by trinodb.
Class PatternRecognitionMatcher, method detailMatches.
/**
 * Compares a PatternRecognitionNode with this matcher's expectations.
 *
 * <p>The expected specification and frame are optional. When absent, any actual value is
 * accepted. Rows-per-match, skip-to label and position, the initial flag, pattern, subsets and
 * variable definitions must all match. Variable definitions are compared per label, using
 * {@code ExpressionAndValuePointersEquivalence} with symbols resolved through
 * {@code symbolAliases}.
 *
 * @return {@code match()} when every expectation holds, otherwise {@code NO_MATCH}
 * @throws IllegalStateException (via {@code checkState}) if called for a node whose shape does not match
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    PatternRecognitionNode actualNode = (PatternRecognitionNode) node;

    if (specification.isPresent()
            && !specification.get().getExpectedValue(symbolAliases).equals(actualNode.getSpecification())) {
        return NO_MATCH;
    }

    if (frame.isPresent()) {
        // An expected frame requires the node to have a common base frame equal to it.
        boolean frameMatches = actualNode.getCommonBaseFrame()
                .map(actualFrame -> frame.get().getExpectedValue(symbolAliases).equals(actualFrame))
                .orElse(false);
        if (!frameMatches) {
            return NO_MATCH;
        }
    }

    // Evaluated left to right with short-circuiting, in the same order as the individual checks.
    boolean attributesMatch = rowsPerMatch == actualNode.getRowsPerMatch()
            && skipToLabel.equals(actualNode.getSkipToLabel())
            && skipToPosition == actualNode.getSkipToPosition()
            && initial == actualNode.isInitial()
            && pattern.equals(actualNode.getPattern())
            && subsets.equals(actualNode.getSubsets())
            && variableDefinitions.size() == actualNode.getVariableDefinitions().size();
    if (!attributesMatch) {
        return NO_MATCH;
    }

    // Sizes are equal, so finding every expected label in the actual map means the label sets are equal.
    Map<IrLabel, ExpressionAndValuePointers> actualDefinitions = actualNode.getVariableDefinitions();
    for (Map.Entry<IrLabel, ExpressionAndValuePointers> expectedDefinition : variableDefinitions.entrySet()) {
        ExpressionAndValuePointers actual = actualDefinitions.get(expectedDefinition.getKey());
        if (actual == null) {
            return NO_MATCH;
        }
        ExpressionVerifier verifier = new ExpressionVerifier(symbolAliases);
        boolean equivalent = ExpressionAndValuePointersEquivalence.equivalent(
                actual,
                expectedDefinition.getValue(),
                (actualSymbol, expectedSymbol) -> verifier.process(actualSymbol.toSymbolReference(), expectedSymbol.toSymbolReference()));
        if (!equivalent) {
            return NO_MATCH;
        }
    }
    return match();
}
Usage of io.trino.sql.planner.plan.PatternRecognitionNode in project trino by trinodb.
Class MergePatternRecognitionNodes, method extractPrerequisites.
/**
 * Selects the project assignments that produce symbols consumed by the PatternRecognitionNode's
 * window functions or measures. Identity assignments are left out.
 *
 * @param node the pattern recognition node whose window functions and measures are inspected
 * @param project the projection whose assignments are filtered
 * @return the non-identity assignments of {@code project} whose output symbols are read by
 *     {@code node}'s window functions or measures
 */
private static Assignments extractPrerequisites(PatternRecognitionNode node, ProjectNode project) {
    ImmutableSet.Builder<Symbol> consumedSymbols = ImmutableSet.builder();
    node.getWindowFunctions().values()
            .forEach(function -> consumedSymbols.addAll(SymbolsExtractor.extractAll(function)));
    node.getMeasures().values()
            .forEach(measure -> consumedSymbols.addAll(measure.getExpressionAndValuePointers().getInputSymbols()));
    Set<Symbol> consumed = consumedSymbols.build();

    Assignments assignments = project.getAssignments();
    // A single predicate is equivalent to filtering out identities first and then keeping consumed symbols.
    return assignments.filter(symbol -> !assignments.isIdentity(symbol) && consumed.contains(symbol));
}
Usage of io.trino.sql.planner.plan.PatternRecognitionNode in project trino by trinodb.
Class PrunePattenRecognitionColumns, method pushDownProjectOff.
/**
 * Prunes a PatternRecognitionNode down to the outputs its parent references.
 *
 * <p>Window functions and measures whose output symbols are not in {@code referencedOutputs} are
 * dropped. If none remain and {@code passesAllRows} holds for the node, the node is replaced by
 * its source. Otherwise the source is restricted to the symbols the remaining node still needs.
 *
 * @param context rule context, used for its plan-node ID allocator
 * @param patternRecognitionNode the node being pruned
 * @param referencedOutputs the node outputs referenced by the parent
 * @return the replacement plan, or empty when nothing can be pruned
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(Context context, PatternRecognitionNode patternRecognitionNode, Set<Symbol> referencedOutputs) {
    PlanNode source = patternRecognitionNode.getSource();
    Map<Symbol, Function> keptFunctions = filterKeys(patternRecognitionNode.getWindowFunctions(), referencedOutputs::contains);
    Map<Symbol, Measure> keptMeasures = filterKeys(patternRecognitionNode.getMeasures(), referencedOutputs::contains);

    if (keptFunctions.isEmpty() && keptMeasures.isEmpty() && passesAllRows(patternRecognitionNode)) {
        return Optional.of(source);
    }

    // TODO prune unreferenced variable definitions when we have pattern transformations that delete parts of the pattern
    ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.builder();
    // Source symbols passed straight through to a referenced output.
    for (Symbol symbol : source.getOutputSymbols()) {
        if (referencedOutputs.contains(symbol)) {
            requiredInputs.add(symbol);
        }
    }
    requiredInputs.addAll(patternRecognitionNode.getPartitionBy());
    patternRecognitionNode.getOrderingScheme().ifPresent(scheme -> requiredInputs.addAll(scheme.getOrderBy()));
    patternRecognitionNode.getHashSymbol().ifPresent(requiredInputs::add);
    for (Function function : keptFunctions.values()) {
        requiredInputs.addAll(SymbolsExtractor.extractUnique(function));
    }
    for (Measure measure : keptMeasures.values()) {
        requiredInputs.addAll(measure.getExpressionAndValuePointers().getInputSymbols());
    }
    patternRecognitionNode.getCommonBaseFrame().flatMap(Frame::getEndValue).ifPresent(requiredInputs::add);
    // Every variable definition is kept, so all of their inputs remain required.
    for (ExpressionAndValuePointers definition : patternRecognitionNode.getVariableDefinitions().values()) {
        requiredInputs.addAll(definition.getInputSymbols());
    }

    Optional<PlanNode> prunedSource = restrictOutputs(context.getIdAllocator(), source, requiredInputs.build());
    boolean nothingPruned = prunedSource.isEmpty()
            && keptFunctions.size() == patternRecognitionNode.getWindowFunctions().size()
            && keptMeasures.size() == patternRecognitionNode.getMeasures().size();
    if (nothingPruned) {
        return Optional.empty();
    }

    return Optional.of(new PatternRecognitionNode(
            patternRecognitionNode.getId(),
            prunedSource.orElse(source),
            patternRecognitionNode.getSpecification(),
            patternRecognitionNode.getHashSymbol(),
            patternRecognitionNode.getPrePartitionedInputs(),
            patternRecognitionNode.getPreSortedOrderPrefix(),
            keptFunctions,
            keptMeasures,
            patternRecognitionNode.getCommonBaseFrame(),
            patternRecognitionNode.getRowsPerMatch(),
            patternRecognitionNode.getSkipToLabel(),
            patternRecognitionNode.getSkipToPosition(),
            patternRecognitionNode.isInitial(),
            patternRecognitionNode.getPattern(),
            patternRecognitionNode.getSubsets(),
            patternRecognitionNode.getVariableDefinitions()));
}
Usage of io.trino.sql.planner.plan.PatternRecognitionNode in project trino by trinodb.
Class RelationPlanner, method visitPatternRecognitionRelation.
/**
 * Plans a pattern recognition relation as a PatternRecognitionNode on top of its input plan.
 *
 * <p>The output layout is built in this order:
 * <ol>
 *   <li>the PARTITION BY symbols;</li>
 *   <li>for non-one-row output only, the translated ORDER BY sort keys;</li>
 *   <li>the measure outputs;</li>
 *   <li>for non-one-row output only, the input relation's fields not already in the layout.</li>
 * </ol>
 * NOTE(review): the order of statements also determines symbol and plan-node ID allocation
 * order, so reorder with care.
 */
@Override
protected RelationPlan visitPatternRecognitionRelation(PatternRecognitionRelation node, Void context) {
RelationPlan subPlan = process(node.getInput(), context);
// Pre-project inputs for PARTITION BY and ORDER BY
List<Expression> inputs = ImmutableList.<Expression>builder().addAll(node.getPartitionBy()).addAll(getSortItemsFromOrderBy(node.getOrderBy()).stream().map(SortItem::getSortKey).collect(toImmutableList())).build();
PlanBuilder planBuilder = newPlanBuilder(subPlan, analysis, lambdaDeclarationToSymbolMap);
// no handleSubqueries because subqueries are not allowed here
planBuilder = planBuilder.appendProjections(inputs, symbolAllocator, idAllocator);
ImmutableList.Builder<Symbol> outputLayout = ImmutableList.builder();
// Treated as one-row output when ROWS PER MATCH is omitted or explicitly specifies one row.
boolean oneRowOutput = node.getRowsPerMatch().isEmpty() || node.getRowsPerMatch().get().isOneRow();
WindowNode.Specification specification = planWindowSpecification(node.getPartitionBy(), node.getOrderBy(), planBuilder::translate);
outputLayout.addAll(specification.getPartitionBy());
if (!oneRowOutput) {
getSortItemsFromOrderBy(node.getOrderBy()).stream().map(SortItem::getSortKey).map(planBuilder::translate).forEach(outputLayout::add);
}
// Subqueries appearing in variable definitions and measures are planned before those expressions are rewritten below.
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, extractPatternRecognitionExpressions(node.getVariableDefinitions(), node.getMeasures()), analysis.getSubqueries(node));
PatternRecognitionComponents components = planPatternRecognitionComponents(planBuilder::rewrite, node.getSubsets(), node.getMeasures(), node.getAfterMatchSkipTo(), node.getPatternSearchMode(), node.getPattern(), node.getVariableDefinitions());
outputLayout.addAll(components.getMeasureOutputs());
if (!oneRowOutput) {
// Append the input relation's fields (from subPlan, not planBuilder) that are not already part of the layout.
Set<Symbol> inputSymbolsOnOutput = ImmutableSet.copyOf(outputLayout.build());
subPlan.getFieldMappings().stream().filter(symbol -> !inputSymbolsOnOutput.contains(symbol)).forEach(outputLayout::add);
}
// Arguments after the specification: no hash symbol, no pre-partitioned inputs, pre-sorted prefix 0,
// no window functions, the planned measures, no common base frame, rows-per-match defaulting to ONE,
// then the skip/pattern/subset/definition components planned above.
PatternRecognitionNode planNode = new PatternRecognitionNode(idAllocator.getNextId(), planBuilder.getRoot(), specification, Optional.empty(), ImmutableSet.of(), 0, ImmutableMap.of(), components.getMeasures(), Optional.empty(), node.getRowsPerMatch().orElse(ONE), components.getSkipToLabel(), components.getSkipToPosition(), components.isInitial(), components.getPattern(), components.getSubsets(), components.getVariableDefinitions());
return new RelationPlan(planNode, analysis.getScope(node), outputLayout.build(), outerContext);
}
Aggregations