Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
Class PushDownDereferencesThroughTopNRanking, method apply.
/**
 * Rewrites {@code Project -> TopNRanking -> source} so that subscript (dereference)
 * expressions used by the project are computed in a new project placed below the
 * TopNRanking node, and the upper project refers to them through fresh symbols.
 *
 * @param projectNode the matched project sitting on top of the TopNRanking node
 * @param captures pattern captures; the TopNRanking node is read from {@code CHILD}
 * @param context rule context supplying the session, symbol allocator and id allocator
 * @return {@code Result.empty()} when no dereference qualifies, otherwise the rewritten plan
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context) {
    TopNRankingNode topNRankingNode = captures.get(CHILD);
    WindowNode.Specification specification = topNRankingNode.getSpecification();

    // Gather the subscript expressions appearing in the project's assignments, keeping only
    // those whose base symbol is neither a partitioning nor an ordering symbol of the node.
    Set<SubscriptExpression> pushableDereferences = extractRowSubscripts(
            projectNode.getAssignments().getExpressions(),
            false,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes())
            .stream()
            .filter(dereference -> {
                Symbol base = getBase(dereference);
                if (specification.getPartitionBy().contains(base)) {
                    return false;
                }
                return !specification.getOrderingScheme()
                        .map(OrderingScheme::getOrderBy)
                        .orElse(ImmutableList.of())
                        .contains(base);
            })
            .collect(toImmutableSet());

    if (pushableDereferences.isEmpty()) {
        return Result.empty();
    }

    // Allocate one new symbol per dereference expression.
    Assignments dereferenceAssignments = Assignments.of(pushableDereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> symbol reference, then substitute those
    // references into the original project's assignments.
    Map<Expression, Symbol> symbolsByDereference = HashBiMap.create(dereferenceAssignments.getMap()).inverse();
    Map<Expression, SymbolReference> mappings = symbolsByDereference.entrySet().stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));

    // The lower project passes every source output through unchanged and adds the dereferences.
    Assignments sourceAssignments = Assignments.builder()
            .putIdentities(topNRankingNode.getSource().getOutputSymbols())
            .putAll(dereferenceAssignments)
            .build();

    // The id of the outer project is requested before the id of the inner one
    // (arguments are evaluated left to right), so the two calls stay nested in this order.
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    topNRankingNode.replaceChildren(ImmutableList.of(
                            new ProjectNode(context.getIdAllocator().getNextId(), topNRankingNode.getSource(), sourceAssignments))),
                    newAssignments));
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
Class LocalQueryRunner, method createPlan.
/**
 * Prepares, analyzes and logically plans a SQL statement with the supplied optimizers.
 *
 * @param session session used for preparation, analysis and planning
 * @param sql the SQL text to plan
 * @param optimizers optimizers handed to the logical planner and to the query explainer factory
 * @param stage planning stage passed to {@code LogicalPlanner.plan}
 * @param warningCollector collector handed to both the analyzer and the planner
 * @return the plan produced by the logical planner
 */
public Plan createPlan(Session session, @Language("SQL") String sql, List<PlanOptimizer> optimizers, LogicalPlanner.Stage stage, WarningCollector warningCollector) {
    PreparedQuery preparedQuery = new QueryPreparer(sqlParser).prepareQuery(session, sql);
    assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement());

    Analyzer analyzer = createAnalyzerFactory(createQueryExplainerFactory(optimizers)).createAnalyzer(
            session,
            preparedQuery.getParameters(),
            parameterExtractor(preparedQuery.getStatement(), preparedQuery.getParameters()),
            warningCollector);

    // The planner is constructed before the statement is analyzed.
    LogicalPlanner logicalPlanner = new LogicalPlanner(
            session,
            optimizers,
            new PlanSanityChecker(true),
            new PlanNodeIdAllocator(),
            getPlannerContext(),
            new TypeAnalyzer(plannerContext, statementAnalyzerFactory),
            statsCalculator,
            costCalculator,
            warningCollector);

    // NOTE(review): this runner is meant to always compute plan statistics for test purposes;
    // no explicit flag for that is visible in this call - confirm it is handled by the planner.
    return logicalPlanner.plan(analyzer.analyze(preparedQuery.getStatement()), stage);
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
Class LocalQueryRunner, method createDrivers.
/**
 * Builds the drivers needed to execute a plan that fragments into a single sub-plan.
 *
 * <p>Steps: optionally print the plan, fragment it, build a local execution plan, enumerate
 * all splits of every table scan, create the drivers that have no source id, then create one
 * driver per split for the source-bound factories, and finally mark every factory as done.
 *
 * @param session session used for plan printing, fragmenting and split enumeration
 * @param plan the plan to execute
 * @param outputFactory output factory passed to the local execution planner
 * @param taskContext task context under which pipeline and driver contexts are registered
 * @return an immutable list of the created drivers
 * @throws AssertionError if the fragmented plan has child sub-plans
 */
