Usage of io.trino.sql.planner.plan.StatisticAggregations in the trino project (trinodb).
Class PruneTableWriterSourceColumns, method apply.
/**
 * Prunes the writer's source down to the symbols the writer actually reads: the written columns,
 * the partitioning columns and optional hash column, and the symbols used by the statistics aggregation.
 * Returns an empty result when {@code restrictChildOutputs} reports nothing to prune.
 */
@Override
public Result apply(TableWriterNode tableWriterNode, Captures captures, Context context) {
    ImmutableSet.Builder<Symbol> neededSymbols = ImmutableSet.builder();

    // The columns being written are always needed.
    neededSymbols.addAll(tableWriterNode.getColumns());

    // Partitioning columns and the optional hash column must stay available to the writer.
    tableWriterNode.getPartitioningScheme().ifPresent(scheme -> {
        scheme.getPartitioning().getColumns().forEach(neededSymbols::add);
        scheme.getHashColumn().ifPresent(neededSymbols::add);
    });

    // Statistics collection reads its grouping symbols plus every symbol referenced by its aggregations.
    tableWriterNode.getStatisticsAggregation().ifPresent(statistics -> {
        neededSymbols.addAll(statistics.getGroupingSymbols());
        statistics.getAggregations().values().stream()
                .map(SymbolsExtractor::extractUnique)
                .forEach(neededSymbols::addAll);
    });

    return restrictChildOutputs(context.getIdAllocator(), tableWriterNode, neededSymbols.build())
            .map(Result::ofPlanNode)
            .orElse(Result.empty());
}
Usage of io.trino.sql.planner.plan.StatisticAggregations in the trino project (trinodb).
Class LogicalPlanner, method createAnalyzePlan.
/**
 * Plans an ANALYZE statement. Every column returned by the table metadata is scanned. The statistics
 * returned by {@code metadata.getStatisticsCollectionMetadata} are computed with a single-step aggregation
 * and handed to a {@code StatisticsWriterNode}.
 *
 * @param analysis analysis holding the ANALYZE target table
 * @param analyzeStatement the statement; its scope becomes the scope of the returned plan
 * @return a relation plan rooted at the statistics writer
 */
private RelationPlan createAnalyzePlan(Analysis analysis, Analyze analyzeStatement) {
    TableHandle targetTable = analysis.getAnalyzeTarget().orElseThrow();

    // Allocate one scan symbol per column, recording its column handle and its name.
    Map<String, ColumnHandle> handlesByName = metadata.getColumnHandles(session, targetTable);
    TableMetadata tableMetadata = metadata.getTableMetadata(session, targetTable);
    ImmutableList.Builder<Symbol> scanSymbols = ImmutableList.builder();
    ImmutableMap.Builder<Symbol, ColumnHandle> scanAssignments = ImmutableMap.builder();
    ImmutableMap.Builder<String, Symbol> symbolsByName = ImmutableMap.builder();
    for (ColumnMetadata columnMetadata : tableMetadata.getColumns()) {
        String name = columnMetadata.getName();
        Symbol scanSymbol = symbolAllocator.newSymbol(name, columnMetadata.getType());
        scanSymbols.add(scanSymbol);
        scanAssignments.put(scanSymbol, handlesByName.get(name));
        symbolsByName.put(name, scanSymbol);
    }

    // Determine which statistics to collect and plan the aggregations that compute them.
    String catalogName = targetTable.getCatalogName().getCatalogName();
    TableStatisticsMetadata statisticsMetadata = metadata.getStatisticsCollectionMetadata(session, catalogName, tableMetadata.getMetadata());
    TableStatisticAggregation statisticAggregation = statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, symbolsByName.buildOrThrow());
    StatisticAggregations aggregations = statisticAggregation.getAggregations();
    boolean rowCountRequested = statisticsMetadata.getTableStatistics().contains(ROW_COUNT);
    List<Symbol> scanOutputs = scanSymbols.build();
    Map<Symbol, ColumnHandle> assignments = scanAssignments.buildOrThrow();

    // Arguments are evaluated left to right, so node ids are allocated writer -> aggregation -> scan.
    PlanNode planNode = new StatisticsWriterNode(
            idAllocator.getNextId(),
            new AggregationNode(
                    idAllocator.getNextId(),
                    TableScanNode.newInstance(idAllocator.getNextId(), targetTable, scanOutputs, assignments, false, Optional.empty()),
                    aggregations.getAggregations(),
                    singleGroupingSet(aggregations.getGroupingSymbols()),
                    ImmutableList.of(),
                    AggregationNode.Step.SINGLE,
                    Optional.empty(),
                    Optional.empty()),
            new StatisticsWriterNode.WriteStatisticsReference(targetTable),
            symbolAllocator.newSymbol("rows", BIGINT),
            rowCountRequested,
            statisticAggregation.getDescriptor());
    return new RelationPlan(planNode, analysis.getScope(analyzeStatement), planNode.getOutputSymbols(), Optional.empty());
}
Usage of io.trino.sql.planner.plan.StatisticAggregations in the trino project (trinodb).
Class LogicalPlanner, method createTableWriterPlan.
/**
 * Plans writing the rows produced by {@code source} into {@code target}.
 * <p>
 * The result is a {@code TableWriterNode} over {@code source}, topped by a {@code TableFinishNode}.
 * When {@code statisticsMetadata} is not empty, the writer also carries a partial statistics aggregation
 * and the finisher carries the matching final aggregation.
 *
 * @param analysis analysis of the statement; supplies the root scope of the returned plan
 * @param source plan producing the rows to write
 * @param symbols source symbols to write, positionally aligned with {@code columnNames}
 * @param target the write target
 * @param columnNames target column names, positionally aligned with {@code symbols}
 * @param columnMetadataList column metadata; non-nullable columns are recorded on the writer
 * @param writeTableLayout optional layout whose partition columns are mapped to symbols via {@code columnNames}
 * @param statisticsMetadata statistics to collect while writing; may be empty
 * @return a relation plan rooted at the {@code TableFinishNode}
 */
