use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.
the class DefaultJdbcMetadata method applyAggregation.
/**
 * Attempts to push the given aggregation down into the JDBC table handle.
 *
 * <p>Returns {@link Optional#empty()} when aggregation pushdown is disabled for the session,
 * when the JDBC client declines pushdown for this table, or when the client cannot
 * implement one of the requested aggregate functions. Otherwise a new table handle wrapping
 * the prepared aggregation query is returned; its columns are the distinct grouping columns
 * followed by one synthetic column per aggregate.
 *
 * @param session current connector session
 * @param table table handle; cast to {@code JdbcTableHandle}
 * @param aggregates aggregate functions to push down
 * @param assignments mapping from variable names used by the aggregates to column handles
 * @param groupingSets grouping sets; must not be empty (a global aggregation is {@code [[]]})
 * @return the pushdown result, or empty when pushdown is not applied
 */
@Override
public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(ConnectorSession session, ConnectorTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets) {
    if (!isAggregationPushdownEnabled(session)) {
        return Optional.empty();
    }

    JdbcTableHandle handle = (JdbcTableHandle) table;

    // Global aggregation is represented by [[]]
    verify(!groupingSets.isEmpty(), "No grouping sets provided");

    if (!jdbcClient.supportsAggregationPushdown(session, handle, aggregates, assignments, groupingSets)) {
        // JDBC client implementation prevents pushdown for the given table
        return Optional.empty();
    }

    // A handle that already carries a limit is first flushed into a query via flushAttributesAsQuery
    if (handle.getLimit().isPresent()) {
        handle = flushAttributesAsQuery(session, handle);
    }

    int nextSyntheticColumnId = handle.getNextSyntheticColumnId();
    Optional<List<JdbcColumnHandle>> tableColumns = handle.getColumns();

    List<List<JdbcColumnHandle>> groupingSetsAsJdbcColumnHandles = groupingSets.stream()
            .map(groupingSet -> groupingSet.stream()
                    .map(JdbcColumnHandle.class::cast)
                    .collect(toImmutableList()))
            .collect(toImmutableList());

    ImmutableList.Builder<JdbcColumnHandle> newColumns = ImmutableList.builder();
    ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
    ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
    ImmutableMap.Builder<String, String> expressions = ImmutableMap.builder();

    // Grouping columns come first in the output; each one must belong to the table's column list when that list is known
    List<JdbcColumnHandle> distinctGroupingColumns = groupingSetsAsJdbcColumnHandles.stream()
            .flatMap(List::stream)
            .distinct()
            .collect(toImmutableList());
    for (JdbcColumnHandle groupKey : distinctGroupingColumns) {
        if (tableColumns.isPresent()) {
            verify(tableColumns.get().contains(groupKey), "applyAggregation called with a grouping column %s which was not included in the table columns: %s", groupKey, tableColumns);
        }
        newColumns.add(groupKey);
    }

    // One synthetic column per aggregate; give up on pushdown as soon as one aggregate cannot be implemented
    for (AggregateFunction aggregate : aggregates) {
        Optional<JdbcExpression> implemented = jdbcClient.implementAggregation(session, aggregate, assignments);
        if (implemented.isEmpty()) {
            return Optional.empty();
        }
        JdbcExpression jdbcExpression = implemented.get();

        String columnName = SYNTHETIC_COLUMN_NAME_PREFIX + nextSyntheticColumnId;
        nextSyntheticColumnId++;

        JdbcColumnHandle newColumn = JdbcColumnHandle.builder()
                .setColumnName(columnName)
                .setJdbcTypeHandle(jdbcExpression.getJdbcTypeHandle())
                .setColumnType(aggregate.getOutputType())
                .setComment(Optional.of("synthetic"))
                .build();

        newColumns.add(newColumn);
        projections.add(new Variable(newColumn.getColumnName(), aggregate.getOutputType()));
        resultAssignments.add(new Assignment(newColumn.getColumnName(), newColumn, aggregate.getOutputType()));
        expressions.put(columnName, jdbcExpression.getExpression());
    }

    List<JdbcColumnHandle> newColumnsList = newColumns.build();

    // We need to have matching column handles in JdbcTableHandle constructed below, as columns read via JDBC must match column handles list.
    // For more context see assertion in JdbcRecordSetProvider.getRecordSet
    PreparedQuery preparedQuery = jdbcClient.prepareQuery(session, handle, Optional.of(groupingSetsAsJdbcColumnHandles), newColumnsList, expressions.buildOrThrow());

    JdbcTableHandle aggregatedHandle = new JdbcTableHandle(
            new JdbcQueryRelationHandle(preparedQuery),
            TupleDomain.all(),
            ImmutableList.of(),
            Optional.empty(),
            OptionalLong.empty(),
            Optional.of(newColumnsList),
            handle.getAllReferencedTables(),
            nextSyntheticColumnId);

    return Optional.of(new AggregationApplicationResult<>(aggregatedHandle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false));
}
use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.
the class TestPushDistinctLimitIntoTableScan method testPushDistinct.
/**
 * Checks that a distinct-limit over a table scan is rewritten into a limit over a projection
 * of the aggregated table scan, and that the connector's aggregation hook is invoked exactly
 * once with no aggregate functions and a single grouping set holding the distinct column.
 */