private List<Driver> createDrivers(Session session, Plan plan, OutputFactory outputFactory, TaskContext taskContext) {
    if (printPlan) {
        System.out.println(PlanPrinter.textLogicalPlan(plan.getRoot(), plan.getTypes(), plannerContext.getMetadata(), plannerContext.getFunctionManager(), plan.getStatsAndCosts(), session, 0, false));
    }

    SubPlan subplan = createSubPlans(session, plan, true);
    if (!subplan.getChildren().isEmpty()) {
        throw new AssertionError("Expected subplan to have no children");
    }

    TableExecuteContextManager tableExecuteContextManager = new TableExecuteContextManager();
    tableExecuteContextManager.registerTableExecuteContextForQuery(taskContext.getQueryContext().getQueryId());

    LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner(
            plannerContext,
            new TypeAnalyzer(plannerContext, statementAnalyzerFactory),
            Optional.empty(),
            pageSourceManager,
            indexManager,
            nodePartitioningManager,
            pageSinkManager,
            null,
            expressionCompiler,
            pageFunctionCompiler,
            joinFilterFunctionCompiler,
            new IndexJoinLookupStats(),
            this.taskManagerConfig,
            spillerFactory,
            singleStreamSpillerFactory,
            partitioningSpillerFactory,
            new PagesIndex.TestingFactory(false),
            joinCompiler,
            operatorFactories,
            new OrderingCompiler(plannerContext.getTypeOperators()),
            new DynamicFilterConfig(),
            blockTypeOperators,
            tableExecuteContextManager,
            exchangeManagerRegistry);

    // Turn the single fragment into a local execution plan.
    StageExecutionDescriptor stageExecutionDescriptor = subplan.getFragment().getStageExecutionDescriptor();
    LocalExecutionPlan localExecutionPlan = executionPlanner.plan(
            taskContext,
            stageExecutionDescriptor,
            subplan.getFragment().getRoot(),
            subplan.getFragment().getPartitioningScheme().getOutputLayout(),
            plan.getTypes(),
            subplan.getFragment().getPartitionedSources(),
            outputFactory);

    // Drain every table scan's split source; sequence ids increase across all scans.
    // NOTE(review): the split source is not closed here - confirm whether it needs to be.
    List<SplitAssignment> pendingAssignments = new ArrayList<>();
    long nextSequenceId = 0;
    for (TableScanNode scan : findTableScanNodes(subplan.getFragment().getRoot())) {
        SplitSource source = splitManager.getSplits(
                session,
                scan.getTable(),
                stageExecutionDescriptor.isScanGroupedExecution(scan.getId()) ? GROUPED_SCHEDULING : UNGROUPED_SCHEDULING,
                EMPTY,
                alwaysTrue());
        ImmutableSet.Builder<ScheduledSplit> splitsForScan = ImmutableSet.builder();
        while (!source.isFinished()) {
            for (Split split : getNextBatch(source)) {
                splitsForScan.add(new ScheduledSplit(nextSequenceId++, scan.getId(), split));
            }
        }
        pendingAssignments.add(new SplitAssignment(scan.getId(), splitsForScan.build(), true));
    }

    // Factories without a source id get their drivers right away; factories with a source id
    // are indexed by that id (each id may be registered only once) and handled below.
    List<Driver> drivers = new ArrayList<>();
    Map<PlanNodeId, DriverFactory> driverFactoriesBySource = new HashMap<>();
    for (DriverFactory driverFactory : localExecutionPlan.getDriverFactories()) {
        for (int i = 0; i < driverFactory.getDriverInstances().orElse(1); i++) {
            if (!driverFactory.getSourceId().isPresent()) {
                drivers.add(driverFactory.createDriver(
                        taskContext.addPipelineContext(driverFactory.getPipelineId(), driverFactory.isInputDriver(), driverFactory.isOutputDriver(), false)
                                .addDriverContext()));
                continue;
            }
            checkState(driverFactoriesBySource.put(driverFactory.getSourceId().get(), driverFactory) == null);
        }
    }

    // One driver per scheduled split, each fed a single-split assignment.
    ImmutableSet<PlanNodeId> partitionedSources = ImmutableSet.copyOf(subplan.getFragment().getPartitionedSources());
    for (SplitAssignment assignment : pendingAssignments) {
        DriverFactory sourceFactory = driverFactoriesBySource.get(assignment.getPlanNodeId());
        checkState(sourceFactory != null);
        boolean partitioned = partitionedSources.contains(sourceFactory.getSourceId().get());
        for (ScheduledSplit scheduledSplit : assignment.getSplits()) {
            Driver driver = sourceFactory.createDriver(
                    taskContext.addPipelineContext(sourceFactory.getPipelineId(), sourceFactory.isInputDriver(), sourceFactory.isOutputDriver(), partitioned)
                            .addDriverContext());
            driver.updateSplitAssignment(new SplitAssignment(scheduledSplit.getPlanNodeId(), ImmutableSet.of(scheduledSplit), true));
            drivers.add(driver);
        }
    }

    // Tell every factory that no further drivers will be requested.
    localExecutionPlan.getDriverFactories().forEach(DriverFactory::noMoreDrivers);

    return ImmutableList.copyOf(drivers);
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
Class QueryExplainer, method getLogicalPlan.
/**
 * Analyzes the statement and produces its logical plan at the
 * {@code OPTIMIZED_AND_VALIDATED} stage.
 *
 * @param session session used for analysis and planning
 * @param statement the statement to plan
 * @param parameters parameter expressions forwarded to analysis
 * @param warningCollector collector handed to both analysis and planning
 * @return the plan returned by the logical planner
 */
public Plan getLogicalPlan(Session session, Statement statement, List<Expression> parameters, WarningCollector warningCollector) {
    // Analysis runs first; the planner is only constructed once it has succeeded.
    Analysis analysis = analyze(session, statement, parameters, warningCollector);

    LogicalPlanner logicalPlanner = new LogicalPlanner(
            session,
            planOptimizers,
            new PlanNodeIdAllocator(),
            plannerContext,
            new TypeAnalyzer(plannerContext, statementAnalyzerFactory),
            statsCalculator,
            costCalculator,
            warningCollector);

    return logicalPlanner.plan(analysis, OPTIMIZED_AND_VALIDATED, true);
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
Class TestPushProjectionIntoTableScan, method testPushProjection.
/**
 * Verifies the rule under test on a project over a table scan whose assignments are an
 * identity, a row dereference, a constant and a function call: the expected plan is a project
 * of plain symbol references over a scan of the "projected_" table handle with the
 * correspondingly renamed columns and statistics.
 */
@Test
public void testPushProjection() {
    try (RuleTester ruleTester = defaultRuleTester()) {
        // Input column shared by every scan symbol
        String columnName = "col0";
        Type columnType = ROW_TYPE;
        Symbol baseColumn = new Symbol(columnName);
        ColumnHandle columnHandle = new TpchColumnHandle(columnName, columnType);

        // Register a mock catalog whose connector is given the mockApplyProjection callback
        MockConnectorFactory factory = createMockFactory(ImmutableMap.of(columnName, columnHandle), Optional.of(this::mockApplyProjection));
        ruleTester.getQueryRunner().createCatalog(MOCK_CATALOG, factory, ImmutableMap.of());
        TypeAnalyzer typeAnalyzer = createTestingTypeAnalyzer(ruleTester.getPlannerContext());

        // Output symbols of the project node and the types of all symbols involved
        Symbol identity = new Symbol("symbol_identity");
        Symbol dereference = new Symbol("symbol_dereference");
        Symbol constant = new Symbol("symbol_constant");
        Symbol call = new Symbol("symbol_call");
        ImmutableMap<Symbol, Type> types = ImmutableMap.of(
                baseColumn, ROW_TYPE,
                identity, ROW_TYPE,
                dereference, BIGINT,
                constant, BIGINT,
                call, VARCHAR);

        // One assignment per kind of expression
        ImmutableMap<Symbol, Expression> inputProjections = ImmutableMap.of(
                identity, baseColumn.toSymbolReference(),
                dereference, new SubscriptExpression(baseColumn.toSymbolReference(), new LongLiteral("1")),
                constant, new LongLiteral("5"),
                call, new FunctionCall(QualifiedName.of("STARTS_WITH"), ImmutableList.of(new StringLiteral("abc"), new StringLiteral("ab"))));

        // Translate each assignment inside a transaction to derive the expected column names
        TransactionId transactionId = ruleTester.getQueryRunner().getTransactionManager().beginTransaction(false);
        Session session = MOCK_SESSION.beginTransactionId(transactionId, ruleTester.getQueryRunner().getTransactionManager(), ruleTester.getQueryRunner().getAccessControl());
        ImmutableMap<Symbol, String> connectorNames = inputProjections.entrySet().stream()
                .collect(toImmutableMap(
                        Map.Entry::getKey,
                        e -> translate(session, e.getValue(), typeAnalyzer, viewOf(types), ruleTester.getPlannerContext()).get().toString()));
        ImmutableMap<Symbol, String> newNames = ImmutableMap.of(
                identity, "projected_variable_" + connectorNames.get(identity),
                dereference, "projected_dereference_" + connectorNames.get(dereference),
                constant, "projected_constant_" + connectorNames.get(constant),
                call, "projected_call_" + connectorNames.get(call));
        Map<String, ColumnHandle> expectedColumns = newNames.entrySet().stream()
                .collect(toImmutableMap(Map.Entry::getValue, e -> column(e.getValue(), types.get(e.getKey()))));

        // Statistics attached to the input scan: only the base column carries symbol statistics
        PlanNodeStatsEstimate scanStatistics = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(42)
                .addSymbolStatistics(baseColumn, SymbolStatsEstimate.builder().setNullsFraction(0).setDistinctValuesCount(33).build())
                .build();

        // Statistics expected on the rewritten scan, keyed by the new symbol names
        PlanNodeStatsEstimate expectedStatistics = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(42)
                .addSymbolStatistics(
                        new Symbol(newNames.get(constant)),
                        SymbolStatsEstimate.builder().setDistinctValuesCount(1).setNullsFraction(0).setLowValue(5).setHighValue(5).build())
                .addSymbolStatistics(
                        new Symbol(newNames.get(call).toLowerCase(ENGLISH)),
                        SymbolStatsEstimate.builder().setDistinctValuesCount(1).setNullsFraction(0).build())
                .addSymbolStatistics(
                        new Symbol(newNames.get(identity)),
                        SymbolStatsEstimate.builder().setDistinctValuesCount(33).setNullsFraction(0).build())
                .addSymbolStatistics(new Symbol(newNames.get(dereference)), SymbolStatsEstimate.unknown())
                .build();

        // Table handle expected on the rewritten scan
        MockConnectorTableHandle expectedTableHandle = new MockConnectorTableHandle(
                new SchemaTableName(TEST_SCHEMA, "projected_" + TEST_TABLE),
                TupleDomain.all(),
                Optional.of(ImmutableList.copyOf(expectedColumns.values())));

        ruleTester.assertThat(createRule(ruleTester))
                .on(p -> {
                    // Register symbols
                    types.forEach((symbol, type) -> p.symbol(symbol.getName(), type));
                    return p.project(
                            new Assignments(inputProjections),
                            p.tableScan(tableScan -> tableScan
                                    .setTableHandle(TEST_TABLE_HANDLE)
                                    .setSymbols(ImmutableList.copyOf(types.keySet()))
                                    .setAssignments(types.keySet().stream().collect(Collectors.toMap(Function.identity(), v -> columnHandle)))
                                    .setStatistics(Optional.of(scanStatistics))));
                })
                .withSession(MOCK_SESSION)
                .matches(project(
                        newNames.entrySet().stream()
                                .collect(toImmutableMap(e -> e.getKey().getName(), e -> expression(symbolReference(e.getValue())))),
                        tableScan(
                                expectedTableHandle::equals,
                                TupleDomain.all(),
                                expectedColumns.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, e -> e.getValue()::equals)),
                                Optional.of(expectedStatistics)::equals)));
    }
}
Aggregations