Search in sources :

Example 1 with Aggregation

use of io.prestosql.spi.plan.AggregationNode.Aggregation in project hetu-core by openlookeng.

the class TestEffectivePredicateExtractor method testAggregation.

/**
 * Extracts the effective predicate of a FINAL-step hash aggregation that groups on
 * A, B and C over a filtered table scan, and checks it against the expected conjuncts
 * written in terms of the group-by symbols.
 */
@Test
public void testAggregation() {
    // One aggregation per output symbol, each built from countFunction() with no arguments.
    Aggregation countForC = new Aggregation(
            new CallExpression("test", new FunctionResolution(metadata.getFunctionAndTypeManager()).countFunction(), BIGINT, ImmutableList.of()),
            ImmutableList.of(),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    Aggregation countForD = new Aggregation(
            new CallExpression("test", new FunctionResolution(metadata.getFunctionAndTypeManager()).countFunction(), BIGINT, ImmutableList.of()),
            ImmutableList.of(),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    PlanNode node = new AggregationNode(
            newId(),
            filter(
                    baseTableScan,
                    and(
                            equals(AE, DE),
                            equals(BE, EE),
                            equals(CE, FE),
                            lessThan(DE, bigintLiteral(10)),
                            lessThan(CE, DE),
                            greaterThan(AE, bigintLiteral(2)),
                            equals(EE, FE))),
            ImmutableMap.of(C, countForC, D, countForD),
            singleGroupingSet(ImmutableList.of(A, B, C)),
            ImmutableList.of(),
            AggregationNode.Step.FINAL,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    Expression extracted = effectivePredicateExtractor.extract(SESSION, node, TypeProvider.empty(), typeAnalyzer);

    // The expected predicate is rewritten in terms of the group-by symbols.
    assertEquals(
            normalizeConjuncts(extracted),
            normalizeConjunctsSet(
                    lessThan(AE, bigintLiteral(10)),
                    lessThan(BE, AE),
                    greaterThan(AE, bigintLiteral(2)),
                    equals(BE, CE)));
}
Also used : AggregationNode.globalAggregation(io.prestosql.spi.plan.AggregationNode.globalAggregation) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) PlanNode(io.prestosql.spi.plan.PlanNode) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) InListExpression(io.prestosql.sql.tree.InListExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) FunctionResolution(io.prestosql.sql.relational.FunctionResolution) Test(org.testng.annotations.Test)

Example 2 with Aggregation

use of io.prestosql.spi.plan.AggregationNode.Aggregation in project hetu-core by openlookeng.

the class AggregationStatsRule method groupBy.

/**
 * Estimates the statistics produced by grouping {@code sourceStats} on {@code groupBySymbols}
 * and computing {@code aggregations}.
 *
 * @param sourceStats statistics of the aggregation input
 * @param groupBySymbols symbols the input is grouped on
 * @param aggregations aggregations keyed by their output symbol
 * @return the estimate holding per-symbol statistics and the output row count
 */
public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, Aggregation> aggregations) {
    PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();

    // Pass 1: per-key statistics. A key without nulls keeps a zero nulls fraction;
    // otherwise the nulls fraction becomes 1 / (distinct values + 1).
    for (Symbol key : groupBySymbols) {
        SymbolStatsEstimate keyStats = sourceStats.getSymbolStatistics(key);
        result.addSymbolStatistics(
                key,
                keyStats.mapNullsFraction(nullsFraction -> nullsFraction == 0.0 ? 0.0 : 1.0 / (keyStats.getDistinctValuesCount() + 1)));
    }

    // Pass 2: multiply the number of groups contributed by every key whose statistics are known.
    double rowsCount = 1;
    boolean anyKnownStats = false;
    for (Symbol key : groupBySymbols) {
        SymbolStatsEstimate keyStats = sourceStats.getSymbolStatistics(key);
        if (!keyStats.isUnknown()) {
            anyKnownStats = true;
            double groupsForKey = keyStats.getDistinctValuesCount();
            if (keyStats.getNullsFraction() != 0.0) {
                // nulls form one additional group
                groupsForKey += 1;
            }
            rowsCount *= groupsForKey;
        }
        // Keys with unknown statistics are left out of the overall row-count calculation.
    }

    // With no usable statistics for any grouping symbol, fall back to a fixed % of the source rows.
    double outputRowCount = anyKnownStats
            ? min(rowsCount, sourceStats.getOutputRowCount())
            : sourceStats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT;
    result.setOutputRowCount(outputRowCount);

    aggregations.forEach((outputSymbol, aggregation) ->
            result.addSymbolStatistics(outputSymbol, estimateAggregationStats(aggregation, sourceStats)));
    return result.build();
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) Lookup(io.prestosql.sql.planner.iterative.Lookup) Patterns.aggregation(io.prestosql.sql.planner.plan.Patterns.aggregation) Collection(java.util.Collection) TypeProvider(io.prestosql.sql.planner.TypeProvider) Pattern(io.prestosql.matching.Pattern) FINAL(io.prestosql.spi.plan.AggregationNode.Step.FINAL) Math.min(java.lang.Math.min) PARTIAL(io.prestosql.spi.plan.AggregationNode.Step.PARTIAL) AggregationNode(io.prestosql.spi.plan.AggregationNode) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) SINGLE(io.prestosql.spi.plan.AggregationNode.Step.SINGLE) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) Optional(java.util.Optional) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Symbol(io.prestosql.spi.plan.Symbol) Map(java.util.Map)

