Search in sources:

Example 1 with AggregationSignature

use of io.hetu.core.spi.cube.aggregator.AggregationSignature in project hetu-core by openlookeng.

The class AggregationRewriteWithCube, method createScanNode.

/**
 * Builds the {@link TableScanNode} that reads the cube table in place of the original table,
 * together with the bookkeeping that ties original plan symbols to the new cube scan symbols.
 *
 * <p>The method walks (1) the symbols referenced by the optional filter predicate and (2) the output
 * symbols of the original aggregation. For every symbol that resolves to a column or to a supported
 * aggregation it allocates (or reuses) a cube scan symbol and records it. As a side effect it
 * populates the instance-level {@code rewrittenMappings} and allocates symbols/ids through
 * {@code symbolAllocator} and {@code idAllocator}.
 *
 * @param originalAggregationNode the aggregation being rewritten; its output symbols and aggregations are inspected
 * @param filterNode optional filter (may be {@code null}); when non-null it is cast to {@link FilterNode}
 * @param cubeTableHandle handle of the cube table to scan
 * @param cubeColumnsMap cube column handles keyed by cube column name
 * @param cubeColumnsMetadata metadata of the cube columns; the position of each entry decides the scan output order
 * @param exactGroupsMatch when {@code true}, an {@code avg} column stored in the cube may be read directly
 * @return the scan node plus the dimension, aggregation and average-aggregation sources that were collected
 * @throws PrestoException with {@code CUBE_ERROR} when the cube has no column for a required aggregation
 *         signature, or with {@code GENERIC_INTERNAL_ERROR} for an aggregation function other than
 *         min/max/sum/count/avg on a column
 */