private RelationPlan createTableWriterPlan(Analysis analysis, PlanNode source, List<Symbol> symbols, WriterTarget target, List<String> columnNames, List<ColumnMetadata> columnMetadataList, Optional<TableLayout> writeTableLayout, TableStatisticsMetadata statisticsMetadata) {
    // symbols and columnNames are parallel lists; check this before indexing symbols by column position below
    verify(columnNames.size() == symbols.size(), "columnNames.size() != symbols.size(): %s and %s", columnNames, symbols);

    Optional<PartitioningScheme> partitioningScheme = Optional.empty();
    Optional<PartitioningScheme> preferredPartitioningScheme = Optional.empty();
    if (writeTableLayout.isPresent()) {
        TableLayout layout = writeTableLayout.get();
        List<Symbol> partitionFunctionArguments = new ArrayList<>();
        for (String partitionColumn : layout.getPartitionColumns()) {
            partitionFunctionArguments.add(symbols.get(columnNames.indexOf(partitionColumn)));
        }
        List<Symbol> outputLayout = new ArrayList<>(symbols);
        Optional<PartitioningHandle> partitioningHandle = layout.getPartitioning();
        if (partitioningHandle.isPresent()) {
            partitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(partitioningHandle.get(), partitionFunctionArguments), outputLayout));
        } else {
            // empty connector partitioning handle means evenly partitioning on partitioning columns
            preferredPartitioningScheme = Optional.of(new PartitioningScheme(Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionFunctionArguments), outputLayout));
        }
    }

    Map<String, Symbol> columnToSymbolMap = zip(columnNames.stream(), symbols.stream(), SimpleImmutableEntry::new)
            .collect(toImmutableMap(Entry::getKey, Entry::getValue));
    Set<Symbol> notNullColumnSymbols = columnMetadataList.stream()
            .filter(column -> !column.isNullable())
            .map(ColumnMetadata::getName)
            .map(columnToSymbolMap::get)
            .collect(toImmutableSet());

    // The partial aggregation runs within the TableWriteOperator to calculate statistics for the data it
    // consumes. The final aggregation runs within the TableFinishOperator to summarize the statistics
    // collected by the partial aggregation on all writer nodes. Plan both first, so their symbols are
    // allocated before the writer/finisher ids and symbols below.
    Optional<TableStatisticAggregation> statisticsAggregation = statisticsMetadata.isEmpty()
            ? Optional.empty()
            : Optional.of(statisticsAggregationPlanner.createStatisticsAggregation(statisticsMetadata, columnToSymbolMap));
    Optional<StatisticAggregations.Parts> statisticsParts = statisticsAggregation
            .map(aggregation -> aggregation.getAggregations().createPartialAggregations(symbolAllocator, plannerContext));

    TableFinishNode commitNode = new TableFinishNode(
            idAllocator.getNextId(),
            new TableWriterNode(
                    idAllocator.getNextId(),
                    source,
                    target,
                    symbolAllocator.newSymbol("partialrows", BIGINT),
                    symbolAllocator.newSymbol("fragment", VARBINARY),
                    symbols,
                    columnNames,
                    notNullColumnSymbols,
                    partitioningScheme,
                    preferredPartitioningScheme,
                    statisticsParts.map(StatisticAggregations.Parts::getPartialAggregation),
                    // writer-side descriptor: the planned descriptor with its symbols remapped through Parts#getMappings
                    statisticsAggregation.flatMap(aggregation -> statisticsParts.map(parts -> aggregation.getDescriptor().map(parts.getMappings()::get)))),
            target,
            symbolAllocator.newSymbol("rows", BIGINT),
            statisticsParts.map(StatisticAggregations.Parts::getFinalAggregation),
            statisticsAggregation.map(TableStatisticAggregation::getDescriptor));
    return new RelationPlan(commitNode, analysis.getRootScope(), commitNode.getOutputSymbols(), Optional.empty());
}
Usage of io.trino.sql.planner.plan.StatisticAggregations in the trino project (trinodb).
Class StatisticsAggregationPlanner, method createStatisticsAggregation.
/**
 * Plans the aggregations that compute the statistics requested by {@code statisticsMetadata}.
 *
 * @param statisticsMetadata grouping columns, table-wide statistics and column statistics to collect
 * @param columnToSymbolMap symbol carrying each column's values; must contain every grouping column and
 *        every column referenced by a column statistic
 * @return the aggregations (keyed by output symbol, grouped by the grouping columns' symbols) together with
 *         a descriptor mapping each grouping column and statistic to its symbol
 * @throws TrinoException with {@code NOT_SUPPORTED} if a table-wide statistic other than {@code ROW_COUNT} is requested
 */