Example 3 with Aggregation

use of io.prestosql.spi.plan.AggregationNode.Aggregation in project hetu-core by openlookeng.

the class QueryPlanner method aggregate.

/**
 * Plans the aggregation part of a query specification on top of {@code inputSubPlan}.
 *
 * <p>When the analysis does not mark {@code node} as an aggregation the input plan is returned
 * untouched. Otherwise the method pre-projects the scalar inputs, adds either a GroupIdNode
 * (more than one grouping set) or a ProjectNode, adds a SINGLE-step AggregationNode, optionally
 * applies post-projection coercions, and finally delegates to {@code handleGroupingOperations}.
 *
 * @param inputSubPlan the plan built so far
 * @param node the query specification being planned
 * @return the plan extended with the aggregation, or {@code inputSubPlan} if there is nothing to aggregate
 */
private PlanBuilder aggregate(PlanBuilder inputSubPlan, QuerySpecification node) {
    PlanBuilder subPlan = inputSubPlan;
    if (!analysis.isAggregation(node)) {
        return subPlan;
    }
    // 1. Pre-project all scalar inputs (arguments and non-trivial group by expressions)
    Set<Expression> groupByExpressions = ImmutableSet.copyOf(analysis.getGroupByExpressions(node));
    ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
    // aggregate arguments, skipping lambdas
    analysis.getAggregates(node).stream().map(FunctionCall::getArguments).flatMap(List::stream).filter(// lambda expression is generated at execution time
    exp -> !(exp instanceof LambdaExpression)).forEach(arguments::add);
    // sort keys of the aggregates' ORDER BY clauses
    analysis.getAggregates(node).stream().map(FunctionCall::getOrderBy).filter(Optional::isPresent).map(Optional::get).map(OrderBy::getSortItems).flatMap(List::stream).map(SortItem::getSortKey).forEach(arguments::add);
    // filter expressions need to be projected first
    analysis.getAggregates(node).stream().map(FunctionCall::getFilter).filter(Optional::isPresent).map(Optional::get).forEach(arguments::add);
    Iterable<Expression> inputs = Iterables.concat(groupByExpressions, arguments.build());
    subPlan = handleSubqueries(subPlan, node, inputs);
    if (!Iterables.isEmpty(inputs)) {
        // avoid an empty projection if the only aggregation is COUNT (which has no arguments)
        subPlan = project(subPlan, inputs);
    }
    // 2. Aggregate
    // 2.a. Rewrite aggregate arguments
    TranslationMap argumentTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
    ImmutableList.Builder<Symbol> aggregationArgumentsBuilder = ImmutableList.builder();
    for (Expression argument : arguments.build()) {
        Symbol symbol = subPlan.translate(argument);
        argumentTranslations.put(argument, symbol);
        aggregationArgumentsBuilder.add(symbol);
    }
    List<Symbol> aggregationArguments = aggregationArgumentsBuilder.build();
    // 2.b. Rewrite grouping columns
    TranslationMap groupingTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
    // maps each freshly allocated grouping output symbol to the input symbol it is read from
    Map<Symbol, Symbol> groupingSetMappings = new LinkedHashMap<>();
    for (Expression expression : groupByExpressions) {
        Symbol input = subPlan.translate(expression);
        Symbol output = planSymbolAllocator.newSymbol(expression, analysis.getTypeWithCoercions(expression), "gid");
        groupingTranslations.put(expression, output);
        groupingSetMappings.put(output, input);
    }
    // This tracks the grouping sets before complex expressions are considered (see comments below)
    // It's also used to compute the descriptors needed to implement grouping()
    // Both lists default to a single empty grouping set and are replaced when a GROUP BY clause is present.
    List<Set<FieldId>> columnOnlyGroupingSets = ImmutableList.of(ImmutableSet.of());
    List<List<Symbol>> groupingSets = ImmutableList.of(ImmutableList.of());
    if (node.getGroupBy().isPresent()) {
        // For the purpose of "distinct", we need to canonicalize column references that may have varying
        // syntactic forms (e.g., "t.a" vs "a"). Thus we need to enumerate grouping sets based on the underlying
        // fieldId associated with each column reference expression.
        // The catch is that simple group-by expressions can be arbitrary expressions (this is a departure from the SQL specification).
        // But, they don't affect the number of grouping sets or the behavior of "distinct" . We can compute all the candidate
        // grouping sets in terms of fieldId, dedup as appropriate and then cross-join them with the complex expressions.
        Analysis.GroupingSetAnalysis groupingSetAnalysis = analysis.getGroupingSets(node);
        columnOnlyGroupingSets = enumerateGroupingSets(groupingSetAnalysis);
        if (node.getGroupBy().get().isDistinct()) {
            columnOnlyGroupingSets = columnOnlyGroupingSets.stream().distinct().collect(toImmutableList());
        }
        // add in the complex expressions and materialize the grouping sets in terms of plan columns
        ImmutableList.Builder<List<Symbol>> groupingSetBuilder = ImmutableList.builder();
        for (Set<FieldId> groupingSet : columnOnlyGroupingSets) {
            ImmutableList.Builder<Symbol> columns = ImmutableList.builder();
            // complex expressions come first in every grouping set, followed by the field-based columns
            groupingSetAnalysis.getComplexExpressions().stream().map(groupingTranslations::get).forEach(columns::add);
            groupingSet.stream().map(field -> groupingTranslations.get(new FieldReference(field.getFieldIndex()))).forEach(columns::add);
            groupingSetBuilder.add(columns.build());
        }
        groupingSets = groupingSetBuilder.build();
    }
    // 2.c. Generate GroupIdNode (multiple grouping sets) or ProjectNode (single grouping set)
    Optional<Symbol> groupIdSymbol = Optional.empty();
    if (groupingSets.size() > 1) {
        groupIdSymbol = Optional.of(planSymbolAllocator.newSymbol("groupId", BIGINT));
        GroupIdNode groupId = new GroupIdNode(idAllocator.getNextId(), subPlan.getRoot(), groupingSets, groupingSetMappings, aggregationArguments, groupIdSymbol.get());
        subPlan = new PlanBuilder(groupingTranslations, groupId);
    } else {
        // single grouping set: project the aggregation arguments as-is and the grouping outputs from their inputs
        Assignments.Builder assignments = Assignments.builder();
        aggregationArguments.forEach(symbol -> assignments.put(symbol, castToRowExpression(toSymbolReference(symbol))));
        groupingSetMappings.forEach((key, value) -> assignments.put(key, castToRowExpression(toSymbolReference(value))));
        ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build());
        subPlan = new PlanBuilder(groupingTranslations, project);
    }
    TranslationMap aggregationTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
    aggregationTranslations.copyMappingsFrom(groupingTranslations);
    // 2.d. Rewrite aggregates
    ImmutableMap.Builder<Symbol, Aggregation> aggregationsBuilder = ImmutableMap.builder();
    boolean needPostProjectionCoercion = false;
    for (FunctionCall aggregate : analysis.getAggregates(node)) {
        Expression rewritten = argumentTranslations.rewrite(aggregate);
        Symbol newSymbol = planSymbolAllocator.newSymbol(rewritten, analysis.getType(aggregate));
        // The rewritten aggregate may come back wrapped in a Cast. Unwrap it here and remember
        // that the implicit cast has to be moved into a post-projection.
        if (rewritten instanceof Cast) {
            rewritten = ((Cast) rewritten).getExpression();
            needPostProjectionCoercion = true;
        }
        aggregationTranslations.put(aggregate, newSymbol);
        FunctionCall functionCall = (FunctionCall) rewritten;
        // The translated argument list is built twice below: once for the call expression and once
        // for the Aggregation's own argument list.
        aggregationsBuilder.put(newSymbol, new Aggregation(call(aggregate.getName().getSuffix(), analysis.getFunctionHandle(aggregate), analysis.getType(aggregate), functionCall.getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())), functionCall.getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList()), functionCall.isDistinct(), functionCall.getFilter().map(SymbolUtils::from), functionCall.getOrderBy().map(OrderingSchemeUtils::fromOrderBy), Optional.empty()));
    }
    Map<Symbol, Aggregation> aggregations = aggregationsBuilder.build();
    // indexes of the grouping sets that contain no columns
    ImmutableSet.Builder<Integer> globalGroupingSets = ImmutableSet.builder();
    for (int i = 0; i < groupingSets.size(); i++) {
        if (groupingSets.get(i).isEmpty()) {
            globalGroupingSets.add(i);
        }
    }
    // grouping keys: the distinct symbols over all grouping sets, plus the group id symbol when present
    ImmutableList.Builder<Symbol> groupingKeys = ImmutableList.builder();
    groupingSets.stream().flatMap(List::stream).distinct().forEach(groupingKeys::add);
    groupIdSymbol.ifPresent(groupingKeys::add);
    AggregationNode aggregationNode = new AggregationNode(idAllocator.getNextId(), subPlan.getRoot(), aggregations, groupingSets(groupingKeys.build(), groupingSets.size(), globalGroupingSets.build()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), groupIdSymbol, AggregationNode.AggregationType.HASH, Optional.empty());
    subPlan = new PlanBuilder(aggregationTranslations, aggregationNode);
    // 3. Post-projection coercion (only when a cast was stripped from an aggregate above)
    // TODO: this is a hack, we should change type coercions to coerce the inputs to functions/operators instead of coercing the output
    if (needPostProjectionCoercion) {
        // group-by expressions and the group id symbol are passed as already coerced
        ImmutableList.Builder<Expression> alreadyCoerced = ImmutableList.builder();
        alreadyCoerced.addAll(groupByExpressions);
        groupIdSymbol.map(SymbolUtils::toSymbolReference).ifPresent(alreadyCoerced::add);
        subPlan = explicitCoercionFields(subPlan, alreadyCoerced.build(), analysis.getAggregates(node));
    }
    // 4. Project and re-write all grouping functions
    return handleGroupingOperations(subPlan, node, groupIdSymbol, columnOnlyGroupingSets);
}
Also used : Table(io.prestosql.sql.tree.Table) SortNode(io.prestosql.sql.planner.plan.SortNode) Property(io.prestosql.sql.tree.Property) SortOrder(io.prestosql.spi.block.SortOrder) AggregationNode(io.prestosql.spi.plan.AggregationNode) Cast(io.prestosql.sql.tree.Cast) HeuristicIndexUtils(io.prestosql.utils.HeuristicIndexUtils) Statement(io.prestosql.sql.tree.Statement) HetuConstant(io.prestosql.spi.HetuConstant) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) FetchFirst(io.prestosql.sql.tree.FetchFirst) Identifier(io.prestosql.sql.tree.Identifier) PlanNodeId(io.prestosql.spi.plan.PlanNodeId) CreateIndexNode(io.prestosql.sql.planner.plan.CreateIndexNode) AssignmentItem(io.prestosql.sql.tree.AssignmentItem) Delete(io.prestosql.sql.tree.Delete) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) Metadata(io.prestosql.metadata.Metadata) NodeRef(io.prestosql.sql.tree.NodeRef) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) SymbolReference(io.prestosql.sql.tree.SymbolReference) StringLiteral(io.prestosql.sql.tree.StringLiteral) LEVEL_PROP_KEY(io.prestosql.spi.connector.CreateIndexMetadata.LEVEL_PROP_KEY) SymbolUtils.from(io.prestosql.sql.planner.SymbolUtils.from) Field(io.prestosql.sql.analyzer.Field) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) GroupIdNode(io.prestosql.spi.plan.GroupIdNode) Iterables(com.google.common.collect.Iterables) TableMetadata(io.prestosql.metadata.TableMetadata) 
NodeUtils.getSortItemsFromOrderBy(io.prestosql.sql.NodeUtils.getSortItemsFromOrderBy) Node(io.prestosql.sql.tree.Node) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) OrderingSchemeUtils.sortItemToSortOrder(io.prestosql.sql.planner.OrderingSchemeUtils.sortItemToSortOrder) MetadataUtil(io.prestosql.metadata.MetadataUtil) Session(io.prestosql.Session) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) FrameBoundType(io.prestosql.spi.sql.expression.Types.FrameBoundType) DeleteNode(io.prestosql.sql.planner.plan.DeleteNode) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) Properties(java.util.Properties) Query(io.prestosql.sql.tree.Query) Assignments(io.prestosql.spi.plan.Assignments) Streams.stream(com.google.common.collect.Streams.stream) DeleteTarget(io.prestosql.sql.planner.plan.TableWriterNode.DeleteTarget) VARBINARY(io.prestosql.spi.type.VarbinaryType.VARBINARY) ValuesNode(io.prestosql.spi.plan.ValuesNode) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CreateIndexMetadata(io.prestosql.spi.connector.CreateIndexMetadata) WindowNode(io.prestosql.spi.plan.WindowNode) LimitNode(io.prestosql.spi.plan.LimitNode) Expression(io.prestosql.sql.tree.Expression) OffsetNode(io.prestosql.sql.planner.plan.OffsetNode) CURRENT_ROW(io.prestosql.spi.sql.expression.Types.FrameBoundType.CURRENT_ROW) UpdateIndexNode(io.prestosql.sql.planner.plan.UpdateIndexNode) UNBOUNDED_PRECEDING(io.prestosql.spi.sql.expression.Types.FrameBoundType.UNBOUNDED_PRECEDING) QualifiedName(io.prestosql.sql.tree.QualifiedName) FieldReference(io.prestosql.sql.tree.FieldReference) ExchangeNode(io.prestosql.sql.planner.plan.ExchangeNode) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) FilterNode(io.prestosql.spi.plan.FilterNode) Expressions.call(io.prestosql.sql.relational.Expressions.call) WindowFrame(io.prestosql.sql.tree.WindowFrame) Locale(java.util.Locale) Window(io.prestosql.sql.tree.Window) 
Type(io.prestosql.spi.type.Type) QuerySpecification(io.prestosql.sql.tree.QuerySpecification) TypeCoercion(io.prestosql.type.TypeCoercion) BIGINT(io.prestosql.spi.type.BigintType.BIGINT) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) ImmutableSet(com.google.common.collect.ImmutableSet) WindowFrameType(io.prestosql.spi.sql.expression.Types.WindowFrameType) ImmutableMap(com.google.common.collect.ImmutableMap) GroupingOperation(io.prestosql.sql.tree.GroupingOperation) SystemSessionProperties.isSkipRedundantSort(io.prestosql.SystemSessionProperties.isSkipRedundantSort) RANGE(io.prestosql.spi.sql.expression.Types.WindowFrameType.RANGE) CreateIndex(io.prestosql.sql.tree.CreateIndex) PropertyService(io.prestosql.spi.service.PropertyService) UUID(java.util.UUID) RelationType(io.prestosql.sql.analyzer.RelationType) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) Pair(io.prestosql.spi.heuristicindex.Pair) String.format(java.lang.String.format) Scope(io.prestosql.sql.analyzer.Scope) List(java.util.List) AggregationNode.groupingSets(io.prestosql.spi.plan.AggregationNode.groupingSets) Entry(java.util.Map.Entry) Optional(java.util.Optional) Analysis(io.prestosql.sql.analyzer.Analysis) MoreObjects.firstNonNull(com.google.common.base.MoreObjects.firstNonNull) FieldId(io.prestosql.sql.analyzer.FieldId) IntStream(java.util.stream.IntStream) UpdateIndex(io.prestosql.sql.tree.UpdateIndex) UpdateIndexMetadata(io.prestosql.spi.connector.UpdateIndexMetadata) HashMap(java.util.HashMap) RelationId(io.prestosql.sql.analyzer.RelationId) TableHandle(io.prestosql.spi.metadata.TableHandle) DecimalLiteral(io.prestosql.sql.tree.DecimalLiteral) INDEX_SUPPORTED_TYPES(io.prestosql.spi.connector.CreateIndexMetadata.INDEX_SUPPORTED_TYPES) ImmutableList(com.google.common.collect.ImmutableList) FunctionCall(io.prestosql.sql.tree.FunctionCall) OrderingScheme(io.prestosql.spi.plan.OrderingScheme) 
Objects.requireNonNull(java.util.Objects.requireNonNull) SortItem(io.prestosql.sql.tree.SortItem) Symbol(io.prestosql.spi.plan.Symbol) TableWriterNode(io.prestosql.sql.planner.plan.TableWriterNode) Iterator(java.util.Iterator) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) OrderBy(io.prestosql.sql.tree.OrderBy) UpdateNode(io.prestosql.sql.planner.plan.UpdateNode) LambdaArgumentDeclaration(io.prestosql.sql.tree.LambdaArgumentDeclaration) Update(io.prestosql.sql.tree.Update) Offset(io.prestosql.sql.tree.Offset) AUTOLOAD_PROP_KEY(io.prestosql.spi.connector.CreateIndexMetadata.AUTOLOAD_PROP_KEY) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) RowExpression(io.prestosql.spi.relation.RowExpression) Assignments(io.prestosql.spi.plan.Assignments) LinkedHashMap(java.util.LinkedHashMap) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ArrayList(java.util.ArrayList) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList) Optional(java.util.Optional) ImmutableMap.toImmutableMap(com.google.common.collect.ImmutableMap.toImmutableMap) ImmutableMap(com.google.common.collect.ImmutableMap) Analysis(io.prestosql.sql.analyzer.Analysis) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) Cast(io.prestosql.sql.tree.Cast) Set(java.util.Set) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ImmutableList(com.google.common.collect.ImmutableList) 
Symbol(io.prestosql.spi.plan.Symbol) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) GroupIdNode(io.prestosql.spi.plan.GroupIdNode) FunctionCall(io.prestosql.sql.tree.FunctionCall) FieldReference(io.prestosql.sql.tree.FieldReference) AggregationNode(io.prestosql.spi.plan.AggregationNode) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) Expression(io.prestosql.sql.tree.Expression) RowExpression(io.prestosql.spi.relation.RowExpression) FieldId(io.prestosql.sql.analyzer.FieldId) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Example 4 with Aggregation