public CubeRewriteResult createScanNode(AggregationNode originalAggregationNode, PlanNode filterNode, TableHandle cubeTableHandle, Map<String, ColumnHandle> cubeColumnsMap, List<ColumnMetadata> cubeColumnsMetadata, boolean exactGroupsMatch) {
    // Symbols that the cube scan will output (ordered later, just before the scan node is created).
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    // Scan symbol -> cube column handle; also used below to detect columns that are already scanned.
    Map<Symbol, ColumnHandle> symbolAssignments = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    // Index the cube column metadata by column name for direct lookup.
    Map<String, ColumnMetadata> columnMetadataMap = cubeColumnsMetadata.stream().collect(Collectors.toMap(ColumnMetadata::getName, Function.identity()));
    // Stays true unless an avg aggregation is served directly from a stored avg column (see "avg" case).
    boolean computeAvgDividingSumByCount = true;
    Set<Symbol> filterSymbols = new HashSet<>();
    if (filterNode != null) {
        filterSymbols.addAll(SymbolsExtractor.extractUnique(((FilterNode) filterNode).getPredicate()));
    }
    // Pass 1: every filter symbol that maps to an original-table column becomes a scanned cube dimension.
    for (Symbol filterSymbol : filterSymbols) {
        if (symbolMappings.containsKey(filterSymbol.getName()) && symbolMappings.get(filterSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(filterSymbol.getName());
            // NOTE(review): assumes the cube has a column with the same name as the original column;
            // a missing entry would make the next line throw a NullPointerException - confirm against callers.
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
            cubeScanSymbols.add(cubeScanSymbol);
            symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
            rewrittenMappings.put(filterSymbol.getName(), cubeScanSymbol);
            dimensionSymbols.add(new CubeRewriteResult.DimensionSource(filterSymbol, cubeScanSymbol));
        }
    }
    // Pass 2: map each output symbol of the original aggregation to a cube scan symbol.
    for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
        if (symbolMappings.containsKey(originalAggOutputSymbol.getName()) && symbolMappings.get(originalAggOutputSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table - column part of group by clause
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(originalAggOutputSymbol.getName());
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            if (!symbolAssignments.containsValue(cubeScanColumn)) {
                // Column not scanned yet: allocate a fresh scan symbol for it.
                Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
                cubeScanSymbols.add(cubeScanSymbol);
                symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
            } else {
                // Column already scanned (e.g. by the filter pass): reuse the existing scan symbol.
                // The get() is safe here because containsValue(cubeScanColumn) was just true.
                Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeScanColumn.equals(symbolAssignments.get(key))).findFirst().get();
                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
            }
        } else if (originalAggregationNode.getAggregations().containsKey(originalAggOutputSymbol)) {
            // output symbol is mapped to an aggregation
            AggregationNode.Aggregation aggregation = originalAggregationNode.getAggregations().get(originalAggOutputSymbol);
            String aggFunction = aggregation.getFunctionCall().getDisplayName();
            List<Expression> arguments = aggregation.getArguments() == null ? null : aggregation.getArguments().stream().map(OriginalExpressionUtils::castToExpression).collect(Collectors.toList());
            // Only the first argument is inspected; aggregations over anything but a plain symbol are skipped.
            if (arguments != null && !arguments.isEmpty() && (!(arguments.get(0) instanceof SymbolReference))) {
                log.info("Not a symbol reference in aggregation function. Agg Function = %s, Arguments = %s", aggFunction, arguments);
                continue;
            }
            // null when there are no arguments or when the argument symbol has no entry in symbolMappings.
            Object mappedValue = arguments == null || arguments.isEmpty() ? null : symbolMappings.get(((SymbolReference) arguments.get(0)).getName());
            if (mappedValue == null || (mappedValue instanceof LongLiteral && ((LongLiteral) mappedValue).getValue() == 1)) {
                // COUNT aggregation
                // Any other function, or a DISTINCT count, reaching this branch is silently left unmapped.
                if (CubeAggregateFunction.COUNT.getName().equals(aggFunction) && !aggregation.isDistinct()) {
                    // COUNT 1
                    AggregationSignature aggregationSignature = AggregationSignature.count();
                    String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                    ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                    // NOTE(review): unlike the min/max/sum/count branch below there is no else here, so when the
                    // count column is already scanned this output symbol gets no rewrittenMappings or
                    // aggregationColumns entry - confirm this is intended.
                    if (!symbolAssignments.containsValue(cubeColHandle)) {
                        ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                        cubeScanSymbols.add(cubeScanSymbol);
                        symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                        symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                        rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                    }
                }
            } else if (mappedValue instanceof ColumnHandle) {
                // Aggregation over an original-table column: look up the cube column that stores it.
                String originalColumnName = ((ColumnHandle) mappedValue).getColumnName();
                boolean distinct = originalAggregationNode.getAggregations().get(originalAggOutputSymbol).isDistinct();
                switch(aggFunction) {
                    // min, max, sum and count share one path: a single cube column holds the pre-aggregated value.
                    case "min":
                    case "max":
                    case "sum":
                    case "count":
                        AggregationSignature aggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                        if (!symbolAssignments.containsValue(cubeColHandle)) {
                            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                            cubeScanSymbols.add(cubeScanSymbol);
                            symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        } else {
                            // Cube column already scanned: reuse its symbol. The add/put calls below re-insert
                            // entries that already exist for that symbol; only the last two lines add new state.
                            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                            Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeColHandle.equals(symbolAssignments.get(key))).findFirst().get();
                            cubeScanSymbols.add(cubeScanSymbol);
                            symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        }
                        break;
                    case "avg":
                        AggregationSignature avgAggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        // A stored avg column is only read directly when the query groups exactly match the cube groups.
                        if (exactGroupsMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
                            computeAvgDividingSumByCount = false;
                            String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
                            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
                            // NOTE(review): as with COUNT 1 above, nothing is recorded for this output symbol
                            // when the avg column is already scanned - confirm.
                            if (!symbolAssignments.containsValue(avgCubeColHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(avgCubeColHandle.getColumnName());
                                Symbol cubeScanSymbol = symbolAllocator.newSymbol(avgCubeColHandle.getColumnName(), columnMetadata.getType());
                                cubeScanSymbols.add(cubeScanSymbol);
                                symbolAssignments.put(cubeScanSymbol, avgCubeColHandle);
                                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                            }
                        } else {
                            // Otherwise avg is derived from the cube's sum and count columns for the same source column.
                            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), originalColumnName, distinct);
                            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
                            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
                            Symbol sumSymbol = null;
                            if (!symbolAssignments.containsValue(sumColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(sumColumnHandle.getColumnName());
                                // Synthetic helper symbol; it is registered under its own name and acts as both the
                                // "original" and the scan symbol of its AggregatorSource.
                                sumSymbol = symbolAllocator.newSymbol("sum_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(sumSymbol);
                                symbolAssignments.put(sumSymbol, sumColumnHandle);
                                symbolMetadataMap.put(sumSymbol, columnMetadata);
                                rewrittenMappings.put(sumSymbol.getName(), sumSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumSymbol, sumSymbol));
                            } else {
                                // Sum column already scanned: find the symbol currently assigned to it.
                                for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
                                    if (assignment.getValue().equals(sumColumnHandle)) {
                                        sumSymbol = assignment.getKey();
                                        break;
                                    }
                                }
                            }
                            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), originalColumnName, distinct);
                            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
                            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
                            Symbol countSymbol = null;
                            if (!symbolAssignments.containsValue(countColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(countColumnHandle.getColumnName());
                                // Same scheme as the synthetic sum symbol above.
                                countSymbol = symbolAllocator.newSymbol("count_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(countSymbol);
                                symbolAssignments.put(countSymbol, countColumnHandle);
                                symbolMetadataMap.put(countSymbol, columnMetadata);
                                rewrittenMappings.put(countSymbol.getName(), countSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countSymbol, countSymbol));
                            } else {
                                // Count column already scanned: find the symbol currently assigned to it.
                                for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
                                    if (assignment.getValue().equals(countColumnHandle)) {
                                        countSymbol = assignment.getKey();
                                        break;
                                    }
                                }
                            }
                            // Remember which sum/count scan symbols make up this avg output.
                            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalAggOutputSymbol, sumSymbol, countSymbol));
                        }
                        break;
                    default:
                        throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported aggregation function " + aggFunction);
                }
            } else {
                log.info("Aggregation function argument is not a Column Handle. Agg Function = %s, Arguments = %s", aggFunction, arguments);
            }
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    // Symbols are ordered by the position of their column metadata within cubeColumnsMetadata.
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnsMetadata.indexOf(symbolMetadataMap.get(outSymbol))));
    TableScanNode tableScanNode = TableScanNode.newInstance(idAllocator.getNextId(), cubeTableHandle, scanOutput, symbolAssignments, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(tableScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
Also used : TypeProvider(io.prestosql.sql.planner.TypeProvider) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) FunctionHandle(io.prestosql.spi.function.FunctionHandle) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) StandardErrorCode(io.prestosql.spi.StandardErrorCode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) 
Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) Function(java.util.function.Function) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) Session(io.prestosql.Session) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) PrestoException(io.prestosql.spi.PrestoException) List(java.util.List) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) UUID(java.util.UUID) HashSet(java.util.HashSet) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) LongLiteral(io.prestosql.sql.tree.LongLiteral) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) 
TableScanNode(io.prestosql.spi.plan.TableScanNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 2 with AggregationSignature

use of io.hetu.core.spi.cube.aggregator.AggregationSignature in project hetu-core by openlookeng.

The class AggregationRewriteWithCube, method rewrite.

/**
 * Rewrites an aggregation over the original table into an equivalent plan that reads the cube table.
 *
 * <p>The resulting plan is, bottom-up: a cube table scan (see {@code createScanNode}), the original
 * filter re-expressed over cube scan symbols (if a filter was supplied), and - unless the query's
 * grouping columns exactly match the cube's groups - a re-aggregation over the cube rows, optionally
 * followed by a projection that derives {@code avg} as {@code sum / count}. A final projection is
 * added when the produced output symbols differ from those of the original aggregation.
 *
 * @param originalAggregationNode the aggregation to replace
 * @param filterNode optional filter below the aggregation (may be {@code null}); cast to {@link FilterNode} when present
 * @return the root of the rewritten plan
 * @throws CubeNotFoundException when no table handle can be resolved for the cube name
 * @throws ColumnNotFoundException when a scanned cube column has no aggregation signature in the cube metadata
 */
public PlanNode rewrite(AggregationNode originalAggregationNode, PlanNode filterNode) {
    QualifiedObjectName starTreeTableName = QualifiedObjectName.valueOf(cubeMetadata.getCubeName());
    TableHandle cubeTableHandle = metadata.getTableHandle(session, starTreeTableName).orElseThrow(() -> new CubeNotFoundException(starTreeTableName.toString()));
    Map<String, ColumnHandle> cubeColumnsMap = metadata.getColumnHandles(session, cubeTableHandle);
    TableMetadata cubeTableMetadata = metadata.getTableMetadata(session, cubeTableHandle);
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    // Add group by
    // Only grouping keys that resolve to an original-table column are considered.
    List<Symbol> groupings = new ArrayList<>(originalAggregationNode.getGroupingKeys().size());
    for (Symbol symbol : originalAggregationNode.getGroupingKeys()) {
        Object column = symbolMappings.get(symbol.getName());
        if (column instanceof ColumnHandle) {
            groupings.add(new Symbol(((ColumnHandle) column).getColumnName()));
        }
    }
    Set<String> cubeGroups = cubeMetadata.getGroup();
    boolean exactGroupsMatch = false;
    if (groupings.size() == cubeGroups.size()) {
        // Lower-case with a fixed locale: the default-locale variant breaks the match for names containing
        // 'I' under e.g. a Turkish default locale, where "I" lower-cases to a dotless 'i'.
        exactGroupsMatch = groupings.stream().map(Symbol::getName).map(name -> name.toLowerCase(java.util.Locale.ROOT)).allMatch(cubeGroups::contains);
    }
    CubeRewriteResult cubeRewriteResult = createScanNode(originalAggregationNode, filterNode, cubeTableHandle, cubeColumnsMap, cubeColumnMetadataList, exactGroupsMatch);
    PlanNode planNode = cubeRewriteResult.getTableScanNode();
    // Add filter node
    // The predicate is rewritten so it references cube scan symbols instead of original-table symbols.
    if (filterNode != null) {
        Expression expression = castToExpression(((FilterNode) filterNode).getPredicate());
        expression = rewriteExpression(expression, rewrittenMappings);
        planNode = new FilterNode(idAllocator.getNextId(), planNode, castToRowExpression(expression));
    }
    if (!exactGroupsMatch) {
        // The cube is grouped differently from the query, so cube rows must be aggregated again.
        Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
        // Rewrite AggregationNode using Cube table
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
        for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
            ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
            ColumnMetadata cubeColumnMetadata = cubeRewriteResult.getSymbolMetadataMap().get(aggregatorSource.getScanSymbol());
            Type type = cubeColumnMetadata.getType();
            AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName(starTreeTableName.getSchemaName(), starTreeTableName.getObjectName()), cubeColHandle.getColumnName()));
            // Stored counts are combined by summing them; every other function is re-applied as is.
            String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? "sum" : aggregationSignature.getFunction();
            SymbolReference argument = toSymbolReference(aggregatorSource.getScanSymbol());
            FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(type.getTypeSignature()));
            cubeScanToAggOutputMap.put(aggregatorSource.getScanSymbol(), aggregatorSource.getOriginalAggSymbol());
            aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument))), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        // Grouping keys are translated to the cube scan symbols recorded by createScanNode.
        List<Symbol> groupingKeys = originalAggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(rewrittenMappings::get).collect(Collectors.toList());
        planNode = new AggregationNode(idAllocator.getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupingKeys), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
        AggregationNode aggNode = (AggregationNode) planNode;
        if (!cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
            // NOTE(review): createScanNode only clears this flag when exactGroupsMatch is true, so the first
            // branch looks unreachable inside this !exactGroupsMatch block - confirm before relying on it.
            if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
                Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
                for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
                    aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
                }
                planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(aggregateAssignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
            } else {
                // If there was an AVG aggregation, map it to AVG = SUM/COUNT
                Map<Symbol, Expression> projections = new HashMap<>();
                aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
                aggNode.getAggregations().keySet().stream().filter(symbol -> symbolMappings.containsValue(symbol.getName())).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
                // Add AVG = SUM / COUNT
                for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
                    Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
                    Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
                    // Both operands are cast to the type of the original avg output before dividing.
                    Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
                    ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString()), new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString()));
                    projections.put(avgAggSource.getOriginalAggSymbol(), division);
                }
                planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(projections.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
            }
        }
    }
    // Safety check to remove redundant symbols and rename original column names to intermediate names
    if (!planNode.getOutputSymbols().equals(originalAggregationNode.getOutputSymbols())) {
        // Map new symbol names to the old symbols
        Map<Symbol, Expression> assignments = new HashMap<>();
        Set<Symbol> planNodeOutput = new HashSet<>(planNode.getOutputSymbols());
        for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
            if (!planNodeOutput.contains(originalAggOutputSymbol)) {
                // Must be grouping key
                assignments.put(originalAggOutputSymbol, toSymbolReference(rewrittenMappings.get(originalAggOutputSymbol.getName())));
            } else {
                // Should be an expression and must have the same name in the new plan node
                assignments.put(originalAggOutputSymbol, toSymbolReference(originalAggOutputSymbol));
            }
        }
        planNode = new ProjectNode(idAllocator.getNextId(), planNode, new Assignments(assignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    }
    return planNode;
}
Also used : TypeProvider(io.prestosql.sql.planner.TypeProvider) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) FunctionHandle(io.prestosql.spi.function.FunctionHandle) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) StandardErrorCode(io.prestosql.spi.StandardErrorCode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) 
Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) Function(java.util.function.Function) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) Session(io.prestosql.Session) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) Assignments(io.prestosql.spi.plan.Assignments) PlanNode(io.prestosql.spi.plan.PlanNode) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) 
HashSet(java.util.HashSet) TableMetadata(io.prestosql.metadata.TableMetadata) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Expression(io.prestosql.sql.tree.Expression) TableHandle(io.prestosql.spi.metadata.TableHandle) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Example 3 with AggregationSignature

