use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.
the class MetadataManager method applyAggregation.
/**
 * Offers an aggregation to the connector that owns {@code table} and, when the connector accepts it,
 * re-wraps the connector-level result as an engine-level result.
 *
 * @param session the engine session; used to look up the connector metadata and to derive the connector session
 * @param table the engine table handle the aggregation would be pushed into
 * @param aggregations the aggregate functions to push down
 * @param assignments column handles keyed by the variable names the aggregate functions refer to
 * @param groupingSets the grouping sets; must not be empty (a global aggregation is passed as {@code [[]]})
 * @return the connector's result with its handle wrapped in a new {@link TableHandle} for the same catalog
 *         and transaction, or empty when the connector returns empty
 * @throws IllegalArgumentException if {@code groupingSets} is empty
 */
@Override
public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(
        Session session,
        TableHandle table,
        List<AggregateFunction> aggregations,
        Map<String, ColumnHandle> assignments,
        List<List<ColumnHandle>> groupingSets) {
    // Global aggregation is represented by [[]], so an empty outer list is never a valid request.
    checkArgument(!groupingSets.isEmpty(), "No grouping sets provided");

    CatalogName catalog = table.getCatalogName();
    ConnectorMetadata connectorMetadata = getMetadata(session, catalog);
    ConnectorSession connectorSession = session.toConnectorSession(catalog);

    return connectorMetadata
            .applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets)
            .map(connectorResult -> {
                // Check the connector's projections/assignments before exposing them to the engine.
                verifyProjection(table, connectorResult.getProjections(), connectorResult.getAssignments(), aggregations.size());

                // Same catalog and transaction as the input table, but with the connector's new handle.
                TableHandle pushedDownTable = new TableHandle(catalog, connectorResult.getHandle(), table.getTransaction());
                return new AggregationApplicationResult<>(
                        pushedDownTable,
                        connectorResult.getProjections(),
                        connectorResult.getAssignments(),
                        connectorResult.getGroupingColumnMapping(),
                        connectorResult.isPrecalculateStatistics());
            });
}
use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.
the class TestDefaultJdbcMetadata method testAggregationPushdownForTableHandle.
/**
 * With aggregation pushdown enabled in the session, applies a {@code count} aggregation grouped by the
 * {@code text} column to two table handles: the pushdown must be accepted for {@code example.numbers}
 * and rejected for the freshly created {@code example.no_aggregation_pushdown} table.
 */
