Search in sources :

Example 1 with AggregationApplicationResult

use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.

The method applyAggregation of the class MetadataManager.

/**
 * Offers an aggregation to the connector that owns {@code table} and, when the connector accepts it,
 * re-wraps the connector-level result into an engine-level {@link TableHandle}.
 *
 * @param session engine session, converted to a connector session for the table's catalog
 * @param table handle of the table being aggregated
 * @param aggregations aggregate functions offered for pushdown
 * @param assignments column handles keyed by variable name
 * @param groupingSets grouping sets; must contain at least one (possibly empty) set
 * @return the rewritten result, or an empty {@code Optional} when the connector returns none
 */
@Override
public Optional<AggregationApplicationResult<TableHandle>> applyAggregation(Session session, TableHandle table, List<AggregateFunction> aggregations, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets) {
    // A global aggregation arrives as one empty grouping set ([[]]), so the outer list must never be empty
    checkArgument(!groupingSets.isEmpty(), "No grouping sets provided");

    CatalogName catalogName = table.getCatalogName();
    ConnectorMetadata metadata = getMetadata(session, catalogName);
    ConnectorSession connectorSession = session.toConnectorSession(catalogName);

    return metadata
            .applyAggregation(connectorSession, table.getConnectorHandle(), aggregations, assignments, groupingSets)
            .map(result -> {
                // Check the connector's answer against the number of requested aggregations before using it
                verifyProjection(table, result.getProjections(), result.getAssignments(), aggregations.size());

                // Keep the catalog and transaction of the original table; only the connector handle changes
                TableHandle aggregatedTable = new TableHandle(catalogName, result.getHandle(), table.getTransaction());
                return new AggregationApplicationResult<>(
                        aggregatedTable,
                        result.getProjections(),
                        result.getAssignments(),
                        result.getGroupingColumnMapping(),
                        result.isPrecalculateStatistics());
            });
}
Also used : CatalogName(io.trino.connector.CatalogName) ConnectorSession(io.trino.spi.connector.ConnectorSession) ConnectorOutputTableHandle(io.trino.spi.connector.ConnectorOutputTableHandle) ConnectorTableHandle(io.trino.spi.connector.ConnectorTableHandle) ConnectorInsertTableHandle(io.trino.spi.connector.ConnectorInsertTableHandle) ConnectorMetadata(io.trino.spi.connector.ConnectorMetadata) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult)

Example 2 with AggregationApplicationResult

use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.

The method testAggregationPushdownForTableHandle of the class TestDefaultJdbcMetadata.

/**
 * Checks that, with aggregation pushdown enabled in the session, {@code applyAggregation} returns a
 * result for the {@code example.numbers} table and returns nothing for a freshly created
 * {@code example.no_aggregation_pushdown} table.
 */