use of io.prestosql.spi.plan.AggregationNode.Aggregation in project hetu-core by openlookeng.

the class ImplementFilteredAggregations method apply.

/**
 * Replaces the filter of every aggregation with a mask symbol.
 *
 * <p>Each filter symbol is projected under a newly allocated BOOLEAN symbol that becomes the
 * aggregation's mask. The rewritten aggregation reads from a ProjectNode (the new mask
 * assignments plus identity assignments for all source outputs) wrapped in a FilterNode.
 * The FilterNode predicate is the disjunction of the masks only when the node has no
 * non-empty grouping set and every aggregation had a filter; otherwise it is TRUE.
 *
 * @param aggregationNode the aggregation to rewrite
 * @param captures pattern captures (not used here)
 * @param context source of the symbol and plan-node-id allocators
 * @return a result wrapping the rewritten aggregation node
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    Assignments.Builder projections = Assignments.builder();
    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    ImmutableList.Builder<Expression> maskReferences = ImmutableList.builder();
    boolean unfilteredAggregatePresent = false;

    for (Map.Entry<Symbol, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        Aggregation original = entry.getValue();
        Optional<Symbol> mask = original.getMask();
        Optional<Symbol> filter = original.getFilter();

        if (!filter.isPresent()) {
            unfilteredAggregatePresent = true;
        } else {
            Symbol filterSymbol = filter.get();
            Symbol maskSymbol = context.getSymbolAllocator().newSymbol(filterSymbol.getName(), BOOLEAN);
            verify(!mask.isPresent(), "Expected aggregation without mask symbols, see Rule pattern");
            projections.put(maskSymbol, castToRowExpression(toSymbolReference(filterSymbol)));
            mask = Optional.of(maskSymbol);
            maskReferences.add(toSymbolReference(maskSymbol));
        }

        // Same aggregation with the filter stripped and the (possibly new) mask attached.
        rewrittenAggregations.put(
                entry.getKey(),
                new Aggregation(
                        original.getFunctionCall(),
                        original.getArguments(),
                        original.isDistinct(),
                        Optional.empty(),
                        original.getOrderingScheme(),
                        mask));
    }

    Expression predicate = (aggregationNode.hasNonEmptyGroupingSet() || unfilteredAggregatePresent)
            ? TRUE_LITERAL
            : combineDisjunctsWithDefault(maskReferences.build(), TRUE_LITERAL);

    // identity projection for all existing inputs
    projections.putAll(AssignmentUtils.identityAsSymbolReferences(aggregationNode.getSource().getOutputSymbols()));

    return Result.ofPlanNode(
            new AggregationNode(
                    context.getIdAllocator().getNextId(),
                    new FilterNode(
                            context.getIdAllocator().getNextId(),
                            new ProjectNode(
                                    context.getIdAllocator().getNextId(),
                                    aggregationNode.getSource(),
                                    projections.build()),
                            castToRowExpression(predicate)),
                    rewrittenAggregations.build(),
                    aggregationNode.getGroupingSets(),
                    ImmutableList.of(),
                    aggregationNode.getStep(),
                    aggregationNode.getHashSymbol(),
                    aggregationNode.getGroupIdSymbol(),
                    aggregationNode.getAggregationType(),
                    aggregationNode.getFinalizeSymbol()));
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.prestosql.spi.plan.FilterNode) Assignments(io.prestosql.spi.plan.Assignments) AggregationNode(io.prestosql.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Example 5 with Aggregation

use of io.prestosql.spi.plan.AggregationNode.Aggregation in project hetu-core by openlookeng.

the class SingleDistinctAggregationToGroupBy method apply.

/**
 * Rewrites the aggregation as two stacked aggregation nodes.
 *
 * <p>The inner node has no aggregations and groups on the original grouping keys plus the
 * symbols of the single argument set. The outer node keeps the original id, grouping sets
 * and step, and runs every aggregation through {@code removeDistinct}.
 *
 * @param aggregation the aggregation node to rewrite
 * @param captures pattern captures (not used here)
 * @param context source of the plan-node-id allocator
 * @return a result wrapping the outer aggregation node
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    List<Set<Expression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    // getOnlyElement requires exactly one argument set
    Set<Symbol> distinctSymbols = Iterables.getOnlyElement(argumentSets).stream()
            .map(SymbolUtils::from)
            .collect(Collectors.toSet());

    ImmutableList<Symbol> innerGroupingKeys = ImmutableList.<Symbol>builder()
            .addAll(aggregation.getGroupingKeys())
            .addAll(distinctSymbols)
            .build();

    // Inner node: pure grouping, no aggregate functions.
    AggregationNode distinctGroupBy = new AggregationNode(
            context.getIdAllocator().getNextId(),
            aggregation.getSource(),
            ImmutableMap.of(),
            singleGroupingSet(innerGroupingKeys),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty(),
            aggregation.getAggregationType(),
            aggregation.getFinalizeSymbol());

    return Result.ofPlanNode(
            new AggregationNode(
                    aggregation.getId(),
                    distinctGroupBy,
                    // remove DISTINCT flag from function calls
                    aggregation.getAggregations().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> removeDistinct(e.getValue()))),
                    aggregation.getGroupingSets(),
                    emptyList(),
                    aggregation.getStep(),
                    aggregation.getHashSymbol(),
                    aggregation.getGroupIdSymbol(),
                    aggregation.getAggregationType(),
                    aggregation.getFinalizeSymbol()));
}
Also used : Iterables(com.google.common.collect.Iterables) Patterns.aggregation(io.prestosql.sql.planner.plan.Patterns.aggregation) Pattern(io.prestosql.matching.Pattern) AggregationNode(io.prestosql.spi.plan.AggregationNode) HashSet(java.util.HashSet) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) ImmutableList(com.google.common.collect.ImmutableList) SINGLE(io.prestosql.spi.plan.AggregationNode.Step.SINGLE) Map(java.util.Map) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils(io.prestosql.sql.planner.SymbolUtils) ImmutableMap(com.google.common.collect.ImmutableMap) Rule(io.prestosql.sql.planner.iterative.Rule) Collections.emptyList(java.util.Collections.emptyList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) Collectors(java.util.stream.Collectors) Captures(io.prestosql.matching.Captures) List(java.util.List) Stream(java.util.stream.Stream) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) Set(java.util.Set) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Aggregations

Aggregation (io.prestosql.spi.plan.AggregationNode.Aggregation)17 AggregationNode (io.prestosql.spi.plan.AggregationNode)16 Symbol (io.prestosql.spi.plan.Symbol)16 Map (java.util.Map)12 CallExpression (io.prestosql.spi.relation.CallExpression)9 ImmutableMap (com.google.common.collect.ImmutableMap)8 PlanNode (io.prestosql.spi.plan.PlanNode)8 ProjectNode (io.prestosql.spi.plan.ProjectNode)6 ImmutableList (com.google.common.collect.ImmutableList)5 RowExpression (io.prestosql.spi.relation.RowExpression)5 Optional (java.util.Optional)5 Set (java.util.Set)5 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)4 Session (io.prestosql.Session)4 Assignments (io.prestosql.spi.plan.Assignments)4 OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression)4 Expression (io.prestosql.sql.tree.Expression)4 List (java.util.List)4 Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument)3 ImmutableSet (com.google.common.collect.ImmutableSet)3