@Test
public void testAggregationPushdownForTableHandle() {
    JdbcMetadataSessionProperties pushdownEnabledProperties = new JdbcMetadataSessionProperties(
            new JdbcMetadataConfig().setAggregationPushdownEnabled(true),
            Optional.empty());
    ConnectorSession session = TestingConnectorSession.builder()
            .setPropertyMetadata(pushdownEnabledProperties.getSessionProperties())
            .build();

    ColumnHandle groupByColumn = metadata.getColumnHandles(session, tableHandle).get("text");

    // The same aggregation request is replayed against each table handle under test.
    Function<ConnectorTableHandle, Optional<AggregationApplicationResult<ConnectorTableHandle>>> applyAggregation = handle -> {
        AggregateFunction count = new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty());
        return metadata.applyAggregation(
                session,
                handle,
                ImmutableList.of(count),
                ImmutableMap.of(),
                ImmutableList.of(ImmutableList.of(groupByColumn)));
    };

    // Existing table: pushdown is expected to succeed.
    ConnectorTableHandle numbersTableHandle = metadata.getTableHandle(session, new SchemaTableName("example", "numbers"));
    assertThat(applyAggregation.apply(numbersTableHandle)).isPresent();

    // Newly created table: pushdown is expected to be declined.
    // NOTE(review): presumably the testing JDBC client refuses pushdown for this table name — confirm in the test fixture.
    SchemaTableName declinedTableName = new SchemaTableName("example", "no_aggregation_pushdown");
    metadata.createTable(
            SESSION,
            new ConnectorTableMetadata(declinedTableName, ImmutableList.of(new ColumnMetadata("text", VARCHAR))),
            false);
    ConnectorTableHandle declinedTableHandle = metadata.getTableHandle(session, declinedTableName);
    assertThat(applyAggregation.apply(declinedTableHandle)).isEmpty();
}
use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.
the class PushAggregationIntoTableScan method pushAggregationIntoTableScan.
/**
 * Tries to push an aggregation into the connector behind {@code tableScan}.
 *
 * <p>The aggregations, together with the single grouping set formed from {@code groupingKeys}, are offered
 * to the connector through the metadata's {@code applyAggregation}. When the connector accepts, the result
 * is a projection over a new table scan: the scan reads from the handle returned by the connector, and the
 * projection maps every aggregation output symbol and every grouping key symbol onto the new scan's outputs.
 *
 * @param plannerContext supplies the metadata used for the pushdown call and for expression translation
 * @param context supplies the session, the symbol and id allocators and the stats provider
 * @param aggregationNode the aggregation being replaced; only passed on when deriving statistics
 * @param tableScan the scan the aggregation sits on
 * @param aggregations the aggregations keyed by their output symbol
 * @param groupingKeys the grouping symbols; each is looked up by name in the scan's assignments
 * @return the replacement plan, or empty when the connector does not accept the aggregation
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
    Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream()
            .collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));

    // Snapshot the entries once so that the aggregate functions and their output symbols stay index-aligned.
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = ImmutableList.copyOf(aggregations.entrySet());
    List<AggregateFunction> aggregateFunctions = aggregationsList.stream()
            .map(Entry::getValue)
            .map(aggregation -> toAggregateFunction(plannerContext.getMetadata(), context, aggregation))
            .collect(toImmutableList());
    List<Symbol> aggregationOutputSymbols = aggregationsList.stream()
            .map(Entry::getKey)
            .collect(toImmutableList());
    List<ColumnHandle> groupByColumns = groupingKeys.stream()
            .map(groupByColumn -> assignments.get(groupByColumn.getName()))
            .collect(toImmutableList());

    Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(
            context.getSession(),
            tableScan.getTable(),
            aggregateFunctions,
            assignments,
            ImmutableList.of(groupByColumns));
    if (aggregationPushdownResult.isEmpty()) {
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();

    // The new scan keeps every output of the original scan and adds one fresh symbol
    // for each assignment returned by the connector.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    Map<String, Symbol> variableMappings = new HashMap<>();
    for (Assignment assignment : result.getAssignments()) {
        Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
        newScanOutputs.add(symbol);
        newScanAssignments.put(symbol, assignment.getColumn());
        variableMappings.put(assignment.getVariable(), symbol);
    }

    // Translate the connector's projections into engine expressions over the newly allocated symbols.
    List<Expression> newProjections = result.getProjections().stream()
            .map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, variableMappings, new LiteralEncoder(plannerContext)))
            .collect(toImmutableList());
    verify(
            aggregationOutputSymbols.size() == newProjections.size(),
            "Expected %s projections from the connector (one per aggregation) but got %s",
            aggregationOutputSymbols.size(),
            newProjections.size());

    // The i-th projection produces the i-th aggregation output.
    Assignments.Builder assignmentBuilder = Assignments.builder();
    for (int index = 0; index < aggregationOutputSymbols.size(); index++) {
        assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index));
    }

    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
    ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
    // The projection must output the grouping keys as well as the aggregations, so every grouping key
    // is added as a reference to the scan symbol that now carries it.
    for (Symbol groupBySymbol : groupingKeys) {
        // If the connector returned a mapping from the old column handle to a new one, the grouping key
        // has to reference the new handle's symbol; otherwise it keeps pointing at the old handle.
        ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        Symbol scanSymbol = columnHandleToSymbol.get(groupByColumnHandle);
        // Fail with context instead of a bare NullPointerException when the (possibly remapped) column
        // is exposed neither by the original scan nor by the connector's returned assignments.
        verify(
                scanSymbol != null,
                "No scan symbol found for grouping column %s of grouping key %s",
                groupByColumnHandle,
                groupBySymbol);
        assignmentBuilder.put(groupBySymbol, scanSymbol.toSymbolReference());
    }

    // The node ids are allocated in argument order (project first, then scan); keep the construction nested.
    return Optional.of(new ProjectNode(
            context.getIdAllocator().getNextId(),
            new TableScanNode(
                    context.getIdAllocator().getNextId(),
                    result.getHandle(),
                    newScanOutputs.build(),
                    scanAssignments,
                    TupleDomain.all(),
                    deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode),
                    tableScan.isUpdateTarget(),
                    // table scan partitioning might have changed with new table handle
                    Optional.empty()),
            assignmentBuilder.build()));
}
use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.
the class DefaultJdbcMetadata method applyAggregation.
/**
 * Attempts to push the given aggregation into the JDBC query behind {@code table}.
 *
 * <p>Returns empty (no pushdown) when aggregation pushdown is disabled for the session, when the JDBC
 * client reports that it does not support this pushdown, or when the client cannot implement one of the
 * aggregate functions. Otherwise returns a result whose handle wraps a prepared query producing the
 * distinct grouping columns followed by one synthetic column per aggregate function.
 *
 * @param session the connector session
 * @param table the table handle; must be a {@link JdbcTableHandle}
 * @param aggregates the aggregate functions to push down
 * @param assignments column handles keyed by the variable names the aggregate functions refer to
 * @param groupingSets the grouping sets; must not be empty (a global aggregation is passed as {@code [[]]}),
 *        and every column in them must be a {@link JdbcColumnHandle}
 * @return the pushed-down handle with one variable projection and one assignment per aggregate, an empty
 *         grouping column mapping and {@code precalculateStatistics} set to {@code false}; or empty
 */