use of io.hetu.core.spi.cube.aggregator.AggregationSignature in project hetu-core by openlookeng.

the class CreateCubeTask method internalExecute.

/**
 * Executes a {@code CREATE CUBE} statement: validates the request, creates the backing cube table
 * through {@code metadata} and persists the cube definition in the cube meta store.
 *
 * <p>When the cube (or a table with the same name) already exists and the statement carries
 * {@code IF NOT EXISTS}, the method returns without making any change.
 *
 * @param statement the CREATE CUBE statement being executed
 * @param metadata used to resolve the catalog and tables and to create the cube table
 * @param accessControl used for the create-table permission check
 * @param session the session the statement runs in
 * @param stateMachine supplies the warning collector used while analyzing the statement
 * @param parameters the statement parameters
 * @return an already-completed future holding {@code null}
 * @throws RuntimeException if the star-tree cube meta store is not available
 * @throws SemanticException if the cube or a table with its name already exists (without
 *         {@code IF NOT EXISTS}), the source table does not exist, or a grouping column is missing
 * @throws PrestoException if the catalog is not found, the catalog does not support cubes, or an
 *         aggregation function is not supported
 */
@VisibleForTesting
public ListenableFuture<?> internalExecute(CreateCube statement, Metadata metadata, AccessControl accessControl, Session session, QueryStateMachine stateMachine, List<Expression> parameters) {
    Optional<CubeMetaStore> optionalCubeMetaStore = cubeManager.getMetaStore(STAR_TREE);
    if (!optionalCubeMetaStore.isPresent()) {
        throw new RuntimeException("HetuMetaStore is not initialized");
    }
    CubeMetaStore cubeMetaStore = optionalCubeMetaStore.get();
    QualifiedObjectName cubeName = createQualifiedObjectName(session, statement, statement.getCubeName());
    QualifiedObjectName tableName = createQualifiedObjectName(session, statement, statement.getSourceTableName());
    Optional<TableHandle> cubeHandle = metadata.getTableHandle(session, cubeName);
    Optional<TableHandle> tableHandle = metadata.getTableHandle(session, tableName);
    if (cubeMetaStore.getMetadataFromCubeName(cubeName.toString()).isPresent()) {
        if (!statement.isNotExists()) {
            throw new SemanticException(CUBE_ALREADY_EXISTS, statement, "Cube '%s' already exists", cubeName);
        }
        return immediateFuture(null);
    }
    if (cubeHandle.isPresent()) {
        if (!statement.isNotExists()) {
            throw new SemanticException(CUBE_OR_TABLE_ALREADY_EXISTS, statement, "Cube or Table '%s' already exists", cubeName);
        }
        return immediateFuture(null);
    }
    CatalogName catalogName = metadata.getCatalogHandle(session, cubeName.getCatalogName()).orElseThrow(() -> new PrestoException(NOT_FOUND, "Catalog not found: " + cubeName.getCatalogName()));
    if (!metadata.isPreAggregationSupported(session, catalogName)) {
        throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, String.format("Cube cannot be created on catalog '%s'", catalogName.toString()));
    }
    if (!tableHandle.isPresent()) {
        throw new SemanticException(MISSING_TABLE, statement, "Table '%s' does not exist", tableName);
    }
    TableMetadata tableMetadata = metadata.getTableMetadata(session, tableHandle.get());
    List<String> groupingSet = statement.getGroupingSet().stream().map(s -> s.getValue().toLowerCase(ENGLISH)).collect(Collectors.toList());
    Map<String, ColumnMetadata> sourceTableColumns = tableMetadata.getColumns().stream().collect(Collectors.toMap(ColumnMetadata::getName, col -> col));
    List<ColumnMetadata> cubeColumns = new ArrayList<>();
    Map<String, AggregationSignature> aggregations = new HashMap<>();
    Analysis analysis = analyzeStatement(statement, session, metadata, accessControl, parameters, stateMachine.getWarningCollector());
    // Key the analyzed output fields with the same locale (ENGLISH) that is used to build cubeColumnName
    // below, so the lookup does not depend on the JVM default locale.
    Map<String, Field> fields = analysis.getOutputDescriptor().getAllFields().stream().collect(Collectors.toMap(col -> col.getName().map(name -> name.toLowerCase(ENGLISH)).get(), col -> col));
    for (FunctionCall aggFunction : statement.getAggregations()) {
        String aggFunctionName = aggFunction.getName().toString().toLowerCase(ENGLISH);
        // No argument, or a literal argument such as count(1), is recorded as a null argument ("all")
        String argument = aggFunction.getArguments().isEmpty() || aggFunction.getArguments().get(0) instanceof LongLiteral ? null : ((Identifier) aggFunction.getArguments().get(0)).getValue().toLowerCase(ENGLISH);
        boolean distinct = aggFunction.isDistinct();
        String cubeColumnName = aggFunctionName + "_" + (argument == null ? "all" : argument) + (distinct ? "_distinct" : "");
        CubeAggregateFunction cubeAggregateFunction;
        try {
            cubeAggregateFunction = CubeAggregateFunction.valueOf(aggFunctionName.toUpperCase(ENGLISH));
        } catch (IllegalArgumentException e) {
            // Enum.valueOf rejects unknown names with IllegalArgumentException; report it the same way
            // as the default branch of the switch below.
            throw new PrestoException(NOT_SUPPORTED, format("Unsupported aggregation function : %s", aggFunctionName), e);
        }
        switch(cubeAggregateFunction) {
            case SUM:
                aggregations.put(cubeColumnName, AggregationSignature.sum(argument, distinct));
                break;
            case COUNT:
                AggregationSignature aggregationSignature = argument == null ? AggregationSignature.count() : AggregationSignature.count(argument, distinct);
                aggregations.put(cubeColumnName, aggregationSignature);
                break;
            case AVG:
                aggregations.put(cubeColumnName, AggregationSignature.avg(argument, distinct));
                break;
            case MAX:
                aggregations.put(cubeColumnName, AggregationSignature.max(argument, distinct));
                break;
            case MIN:
                aggregations.put(cubeColumnName, AggregationSignature.min(argument, distinct));
                break;
            default:
                throw new PrestoException(NOT_SUPPORTED, format("Unsupported aggregation function : %s", aggFunctionName));
        }
        // NOTE(review): assumes the analyzed output always contains a field named cubeColumnName;
        // a missing entry would surface as a NullPointerException here - confirm against analyzeStatement.
        Field tableField = fields.get(cubeColumnName);
        ColumnMetadata cubeCol = new ColumnMetadata(cubeColumnName, tableField.getType(), true, null, null, false, Collections.emptyMap());
        cubeColumns.add(cubeCol);
    }
    accessControl.checkCanCreateTable(session.getRequiredTransactionId(), session.getIdentity(), tableName);
    Map<String, Expression> sqlProperties = mapFromProperties(statement.getProperties());
    Map<String, Object> properties = metadata.getTablePropertyManager().getProperties(catalogName, cubeName.getCatalogName(), sqlProperties, session, metadata, parameters);
    if (properties.containsKey("partitioned_by")) {
        List<String> partitionCols = new ArrayList<>(((List<String>) properties.get("partitioned_by")));
        // put all partition columns at the end of the list
        groupingSet.removeAll(partitionCols);
        groupingSet.addAll(partitionCols);
    }
    // Dimension columns are appended after the aggregation columns, copying type/nullability from the source table
    for (String dimension : groupingSet) {
        if (!sourceTableColumns.containsKey(dimension)) {
            throw new SemanticException(MISSING_COLUMN, statement, "Column %s does not exist", dimension);
        }
        ColumnMetadata tableCol = sourceTableColumns.get(dimension);
        ColumnMetadata cubeCol = new ColumnMetadata(dimension, tableCol.getType(), tableCol.isNullable(), null, null, false, tableCol.getProperties());
        cubeColumns.add(cubeCol);
    }
    ConnectorTableMetadata cubeTableMetadata = new ConnectorTableMetadata(cubeName.asSchemaTableName(), ImmutableList.copyOf(cubeColumns), properties);
    try {
        metadata.createTable(session, cubeName.getCatalogName(), cubeTableMetadata, statement.isNotExists());
    } catch (PrestoException e) {
        // connectors are not required to handle the ignoreExisting flag
        if (!e.getErrorCode().equals(ALREADY_EXISTS.toErrorCode()) || !statement.isNotExists()) {
            throw e;
        }
    }
    CubeMetadataBuilder builder = cubeMetaStore.getBuilder(cubeName.toString(), tableName.toString());
    groupingSet.forEach(dimension -> builder.addDimensionColumn(dimension, dimension));
    aggregations.forEach((column, aggregationSignature) -> builder.addAggregationColumn(column, aggregationSignature.getFunction(), aggregationSignature.getDimension(), aggregationSignature.isDistinct()));
    builder.addGroup(new HashSet<>(groupingSet));
    // Status and Table modified time will be updated on the first insert into the cube
    builder.setCubeStatus(CubeStatus.INACTIVE);
    builder.setTableLastUpdatedTime(-1L);
    statement.getSourceFilter().ifPresent(sourceTablePredicate -> {
        // Store the source filter as formatted SQL text after applying the analyzer's coercions
        sourceTablePredicate = Coercer.addCoercions(sourceTablePredicate, analysis);
        builder.withCubeFilter(new CubeFilter(ExpressionFormatter.formatExpression(sourceTablePredicate, Optional.empty())));
    });
    builder.setCubeLastUpdatedTime(System.currentTimeMillis());
    cubeMetaStore.persist(builder.build());
    return immediateFuture(null);
}
Also used : SqlParser(io.prestosql.sql.parser.SqlParser) CreateCube(io.prestosql.sql.tree.CreateCube) Statement(io.prestosql.sql.tree.Statement) CUBE_OR_TABLE_ALREADY_EXISTS(io.prestosql.sql.analyzer.SemanticErrorCode.CUBE_OR_TABLE_ALREADY_EXISTS) WarningCollector(io.prestosql.execution.warnings.WarningCollector) ExpressionFormatter(io.prestosql.sql.ExpressionFormatter) Map(java.util.Map) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Coercer(io.prestosql.sql.planner.Coercer) Identifier(io.prestosql.sql.tree.Identifier) VisibleForTesting(org.assertj.core.util.VisibleForTesting) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) HeuristicIndexerManager(io.prestosql.heuristicindex.HeuristicIndexerManager) PrestoException(io.prestosql.spi.PrestoException) AccessControl(io.prestosql.security.AccessControl) CatalogName(io.prestosql.spi.connector.CatalogName) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) Analysis(io.prestosql.sql.analyzer.Analysis) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) StandardErrorCode(io.prestosql.spi.StandardErrorCode) STAR_TREE(io.prestosql.cube.CubeManager.STAR_TREE) CubeMetadataBuilder(io.hetu.core.spi.cube.CubeMetadataBuilder) Field(io.prestosql.sql.analyzer.Field) MetadataUtil.createQualifiedObjectName(io.prestosql.metadata.MetadataUtil.createQualifiedObjectName) NodeUtils.mapFromProperties(io.prestosql.sql.NodeUtils.mapFromProperties) ALREADY_EXISTS(io.prestosql.spi.StandardErrorCode.ALREADY_EXISTS) CUBE_ALREADY_EXISTS(io.prestosql.sql.analyzer.SemanticErrorCode.CUBE_ALREADY_EXISTS) ListenableFuture(com.google.common.util.concurrent.ListenableFuture) 
TableMetadata(io.prestosql.metadata.TableMetadata) TransactionManager(io.prestosql.transaction.TransactionManager) HashMap(java.util.HashMap) NOT_FOUND(io.prestosql.spi.StandardErrorCode.NOT_FOUND) TableHandle(io.prestosql.spi.metadata.TableHandle) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) Inject(javax.inject.Inject) HashSet(java.util.HashSet) CubeStatus(io.hetu.core.spi.cube.CubeStatus) SemanticException(io.prestosql.sql.analyzer.SemanticException) ImmutableList(com.google.common.collect.ImmutableList) FunctionCall(io.prestosql.sql.tree.FunctionCall) Session(io.prestosql.Session) MISSING_COLUMN(io.prestosql.sql.analyzer.SemanticErrorCode.MISSING_COLUMN) Futures.immediateFuture(com.google.common.util.concurrent.Futures.immediateFuture) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) MISSING_TABLE(io.prestosql.sql.analyzer.SemanticErrorCode.MISSING_TABLE) ConnectorTableMetadata(io.prestosql.spi.connector.ConnectorTableMetadata) Analyzer(io.prestosql.sql.analyzer.Analyzer) CubeManager(io.prestosql.cube.CubeManager) Collections(java.util.Collections) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) ArrayList(java.util.ArrayList) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) Field(io.prestosql.sql.analyzer.Field) Identifier(io.prestosql.sql.tree.Identifier) CubeMetadataBuilder(io.hetu.core.spi.cube.CubeMetadataBuilder) CubeFilter(io.hetu.core.spi.cube.CubeFilter) List(java.util.List) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) FunctionCall(io.prestosql.sql.tree.FunctionCall) ConnectorTableMetadata(io.prestosql.spi.connector.ConnectorTableMetadata) SemanticException(io.prestosql.sql.analyzer.SemanticException) TableMetadata(io.prestosql.metadata.TableMetadata) 
ConnectorTableMetadata(io.prestosql.spi.connector.ConnectorTableMetadata) LongLiteral(io.prestosql.sql.tree.LongLiteral) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) MetadataUtil.createQualifiedObjectName(io.prestosql.metadata.MetadataUtil.createQualifiedObjectName) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) Expression(io.prestosql.sql.tree.Expression) Analysis(io.prestosql.sql.analyzer.Analysis) TableHandle(io.prestosql.spi.metadata.TableHandle) CatalogName(io.prestosql.spi.connector.CatalogName) VisibleForTesting(org.assertj.core.util.VisibleForTesting)

