Use of io.trino.sql.tree.ExpressionTreeRewriter in the project trino by trinodb.
The class SymbolMapper, method map.
/**
 * Applies this symbol mapping to the value pointers of a row pattern
 * {@code ExpressionAndValuePointers} structure.
 * <p>
 * Only symbols coming from the source node are remapped:
 * <ul>
 * <li>a scalar pointer gets its input symbol mapped, unless that symbol is one of the structure's
 * classifier or match number symbols, in which case the pointer is reused unchanged;</li>
 * <li>an aggregation pointer gets every symbol reference inside its arguments mapped, except
 * references to the pointer's own classifier and match number symbols.</li>
 * </ul>
 * The classifier and match number symbols are synthetic unique symbols with no outer usage or
 * dependencies, which is why they are never remapped. The expression, layout, classifier symbols
 * and match number symbols of the input are passed through to the result as-is.
 *
 * @param expressionAndValuePointers the structure whose value pointers are to be remapped
 * @return a new structure carrying the remapped value pointers
 */
private ExpressionAndValuePointers map(ExpressionAndValuePointers expressionAndValuePointers) {
    ImmutableList.Builder<ValuePointer> mappedPointers = ImmutableList.builder();

    for (ValuePointer pointer : expressionAndValuePointers.getValuePointers()) {
        if (pointer instanceof ScalarValuePointer) {
            ScalarValuePointer scalarPointer = (ScalarValuePointer) pointer;
            Symbol input = scalarPointer.getInputSymbol();
            boolean synthetic = expressionAndValuePointers.getClassifierSymbols().contains(input)
                    || expressionAndValuePointers.getMatchNumberSymbols().contains(input);
            mappedPointers.add(synthetic
                    ? scalarPointer
                    : new ScalarValuePointer(scalarPointer.getLogicalIndexPointer(), map(input)));
            continue;
        }

        // Every non-scalar pointer is handled as an aggregation pointer (the cast fails otherwise).
        AggregationValuePointer aggregationPointer = (AggregationValuePointer) pointer;

        // Stateless rewriter shared by all arguments of this aggregation: it leaves the
        // pointer's own classifier / match number references alone and maps everything else.
        ExpressionRewriter<Void> argumentRewriter = new ExpressionRewriter<Void>() {
            @Override
            public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
                Symbol referenced = Symbol.from(node);
                if (referenced.equals(aggregationPointer.getClassifierSymbol())
                        || referenced.equals(aggregationPointer.getMatchNumberSymbol())) {
                    return node;
                }
                return map(node);
            }
        };

        ImmutableList.Builder<Expression> mappedArguments = ImmutableList.builder();
        for (Expression argument : aggregationPointer.getArguments()) {
            mappedArguments.add(ExpressionTreeRewriter.rewriteWith(argumentRewriter, argument));
        }

        mappedPointers.add(new AggregationValuePointer(
                aggregationPointer.getFunction(),
                aggregationPointer.getSetDescriptor(),
                mappedArguments.build(),
                aggregationPointer.getClassifierSymbol(),
                aggregationPointer.getMatchNumberSymbol()));
    }

    return new ExpressionAndValuePointers(
            expressionAndValuePointers.getExpression(),
            expressionAndValuePointers.getLayout(),
            mappedPointers.build(),
            expressionAndValuePointers.getClassifierSymbols(),
            expressionAndValuePointers.getMatchNumberSymbols());
}
Aggregations