@Override
public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(ConnectorSession session, ConnectorTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets) {
    if (!isAggregationPushdownEnabled(session)) {
        return Optional.empty();
    }
    JdbcTableHandle handle = (JdbcTableHandle) table;
    // Global aggregation is represented by [[]]
    verify(!groupingSets.isEmpty(), "No grouping sets provided");
    if (!jdbcClient.supportsAggregationPushdown(session, handle, aggregates, assignments, groupingSets)) {
        // JDBC client implementation prevents pushdown for the given table
        return Optional.empty();
    }
    if (handle.getLimit().isPresent()) {
        // NOTE(review): presumably this folds the pending limit into a query relation so that the
        // aggregation is computed over the limited rows — confirm against flushAttributesAsQuery.
        handle = flushAttributesAsQuery(session, handle);
    }

    int nextSyntheticColumnId = handle.getNextSyntheticColumnId();
    ImmutableList.Builder<JdbcColumnHandle> newColumns = ImmutableList.builder();
    ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
    ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
    ImmutableMap.Builder<String, String> expressions = ImmutableMap.builder();

    List<List<JdbcColumnHandle>> groupingSetsAsJdbcColumnHandles = groupingSets.stream()
            .map(groupingSet -> groupingSet.stream()
                    .map(JdbcColumnHandle.class::cast)
                    .collect(toImmutableList()))
            .collect(toImmutableList());

    // Every distinct grouping column becomes an output column. When the handle carries an explicit
    // column list, each grouping column must be part of it.
    Optional<List<JdbcColumnHandle>> tableColumns = handle.getColumns();
    List<JdbcColumnHandle> distinctGroupingColumns = groupingSetsAsJdbcColumnHandles.stream()
            .flatMap(List::stream)
            .distinct()
            .collect(toImmutableList());
    for (JdbcColumnHandle groupKey : distinctGroupingColumns) {
        if (tableColumns.isPresent()) {
            verify(tableColumns.get().contains(groupKey), "applyAggregation called with a grouping column %s which was not included in the table columns: %s", groupKey, tableColumns);
        }
        newColumns.add(groupKey);
    }

    // Each aggregate is exposed as a synthetic column backed by the expression the client produced for it.
    for (AggregateFunction aggregate : aggregates) {
        Optional<JdbcExpression> expression = jdbcClient.implementAggregation(session, aggregate, assignments);
        if (expression.isEmpty()) {
            // A single unsupported aggregate rejects the whole pushdown.
            return Optional.empty();
        }
        String columnName = SYNTHETIC_COLUMN_NAME_PREFIX + nextSyntheticColumnId;
        nextSyntheticColumnId++;
        JdbcColumnHandle newColumn = JdbcColumnHandle.builder()
                .setColumnName(columnName)
                .setJdbcTypeHandle(expression.get().getJdbcTypeHandle())
                .setColumnType(aggregate.getOutputType())
                .setComment(Optional.of("synthetic"))
                .build();
        newColumns.add(newColumn);
        projections.add(new Variable(newColumn.getColumnName(), aggregate.getOutputType()));
        resultAssignments.add(new Assignment(newColumn.getColumnName(), newColumn, aggregate.getOutputType()));
        expressions.put(columnName, expression.get().getExpression());
    }

    List<JdbcColumnHandle> newColumnsList = newColumns.build();
    // We need to have matching column handles in JdbcTableHandle constructed below, as columns read via JDBC must match column handles list.
    // For more context see assertion in JdbcRecordSetProvider.getRecordSet
    PreparedQuery preparedQuery = jdbcClient.prepareQuery(session, handle, Optional.of(groupingSetsAsJdbcColumnHandles), newColumnsList, expressions.buildOrThrow());
    handle = new JdbcTableHandle(
            new JdbcQueryRelationHandle(preparedQuery),
            TupleDomain.all(),
            ImmutableList.of(),
            Optional.empty(),
            OptionalLong.empty(),
            Optional.of(newColumnsList),
            handle.getAllReferencedTables(),
            nextSyntheticColumnId);
    return Optional.of(new AggregationApplicationResult<>(handle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false));
}
use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.
the class TestPushDistinctLimitIntoTableScan method testPushDistinct.
/**
 * Runs the rule on a distinct-limit over a table scan and checks two things: the plan becomes
 * limit → project → scan of the table returned by the mock connector, and the connector's
 * {@code applyAggregation} was called exactly once with no aggregates, the scan's assignments
 * and a single grouping set holding the distinct column.
 */