@Test
public void testAggregationPushdownForTableHandle() {
    // Session with the aggregation pushdown property switched on
    JdbcMetadataSessionProperties sessionProperties = new JdbcMetadataSessionProperties(new JdbcMetadataConfig().setAggregationPushdownEnabled(true), Optional.empty());
    ConnectorSession session = TestingConnectorSession.builder()
            .setPropertyMetadata(sessionProperties.getSessionProperties())
            .build();

    ColumnHandle groupByColumn = metadata.getColumnHandles(session, tableHandle).get("text");

    // Same count() aggregation grouped by the "text" column, applied to whichever handle is passed in
    Function<ConnectorTableHandle, Optional<AggregationApplicationResult<ConnectorTableHandle>>> applyAggregation = handle -> {
        return metadata.applyAggregation(
                session,
                handle,
                ImmutableList.of(new AggregateFunction("count", BIGINT, List.of(), List.of(), false, Optional.empty())),
                ImmutableMap.of(),
                ImmutableList.of(ImmutableList.of(groupByColumn)));
    };

    // Pushdown is expected to succeed for the regular table
    ConnectorTableHandle baseTableHandle = metadata.getTableHandle(session, new SchemaTableName("example", "numbers"));
    Optional<AggregationApplicationResult<ConnectorTableHandle>> pushedDown = applyAggregation.apply(baseTableHandle);
    assertThat(pushedDown).isPresent();

    // NOTE(review): presumably the testing JDBC client refuses pushdown based on this table name - confirm
    SchemaTableName noAggregationPushdownTable = new SchemaTableName("example", "no_aggregation_pushdown");
    metadata.createTable(SESSION, new ConnectorTableMetadata(noAggregationPushdownTable, ImmutableList.of(new ColumnMetadata("text", VARCHAR))), false);
    ConnectorTableHandle noAggregationPushdownTableHandle = metadata.getTableHandle(session, noAggregationPushdownTable);
    Optional<AggregationApplicationResult<ConnectorTableHandle>> notPushedDown = applyAggregation.apply(noAggregationPushdownTableHandle);
    assertThat(notPushedDown).isEmpty();
}
Also used : Constraint(io.trino.spi.connector.Constraint) Assert.assertNull(org.testng.Assert.assertNull) ColumnMetadata(io.trino.spi.connector.ColumnMetadata) AggregateFunction(io.trino.spi.connector.AggregateFunction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Assert.assertEquals(org.testng.Assert.assertEquals) ConnectorTableMetadata(io.trino.spi.connector.ConnectorTableMetadata) Test(org.testng.annotations.Test) AfterMethod(org.testng.annotations.AfterMethod) TrinoExceptionAssert.assertTrinoExceptionThrownBy(io.trino.testing.assertions.TrinoExceptionAssert.assertTrinoExceptionThrownBy) Function(java.util.function.Function) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) TableNotFoundException(io.trino.spi.connector.TableNotFoundException) ImmutableList(com.google.common.collect.ImmutableList) Assertions.assertThatThrownBy(org.assertj.core.api.Assertions.assertThatThrownBy) ConnectorTableHandle(io.trino.spi.connector.ConnectorTableHandle) Map(java.util.Map) ColumnHandle(io.trino.spi.connector.ColumnHandle) Slices.utf8Slice(io.airlift.slice.Slices.utf8Slice) NOT_FOUND(io.trino.spi.StandardErrorCode.NOT_FOUND) JDBC_VARCHAR(io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_VARCHAR) Collections.emptyMap(java.util.Collections.emptyMap) ImmutableSet(com.google.common.collect.ImmutableSet) ConstraintApplicationResult(io.trino.spi.connector.ConstraintApplicationResult) ImmutableMap(com.google.common.collect.ImmutableMap) Domain(io.trino.spi.predicate.Domain) BeforeMethod(org.testng.annotations.BeforeMethod) JDBC_BIGINT(io.trino.plugin.jdbc.TestingJdbcTypeHandle.JDBC_BIGINT) SESSION(io.trino.testing.TestingConnectorSession.SESSION) ConnectorSession(io.trino.spi.connector.ConnectorSession) TupleDomain(io.trino.spi.predicate.TupleDomain) SchemaTableName(io.trino.spi.connector.SchemaTableName) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) TestingConnectorSession(io.trino.testing.TestingConnectorSession) 
List(java.util.List) BIGINT(io.trino.spi.type.BigintType.BIGINT) Optional(java.util.Optional) Assert.assertTrue(org.testng.Assert.assertTrue) VarcharType.createVarcharType(io.trino.spi.type.VarcharType.createVarcharType) ColumnHandle(io.trino.spi.connector.ColumnHandle) ColumnMetadata(io.trino.spi.connector.ColumnMetadata) Optional(java.util.Optional) SchemaTableName(io.trino.spi.connector.SchemaTableName) ConnectorTableHandle(io.trino.spi.connector.ConnectorTableHandle) AggregateFunction(io.trino.spi.connector.AggregateFunction) ConnectorSession(io.trino.spi.connector.ConnectorSession) TestingConnectorSession(io.trino.testing.TestingConnectorSession) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) ConnectorTableMetadata(io.trino.spi.connector.ConnectorTableMetadata) Test(org.testng.annotations.Test)

Example 3 with AggregationApplicationResult

use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.

The method pushAggregationIntoTableScan of the class PushAggregationIntoTableScan.