Example 4 with AggregationSignature

use of io.hetu.core.spi.cube.aggregator.AggregationSignature in project hetu-core by openlookeng.

the class CubeOptimizer method createScanNode.

/**
 * Builds the {@code TableScanNode} over the cube table together with the symbol mappings that
 * describe how the original dimensions and aggregations map onto cube columns.
 *
 * <p>Scan symbols are allocated for the group-by dimensions, for filter columns that exist in the
 * cube, and for every non-AVG aggregation. AVG aggregations either read the cube's AVG column
 * (only when the query group-by matches the cube group exactly and the cube has such a column) or
 * are recorded as a SUM/COUNT pair.
 *
 * @return the rewrite result holding the created scan node (also stored in {@code cubeScanNode}),
 *         the symbol-to-column metadata and the dimension/aggregation sources
 * @throws PrestoException with {@code CUBE_ERROR} if the cube has no column for a required aggregation
 */
public CubeRewriteResult createScanNode() {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> cubeColumnHandles = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    SymbolAllocator symbolAllocator = context.getSymbolAllocator();
    dimensions.forEach(dimension -> {
        // assumed that source table dimension is same in cube table as well
        ColumnWithTable sourceColumn = originalPlanMappings.get(dimension);
        ColumnHandle cubeColumn = cubeColumnsMap.get(sourceColumn.getColumnName());
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    filterColumns.forEach(filterColumn -> {
        ColumnHandle cubeColumn = cubeColumnsMap.get(filterColumn);
        if (cubeColumn == null) {
            // Predicate rewrite will take care of removing those columns from the predicate
            return;
        }
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    // process all aggregations except avg.
    originalAggregationsMap.forEach((originalAgg, originalAggOutputSymbol) -> {
        if (originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // skip average aggregation
            return;
        }
        String cubeColumnName = cubeMetadata.getColumn(originalAgg).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + originalAgg));
        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColHandle);
        symbolMetadataMap.put(originalAggOutputSymbol, columnMetadata);
        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
    });
    boolean computeAvgDividingSumByCount = true;
    // process average aggregation
    for (AggregationSignature originalAgg : originalAggregationsMap.keySet()) {
        if (!originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // process only average aggregation. skip others
            continue;
        }
        Symbol originalOutputSymbol = originalAggregationsMap.get(originalAgg);
        String aggSourceColumnName = originalAgg.getDimension();
        AggregationSignature avgAggregationSignature = new AggregationSignature(originalAgg.getFunction(), aggSourceColumnName, originalAgg.isDistinct());
        // exactGroupByMatch is true iff the query group by clause does not contain columns from other tables
        // and names exactly the cube's group. This is extremely unlikely for Join queries
        boolean exactGroupByMatch = !groupByFromOtherTable && groupBy.size() == cubeMetadata.getGroup().size() && groupBy.stream().map(String::toLowerCase).allMatch(cubeMetadata.getGroup()::contains);
        // Look the AVG column up once, and only when the group by matches exactly
        Optional<String> avgCubeColumnName = exactGroupByMatch ? cubeMetadata.getColumn(avgAggregationSignature) : Optional.empty();
        if (avgCubeColumnName.isPresent()) {
            // Use AVG column for exact group match only if Cube has it. The Cubes created before shipping
            // this optimization will not have AVG column
            computeAvgDividingSumByCount = false;
            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName.get());
            if (!cubeColumnHandles.containsValue(avgCubeColHandle)) {
                ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, avgCubeColHandle);
                cubeColumnHandles.put(originalOutputSymbol, avgCubeColHandle);
                symbolMetadataMap.put(originalOutputSymbol, columnMetadata);
                cubeScanSymbols.add(originalOutputSymbol);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalOutputSymbol, originalOutputSymbol));
            }
        } else {
            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
            ColumnMetadata sumColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, sumColumnHandle);
            // Reuse the scan symbol already bound to the SUM column; allocate a new symbol only when none exists
            // (orElseGet keeps the allocator untouched when a symbol is reused)
            Symbol sumColumnSymbol = cubeColumnHandles.entrySet().stream().filter(entry -> sumColumnHandle.equals(entry.getValue())).findAny().map(Map.Entry::getKey).orElseGet(() -> symbolAllocator.newSymbol("sum_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), sumColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(sumColumnSymbol)) {
                cubeScanSymbols.add(sumColumnSymbol);
                cubeColumnHandles.put(sumColumnSymbol, sumColumnHandle);
                symbolMetadataMap.put(sumColumnSymbol, sumColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumColumnSymbol, sumColumnSymbol));
            }
            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
            ColumnMetadata countColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, countColumnHandle);
            // Same reuse-or-allocate rule for the COUNT column
            Symbol countColumnSymbol = cubeColumnHandles.entrySet().stream().filter(entry -> countColumnHandle.equals(entry.getValue())).findAny().map(Map.Entry::getKey).orElseGet(() -> symbolAllocator.newSymbol("count_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), countColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(countColumnSymbol)) {
                cubeScanSymbols.add(countColumnSymbol);
                cubeColumnHandles.put(countColumnSymbol, countColumnHandle);
                symbolMetadataMap.put(countColumnSymbol, countColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countColumnSymbol, countColumnSymbol));
            }
            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalOutputSymbol, sumColumnSymbol, countColumnSymbol));
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    // Symbols without an entry in symbolMetadataMap resolve to indexOf(null), i.e. -1 unless the list holds a null
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnMetadataList.indexOf(symbolMetadataMap.get(outSymbol))));
    cubeScanNode = TableScanNode.newInstance(context.getIdAllocator().getNextId(), cubeTableHandle, scanOutput, cubeColumnHandles, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(cubeScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
Also used : SymbolAllocator(io.prestosql.spi.SymbolAllocator) LongSupplier(java.util.function.LongSupplier) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) PrestoWarning(io.prestosql.spi.PrestoWarning) TypeProvider(io.prestosql.sql.planner.TypeProvider) SqlParser(io.prestosql.sql.parser.SqlParser) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Identifier(io.prestosql.sql.tree.Identifier) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Objects(java.util.Objects) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) 
List(java.util.List) ExpressionUtils(io.prestosql.sql.ExpressionUtils) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) TypeSignature(io.prestosql.spi.type.TypeSignature) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) Iterables(com.google.common.collect.Iterables) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) Literal(io.prestosql.sql.tree.Literal) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) Lists(com.google.common.collect.Lists) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) RowExpressionRewriter(io.prestosql.expressions.RowExpressionRewriter) RowExpressionTreeRewriter(io.prestosql.expressions.RowExpressionTreeRewriter) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) JoinNode(io.prestosql.spi.plan.JoinNode) ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) EXPIRED_CUBE(io.prestosql.spi.connector.StandardWarningCode.EXPIRED_CUBE) 
Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) TupleDomain(io.prestosql.spi.predicate.TupleDomain) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) RowExpression(io.prestosql.spi.relation.RowExpression) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) ArrayList(java.util.ArrayList) PrestoException(io.prestosql.spi.PrestoException) UUID(java.util.UUID) HashSet(java.util.HashSet) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 5 with AggregationSignature