@Test
public void testPushDistinct() {
    MockConnectorColumnHandle regionkeyColumn = new MockConnectorColumnHandle("regionkey", BIGINT);

    // Recorders for the arguments the rule passes to the connector
    AtomicInteger applyCallCounter = new AtomicInteger();
    AtomicReference<List<AggregateFunction>> applyAggregates = new AtomicReference<>();
    AtomicReference<Map<String, ColumnHandle>> applyAssignments = new AtomicReference<>();
    AtomicReference<List<List<ColumnHandle>>> applyGroupingSets = new AtomicReference<>();

    // Mock connector hook: snapshot the arguments and always accept the pushdown
    testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
        applyCallCounter.incrementAndGet();
        applyAggregates.set(List.copyOf(aggregates));
        applyAssignments.set(Map.copyOf(assignments));
        applyGroupingSets.set(groupingSets.stream()
                .map(List::copyOf)
                .collect(toUnmodifiableList()));

        MockConnectorTableHandle aggregatedTable = new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation_aggregated"));
        return Optional.of(new AggregationApplicationResult<>(aggregatedTable, List.of(), List.of(), Map.of(), false));
    };

    tester().assertThat(rule)
            .on(planBuilder -> {
                Symbol regionkey = planBuilder.symbol("regionkey_symbol");
                return planBuilder.distinctLimit(
                        43,
                        List.of(regionkey),
                        planBuilder.tableScan(tableHandle, ImmutableList.of(regionkey), ImmutableMap.of(regionkey, regionkeyColumn)));
            })
            .matches(limit(43, project(tableScan("mock_nation_aggregated"))));

    assertThat(applyCallCounter).as("applyCallCounter").hasValue(1);
    assertThat(applyAggregates).as("applyAggregates").hasValue(List.of());
    assertThat(applyAssignments).as("applyAssignments").hasValue(Map.of("regionkey_symbol", regionkeyColumn));
    assertThat(applyGroupingSets).as("applyGroupingSets").hasValue(List.of(List.of(regionkeyColumn)));
}
use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.
the class TestMySqlClient method testImplementSum.
/**
 * Exercises {@code sum} aggregations through {@code testImplementAggregation}: plain
 * {@code sum} over a bigint and a double column is expected to yield a MySQL expression,
 * while the DISTINCT and FILTER variants are expected to yield no expression.
 */