/**
 * Tries to push the given aggregations and grouping keys into the table scan by asking the
 * metadata layer to apply them, and on success builds a replacement plan fragment.
 *
 * @param plannerContext provides the metadata used for the pushdown call and expression translation
 * @param context supplies the session, symbol allocator, id allocator and stats provider
 * @param aggregationNode the node being replaced; used here only when deriving table statistics
 * @param tableScan the scan the aggregation is pushed into
 * @param aggregations aggregations keyed by their output symbol
 * @param groupingKeys symbols being grouped on; each is looked up in the scan's assignments by name
 * @return a {@code ProjectNode} over a new {@code TableScanNode}, or an empty {@code Optional}
 *         when the metadata layer returns no pushdown result
 */
public static Optional<PlanNode> pushAggregationIntoTableScan(PlannerContext plannerContext, Context context, PlanNode aggregationNode, TableScanNode tableScan, Map<Symbol, AggregationNode.Aggregation> aggregations, List<Symbol> groupingKeys) {
    // Re-key the scan's assignments by symbol name, the form passed to applyAggregation
    Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), Entry::getValue));
    // Snapshot the entries once so the function list and the output-symbol list below share one order
    List<Entry<Symbol, AggregationNode.Aggregation>> aggregationsList = aggregations.entrySet().stream().collect(toImmutableList());
    List<AggregateFunction> aggregateFunctions = aggregationsList.stream().map(Entry::getValue).map(aggregation -> toAggregateFunction(plannerContext.getMetadata(), context, aggregation)).collect(toImmutableList());
    List<Symbol> aggregationOutputSymbols = aggregationsList.stream().map(Entry::getKey).collect(toImmutableList());
    // Resolve each grouping symbol to its column handle by name
    List<ColumnHandle> groupByColumns = groupingKeys.stream().map(groupByColumn -> assignments.get(groupByColumn.getName())).collect(toImmutableList());
    // The grouping columns are passed as a single grouping set
    Optional<AggregationApplicationResult<TableHandle>> aggregationPushdownResult = plannerContext.getMetadata().applyAggregation(context.getSession(), tableScan.getTable(), aggregateFunctions, assignments, ImmutableList.of(groupByColumns));
    if (aggregationPushdownResult.isEmpty()) {
        // No pushdown result; signal the caller that nothing was rewritten
        return Optional.empty();
    }
    AggregationApplicationResult<TableHandle> result = aggregationPushdownResult.get();
    // The new scan outputs should be the symbols associated with grouping columns plus the symbols associated with aggregations.
    // As written, all outputs and assignments of the original scan are carried over, then the new ones are appended.
    ImmutableList.Builder<Symbol> newScanOutputs = ImmutableList.builder();
    newScanOutputs.addAll(tableScan.getOutputSymbols());
    ImmutableBiMap.Builder<Symbol, ColumnHandle> newScanAssignments = ImmutableBiMap.builder();
    newScanAssignments.putAll(tableScan.getAssignments());
    // Maps each connector-side variable name to the symbol allocated for it; used when translating projections
    Map<String, Symbol> variableMappings = new HashMap<>();
    for (Assignment assignment : result.getAssignments()) {
        // Allocate a fresh symbol for every column the pushdown result exposes
        Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
        newScanOutputs.add(symbol);
        newScanAssignments.put(symbol, assignment.getColumn());
        variableMappings.put(assignment.getVariable(), symbol);
    }
    // Translate the connector expressions back into engine expressions over the newly allocated symbols
    List<Expression> newProjections = result.getProjections().stream().map(expression -> ConnectorExpressionTranslator.translate(context.getSession(), expression, plannerContext, variableMappings, new LiteralEncoder(plannerContext))).collect(toImmutableList());
    // Projections are matched to aggregation outputs by position, so the counts must agree
    verify(aggregationOutputSymbols.size() == newProjections.size());
    Assignments.Builder assignmentBuilder = Assignments.builder();
    // Bind the i-th original aggregation output symbol to the i-th translated projection
    IntStream.range(0, aggregationOutputSymbols.size()).forEach(index -> assignmentBuilder.put(aggregationOutputSymbols.get(index), newProjections.get(index)));
    ImmutableBiMap<Symbol, ColumnHandle> scanAssignments = newScanAssignments.build();
    // Inverse view, used to find the symbol that carries a given column handle
    ImmutableBiMap<ColumnHandle, Symbol> columnHandleToSymbol = scanAssignments.inverse();
    // projections assignmentBuilder should have both agg and group by so we add all the group bys as symbol references
    groupingKeys.forEach(groupBySymbol -> {
        // if the connector returned a new mapping from oldColumnHandle to newColumnHandle, groupBy needs to point to
        // new columnHandle's symbol reference, otherwise it will continue pointing at oldColumnHandle.
        ColumnHandle originalColumnHandle = assignments.get(groupBySymbol.getName());
        ColumnHandle groupByColumnHandle = result.getGroupingColumnMapping().getOrDefault(originalColumnHandle, originalColumnHandle);
        // NOTE(review): if the resolved handle is not among the scan assignments, get() yields null and this line throws - confirm that cannot happen
        assignmentBuilder.put(groupBySymbol, columnHandleToSymbol.get(groupByColumnHandle).toSymbolReference());
    });
    // Result: a projection (aggregation outputs + grouping keys) over a scan of the handle returned by the pushdown,
    // built with an unconstrained TupleDomain and statistics derived via deriveTableStatisticsForPushdown
    return Optional.of(new ProjectNode(context.getIdAllocator().getNextId(), new TableScanNode(context.getIdAllocator().getNextId(), result.getHandle(), newScanOutputs.build(), scanAssignments, TupleDomain.all(), deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), aggregationNode), tableScan.isUpdateTarget(), // table scan partitioning might have changed with new table handle
    Optional.empty()), assignmentBuilder.build()));
}
Also used : IntStream(java.util.stream.IntStream) SortItem(io.trino.spi.connector.SortItem) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) Aggregation.step(io.trino.sql.planner.plan.Patterns.Aggregation.step) AggregateFunction(io.trino.spi.connector.AggregateFunction) HashMap(java.util.HashMap) Variable(io.trino.spi.expression.Variable) Capture.newCapture(io.trino.matching.Capture.newCapture) PlanNode(io.trino.sql.planner.plan.PlanNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) Patterns.aggregation(io.trino.sql.planner.plan.Patterns.aggregation) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Rule(io.trino.sql.planner.iterative.Rule) ProjectNode(io.trino.sql.planner.plan.ProjectNode) TableScanNode(io.trino.sql.planner.plan.TableScanNode) Symbol(io.trino.sql.planner.Symbol) Rules.deriveTableStatisticsForPushdown(io.trino.sql.planner.iterative.rule.Rules.deriveTableStatisticsForPushdown) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ConnectorExpressionTranslator(io.trino.sql.planner.ConnectorExpressionTranslator) Assignments(io.trino.sql.planner.plan.Assignments) TupleDomain(io.trino.spi.predicate.TupleDomain) OrderingScheme(io.trino.sql.planner.OrderingScheme) Patterns.tableScan(io.trino.sql.planner.plan.Patterns.tableScan) Capture(io.trino.matching.Capture) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) Pattern(io.trino.matching.Pattern) TableHandle(io.trino.metadata.TableHandle) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) 
SystemSessionProperties.isAllowPushdownIntoConnectors(io.trino.SystemSessionProperties.isAllowPushdownIntoConnectors) Patterns.source(io.trino.sql.planner.plan.Patterns.source) SymbolReference(io.trino.sql.tree.SymbolReference) Captures(io.trino.matching.Captures) BoundSignature(io.trino.metadata.BoundSignature) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) Metadata(io.trino.metadata.Metadata) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Assignments(io.trino.sql.planner.plan.Assignments) Assignment(io.trino.spi.connector.Assignment) Entry(java.util.Map.Entry) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) ColumnHandle(io.trino.spi.connector.ColumnHandle) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) TableScanNode(io.trino.sql.planner.plan.TableScanNode) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) Expression(io.trino.sql.tree.Expression) AggregateFunction(io.trino.spi.connector.AggregateFunction) TableHandle(io.trino.metadata.TableHandle) ProjectNode(io.trino.sql.planner.plan.ProjectNode) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult)

