Search in sources:

Example 1 with AggregateFunction

Use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.

From class TestDefaultJdbcMetadata, method testAggregationPushdownForTableHandle.

@Test
public void testAggregationPushdownForTableHandle() {
    // Enable aggregation pushdown in the session, so the table handle is what decides whether pushdown happens.
    JdbcMetadataConfig pushdownEnabledConfig = new JdbcMetadataConfig().setAggregationPushdownEnabled(true);
    ConnectorSession pushdownSession = TestingConnectorSession.builder()
            .setPropertyMetadata(new JdbcMetadataSessionProperties(pushdownEnabledConfig, Optional.empty()).getSessionProperties())
            .build();

    // count(*) GROUP BY text, applied to whichever table handle is passed in.
    ColumnHandle textColumn = metadata.getColumnHandles(pushdownSession, tableHandle).get("text");
    Function<ConnectorTableHandle, Optional<AggregationApplicationResult<ConnectorTableHandle>>> countGroupedByText = handle -> metadata.applyAggregation(
            pushdownSession,
            handle,
            ImmutableList.of(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty())),
            ImmutableMap.of(),
            ImmutableList.of(ImmutableList.of(textColumn)));

    // The example.numbers table is expected to accept the pushdown.
    ConnectorTableHandle numbersTable = metadata.getTableHandle(pushdownSession, new SchemaTableName("example", "numbers"));
    Optional<AggregationApplicationResult<ConnectorTableHandle>> numbersResult = countGroupedByText.apply(numbersTable);
    assertThat(numbersResult).isPresent();

    // A table named no_aggregation_pushdown is expected to be refused
    // (presumably special-cased by the testing JDBC client — confirm there).
    SchemaTableName noPushdownName = new SchemaTableName("example", "no_aggregation_pushdown");
    metadata.createTable(SESSION, new ConnectorTableMetadata(noPushdownName, ImmutableList.of(new ColumnMetadata("text", VARCHAR))), false);
    ConnectorTableHandle noPushdownTable = metadata.getTableHandle(pushdownSession, noPushdownName);
    Optional<AggregationApplicationResult<ConnectorTableHandle>> noPushdownResult = countGroupedByText.apply(noPushdownTable);
    assertThat(noPushdownResult).isEmpty();
}
Also used : Constraint(io.trino.spi.connector.Constraint) Assert.assertNull(org.testng.Assert.assertNull) ColumnMetadata(io.trino.spi.connector.ColumnMetadata) AggregateFunction(io.trino.spi.connector.AggregateFunction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Assert.assertEquals(org.testng.Assert.assertEquals) ConnectorTableMetadata(io.trino.spi.connector.ConnectorTableMetadata) Test(org.testng.annotations.Test) AfterMethod(org.testng.annotations.AfterMethod) TrinoExceptionAssert.assertTrinoExceptionThrownBy(io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy) Function(java.util.function.Function) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) TableNotFoundException(io.trino.spi.connector.TableNotFoundException) ImmutableList(com.google.common.collect.ImmutableList) Assertions.assertThatThrownBy(org.assertj.core.api.Assertions.assertThatThrownBy) ConnectorTableHandle(io.trino.spi.connector.ConnectorTableHandle) Map(java.util.Map) ColumnHandle(io.trino.spi.connector.ColumnHandle) Slices.utf8Slice(io.airlift.slice.Slices.utf8Slice) NOT_FOUND(io.trino.spi.StandardErrorCode.NOT_FOUND) JDBC_VARCHAR(io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_VARCHAR) Collections.emptyMap(java.util.Collections.emptyMap) ImmutableSet(com.google.common.collect.ImmutableSet) ConstraintApplicationResult(io.trino.spi.connector.ConstraintApplicationResult) ImmutableMap(com.google.common.collect.ImmutableMap) Domain(io.trino.spi.predicate.Domain) BeforeMethod(org.testng.annotations.BeforeMethod) JDBC_BIGINT(io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BIGINT) SESSION(io.trino.testing.TestingConnectorSession.SESSION) ConnectorSession(io.trino.spi.connector.ConnectorSession) TupleDomain(io.trino.spi.predicate.TupleDomain) SchemaTableName(io.trino.spi.connector.SchemaTableName) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) TestingConnectorSession(io.trino.testing.TestingConnectorSession) 
List(java.util.List) BIGINT(io.trino.spi.type.BigintType.BIGINT) Optional(java.util.Optional) Assert.assertTrue(org.testng.Assert.assertTrue) VarcharType.createVarcharType(io.trino.spi.type.VarcharType.createVarcharType) ColumnHandle(io.trino.spi.connector.ColumnHandle) ColumnMetadata(io.trino.spi.connector.ColumnMetadata) Optional(java.util.Optional) SchemaTableName(io.trino.spi.connector.SchemaTableName) ConnectorTableHandle(io.trino.spi.connector.ConnectorTableHandle) AggregateFunction(io.trino.spi.connector.AggregateFunction) ConnectorSession(io.trino.spi.connector.ConnectorSession) TestingConnectorSession(io.trino.testing.TestingConnectorSession) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) ConnectorTableMetadata(io.trino.spi.connector.ConnectorTableMetadata) Test(org.testng.annotations.Test)