public TableStatisticAggregation createStatisticsAggregation(TableStatisticsMetadata statisticsMetadata, Map<String, Symbol> columnToSymbolMap) {
    StatisticAggregationsDescriptor.Builder<Symbol> descriptor = StatisticAggregationsDescriptor.builder();

    // Resolve grouping columns to symbols. A missing column fails with its name, as the column-statistic
    // lookup below does, instead of an anonymous null rejected by the list builder.
    ImmutableList.Builder<Symbol> groupingSymbols = ImmutableList.builder();
    for (String groupingColumn : statisticsMetadata.getGroupingColumns()) {
        Symbol groupingSymbol = columnToSymbolMap.get(groupingColumn);
        verifyNotNull(groupingSymbol, "groupingSymbol is null for column: %s", groupingColumn);
        groupingSymbols.add(groupingSymbol);
        descriptor.addGrouping(groupingColumn, groupingSymbol);
    }

    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();

    // ROW_COUNT is the only supported table-wide statistic; it is computed with count().
    for (TableStatisticType type : statisticsMetadata.getTableStatistics()) {
        if (type != ROW_COUNT) {
            throw new TrinoException(NOT_SUPPORTED, "Table-wide statistic type not supported: " + type);
        }
        AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
                metadata.resolveFunction(session, QualifiedName.of("count"), ImmutableList.of()),
                ImmutableList.of(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        Symbol symbol = symbolAllocator.newSymbol("rowCount", BIGINT);
        aggregations.put(symbol, aggregation);
        descriptor.addTableStatistic(ROW_COUNT, symbol);
    }

    // Each column statistic becomes one aggregation over that column's input symbol.
    for (ColumnStatisticMetadata columnStatisticMetadata : statisticsMetadata.getColumnStatistics()) {
        String columnName = columnStatisticMetadata.getColumnName();
        ColumnStatisticType statisticType = columnStatisticMetadata.getStatisticType();
        Symbol inputSymbol = columnToSymbolMap.get(columnName);
        verifyNotNull(inputSymbol, "inputSymbol is null for column: %s", columnName);
        Type inputType = symbolAllocator.getTypes().get(inputSymbol);
        verifyNotNull(inputType, "inputType is null for symbol: %s", inputSymbol);
        ColumnStatisticsAggregation aggregation = createColumnAggregation(statisticType, inputSymbol, inputType);
        Symbol symbol = symbolAllocator.newSymbol(statisticType + ":" + columnName, aggregation.getOutputType());
        aggregations.put(symbol, aggregation.getAggregation());
        descriptor.addColumnStatistic(columnStatisticMetadata, symbol);
    }

    StatisticAggregations aggregation = new StatisticAggregations(aggregations.buildOrThrow(), groupingSymbols.build());
    return new TableStatisticAggregation(aggregation, descriptor.build());
}
Aggregations