use of io.hetu.core.spi.cube.aggregator.AggregationSignature in project hetu-core by openlookeng.

the class CubeOptimizer method parseAggregation.

/**
 * Collects the aggregation signatures and group-by columns of the original aggregation node and
 * drops every candidate cube that cannot serve them.
 *
 * <p>A cube is kept only if its group contains all group-by columns taken from the source table.
 * When the query has a COUNT DISTINCT aggregation, the cube group must additionally equal the
 * group-by columns and no group-by column may come from another table.
 *
 * @throws UnsupportedOperationException if no candidate cube remains
 */
private void parseAggregation() {
    // Record each aggregation's signature against the symbol it produces
    aggregationNode.getAggregations().forEach((outputSymbol, aggregation) -> originalAggregationsMap.put(getSignature(aggregation), outputSymbol));

    boolean countDistinctSeen = false;
    for (AggregationSignature candidateSignature : originalAggregationsMap.keySet()) {
        if (candidateSignature.getFunction().equalsIgnoreCase(COUNT.toString()) && candidateSignature.isDistinct()) {
            countDistinctSeen = true;
            break;
        }
    }
    // effectively-final copy for use inside the removeIf lambda
    final boolean hasCountDistinctAgg = countDistinctSeen;

    boolean otherTableColumnSeen = false;
    String sourceTableName = sourceTableMetadata.getQualifiedName().toString();
    for (Symbol groupingKey : aggregationNode.getGroupingKeys()) {
        ColumnWithTable groupingColumn = originalPlanMappings.get(groupingKey.getName());
        if (!groupingColumn.getFQTableName().equalsIgnoreCase(sourceTableName)) {
            // The group by also contains a column of another table - not just of the source table
            otherTableColumnSeen = true;
            continue;
        }
        // Only columns of the source table are added; join columns are part of the group by list as well
        String columnName = groupingColumn.getColumnName();
        dimensions.add(columnName);
        groupBy.add(columnName);
    }
    groupByFromOtherTable = otherTableColumnSeen;

    matchingMetadataList.removeIf(candidateCube -> {
        if (hasCountDistinctAgg && (groupByFromOtherTable || !candidateCube.getGroup().equals(groupBy))) {
            return true;
        }
        return !candidateCube.getGroup().containsAll(groupBy);
    });
    if (matchingMetadataList.isEmpty()) {
        throw new UnsupportedOperationException("Matching cubes not found. Unable to rewrite");
    }
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature)

