Usage of io.trino.spi.connector.AggregateFunction in the Trino project (trinodb).
Class TestDefaultJdbcMetadata, method testAggregationPushdownForTableHandle.
/**
 * Checks that a count aggregation grouped by the "text" column is accepted for pushdown on the
 * example "numbers" table, but not on a newly created "no_aggregation_pushdown" table.
 */
@Test
public void testAggregationPushdownForTableHandle() {
    // Session whose property metadata comes from a config with aggregation pushdown enabled
    JdbcMetadataSessionProperties sessionProperties = new JdbcMetadataSessionProperties(
            new JdbcMetadataConfig().setAggregationPushdownEnabled(true),
            Optional.empty());
    ConnectorSession session = TestingConnectorSession.builder()
            .setPropertyMetadata(sessionProperties.getSessionProperties())
            .build();

    ColumnHandle textColumn = metadata.getColumnHandles(session, tableHandle).get("text");

    // Offers count() grouped by the "text" column to the connector for the given table handle
    Function<ConnectorTableHandle, Optional<AggregationApplicationResult<ConnectorTableHandle>>> applyAggregation = handle -> {
        List<AggregateFunction> countOnly = ImmutableList.of(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()));
        return metadata.applyAggregation(session, handle, countOnly, ImmutableMap.of(), ImmutableList.of(ImmutableList.of(textColumn)));
    };

    // Pushdown is expected to be accepted for the example "numbers" table
    ConnectorTableHandle numbersTableHandle = metadata.getTableHandle(session, new SchemaTableName("example", "numbers"));
    assertThat(applyAggregation.apply(numbersTableHandle)).isPresent();

    // Pushdown is expected to be refused for a freshly created "no_aggregation_pushdown" table
    // (presumably the test client is configured to reject this table name — verify in the test setup)
    SchemaTableName rejectingTableName = new SchemaTableName("example", "no_aggregation_pushdown");
    metadata.createTable(SESSION, new ConnectorTableMetadata(rejectingTableName, ImmutableList.of(new ColumnMetadata("text", VARCHAR))), false);
    ConnectorTableHandle rejectingTableHandle = metadata.getTableHandle(session, rejectingTableName);
    assertThat(applyAggregation.apply(rejectingTableHandle)).isEmpty();
}
Usage of io.trino.spi.connector.AggregateFunction in the Trino project (trinodb).
Class TestSqlServerClient, method testImplementSum.
/**
 * Checks which {@code sum} aggregations the SQL Server client rewrites into remote SQL:
 * plain sums over bigint and double columns are rewritten, while DISTINCT and FILTER variants
 * are expected to yield no rewrite.
 */
@Test
public void testImplementSum() {
    Variable bigintArgument = new Variable("v_bigint", BIGINT);
    Variable doubleArgument = new Variable("v_double", DOUBLE);
    Optional<ConnectorExpression> aggregationFilter = Optional.of(new Variable("a_filter", BOOLEAN));

    // sum(bigint)
    AggregateFunction sumOfBigint = new AggregateFunction("sum", BIGINT, List.of(bigintArgument), List.of(), false, Optional.empty());
    testImplementAggregation(sumOfBigint, Map.of(bigintArgument.getName(), BIGINT_COLUMN), Optional.of("sum(\"c_bigint\")"));

    // sum(double)
    AggregateFunction sumOfDouble = new AggregateFunction("sum", DOUBLE, List.of(doubleArgument), List.of(), false, Optional.empty());
    testImplementAggregation(sumOfDouble, Map.of(doubleArgument.getName(), DOUBLE_COLUMN), Optional.of("sum(\"c_double\")"));

    // sum(DISTINCT bigint): distinct is not supported, so no rewrite is expected
    AggregateFunction distinctSumOfBigint = new AggregateFunction("sum", BIGINT, List.of(bigintArgument), List.of(), true, Optional.empty());
    testImplementAggregation(distinctSumOfBigint, Map.of(bigintArgument.getName(), BIGINT_COLUMN), Optional.empty());

    // sum(bigint) FILTER (WHERE ...): filter is not supported, so no rewrite is expected
    AggregateFunction filteredSumOfBigint = new AggregateFunction("sum", BIGINT, List.of(bigintArgument), List.of(), false, aggregationFilter);
    testImplementAggregation(filteredSumOfBigint, Map.of(bigintArgument.getName(), BIGINT_COLUMN), Optional.empty());
}
Usage of io.trino.spi.connector.AggregateFunction in the Trino project (trinodb).
Class TestPostgreSqlClient, method testImplementCount.
/**
 * Checks which {@code count} aggregations the PostgreSQL client rewrites into remote SQL:
 * count(*), count over bigint and double columns, and count(DISTINCT ...) are rewritten,
 * while FILTER variants are expected to yield no rewrite.
 */