@Test
public void testImplementSum() {
    Variable bigintVariable = new Variable("v_bigint", BIGINT);
    Variable doubleVariable = new Variable("v_double", DOUBLE);
    Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BOOLEAN));

    // sum(bigint)
    AggregateFunction sumBigint = new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty());
    testImplementAggregation(
            sumBigint,
            Map.of(bigintVariable.getName(), BIGINT_COLUMN),
            Optional.of("sum(`c_bigint`)"));

    // sum(double)
    AggregateFunction sumDouble = new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), false, Optional.empty());
    testImplementAggregation(
            sumDouble,
            Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
            Optional.of("sum(`c_double`)"));

    // sum(DISTINCT bigint): distinct not supported
    AggregateFunction sumDistinctBigint = new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty());
    testImplementAggregation(
            sumDistinctBigint,
            Map.of(bigintVariable.getName(), BIGINT_COLUMN),
            Optional.empty());

    // sum(bigint) FILTER (WHERE ...): filter not supported
    AggregateFunction sumFilteredBigint = new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), false, filter);
    testImplementAggregation(
            sumFilteredBigint,
            Map.of(bigintVariable.getName(), BIGINT_COLUMN),
            Optional.empty());
}
use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.
the class TestMySqlClient method testImplementCount.
/**
 * Exercises {@code count} aggregations through {@code testImplementAggregation}:
 * {@code count(*)} and {@code count(column)} are expected to yield a MySQL expression,
 * while the DISTINCT and FILTER variants are expected to yield no expression.
 */
@Test
public void testImplementCount() {
    Variable bigintVariable = new Variable("v_bigint", BIGINT);
    // Typed as DOUBLE so the variable agrees with DOUBLE_COLUMN, which it is bound to below
    Variable doubleVariable = new Variable("v_double", DOUBLE);
    Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BOOLEAN));

    // count(*)
    testImplementAggregation(
            new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()),
            Map.of(),
            Optional.of("count(*)"));

    // count(bigint)
    testImplementAggregation(
            new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
            Map.of(bigintVariable.getName(), BIGINT_COLUMN),
            Optional.of("count(`c_bigint`)"));

    // count(double)
    testImplementAggregation(
            new AggregateFunction("count", BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()),
            Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
            Optional.of("count(`c_double`)"));

    // count(DISTINCT bigint): no expression expected
    testImplementAggregation(
            new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
            Map.of(bigintVariable.getName(), BIGINT_COLUMN),
            Optional.empty());

    // count(*) FILTER (WHERE ...): no expression expected
    testImplementAggregation(
            new AggregateFunction("count", BIGINT, List.of(), List.of(), false, filter),
            Map.of(),
            Optional.empty());

    // count(bigint) FILTER (WHERE ...): no expression expected
    testImplementAggregation(
            new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, filter),
            Map.of(bigintVariable.getName(), BIGINT_COLUMN),
            Optional.empty());
}
use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.
the class PinotMetadata method applyAggregation.
/**
 * Attempts to push the given aggregation down into the Pinot table handle.
 *
 * <p>Returns {@link Optional#empty()} when aggregation pushdown is disabled for the session,
 * when more than one grouping set is requested, when a grouping column has an array type,
 * when the handle's existing query already aggregates or carries an offset, or when one of
 * the aggregate functions cannot be rewritten. Otherwise a new table handle carrying a
 * {@code DynamicTable} with the grouping and aggregation columns is returned.
 *
 * @param session current connector session
 * @param handle table handle; cast to {@code PinotTableHandle}
 * @param aggregates aggregate functions to push down
 * @param assignments mapping from variable names used by the aggregates to column handles
 * @param groupingSets grouping sets; must not be empty (a global aggregation is {@code [[]]})
 * @return the pushdown result, or empty when pushdown is not applied
 */
