Search in sources:

Example 6 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method apply of the class PruneCountAggregationOverScalar.

/**
 * Replaces an aggregation whose only aggregations are argument-less count functions with a
 * single-row {@link ValuesNode} holding the literal {@code 1}, provided the aggregation has a
 * default output, produces exactly one output symbol, and its source is scalar.
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    if (!(parent.hasDefaultOutput() && parent.getOutputSymbols().size() == 1)) {
        return Result.empty();
    }

    Map<Symbol, AggregationNode.Aggregation> aggregations = parent.getAggregations();
    // Evaluated lazily and in iteration order, so the null check runs on each aggregation just
    // before that aggregation is inspected, and stops at the first non-qualifying one.
    boolean hasOtherAggregation = aggregations.values().stream()
            .map(candidate -> requireNonNull(candidate, "aggregation is null"))
            .anyMatch(candidate -> !functionResolution.isCountFunction(candidate.getFunctionHandle())
                    || !candidate.getArguments().isEmpty());
    if (hasOtherAggregation || aggregations.isEmpty()) {
        return Result.empty();
    }

    if (!isScalar(parent.getSource(), context.getLookup())) {
        return Result.empty();
    }
    return Result.ofPlanNode(new ValuesNode(
            parent.getId(),
            parent.getOutputSymbols(),
            ImmutableList.of(ImmutableList.of(castToRowExpression(new LongLiteral("1"))))));
}
Also used : ValuesNode(io.prestosql.spi.plan.ValuesNode) LongLiteral(io.prestosql.sql.tree.LongLiteral) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) Map(java.util.Map)

Example 7 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method apply of the class TransformCorrelatedScalarAggregationToJoin.

/**
 * Rewrites a lateral join whose subquery is scalar and contains a global aggregation (one with no
 * grouping keys) using {@link ScalarAggregationToJoinRewriter}.
 *
 * <p>The rule does not fire when the subquery is not scalar, when no global aggregation is found,
 * or when the rewriter hands back a {@link LateralJoinNode} (treated as "nothing rewritten").
 */
@Override
public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context context) {
    PlanNode subquery = lateralJoinNode.getSubquery();
    if (!isScalar(subquery, context.getLookup())) {
        return Result.empty();
    }

    Optional<AggregationNode> globalAggregation = findAggregation(subquery, context.getLookup())
            .filter(candidate -> candidate.getGroupingKeys().isEmpty());
    if (!globalAggregation.isPresent()) {
        return Result.empty();
    }

    PlanNode rewritten = new ScalarAggregationToJoinRewriter(
            metadata,
            context.getSymbolAllocator(),
            context.getIdAllocator(),
            context.getLookup())
            .rewriteScalarAggregation(lateralJoinNode, globalAggregation.get());
    return rewritten instanceof LateralJoinNode ? Result.empty() : Result.ofPlanNode(rewritten);
}
Also used : PlanNode(io.prestosql.spi.plan.PlanNode) LateralJoinNode(io.prestosql.sql.planner.plan.LateralJoinNode) ScalarAggregationToJoinRewriter(io.prestosql.sql.planner.optimizations.ScalarAggregationToJoinRewriter) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Example 8 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method transformProjectNode of the class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate.

/**
 * Rewrites an uncorrelated IN-predicate subquery shaped as
 * {@code Project <- Filter <- Join(TableScan, TableScan)} (optionally with a {@link CTEScanNode}
 * between the project and the filter) into
 * {@code Project <- Filter(count > 1) <- Aggregation(count(DISTINCT ...)) <- TableScan}
 * when {@code isSelfJoin} accepts the join.
 *
 * <p>If the join at this level is not accepted as a self join, the left and right join children
 * that are {@link ProjectNode}s are transformed recursively. If either changed, the
 * join/filter/project chain is rebuilt on top of them.
 *
 * @param context rule context supplying the lookup, id allocator and symbol allocator
 * @param projectNode projection on top of the subquery
 * @return the rewritten projection, or {@link Optional#empty()} when the pattern does not apply
 */