@Test
public void testImplementCount() {
    Variable bigintVariable = new Variable("v_bigint", BIGINT);
    // Typed DOUBLE so that the "count(double)" case really passes a double argument,
    // consistent with the variable name and the DOUBLE_COLUMN it is bound to below
    Variable doubleVariable = new Variable("v_double", DOUBLE);
    Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BOOLEAN));
    // count(*)
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()), Map.of(), Optional.of("count(*)"));
    // count(bigint)
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), Optional.of("count(\"c_bigint\")"));
    // count(double)
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()), Map.of(doubleVariable.getName(), DOUBLE_COLUMN), Optional.of("count(\"c_double\")"));
    // count(DISTINCT bigint)
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), Optional.of("count(DISTINCT \"c_bigint\")"));
    // count() FILTER (WHERE ...): filter is not supported, so no rewrite is expected
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, filter), Map.of(), Optional.empty());
    // count(bigint) FILTER (WHERE ...): filter is not supported, so no rewrite is expected
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, filter), Map.of(bigintVariable.getName(), BIGINT_COLUMN), Optional.empty());
}
Usage of io.trino.spi.connector.AggregateFunction in the Trino project (trinodb).
Class PushAggregationIntoTableScan, method toAggregateFunction.
/**
 * Translates a planner-level aggregation into the connector SPI {@link AggregateFunction}
 * representation that is offered to connectors for pushdown.
 * <p>
 * Every argument is cast to {@link SymbolReference}. The method therefore presumably relies on the
 * caller having verified that all aggregation arguments are plain symbol references; any other
 * expression would fail with a {@link ClassCastException}.
 *
 * @param metadata used to look up the canonical name of the resolved function
 * @param context rule context; its symbol allocator supplies the type of the filter symbol
 * @param aggregation the aggregation to translate
 * @return the SPI function, carrying the canonical name, the bound return type, one typed
 *         variable per argument, the sort items (empty when there is no ordering), the
 *         distinct flag, and an optional filter variable
 */