@Override
public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(ConnectorSession session, ConnectorTableHandle handle, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets) {
    if (!isAggregationPushdownEnabled(session)) {
        return Optional.empty();
    }

    // Global aggregation is represented by [[]]
    verify(!groupingSets.isEmpty(), "No grouping sets provided");

    // Pinot currently only supports simple GROUP BY clauses with a single grouping set
    if (groupingSets.size() != 1) {
        return Optional.empty();
    }
    List<ColumnHandle> groupingSet = getOnlyElement(groupingSets);

    // Grouping by an array-typed column is not pushed down.
    // See https://github.com/apache/pinot/issues/8353 for more details.
    boolean groupsByArrayColumn = groupingSet.stream()
            .anyMatch(columnHandle -> ((PinotColumnHandle) columnHandle).getDataType() instanceof ArrayType);
    if (groupsByArrayColumn) {
        return Optional.empty();
    }

    PinotTableHandle tableHandle = (PinotTableHandle) handle;
    Optional<DynamicTable> existingQuery = tableHandle.getQuery();

    // Do not stack an aggregation on a query that already aggregates.
    // If there is an offset then do not push the aggregation down as the results will not be correct
    if (existingQuery.isPresent()) {
        DynamicTable query = existingQuery.get();
        if (!query.getAggregateColumns().isEmpty() || query.isAggregateInProjections() || query.getOffset().isPresent()) {
            return Optional.empty();
        }
    }

    ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
    ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
    ImmutableList.Builder<PinotColumnHandle> aggregateColumnsBuilder = ImmutableList.builder();

    // Rewrite every aggregate; give up on pushdown as soon as one of them cannot be rewritten
    for (AggregateFunction aggregate : aggregates) {
        Optional<AggregateExpression> rewriteResult = applyCountDistinct(
                session,
                aggregate,
                assignments,
                tableHandle,
                aggregateFunctionRewriter.rewrite(session, aggregate, assignments));
        if (rewriteResult.isEmpty()) {
            return Optional.empty();
        }
        AggregateExpression aggregateExpression = rewriteResult.get();

        PinotColumnHandle pinotColumnHandle = new PinotColumnHandle(
                aggregateExpression.toFieldName(),
                aggregate.getOutputType(),
                aggregateExpression.toExpression(),
                false,
                true,
                aggregateExpression.isReturnNullOnEmptyGroup(),
                Optional.of(aggregateExpression.getFunction()),
                Optional.of(aggregateExpression.getArgument()));

        aggregateColumnsBuilder.add(pinotColumnHandle);
        projections.add(new Variable(pinotColumnHandle.getColumnName(), pinotColumnHandle.getDataType()));
        resultAssignments.add(new Assignment(pinotColumnHandle.getColumnName(), pinotColumnHandle, pinotColumnHandle.getDataType()));
    }

    List<PinotColumnHandle> groupingColumns = groupingSet.stream()
            .map(PinotColumnHandle.class::cast)
            .map(PinotColumnHandle::fromNonAggregateColumnHandle)
            .collect(toImmutableList());

    // For a grouped aggregation without an explicit limit, request one row more than maxRowsPerBrokerQuery.
    // NOTE(review): presumably this lets later code detect that the limit was exceeded and throw an error - confirm
    OptionalLong limitForDynamicTable = (tableHandle.getLimit().isEmpty() && !groupingColumns.isEmpty())
            ? OptionalLong.of(maxRowsPerBrokerQuery + 1)
            : OptionalLong.empty();

    List<PinotColumnHandle> aggregationColumns = aggregateColumnsBuilder.build();
    String newQuery = "";
    if (existingQuery.isPresent()) {
        DynamicTable query = existingQuery.get();
        newQuery = query.getQuery();

        // Prefer the existing query's projection for any grouping column with the same name
        Map<String, PinotColumnHandle> projectionsMap = query.getProjections().stream()
                .collect(toImmutableMap(PinotColumnHandle::getColumnName, identity()));
        groupingColumns = groupingColumns.stream()
                .map(groupingColumn -> projectionsMap.getOrDefault(groupingColumn.getColumnName(), groupingColumn))
                .collect(toImmutableList());
        aggregationColumns = aggregationColumns.stream()
                .map(aggregationColumn -> resolveAggregateExpressionWithAlias(aggregationColumn, projectionsMap))
                .collect(toImmutableList());
    }

    // The selections are exactly the (possibly re-resolved) grouping columns
    List<PinotColumnHandle> newSelections = groupingColumns;

    DynamicTable dynamicTable = new DynamicTable(
            tableHandle.getTableName(),
            Optional.empty(),
            newSelections,
            existingQuery.flatMap(DynamicTable::getFilter),
            groupingColumns,
            aggregationColumns,
            ImmutableList.of(),
            limitForDynamicTable,
            OptionalLong.empty(),
            newQuery);

    PinotTableHandle aggregatedHandle = new PinotTableHandle(
            tableHandle.getSchemaName(),
            tableHandle.getTableName(),
            tableHandle.getConstraint(),
            tableHandle.getLimit(),
            Optional.of(dynamicTable));

    return Optional.of(new AggregationApplicationResult<>(aggregatedHandle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false));
}
Aggregations