private Optional<ProjectNode> transformProjectNode(Context context, ProjectNode projectNode) {
    // IN predicate requires only one projection
    if (projectNode.getOutputSymbols().size() > 1) {
        return Optional.empty();
    }
    PlanNode source = context.getLookup().resolve(projectNode.getSource());
    if (source instanceof CTEScanNode) {
        source = getChildFilterNode(context, context.getLookup().resolve(((CTEScanNode) source).getSource()));
    }
    if (!(source instanceof FilterNode && context.getLookup().resolve(((FilterNode) source).getSource()) instanceof JoinNode)) {
        return Optional.empty();
    }
    FilterNode filter = (FilterNode) source;
    Expression predicate = OriginalExpressionUtils.castToExpression(filter.getPredicate());
    List<SymbolReference> allPredicateSymbols = new ArrayList<>();
    getAllSymbols(predicate, allPredicateSymbols);
    JoinNode joinNode = (JoinNode) context.getLookup().resolve(filter.getSource());
    if (!isSelfJoin(projectNode, predicate, joinNode, context.getLookup())) {
        // Check next level for Self Join
        PlanNode left = context.getLookup().resolve(joinNode.getLeft());
        boolean changed = false;
        if (left instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) left);
            if (transformResult.isPresent()) {
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), transformResult.get(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        PlanNode right = context.getLookup().resolve(joinNode.getRight());
        if (right instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) right);
            if (transformResult.isPresent()) {
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), transformResult.get(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        if (changed) {
            FilterNode transformedFilter = new FilterNode(filter.getId(), joinNode, filter.getPredicate());
            ProjectNode transformedProject = new ProjectNode(projectNode.getId(), transformedFilter, projectNode.getAssignments());
            return Optional.of(transformedProject);
        }
        return Optional.empty();
    }
    // Choose the table to use based on projected output.
    TableScanNode leftTable = (TableScanNode) context.getLookup().resolve(joinNode.getLeft());
    TableScanNode rightTable = (TableScanNode) context.getLookup().resolve(joinNode.getRight());
    TableScanNode tableToUse;
    List<RowExpression> aggregationSymbols;
    AggregationNode.GroupingSetDescriptor groupingSetDescriptor;
    Assignments projectionsForCTE = null;
    Assignments projectionsFromFilter = null;
    // Use non-projected column for aggregation
    if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
        CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
        ProjectNode childProjectOfCte = (ProjectNode) context.getLookup().resolve(cteScanNode.getSource());
        Map<Symbol, RowExpression> cteAssignments = childProjectOfCte.getAssignments().getMap();
        // Candidate table-scan symbols, right side first and then left side.
        List<Symbol> completeOutputSymbols = new ArrayList<>(rightTable.getOutputSymbols());
        completeOutputSymbols.addAll(leftTable.getOutputSymbols());
        // Trace the projected CTE symbol back to the table-scan symbol its CTE assignment references.
        List<Symbol> outputSymbols = new ArrayList<>();
        for (Symbol outputSymbol : completeOutputSymbols) {
            for (Symbol symbol : projectNode.getOutputSymbols()) {
                if (cteAssignments.containsKey(symbol)) {
                    Expression assignedExpression = OriginalExpressionUtils.castToExpression(cteAssignments.get(symbol));
                    // Only plain symbol references can be traced; anything else is simply not a match.
                    if (assignedExpression instanceof SymbolReference && ((SymbolReference) assignedExpression).getName().equals(outputSymbol.getName())) {
                        outputSymbols.add(outputSymbol);
                    }
                }
            }
        }
        if (outputSymbols.size() != 1) {
            // The projection cannot be mapped to exactly one table column: leave the plan untouched
            // rather than failing the query on getOnlyElement below.
            return Optional.empty();
        }
        Symbol tableOutputSymbol = outputSymbols.get(0);
        Symbol projectedSymbol = getOnlyElement(projectNode.getOutputSymbols());
        Map<Symbol, RowExpression> projectionsForCTEMap = new HashMap<>();
        Map<Symbol, RowExpression> projectionsFromFilterMap = new HashMap<>();
        for (Map.Entry<Symbol, RowExpression> entry : cteAssignments.entrySet()) {
            if (entry.getKey().equals(projectedSymbol)) {
                projectionsForCTEMap.put(entry.getKey(), entry.getValue());
                projectionsFromFilterMap.put(tableOutputSymbol, entry.getValue());
            }
        }
        projectionsForCTE = new Assignments(projectionsForCTEMap);
        projectionsFromFilter = new Assignments(projectionsFromFilterMap);
        tableToUse = leftTable.getOutputSymbols().contains(tableOutputSymbol) ? leftTable : rightTable;
        aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !outputSymbols.contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
        // Create aggregation
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(outputSymbols), 1, ImmutableSet.of());
    } else {
        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(projectNode.getOutputSymbols())) ? leftTable : rightTable;
        aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !projectNode.getOutputSymbols().contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
        // Create aggregation
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(projectNode.getOutputSymbols()), 1, ImmutableSet.of());
    }
    AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(new CallExpression("count", functionResolution.countFunction(), BIGINT, aggregationSymbols), aggregationSymbols, // mark DISTINCT since NOT_EQUALS predicate
    true, Optional.empty(), Optional.empty(), Optional.empty());
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), BIGINT);
    aggregationsBuilder.put(countSymbol, aggregation);
    AggregationNode aggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), tableToUse, aggregationsBuilder.build(), groupingSetDescriptor, ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
    // Keep only groups whose distinct count is greater than 1 (drop count <= 1) to match the
    // NOT_EQUALS clause in the original query.
    FilterNode filterNode = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, SymbolUtils.toSymbolReference(countSymbol), new GenericLiteral("BIGINT", "1"))));
    // Project the aggregated+filtered rows.
    ProjectNode transformedSubquery = new ProjectNode(projectNode.getId(), filterNode, projectNode.getAssignments());
    if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
        CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
        PlanNode projectNodeForCTE = context.getLookup().resolve(cteScanNode.getSource());
        // The intermediate projection is a distinct new node, so it gets a fresh id instead of
        // reusing the CTE projection's id (which the outer projection below keeps).
        PlanNode projectNodeFromFilter = new ProjectNode(context.getIdAllocator().getNextId(), filterNode, projectionsFromFilter);
        projectNodeForCTE = new ProjectNode(projectNodeForCTE.getId(), projectNodeFromFilter, projectionsForCTE);
        cteScanNode = (CTEScanNode) cteScanNode.replaceChildren(ImmutableList.of(projectNodeForCTE));
        cteScanNode.setOutputSymbols(projectNode.getOutputSymbols());
        transformedSubquery = new ProjectNode(projectNode.getId(), cteScanNode, projectNode.getAssignments());
    }
    return Optional.of(transformedSubquery);
}
Also used : Apply.subQuery(io.prestosql.sql.planner.plan.Patterns.Apply.subQuery) Lookup(io.prestosql.sql.planner.iterative.Lookup) Patterns.applyNode(io.prestosql.sql.planner.plan.Patterns.applyNode) HashMap(java.util.HashMap) Pattern(io.prestosql.matching.Pattern) StandardFunctionResolution(io.prestosql.spi.function.StandardFunctionResolution) CTEScanNode(io.prestosql.spi.plan.CTEScanNode) ArrayList(java.util.ArrayList) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Capture.newCapture(io.prestosql.matching.Capture.newCapture) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) BIGINT(io.prestosql.spi.type.BigintType.BIGINT) Patterns.project(io.prestosql.sql.planner.plan.Patterns.project) JoinNode(io.prestosql.spi.plan.JoinNode) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils(io.prestosql.sql.planner.SymbolUtils) LogicalBinaryExpression(io.prestosql.sql.tree.LogicalBinaryExpression) ApplyNode(io.prestosql.sql.planner.plan.ApplyNode) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) Pattern.empty(io.prestosql.matching.Pattern.empty) TableScanNode(io.prestosql.spi.plan.TableScanNode) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Iterables.getOnlyElement(com.google.common.collect.Iterables.getOnlyElement) PlanNode(io.prestosql.spi.plan.PlanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) Captures(io.prestosql.matching.Captures) List(java.util.List) FunctionResolution(io.prestosql.sql.relational.FunctionResolution) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) 
SymbolReference(io.prestosql.sql.tree.SymbolReference) TRANSFORM_SELF_JOIN_TO_GROUPBY(io.prestosql.SystemSessionProperties.TRANSFORM_SELF_JOIN_TO_GROUPBY) Apply.correlation(io.prestosql.sql.planner.plan.Patterns.Apply.correlation) Capture(io.prestosql.matching.Capture) RowExpression(io.prestosql.spi.relation.RowExpression) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) InPredicate(io.prestosql.sql.tree.InPredicate) HashMap(java.util.HashMap) SymbolReference(io.prestosql.sql.tree.SymbolReference) Symbol(io.prestosql.spi.plan.Symbol) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) Assignments(io.prestosql.spi.plan.Assignments) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) PlanNode(io.prestosql.spi.plan.PlanNode) CTEScanNode(io.prestosql.spi.plan.CTEScanNode) CallExpression(io.prestosql.spi.relation.CallExpression) JoinNode(io.prestosql.spi.plan.JoinNode) RowExpression(io.prestosql.spi.relation.RowExpression) AggregationNode(io.prestosql.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) TableScanNode(io.prestosql.spi.plan.TableScanNode) CallExpression(io.prestosql.spi.relation.CallExpression) LogicalBinaryExpression(io.prestosql.sql.tree.LogicalBinaryExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 9 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method createScanNode of the class AggregationRewriteWithCube.

/**
 * Builds the table scan over the cube table that replaces the original aggregation's source.
 *
 * <p>One cube scan symbol is allocated per cube column needed to answer the query:
 * <ul>
 *   <li>columns referenced by the filter predicate and grouping outputs, recorded as dimension sources;</li>
 *   <li>pre-aggregated cube columns backing each supported aggregation, recorded as aggregator sources.</li>
 * </ul>
 * An {@code avg} is read directly from a cube avg column only when the grouping matches the cube
 * exactly and such a column exists. Otherwise it is recorded as an average source computed from the
 * cube's sum and count columns. As a side effect, {@code rewrittenMappings} is updated with the cube
 * scan symbol that replaces each original symbol.
 *
 * @param originalAggregationNode aggregation being rewritten
 * @param filterNode the {@link FilterNode} between the aggregation and the scan, or {@code null} if none
 * @param cubeTableHandle handle of the cube table to scan
 * @param cubeColumnsMap cube column handles keyed by column name
 * @param cubeColumnsMetadata cube column metadata; its order determines the scan output order
 * @param exactGroupsMatch whether the query grouping matches the cube grouping exactly
 * @return the cube scan node together with the symbol bookkeeping needed by the rest of the rewrite
 * @throws PrestoException if the cube lacks a column for a required aggregation, or the aggregation
 *         function is not supported
 */
public CubeRewriteResult createScanNode(AggregationNode originalAggregationNode, PlanNode filterNode, TableHandle cubeTableHandle, Map<String, ColumnHandle> cubeColumnsMap, List<ColumnMetadata> cubeColumnsMetadata, boolean exactGroupsMatch) {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> symbolAssignments = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    Map<String, ColumnMetadata> columnMetadataMap = cubeColumnsMetadata.stream().collect(Collectors.toMap(ColumnMetadata::getName, Function.identity()));
    boolean computeAvgDividingSumByCount = true;
    Set<Symbol> filterSymbols = new HashSet<>();
    if (filterNode != null) {
        filterSymbols.addAll(SymbolsExtractor.extractUnique(((FilterNode) filterNode).getPredicate()));
    }
    for (Symbol filterSymbol : filterSymbols) {
        if (symbolMappings.containsKey(filterSymbol.getName()) && symbolMappings.get(filterSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(filterSymbol.getName());
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
            cubeScanSymbols.add(cubeScanSymbol);
            symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
            rewrittenMappings.put(filterSymbol.getName(), cubeScanSymbol);
            dimensionSymbols.add(new CubeRewriteResult.DimensionSource(filterSymbol, cubeScanSymbol));
        }
    }
    for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
        if (symbolMappings.containsKey(originalAggOutputSymbol.getName()) && symbolMappings.get(originalAggOutputSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table - column part of group by clause
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(originalAggOutputSymbol.getName());
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            if (!symbolAssignments.containsValue(cubeScanColumn)) {
                Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
                cubeScanSymbols.add(cubeScanSymbol);
                symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
            } else {
                // Column already scanned (e.g. it is also referenced by the filter): reuse its symbol.
                Symbol cubeScanSymbol = findAssignedSymbol(symbolAssignments, cubeScanColumn);
                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
            }
        } else if (originalAggregationNode.getAggregations().containsKey(originalAggOutputSymbol)) {
            // output symbol is mapped to an aggregation
            AggregationNode.Aggregation aggregation = originalAggregationNode.getAggregations().get(originalAggOutputSymbol);
            String aggFunction = aggregation.getFunctionCall().getDisplayName();
            List<Expression> arguments = aggregation.getArguments() == null ? null : aggregation.getArguments().stream().map(OriginalExpressionUtils::castToExpression).collect(Collectors.toList());
            if (arguments != null && !arguments.isEmpty() && (!(arguments.get(0) instanceof SymbolReference))) {
                log.info("Not a symbol reference in aggregation function. Agg Function = %s, Arguments = %s", aggFunction, arguments);
                continue;
            }
            Object mappedValue = arguments == null || arguments.isEmpty() ? null : symbolMappings.get(((SymbolReference) arguments.get(0)).getName());
            if (mappedValue == null || (mappedValue instanceof LongLiteral && ((LongLiteral) mappedValue).getValue() == 1)) {
                // COUNT aggregation
                if (CubeAggregateFunction.COUNT.getName().equals(aggFunction) && !aggregation.isDistinct()) {
                    // COUNT 1
                    AggregationSignature aggregationSignature = AggregationSignature.count();
                    String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                    ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                    if (!symbolAssignments.containsValue(cubeColHandle)) {
                        ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                        cubeScanSymbols.add(cubeScanSymbol);
                        symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                        symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                        rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                    } else {
                        // Another count(*)/count(1) already uses this cube column (e.g. both appear in
                        // the query); map this output to the same scan symbol instead of dropping it.
                        Symbol cubeScanSymbol = findAssignedSymbol(symbolAssignments, cubeColHandle);
                        rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                    }
                }
            } else if (mappedValue instanceof ColumnHandle) {
                String originalColumnName = ((ColumnHandle) mappedValue).getColumnName();
                boolean distinct = aggregation.isDistinct();
                switch(aggFunction) {
                    case "min":
                    case "max":
                    case "sum":
                    case "count":
                        AggregationSignature aggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                        if (!symbolAssignments.containsValue(cubeColHandle)) {
                            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                            cubeScanSymbols.add(cubeScanSymbol);
                            symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        } else {
                            // The scan symbol, its assignment and metadata are already registered.
                            Symbol cubeScanSymbol = findAssignedSymbol(symbolAssignments, cubeColHandle);
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        }
                        break;
                    case "avg":
                        AggregationSignature avgAggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        // A stored avg column is only valid when the cube groups exactly like the query.
                        Optional<String> avgCubeColumnName = exactGroupsMatch ? cubeMetadata.getColumn(avgAggregationSignature) : Optional.empty();
                        if (avgCubeColumnName.isPresent()) {
                            computeAvgDividingSumByCount = false;
                            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName.get());
                            if (!symbolAssignments.containsValue(avgCubeColHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(avgCubeColHandle.getColumnName());
                                Symbol cubeScanSymbol = symbolAllocator.newSymbol(avgCubeColHandle.getColumnName(), columnMetadata.getType());
                                cubeScanSymbols.add(cubeScanSymbol);
                                symbolAssignments.put(cubeScanSymbol, avgCubeColHandle);
                                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                            } else {
                                Symbol cubeScanSymbol = findAssignedSymbol(symbolAssignments, avgCubeColHandle);
                                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                            }
                        } else {
                            // avg = sum / count, computed later from the cube's sum and count columns.
                            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), originalColumnName, distinct);
                            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
                            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
                            Symbol sumSymbol;
                            if (!symbolAssignments.containsValue(sumColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(sumColumnHandle.getColumnName());
                                sumSymbol = symbolAllocator.newSymbol("sum_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(sumSymbol);
                                symbolAssignments.put(sumSymbol, sumColumnHandle);
                                symbolMetadataMap.put(sumSymbol, columnMetadata);
                                rewrittenMappings.put(sumSymbol.getName(), sumSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumSymbol, sumSymbol));
                            } else {
                                sumSymbol = findAssignedSymbol(symbolAssignments, sumColumnHandle);
                            }
                            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), originalColumnName, distinct);
                            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
                            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
                            Symbol countSymbol;
                            if (!symbolAssignments.containsValue(countColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(countColumnHandle.getColumnName());
                                countSymbol = symbolAllocator.newSymbol("count_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(countSymbol);
                                symbolAssignments.put(countSymbol, countColumnHandle);
                                symbolMetadataMap.put(countSymbol, columnMetadata);
                                rewrittenMappings.put(countSymbol.getName(), countSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countSymbol, countSymbol));
                            } else {
                                countSymbol = findAssignedSymbol(symbolAssignments, countColumnHandle);
                            }
                            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalAggOutputSymbol, sumSymbol, countSymbol));
                        }
                        break;
                    default:
                        throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported aggregation function " + aggFunction);
                }
            } else {
                log.info("Aggregation function argument is not a Column Handle. Agg Function = %s, Arguments = %s", aggFunction, arguments);
            }
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnsMetadata.indexOf(symbolMetadataMap.get(outSymbol))));
    TableScanNode tableScanNode = TableScanNode.newInstance(idAllocator.getNextId(), cubeTableHandle, scanOutput, symbolAssignments, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(tableScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}

/**
 * Returns the cube scan symbol already assigned to {@code columnHandle}.
 * Callers must first check {@code symbolAssignments.containsValue(columnHandle)}.
 *
 * @throws IllegalStateException if no symbol is assigned to the column
 */
private Symbol findAssignedSymbol(Map<Symbol, ColumnHandle> symbolAssignments, ColumnHandle columnHandle) {
    return symbolAssignments.entrySet().stream()
            .filter(assignment -> columnHandle.equals(assignment.getValue()))
            .map(Map.Entry::getKey)
            .findFirst()
            .orElseThrow(() -> new IllegalStateException("No cube scan symbol assigned to column " + columnHandle));
}
Also used : TypeProvider(io.prestosql.sql.planner.TypeProvider) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) FunctionHandle(io.prestosql.spi.function.FunctionHandle) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) StandardErrorCode(io.prestosql.spi.StandardErrorCode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) 
Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) Function(java.util.function.Function) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) Session(io.prestosql.Session) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) PrestoException(io.prestosql.spi.PrestoException) List(java.util.List) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) UUID(java.util.UUID) HashSet(java.util.HashSet) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) LongLiteral(io.prestosql.sql.tree.LongLiteral) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) 
TableScanNode(io.prestosql.spi.plan.TableScanNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 10 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class AggregationRewriteWithCube method rewrite.

/**
 * Rewrites {@code originalAggregationNode} so that it reads from the star-tree cube table named by
 * {@code cubeMetadata.getCubeName()} instead of its original source.
 *
 * <p>The replacement plan is built bottom-up:
 * <ol>
 *   <li>a scan over the cube table (built by {@code createScanNode});</li>
 *   <li>if {@code filterNode} is present, a filter whose predicate is the original predicate
 *       rewritten through {@code rewrittenMappings};</li>
 *   <li>if the query's grouping columns are not exactly the cube's groups, a SINGLE-step
 *       re-aggregation over the cube's pre-aggregated columns (COUNT columns are rolled up with
 *       {@code sum}), followed by a projection when AVG aggregations are involved;</li>
 *   <li>if the resulting output symbols differ from the original node's, a final projection that
 *       emits exactly the original output symbols, in the original order.</li>
 * </ol>
 *
 * @param originalAggregationNode the aggregation being replaced
 * @param filterNode a {@link FilterNode} beneath the aggregation, or {@code null} if there is none
 * @return the replacement plan node
 * @throws CubeNotFoundException if the cube table cannot be resolved
 * @throws ColumnNotFoundException if a cube column used for re-aggregation has no aggregation signature
 */
public PlanNode rewrite(AggregationNode originalAggregationNode, PlanNode filterNode) {
    QualifiedObjectName starTreeTableName = QualifiedObjectName.valueOf(cubeMetadata.getCubeName());
    TableHandle cubeTableHandle = metadata.getTableHandle(session, starTreeTableName).orElseThrow(() -> new CubeNotFoundException(starTreeTableName.toString()));
    Map<String, ColumnHandle> cubeColumnsMap = metadata.getColumnHandles(session, cubeTableHandle);
    TableMetadata cubeTableMetadata = metadata.getTableMetadata(session, cubeTableHandle);
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    // Add group by: translate the query's grouping keys into cube column names
    List<Symbol> groupings = new ArrayList<>(originalAggregationNode.getGroupingKeys().size());
    for (Symbol symbol : originalAggregationNode.getGroupingKeys()) {
        Object column = symbolMappings.get(symbol.getName());
        if (column instanceof ColumnHandle) {
            groupings.add(new Symbol(((ColumnHandle) column).getColumnName()));
        }
    }
    Set<String> cubeGroups = cubeMetadata.getGroup();
    boolean exactGroupsMatch = false;
    if (groupings.size() == cubeGroups.size()) {
        // Lowercase with a fixed locale: the no-arg toLowerCase() follows the JVM default locale
        // (e.g. Turkish maps "I" to dotless "ı"), which would make this match locale-dependent.
        exactGroupsMatch = groupings.stream()
                .map(Symbol::getName)
                .map(name -> name.toLowerCase(java.util.Locale.ENGLISH))
                .allMatch(cubeGroups::contains);
    }
    CubeRewriteResult cubeRewriteResult = createScanNode(originalAggregationNode, filterNode, cubeTableHandle, cubeColumnsMap, cubeColumnMetadataList, exactGroupsMatch);
    PlanNode planNode = cubeRewriteResult.getTableScanNode();
    // Add filter node, with the original predicate rewritten onto the cube's symbols
    if (filterNode != null) {
        Expression expression = castToExpression(((FilterNode) filterNode).getPredicate());
        expression = rewriteExpression(expression, rewrittenMappings);
        planNode = new FilterNode(idAllocator.getNextId(), planNode, castToRowExpression(expression));
    }
    if (!exactGroupsMatch) {
        // Cube scan symbol -> output symbol of the re-aggregation that consumes it
        Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
        // Rewrite AggregationNode using Cube table: COUNT partials are rolled up by summing them,
        // every other function is re-applied to its own pre-aggregated column.
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
        for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
            Symbol scanSymbol = aggregatorSource.getScanSymbol();
            ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(scanSymbol);
            ColumnMetadata cubeColumnMetadata = cubeRewriteResult.getSymbolMetadataMap().get(scanSymbol);
            Type type = cubeColumnMetadata.getType();
            AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName(starTreeTableName.getSchemaName(), starTreeTableName.getObjectName()), cubeColHandle.getColumnName()));
            String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? "sum" : aggregationSignature.getFunction();
            SymbolReference argument = toSymbolReference(scanSymbol);
            FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(type.getTypeSignature()));
            cubeScanToAggOutputMap.put(scanSymbol, aggregatorSource.getOriginalAggSymbol());
            aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument))), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        List<Symbol> groupingKeys = originalAggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(rewrittenMappings::get).collect(Collectors.toList());
        AggregationNode aggNode = new AggregationNode(idAllocator.getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupingKeys), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
        planNode = aggNode;
        if (!cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
            if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
                // NOTE(review): this projects each aggregation output onto its cube *scan* symbol, which
                // aggNode does not itself output, and drops the grouping keys. Presumably this branch is
                // not reached when groups do not match exactly -- confirm against createScanNode.
                Map<Symbol, Expression> aggregateAssignments = new java.util.LinkedHashMap<>();
                for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
                    aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
                }
                planNode = new ProjectNode(idAllocator.getNextId(), aggNode, toOrderedAssignments(aggregateAssignments));
            } else {
                // Pass grouping keys and the mapped aggregation outputs through unchanged ...
                Map<Symbol, Expression> projections = new java.util.LinkedHashMap<>();
                aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
                aggNode.getAggregations().keySet().stream().filter(symbol -> symbolMappings.containsValue(symbol.getName())).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
                // ... and compute each AVG as CAST(sum AS t) / CAST(count AS t) over the re-aggregated columns
                for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
                    Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
                    Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
                    Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
                    String avgResultTypeName = avgResultType.getTypeSignature().toString();
                    ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultTypeName), new Cast(toSymbolReference(countSymbol), avgResultTypeName));
                    projections.put(avgAggSource.getOriginalAggSymbol(), division);
                }
                planNode = new ProjectNode(idAllocator.getNextId(), aggNode, toOrderedAssignments(projections));
            }
        }
    }
    // Safety check to remove redundant symbols and rename original column names to intermediate names.
    // The assignments are inserted in the original output order so the projection reproduces it.
    if (!planNode.getOutputSymbols().equals(originalAggregationNode.getOutputSymbols())) {
        // Map new symbol names to the old symbols
        Map<Symbol, Expression> assignments = new java.util.LinkedHashMap<>();
        Set<Symbol> planNodeOutput = new HashSet<>(planNode.getOutputSymbols());
        for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
            if (!planNodeOutput.contains(originalAggOutputSymbol)) {
                // Must be grouping key
                assignments.put(originalAggOutputSymbol, toSymbolReference(rewrittenMappings.get(originalAggOutputSymbol.getName())));
            } else {
                // Should be an expression and must have the same name in the new plan node
                assignments.put(originalAggOutputSymbol, toSymbolReference(originalAggOutputSymbol));
            }
        }
        planNode = new ProjectNode(idAllocator.getNextId(), planNode, toOrderedAssignments(assignments));
    }
    return planNode;
}