Example 4 with AggregationApplicationResult

use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.

The method applyAggregation of the class DefaultJdbcMetadata.

/**
 * Attempts to push the given aggregates and grouping sets into the JDBC query behind {@code table}.
 *
 * <p>Returns an empty {@code Optional} when aggregation pushdown is disabled for the session, when the
 * JDBC client reports it does not support pushdown for this request, or when any single aggregate
 * cannot be implemented by the JDBC client. Otherwise returns a result whose handle wraps a prepared
 * query and whose output columns are the distinct grouping columns followed by one synthetic column
 * per aggregate.
 *
 * @param session connector session; consulted for the pushdown toggle and passed to the JDBC client
 * @param table table handle; cast to {@code JdbcTableHandle}
 * @param aggregates aggregate functions to push down
 * @param assignments column handles keyed by variable name, passed to the JDBC client
 * @param groupingSets grouping sets; must contain at least one (possibly empty) set
 */
@Override
public Optional<AggregationApplicationResult<ConnectorTableHandle>> applyAggregation(ConnectorSession session, ConnectorTableHandle table, List<AggregateFunction> aggregates, Map<String, ColumnHandle> assignments, List<List<ColumnHandle>> groupingSets) {
    if (!isAggregationPushdownEnabled(session)) {
        return Optional.empty();
    }

    JdbcTableHandle handle = (JdbcTableHandle) table;

    // A global aggregation arrives as one empty grouping set ([[]]), so the outer list must never be empty
    verify(!groupingSets.isEmpty(), "No grouping sets provided");

    if (!jdbcClient.supportsAggregationPushdown(session, handle, aggregates, assignments, groupingSets)) {
        // The JDBC client implementation vetoes pushdown for this table
        return Optional.empty();
    }

    // A handle that already carries a limit is first flushed into a query before aggregating on top of it
    if (handle.getLimit().isPresent()) {
        handle = flushAttributesAsQuery(session, handle);
    }

    int syntheticColumnId = handle.getNextSyntheticColumnId();
    ImmutableList.Builder<JdbcColumnHandle> outputColumns = ImmutableList.builder();
    ImmutableList.Builder<ConnectorExpression> projections = ImmutableList.builder();
    ImmutableList.Builder<Assignment> resultAssignments = ImmutableList.builder();
    ImmutableMap.Builder<String, String> expressions = ImmutableMap.builder();

    List<List<JdbcColumnHandle>> jdbcGroupingSets = groupingSets.stream()
            .map(groupingSet -> groupingSet.stream()
                    .map(JdbcColumnHandle.class::cast)
                    .collect(toImmutableList()))
            .collect(toImmutableList());

    // Every distinct grouping column becomes an output column; when the handle lists its columns
    // explicitly, each grouping column must be one of them
    Optional<List<JdbcColumnHandle>> tableColumns = handle.getColumns();
    List<JdbcColumnHandle> distinctGroupingColumns = jdbcGroupingSets.stream()
            .flatMap(List::stream)
            .distinct()
            .collect(toImmutableList());
    for (JdbcColumnHandle groupKey : distinctGroupingColumns) {
        if (tableColumns.isPresent()) {
            verify(tableColumns.get().contains(groupKey), "applyAggregation called with a grouping column %s which was not included in the table columns: %s", groupKey, tableColumns);
        }
        outputColumns.add(groupKey);
    }

    // One synthetic output column per aggregate; give up entirely if any aggregate cannot be implemented
    for (AggregateFunction aggregate : aggregates) {
        Optional<JdbcExpression> implemented = jdbcClient.implementAggregation(session, aggregate, assignments);
        if (implemented.isEmpty()) {
            return Optional.empty();
        }
        JdbcExpression jdbcExpression = implemented.get();

        String syntheticName = SYNTHETIC_COLUMN_NAME_PREFIX + syntheticColumnId;
        syntheticColumnId++;

        JdbcColumnHandle syntheticColumn = JdbcColumnHandle.builder()
                .setColumnName(syntheticName)
                .setJdbcTypeHandle(jdbcExpression.getJdbcTypeHandle())
                .setColumnType(aggregate.getOutputType())
                .setComment(Optional.of("synthetic"))
                .build();

        outputColumns.add(syntheticColumn);
        projections.add(new Variable(syntheticColumn.getColumnName(), aggregate.getOutputType()));
        resultAssignments.add(new Assignment(syntheticColumn.getColumnName(), syntheticColumn, aggregate.getOutputType()));
        expressions.put(syntheticName, jdbcExpression.getExpression());
    }

    List<JdbcColumnHandle> outputColumnList = outputColumns.build();

    // We need to have matching column handles in JdbcTableHandle constructed below, as columns read via JDBC must match column handles list.
    // For more context see assertion in JdbcRecordSetProvider.getRecordSet
    PreparedQuery aggregationQuery = jdbcClient.prepareQuery(session, handle, Optional.of(jdbcGroupingSets), outputColumnList, expressions.buildOrThrow());
    JdbcTableHandle aggregatedHandle = new JdbcTableHandle(
            new JdbcQueryRelationHandle(aggregationQuery),
            TupleDomain.all(),
            ImmutableList.of(),
            Optional.empty(),
            OptionalLong.empty(),
            Optional.of(outputColumnList),
            handle.getAllReferencedTables(),
            syntheticColumnId);

    return Optional.of(new AggregationApplicationResult<>(aggregatedHandle, projections.build(), resultAssignments.build(), ImmutableMap.of(), false));
}
Also used : SortItem(io.trino.spi.connector.SortItem) AggregateFunction(io.trino.spi.connector.AggregateFunction) TopNApplicationResult(io.trino.spi.connector.TopNApplicationResult) NOT_SUPPORTED(io.trino.spi.StandardErrorCode.NOT_SUPPORTED) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) JdbcMetadataSessionProperties.isAggregationPushdownEnabled(io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isAggregationPushdownEnabled) TableNotFoundException(io.trino.spi.connector.TableNotFoundException) ConnectorOutputTableHandle(io.trino.spi.connector.ConnectorOutputTableHandle) ConnectorTableHandle(io.trino.spi.connector.ConnectorTableHandle) Map(java.util.Map) ProjectionApplicationResult(io.trino.spi.connector.ProjectionApplicationResult) Functions.identity(com.google.common.base.Functions.identity) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Domain(io.trino.spi.predicate.Domain) Collection(java.util.Collection) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ComputedStatistics(io.trino.spi.statistics.ComputedStatistics) Set(java.util.Set) ConnectorExpressions.extractConjuncts(io.trino.plugin.base.expression.ConnectorExpressions.extractConjuncts) TrinoException(io.trino.spi.TrinoException) LimitApplicationResult(io.trino.spi.connector.LimitApplicationResult) ConnectorOutputMetadata(io.trino.spi.connector.ConnectorOutputMetadata) SchemaTableName(io.trino.spi.connector.SchemaTableName) ConnectorTableSchema(io.trino.spi.connector.ConnectorTableSchema) Preconditions.checkState(com.google.common.base.Preconditions.checkState) List(java.util.List) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) TrinoPrincipal(io.trino.spi.security.TrinoPrincipal) BIGINT(io.trino.spi.type.BigintType.BIGINT) SchemaTablePrefix(io.trino.spi.connector.SchemaTablePrefix) Assignment(io.trino.spi.connector.Assignment) 
Optional(java.util.Optional) Math.max(java.lang.Math.max) TableScanRedirectApplicationResult(io.trino.spi.connector.TableScanRedirectApplicationResult) SystemTable(io.trino.spi.connector.SystemTable) Types(java.sql.Types) Verify.verifyNotNull(com.google.common.base.Verify.verifyNotNull) AccessDeniedException(io.trino.spi.security.AccessDeniedException) ConnectorTableLayout(io.trino.spi.connector.ConnectorTableLayout) Constraint(io.trino.spi.connector.Constraint) ConnectorInsertTableHandle(io.trino.spi.connector.ConnectorInsertTableHandle) JoinCondition(io.trino.spi.connector.JoinCondition) Slice(io.airlift.slice.Slice) ColumnMetadata(io.trino.spi.connector.ColumnMetadata) ConnectorTableMetadata(io.trino.spi.connector.ConnectorTableMetadata) HashMap(java.util.HashMap) Variable(io.trino.spi.expression.Variable) DomainPushdownResult(io.trino.plugin.jdbc.PredicatePushdownController.DomainPushdownResult) AtomicReference(java.util.concurrent.atomic.AtomicReference) ArrayList(java.util.ArrayList) OptionalLong(java.util.OptionalLong) JoinStatistics(io.trino.spi.connector.JoinStatistics) ImmutableList(com.google.common.collect.ImmutableList) JoinType(io.trino.spi.connector.JoinType) Verify.verify(com.google.common.base.Verify.verify) Objects.requireNonNull(java.util.Objects.requireNonNull) ColumnHandle(io.trino.spi.connector.ColumnHandle) TableStatistics(io.trino.spi.statistics.TableStatistics) JdbcMetadataSessionProperties.isJoinPushdownEnabled(io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isJoinPushdownEnabled) JdbcMetadataSessionProperties.isTopNPushdownEnabled(io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isTopNPushdownEnabled) Constant(io.trino.spi.expression.Constant) ConstraintApplicationResult(io.trino.spi.connector.ConstraintApplicationResult) JdbcMetadataSessionProperties.isComplexExpressionPushdown(io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isComplexExpressionPushdown) ConnectorSession(io.trino.spi.connector.ConnectorSession) 
TupleDomain(io.trino.spi.predicate.TupleDomain) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) Consumer(java.util.function.Consumer) JoinApplicationResult(io.trino.spi.connector.JoinApplicationResult) ConnectorExpressions.and(io.trino.plugin.base.expression.ConnectorExpressions.and) ConnectorTableProperties(io.trino.spi.connector.ConnectorTableProperties) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) Variable(io.trino.spi.expression.Variable) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ImmutableList(com.google.common.collect.ImmutableList) ConnectorExpression(io.trino.spi.expression.ConnectorExpression) Assignment(io.trino.spi.connector.Assignment) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) List(java.util.List) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) Constraint(io.trino.spi.connector.Constraint) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) AggregateFunction(io.trino.spi.connector.AggregateFunction)