@Test
public void testPushDistinct() {
    // Recorders for the arguments the mock connector receives.
    AtomicInteger invocationCount = new AtomicInteger();
    AtomicReference<List<AggregateFunction>> recordedAggregates = new AtomicReference<>();
    AtomicReference<Map<String, ColumnHandle>> recordedAssignments = new AtomicReference<>();
    AtomicReference<List<List<ColumnHandle>>> recordedGroupingSets = new AtomicReference<>();

    testApplyAggregation = (connectorSession, connectorTable, aggregateFunctions, columnAssignments, groupingColumnSets) -> {
        invocationCount.incrementAndGet();
        // Take immutable snapshots so that later mutation by the caller cannot affect the assertions.
        recordedAggregates.set(List.copyOf(aggregateFunctions));
        recordedAssignments.set(Map.copyOf(columnAssignments));
        recordedGroupingSets.set(groupingColumnSets.stream()
                .map(List::copyOf)
                .collect(toUnmodifiableList()));
        return Optional.of(new AggregationApplicationResult<>(
                new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation_aggregated")),
                List.of(),
                List.of(),
                Map.of(),
                false));
    };

    MockConnectorColumnHandle regionkeyHandle = new MockConnectorColumnHandle("regionkey", BIGINT);

    tester().assertThat(rule)
            .on(planBuilder -> {
                Symbol regionkeySymbol = planBuilder.symbol("regionkey_symbol");
                return planBuilder.distinctLimit(
                        43,
                        List.of(regionkeySymbol),
                        planBuilder.tableScan(
                                tableHandle,
                                ImmutableList.of(regionkeySymbol),
                                ImmutableMap.of(regionkeySymbol, regionkeyHandle)));
            })
            .matches(limit(43, project(tableScan("mock_nation_aggregated"))));

    assertThat(invocationCount).as("applyCallCounter").hasValue(1);
    assertThat(recordedAggregates).as("applyAggregates").hasValue(List.of());
    assertThat(recordedAssignments).as("applyAssignments").hasValue(Map.of("regionkey_symbol", regionkeyHandle));
    assertThat(recordedGroupingSets).as("applyGroupingSets").hasValue(List.of(List.of(regionkeyHandle)));
}
Aggregations