Example 2 with AggregateFunction

Use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.

From class TestSqlServerClient, method testImplementSum.

@Test
public void testImplementSum() {
    Variable bigintArgument = new Variable("v_bigint", BIGINT);
    Variable doubleArgument = new Variable("v_double", DOUBLE);
    Optional<ConnectorExpression> filterExpression = Optional.of(new Variable("a_filter", BOOLEAN));

    // sum(bigint): expected to translate to sum over c_bigint
    testImplementAggregation(
            new AggregateFunction("sum", BIGINT, List.of(bigintArgument), List.of(), false, Optional.empty()),
            Map.of(bigintArgument.getName(), BIGINT_COLUMN),
            Optional.of("sum(\"c_bigint\")"));

    // sum(double): expected to translate to sum over c_double
    testImplementAggregation(
            new AggregateFunction("sum", DOUBLE, List.of(doubleArgument), List.of(), false, Optional.empty()),
            Map.of(doubleArgument.getName(), DOUBLE_COLUMN),
            Optional.of("sum(\"c_double\")"));

    // sum(DISTINCT bigint): DISTINCT is not supported, so no translation is expected
    testImplementAggregation(
            new AggregateFunction("sum", BIGINT, List.of(bigintArgument), List.of(), true, Optional.empty()),
            Map.of(bigintArgument.getName(), BIGINT_COLUMN),
            Optional.empty());

    // sum(bigint) FILTER (WHERE ...): FILTER is not supported, so no translation is expected
    testImplementAggregation(
            new AggregateFunction("sum", BIGINT, List.of(bigintArgument), List.of(), false, filterExpression),
            Map.of(bigintArgument.getName(), BIGINT_COLUMN),
            Optional.empty());
}
Also used : Variable(io.trino.spi.expression.Variable) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) AggregateFunction(io.trino.spi.connector.AggregateFunction) Test(org.testng.annotations.Test)

Example 3 with AggregateFunction

Use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.

From class TestPostgreSqlClient, method testImplementCount.

@Test
public void testImplementCount() {
    Variable bigintVariable = new Variable("v_bigint", BIGINT);
    // Typed DOUBLE so the count(double) case below really passes a double-typed argument.
    // It was previously declared BIGINT, which contradicted its name and the purpose of that case.
    // Fully qualified so the fix does not rely on a static import of DOUBLE.
    Variable doubleVariable = new Variable("v_double", io.trino.spi.type.DoubleType.DOUBLE);
    Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BOOLEAN));
    // count(*): no arguments, no assignments
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty()), Map.of(), Optional.of("count(*)"));
    // count(bigint)
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), Optional.of("count(\"c_bigint\")"));
    // count(double)
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()), Map.of(doubleVariable.getName(), DOUBLE_COLUMN), Optional.of("count(\"c_double\")"));
    // count(DISTINCT bigint): expected to translate with DISTINCT preserved
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), Optional.of("count(DISTINCT \"c_bigint\")"));
    // count() FILTER (WHERE ...): no translation expected
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, filter), Map.of(), Optional.empty());
    // count(bigint) FILTER (WHERE ...): no translation expected
    testImplementAggregation(new AggregateFunction("count", BIGINT, List.of(bigintVariable), List.of(), false, filter), Map.of(bigintVariable.getName(), BIGINT_COLUMN), Optional.empty());
}
Also used : Variable(io.trino.spi.expression.Variable) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) AggregateFunction(io.trino.spi.connector.AggregateFunction) Test(org.testng.annotations.Test)

Example 4 with AggregateFunction

Use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.

From class PushAggregationIntoTableScan, method toAggregateFunction.

/**
 * Converts a planner aggregation into the SPI {@link AggregateFunction} that is offered to the connector.
 *
 * <p>Each argument becomes a {@code Variable} named after its symbol and typed with the bound signature's
 * argument type at the same position. The ordering scheme becomes sort items (an empty list when absent).
 * The filter symbol, if any, becomes a {@code Variable} typed via the symbol allocator. Arguments are cast
 * unconditionally to {@code SymbolReference}, so any other expression kind fails with a ClassCastException.
 *
 * @param metadata used to look up the function's canonical name
 * @param context rule context whose symbol allocator supplies the filter symbol's type
 * @param aggregation the aggregation to convert
 * @return the connector-facing description of the aggregation
 */