Example 5 with AggregationApplicationResult

use of io.trino.spi.connector.AggregationApplicationResult in project trino by trinodb.

The method testPushDistinct of the class TestPushDistinctLimitIntoTableScan.

/**
 * Checks that a distinct-limit over a table scan is rewritten into limit / project / scan of the
 * table returned by the mocked {@code applyAggregation}, and that the mock was invoked exactly once
 * with no aggregates, the scan's assignments, and a single grouping set holding the distinct column.
 */
@Test
public void testPushDistinct() {
    // Capture what the rule hands to the connector so it can be asserted after the plan check
    AtomicInteger applyCallCounter = new AtomicInteger();
    AtomicReference<List<AggregateFunction>> applyAggregates = new AtomicReference<>();
    AtomicReference<Map<String, ColumnHandle>> applyAssignments = new AtomicReference<>();
    AtomicReference<List<List<ColumnHandle>>> applyGroupingSets = new AtomicReference<>();

    testApplyAggregation = (session, handle, aggregates, assignments, groupingSets) -> {
        applyCallCounter.incrementAndGet();
        // Store immutable copies of the arguments
        applyAggregates.set(List.copyOf(aggregates));
        applyAssignments.set(Map.copyOf(assignments));
        applyGroupingSets.set(groupingSets.stream()
                .map(List::copyOf)
                .collect(toUnmodifiableList()));

        // Always accept the pushdown and answer with a handle for a differently named table
        MockConnectorTableHandle aggregatedTable = new MockConnectorTableHandle(new SchemaTableName("mock_schema", "mock_nation_aggregated"));
        return Optional.of(new AggregationApplicationResult<>(aggregatedTable, List.of(), List.of(), Map.of(), false));
    };

    MockConnectorColumnHandle regionkeyHandle = new MockConnectorColumnHandle("regionkey", BIGINT);

    tester().assertThat(rule)
            .on(planBuilder -> {
                Symbol regionkeySymbol = planBuilder.symbol("regionkey_symbol");
                return planBuilder.distinctLimit(
                        43,
                        List.of(regionkeySymbol),
                        planBuilder.tableScan(
                                tableHandle,
                                ImmutableList.of(regionkeySymbol),
                                ImmutableMap.of(regionkeySymbol, regionkeyHandle)));
            })
            .matches(
                    limit(43,
                            project(
                                    tableScan("mock_nation_aggregated"))));

    assertThat(applyCallCounter).as("applyCallCounter").hasValue(1);
    assertThat(applyAggregates).as("applyAggregates").hasValue(List.of());
    assertThat(applyAssignments).as("applyAssignments").hasValue(Map.of("regionkey_symbol", regionkeyHandle));
    assertThat(applyGroupingSets).as("applyGroupingSets").hasValue(List.of(List.of(regionkeyHandle)));
}
Also used : AggregateFunction(io.trino.spi.connector.AggregateFunction) Assertions.assertThat(org.assertj.core.api.Assertions.assertThat) Test(org.testng.annotations.Test) AtomicReference(java.util.concurrent.atomic.AtomicReference) Collectors.toUnmodifiableList(java.util.stream.Collectors.toUnmodifiableList) PlanMatchPattern.limit(io.trino.sql.planner.assertions.PlanMatchPattern.limit) CatalogName(io.trino.connector.CatalogName) ImmutableList(com.google.common.collect.ImmutableList) MockConnectorFactory(io.trino.connector.MockConnectorFactory) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) LocalQueryRunner(io.trino.testing.LocalQueryRunner) Map(java.util.Map) ColumnHandle(io.trino.spi.connector.ColumnHandle) Symbol(io.trino.sql.planner.Symbol) ApplyAggregation(io.trino.connector.MockConnectorFactory.ApplyAggregation) ImmutableMap(com.google.common.collect.ImmutableMap) BaseRuleTest(io.trino.sql.planner.iterative.rule.test.BaseRuleTest) MockConnectorColumnHandle(io.trino.connector.MockConnectorColumnHandle) BeforeClass(org.testng.annotations.BeforeClass) BeforeMethod(org.testng.annotations.BeforeMethod) SchemaTableName(io.trino.spi.connector.SchemaTableName) AggregationApplicationResult(io.trino.spi.connector.AggregationApplicationResult) MockConnectorTableHandle(io.trino.connector.MockConnectorTableHandle) MockConnectorTransactionHandle(io.trino.connector.MockConnectorTransactionHandle) List(java.util.List) TableHandle(io.trino.metadata.TableHandle) BIGINT(io.trino.spi.type.BigintType.BIGINT) PlanMatchPattern.project(io.trino.sql.planner.assertions.PlanMatchPattern.project) TestingSession(io.trino.testing.TestingSession) Optional(java.util.Optional) PlanMatchPattern.tableScan(io.trino.sql.planner.assertions.PlanMatchPattern.tableScan) Session(io.trino.Session) ColumnHandle(io.trino.spi.connector.ColumnHandle) MockConnectorColumnHandle(io.trino.connector.MockConnectorColumnHandle) Symbol(io.trino.sql.planner.Symbol) 
AtomicReference(java.util.concurrent.atomic.AtomicReference) SchemaTableName(io.trino.spi.connector.SchemaTableName) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) MockConnectorTableHandle(io.trino.connector.MockConnectorTableHandle) Collectors.toUnmodifiableList(java.util.stream.Collectors.toUnmodifiableList) ImmutableList(com.google.common.collect.ImmutableList) List(java.util.List) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) MockConnectorColumnHandle(io.trino.connector.MockConnectorColumnHandle) Test(org.testng.annotations.Test) BaseRuleTest(io.trino.sql.planner.iterative.rule.test.BaseRuleTest)

Aggregations

AggregationApplicationResult (io.trino.spi.connector.AggregationApplicationResult)5 ImmutableList (com.google.common.collect.ImmutableList)4 AggregateFunction (io.trino.spi.connector.AggregateFunction)4 ColumnHandle (io.trino.spi.connector.ColumnHandle)4 List (java.util.List)4 Map (java.util.Map)4 Optional (java.util.Optional)4 ImmutableMap (com.google.common.collect.ImmutableMap)3 SchemaTableName (io.trino.spi.connector.SchemaTableName)3 TupleDomain (io.trino.spi.predicate.TupleDomain)3 BIGINT (io.trino.spi.type.BigintType.BIGINT)3 Verify.verify (com.google.common.base.Verify.verify)2 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)2 ImmutableMap.toImmutableMap (com.google.common.collect.ImmutableMap.toImmutableMap)2 ImmutableSet (com.google.common.collect.ImmutableSet)2 Session (io.trino.Session)2 CatalogName (io.trino.connector.CatalogName)2 TableHandle (io.trino.metadata.TableHandle)2 Assignment (io.trino.spi.connector.Assignment)2 ConnectorSession (io.trino.spi.connector.ConnectorSession)2