Aggregations

AggregationSignature (io.hetu.core.spi.cube.aggregator.AggregationSignature)7 Session (io.prestosql.Session)6 Metadata (io.prestosql.metadata.Metadata)6 TableMetadata (io.prestosql.metadata.TableMetadata)6 TableHandle (io.prestosql.spi.metadata.TableHandle)6 Symbol (io.prestosql.spi.plan.Symbol)6 ImmutableList (com.google.common.collect.ImmutableList)5 Logger (io.airlift.log.Logger)5 CubeAggregateFunction (io.hetu.core.spi.cube.CubeAggregateFunction)5 CubeMetadata (io.hetu.core.spi.cube.CubeMetadata)5 PrestoException (io.prestosql.spi.PrestoException)5 ColumnMetadata (io.prestosql.spi.connector.ColumnMetadata)5 QualifiedObjectName (io.prestosql.spi.connector.QualifiedObjectName)5 AggregationNode (io.prestosql.spi.plan.AggregationNode)5 FilterNode (io.prestosql.spi.plan.FilterNode)5 PlanNode (io.prestosql.spi.plan.PlanNode)5 ProjectNode (io.prestosql.spi.plan.ProjectNode)5 TableScanNode (io.prestosql.spi.plan.TableScanNode)5 Expression (io.prestosql.sql.tree.Expression)5 ArrayList (java.util.ArrayList)5