private static AggregateFunction toAggregateFunction(Metadata metadata, Context context, AggregationNode.Aggregation aggregation) {
    String canonicalName = metadata.getFunctionMetadata(aggregation.getResolvedFunction()).getCanonicalName();
    BoundSignature boundSignature = aggregation.getResolvedFunction().getSignature();

    // Pair each argument with the argument type at the same position in the bound signature.
    List<Expression> symbolArguments = aggregation.getArguments();
    List<ConnectorExpression> connectorArguments = IntStream.range(0, symbolArguments.size())
            .mapToObj(position -> {
                SymbolReference reference = (SymbolReference) symbolArguments.get(position);
                return (ConnectorExpression) new Variable(reference.getName(), boundSignature.getArgumentTypes().get(position));
            })
            .collect(toImmutableList());

    List<SortItem> sortItems = aggregation.getOrderingScheme()
            .map(OrderingScheme::toSortItems)
            .orElse(ImmutableList.of());

    Optional<ConnectorExpression> filter = aggregation.getFilter()
            .map(filterSymbol -> new Variable(filterSymbol.getName(), context.getSymbolAllocator().getTypes().get(filterSymbol)));

    return new AggregateFunction(canonicalName, boundSignature.getReturnType(), connectorArguments, sortItems, aggregation.isDistinct(), filter);
}
Also used : OrderingScheme(io.trino.sql.planner.OrderingScheme) Variable(io.trino.spi.expression.Variable) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) SymbolReference(io.trino.sql.tree.SymbolReference) BoundSignature(io.trino.metadata.BoundSignature) AggregateFunction(io.trino.spi.connector.AggregateFunction) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) List(java.util.List)

Example 5 with AggregateFunction

Use of io.trino.spi.connector.AggregateFunction in project trino by trinodb.

From class PushAggregationIntoTableScan, method pushAggregationIntoTableScan.