private static AggregateFunction toAggregateFunction(Metadata metadata, Context context, AggregationNode.Aggregation aggregation) {
    String functionName = metadata.getFunctionMetadata(aggregation.getResolvedFunction()).getCanonicalName();
    BoundSignature boundSignature = aggregation.getResolvedFunction().getSignature();

    // Pair each argument symbol with the argument type declared by the bound signature
    int argumentCount = aggregation.getArguments().size();
    ImmutableList.Builder<ConnectorExpression> argumentVariables = ImmutableList.builder();
    for (int index = 0; index < argumentCount; index++) {
        SymbolReference symbolReference = (SymbolReference) aggregation.getArguments().get(index);
        argumentVariables.add(new Variable(symbolReference.getName(), boundSignature.getArgumentTypes().get(index)));
    }

    // An aggregation without ORDER BY is represented by an empty list of sort items
    List<SortItem> sortItems = aggregation.getOrderingScheme()
            .map(OrderingScheme::toSortItems)
            .orElse(ImmutableList.of());

    // The FILTER symbol, if any, becomes a variable typed via the symbol allocator
    Optional<ConnectorExpression> filterVariable = aggregation.getFilter()
            .map(filterSymbol -> new Variable(filterSymbol.getName(), context.getSymbolAllocator().getTypes().get(filterSymbol)));

    return new AggregateFunction(functionName, boundSignature.getReturnType(), argumentVariables.build(), sortItems, aggregation.isDistinct(), filterVariable);
}
Usage of io.trino.spi.connector.AggregateFunction in the Trino project (trinodb).
Class PushAggregationIntoTableScan, method pushAggregationIntoTableScan.
/**
 * Attempts to push the given aggregations, grouped by {@code groupingKeys}, into {@code tableScan}
 * by calling {@code Metadata.applyAggregation} with a single grouping set.
 * <p>
 * If the connector declines, {@link Optional#empty()} is returned. Otherwise the result is a
 * {@link ProjectNode} over a new {@link TableScanNode} for the handle returned by the connector.
 * The new scan keeps all of the original scan's output symbols and assignments. It also adds one
 * freshly allocated symbol per assignment returned by the connector.
 * <p>
 * The projection maps each aggregation output symbol to the corresponding connector projection,
 * translated back into a planner expression. It maps each grouping key to the symbol of its column
 * handle, following the connector's grouping column mapping when the connector provides one.
 *
 * @param plannerContext supplies metadata for the pushdown call and for expression translation
 * @param context rule context (session, symbol and id allocators, stats provider)
 * @param aggregationNode the aggregation being replaced; used when deriving statistics for the new scan
 * @param tableScan the scan beneath the aggregation
 * @param aggregations map from aggregation output symbol to aggregation
 * @param groupingKeys grouping symbols; each is presumably guaranteed by the caller to be an output of {@code tableScan}
 * @return the replacement plan, or empty if the connector did not accept the pushdown
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
// Scan column handles keyed by symbol name; this map is also passed to the connector as the assignments argument
Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));
// Materialize the entries once so that aggregateFunctions and aggregationOutputSymbols below are index-aligned
List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = aggregations.entrySet().stream().collect(toImmutableList());
List<AggregateFunction> aggregateFunctions = aggregationsList.stream().map(Entry::getValue).map(aggregation -> toAggregateFunction(plannerContext.getMetadata(), context, aggregation)).collect(toImmutableList());
List<Symbol> aggregationOutputSymbols = aggregationsList.stream().map(Entry::getKey).collect(toImmutableList());
// NOTE(review): assignments.get returns null for a grouping key that is not a scan output, and
// toImmutableList rejects null elements, so such a key would fail here with an NPE
List<ColumnHandle> groupByColumns = groupingKeys.stream().map(groupByColumn -> assignments.get(groupByColumn.getName())).collect(toImmutableList());
// Only a single grouping set is offered to the connector
Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(context.getSession(), tableScan.getTable(), aggregateFunctions, assignments, ImmutableList.of(groupByColumns));
if (aggregationPushdownResult.isEmpty()) {
return Optional.empty();
}
AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();
// The new scan outputs are all of the original scan's outputs, followed by one new symbol per
// assignment returned by the connector (the aggregation results and any remapped grouping columns).
ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
newScanOutputs.addAll(tableScan.getOutputSymbols());
// NOTE(review): ImmutableBiMap rejects duplicate keys or values at build(). A connector assignment
// whose column handle already appears in the original scan assignments would therefore fail there.
// Confirm that connectors never reuse existing handles.
ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
newScanAssignments.putAll(tableScan.getAssignments());
// Connector variable name -> newly allocated planner symbol, used to translate the projections below
Map<String, Symbol> variableMappings = new HashMap<>();
for (Assignment assignment : result.getAssignments()) {
Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
newScanOutputs.add(symbol);
newScanAssignments.put(symbol, assignment.getColumn());
variableMappings.put(assignment.getVariable(), symbol);
}
// Translate each connector projection back into a planner expression over the symbols allocated above
List<Expression> newProjections = result.getProjections().stream().map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, variableMappings, new LiteralEncoder(plannerContext))).collect(toImmutableList());
// The connector must return exactly one projection per requested aggregation; they are paired by index below
verify(aggregationOutputSymbols.size() == newProjections.size());
Assignments.Builder assignmentBuilder = Assignments.builder();
IntStream.range(0, aggregationOutputSymbols.size()).forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));
ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
// Inverse view used to find the scan symbol that produces a given column handle
ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
// projections assignmentBuilder should have both agg and group by so we add all the group bys as symbol references
groupingKeys.forEach(groupBySymbol -> {
// if the connector returned a new mapping from oldColumnHandle to newColumnHandle, groupBy needs to point to
// new columnHandle's symbol reference, otherwise it will continue pointing at oldColumnHandle.
ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
// NOTE(review): get(...) returns null (leading to an NPE) if the handle is not among the new scan's assignments
assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
});
// The new scan uses the connector's handle, an enforced constraint of TupleDomain.all(), and statistics
// obtained from deriveTableStatisticsForPushdown; the projection above sits on top of it
return Optional.of(new ProjectNode(context.getIdAllocator().getNextId(), new TableScanNode(context.getIdAllocator().getNextId(), result.getHandle(), newScanOutputs.build(), scanAssignments, TupleDomain.all(), deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode), tableScan.isUpdateTarget(), // table scan partitioning might have changed with new table handle
Optional.empty()), assignmentBuilder.build()));
}
Aggregations