/**
 * Converts tree-expression projections into {@link Assignments}, building the row-expression map in
 * the iteration order of {@code projections} (so callers should pass an insertion-ordered map).
 * NOTE(review): relies on {@link Assignments} keeping the order of the map it receives -- confirm.
 */
private static Assignments toOrderedAssignments(Map<Symbol, Expression> projections) {
    return new Assignments(projections.entrySet().stream()
            .collect(Collectors.toMap(
                    Map.Entry::getKey,
                    entry -> castToRowExpression(entry.getValue()),
                    // keys come from a Map, so they are unique and this merge is never invoked
                    (first, second) -> first,
                    java.util.LinkedHashMap::new)));
}
Also used : TypeProvider(io.prestosql.sql.planner.TypeProvider) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) FunctionHandle(io.prestosql.spi.function.FunctionHandle) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) StandardErrorCode(io.prestosql.spi.StandardErrorCode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) 
Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) Function(java.util.function.Function) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) Session(io.prestosql.Session) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) Assignments(io.prestosql.spi.plan.Assignments) PlanNode(io.prestosql.spi.plan.PlanNode) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) 
HashSet(java.util.HashSet) TableMetadata(io.prestosql.metadata.TableMetadata) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Expression(io.prestosql.sql.tree.Expression) TableHandle(io.prestosql.spi.metadata.TableHandle) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Aggregations

AggregationNode (io.prestosql.spi.plan.AggregationNode)52 Symbol (io.prestosql.spi.plan.Symbol)40 PlanNode (io.prestosql.spi.plan.PlanNode)30 ProjectNode (io.prestosql.spi.plan.ProjectNode)24 Map (java.util.Map)23 CallExpression (io.prestosql.spi.relation.CallExpression)20 ImmutableMap (com.google.common.collect.ImmutableMap)17 Assignments (io.prestosql.spi.plan.Assignments)16 Aggregation (io.prestosql.spi.plan.AggregationNode.Aggregation)14 RowExpression (io.prestosql.spi.relation.RowExpression)14 Expression (io.prestosql.sql.tree.Expression)14 List (java.util.List)14 ImmutableList (com.google.common.collect.ImmutableList)13 TableScanNode (io.prestosql.spi.plan.TableScanNode)13 Optional (java.util.Optional)13 OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression)12 Test (org.testng.annotations.Test)12 Session (io.prestosql.Session)11 FilterNode (io.prestosql.spi.plan.FilterNode)11 Metadata (io.prestosql.metadata.Metadata)9