/**
 * Attempts to push the given aggregations, grouped by {@code groupingKeys}, into {@code tableScan}
 * through {@code Metadata.applyAggregation}.
 *
 * @param plannerContext supplies the {@code Metadata} used for the pushdown request and for translating
 *        the connector's projections back into planner expressions
 * @param context rule context supplying the session, symbol allocator, id allocator and stats provider
 * @param aggregationNode the aggregation being replaced; used here only to derive statistics for the new scan
 * @param tableScan the scan underneath the aggregation
 * @param aggregations aggregation output symbol to aggregation
 * @param groupingKeys grouping symbols; each is resolved by name against the scan's assignments
 * @return a {@code ProjectNode} over a new {@code TableScanNode} that computes every aggregation output and
 *         grouping key, or empty when the connector declines the pushdown
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
    // Scan assignments keyed by symbol name, the same names toAggregateFunction uses for its Variables.
    Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));
    // Fix a single iteration order so that aggregateFunctions.get(i) and aggregationOutputSymbols.get(i) refer to the same aggregation.
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = aggregations.entrySet().stream().collect(toImmutableList());
    List<AggregateFunction> aggregateFunctions = aggregationsList.stream().map(Entry::getValue).map(aggregation -> toAggregateFunction(plannerContext.getMetadata(), context, aggregation)).collect(toImmutableList());
    List<Symbol> aggregationOutputSymbols = aggregationsList.stream().map(Entry::getKey).collect(toImmutableList());
    // Grouping symbols resolved to column handles, offered to the connector as one grouping set.
    // NOTE(review): a grouping key absent from the scan assignments would map to null, which toImmutableList rejects.
    List<ColumnHandle> groupByColumns = groupingKeys.stream().map(groupByColumn -> assignments.get(groupByColumn.getName())).collect(toImmutableList());
    Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(context.getSession(), tableScan.getTable(), aggregateFunctions, assignments, ImmutableList.of(groupByColumns));
    if (aggregationPushdownResult.isEmpty()) {
        // The connector declined the pushdown; there is nothing to rewrite.
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();
    // The new scan outputs should be the symbols associated with grouping columns plus the symbols associated with aggregations.
    // Concretely: every original scan output and assignment is kept, and connector-produced columns are appended after them.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    // Connector variable name -> freshly allocated symbol; used below to translate the connector's projections.
    Map<String, Symbol> variableMappings = new HashMap<>();
    for (Assignment assignment : result.getAssignments()) {
        Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
        newScanOutputs.add(symbol);
        newScanAssignments.put(symbol, assignment.getColumn());
        variableMappings.put(assignment.getVariable(), symbol);
    }
    // One projection per aggregation is required (verified by count); they are paired with aggregationOutputSymbols by index.
    List<Expression> newProjections = result.getProjections().stream().map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, variableMappings, new LiteralEncoder(plannerContext))).collect(toImmutableList());
    verify(aggregationOutputSymbols.size() == newProjections.size());
    Assignments.Builder assignmentBuilder = Assignments.builder();
    IntStream.range(0, aggregationOutputSymbols.size()).forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));
    // NOTE(review): Guava's ImmutableBiMap rejects duplicate keys or values when built, so a connector column handle
    // equal to one already in the scan assignments would fail here — confirm connectors never return one.
    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
    ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
    // projections assignmentBuilder should have both agg and group by so we add all the group bys as symbol references
    groupingKeys.forEach(groupBySymbol -> {
        // if the connector returned a new mapping from oldColumnHandle to newColumnHandle, groupBy needs to point to
        // new columnHandle's symbol reference, otherwise it will continue pointing at oldColumnHandle.
        ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        // NOTE(review): get() returns null, and toSymbolReference() then throws, if the mapped handle is not among the scan assignments.
        assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
    });
    // Project the aggregation outputs and grouping keys over a new scan of the connector-provided handle.
    // The scan receives TupleDomain.all() (presumably as its enforced constraint — confirm against TableScanNode)
    // and statistics derived for the pushed-down aggregation.
    return Optional.of(new ProjectNode(context.getIdAllocator().getNextId(), new TableScanNode(context.getIdAllocator().getNextId(), result.getHandle(), newScanOutputs.build(), scanAssignments, TupleDomain.all(), deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode), tableScan.isUpdateTarget(), // table scan partitioning might have changed with new table handle
    Optional.empty()), assignmentBuilder.build()));
}
Also used : IntStream(java.util.stream.IntStream) SortItem(io.trino.spi.connector.SortItem) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) Aggregation.step(io.trino.sql.planner.plan.Patterns.Aggregation.step) AggregateFunction(io.trino.spi.connector.AggregateFunction) HashMap(java.util.HashMap) Variable(io.trino.spi.expression.Variable) Capture.newCapture(io.trino.matching.Capture.newCapture) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) TableScanNode(io.trino.sql.planner.plan.TableScanNode) Symbol(io.trino.sql.planner.Symbol) Rules.deriveTableStatisticsForPushdown(io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ConnectorExpressionTranslator(io.trino.sql.planner.ConnectorExpressionTranslator) Assignments(io.trino.sql.planner.plan.Assignments) TupleDomain(io.trino.spi.predicate.TupleDomain) OrderingScheme(io.trino.sql.planner.OrderingScheme) Patterns.tableScan(io.trino.sql.planner.plan.Patterns.tableScan) Capture(io.trino.matching.Capture) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TableHandle(io.trino.metadata.TableHandle) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) 
SystemSessionProperties.isAllowPushdownIntoConnectors(io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) BoundSignature(io.trino.metadata.BoundSignature) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) Metadata(io.trino.metadata.Metadata) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) TableScanNode(io.trino.sql.planner.plan.TableScanNode) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) Expression(io.trino.sql.tree.Expression) AggregateFunction(io.trino.spi.connector.AggregateFunction) TableHandle(io.trino.metadata.TableHandle) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult)

Aggregations (number of examples in which each class is used)

AggregateFunction (io.trino.spi.connector.AggregateFunction)12 ConnectorExpression (io.trino.spi.expression.ConnectorExpression)10 Variable (io.trino.spi.expression.Variable)10 Test (org.testng.annotations.Test)8 ImmutableList (com.google.common.collect.ImmutableList)6 List (java.util.List)5 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)4 AggregationApplicationResult (io.trino.spi.connector.AggregationApplicationResult)4 ColumnHandle (io.trino.spi.connector.ColumnHandle)4 Map (java.util.Map)4 Optional (java.util.Optional)4 ImmutableMap (com.google.common.collect.ImmutableMap)3 Assignment (io.trino.spi.connector.Assignment)3 SchemaTableName (io.trino.spi.connector.SchemaTableName)3 TupleDomain (io.trino.spi.predicate.TupleDomain)3 BIGINT (io.trino.spi.type.BigintType.BIGINT)3 Verify.verify (com.google.common.base.Verify.verify)2 ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap)2 ImmutableSet (com.google.common.collect.ImmutableSet)2 Session (io.trino.Session)2