Search in sources:

Example 1 with CUBE_ERROR

Use of io.prestosql.spi.StandardErrorCode.CUBE_ERROR in project hetu-core by openLooKeng.

From class AggregationRewriteWithCube, method createScanNode.

/**
 * Builds a table scan over the cube table that can replace the source of the original aggregation.
 *
 * <p>Every column the rewritten plan needs from the cube is mapped to a cube scan symbol. This covers
 * filter columns, group-by (dimension) columns and the pre-aggregated measure columns backing each
 * aggregation. For each original symbol, its name is recorded in {@code rewrittenMappings} against
 * the scan symbol that replaces it. A cube column is scanned only once; later references to the same
 * column reuse the symbol already assigned to it.
 *
 * @param originalAggregationNode aggregation being rewritten
 * @param filterNode filter between the aggregation and the original source, or {@code null};
 *        when non-null it is cast to {@link FilterNode}
 * @param cubeTableHandle handle of the cube table to scan
 * @param cubeColumnsMap cube column handles keyed by column name
 * @param cubeColumnsMetadata metadata of the cube columns; the scan output is ordered by position in this list
 * @param exactGroupsMatch whether a stored AVG column may be used directly (when the cube has one)
 *        instead of computing AVG as SUM / COUNT
 * @return the cube scan node together with the symbol bookkeeping needed to rewrite the aggregation
 * @throws PrestoException with {@code CUBE_ERROR} if the cube has no column for a required aggregation,
 *         or with {@code GENERIC_INTERNAL_ERROR} for an unsupported aggregation function
 */
public CubeRewriteResult createScanNode(AggregationNode originalAggregationNode, PlanNode filterNode, TableHandle cubeTableHandle, Map<String, ColumnHandle> cubeColumnsMap, List<ColumnMetadata> cubeColumnsMetadata, boolean exactGroupsMatch) {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> symbolAssignments = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    Map<String, ColumnMetadata> columnMetadataMap = cubeColumnsMetadata.stream().collect(Collectors.toMap(ColumnMetadata::getName, Function.identity()));
    boolean computeAvgDividingSumByCount = true;
    Set<Symbol> filterSymbols = new HashSet<>();
    if (filterNode != null) {
        filterSymbols.addAll(SymbolsExtractor.extractUnique(((FilterNode) filterNode).getPredicate()));
    }
    // Columns referenced by the filter predicate must also be produced by the cube scan.
    for (Symbol filterSymbol : filterSymbols) {
        if (symbolMappings.containsKey(filterSymbol.getName()) && symbolMappings.get(filterSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(filterSymbol.getName());
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
            cubeScanSymbols.add(cubeScanSymbol);
            symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
            rewrittenMappings.put(filterSymbol.getName(), cubeScanSymbol);
            dimensionSymbols.add(new CubeRewriteResult.DimensionSource(filterSymbol, cubeScanSymbol));
        }
    }
    for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
        if (symbolMappings.containsKey(originalAggOutputSymbol.getName()) && symbolMappings.get(originalAggOutputSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table - column part of group by clause
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(originalAggOutputSymbol.getName());
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            Symbol cubeScanSymbol;
            if (!symbolAssignments.containsValue(cubeScanColumn)) {
                cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
                cubeScanSymbols.add(cubeScanSymbol);
                symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
            } else {
                // The column is already scanned (e.g. it also appears in the filter); reuse its symbol.
                cubeScanSymbol = findAssignedSymbol(symbolAssignments, cubeScanColumn);
            }
            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
            dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
        } else if (originalAggregationNode.getAggregations().containsKey(originalAggOutputSymbol)) {
            // output symbol is mapped to an aggregation
            AggregationNode.Aggregation aggregation = originalAggregationNode.getAggregations().get(originalAggOutputSymbol);
            String aggFunction = aggregation.getFunctionCall().getDisplayName();
            List<Expression> arguments = aggregation.getArguments() == null ? null : aggregation.getArguments().stream().map(OriginalExpressionUtils::castToExpression).collect(Collectors.toList());
            if (arguments != null && !arguments.isEmpty() && (!(arguments.get(0) instanceof SymbolReference))) {
                log.info("Not a symbol reference in aggregation function. Agg Function = %s, Arguments = %s", aggFunction, arguments);
                continue;
            }
            Object mappedValue = arguments == null || arguments.isEmpty() ? null : symbolMappings.get(((SymbolReference) arguments.get(0)).getName());
            if (mappedValue == null || (mappedValue instanceof LongLiteral && ((LongLiteral) mappedValue).getValue() == 1)) {
                // COUNT aggregation
                if (CubeAggregateFunction.COUNT.getName().equals(aggFunction) && !aggregation.isDistinct()) {
                    // COUNT 1
                    AggregationSignature aggregationSignature = AggregationSignature.count();
                    String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                    ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                    Symbol cubeScanSymbol;
                    if (!symbolAssignments.containsValue(cubeColHandle)) {
                        ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                        cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                        cubeScanSymbols.add(cubeScanSymbol);
                        symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                        symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                    } else {
                        // Another count(*) / count(1) output already maps to this cube column. Reuse it so this
                        // output is not silently dropped from the rewrite.
                        cubeScanSymbol = findAssignedSymbol(symbolAssignments, cubeColHandle);
                    }
                    rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                    aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                }
            } else if (mappedValue instanceof ColumnHandle) {
                String originalColumnName = ((ColumnHandle) mappedValue).getColumnName();
                boolean distinct = aggregation.isDistinct();
                switch (aggFunction) {
                    case "min":
                    case "max":
                    case "sum":
                    case "count": {
                        AggregationSignature aggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                        Symbol cubeScanSymbol;
                        if (!symbolAssignments.containsValue(cubeColHandle)) {
                            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                            cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                            cubeScanSymbols.add(cubeScanSymbol);
                            symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                        } else {
                            // Measure column already scanned (e.g. as part of an avg); reuse its symbol.
                            cubeScanSymbol = findAssignedSymbol(symbolAssignments, cubeColHandle);
                        }
                        rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        break;
                    }
                    case "avg": {
                        AggregationSignature avgAggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        if (exactGroupsMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
                            // Use the cube's stored AVG column directly.
                            // NOTE(review): this flag is shared by every avg in the query, while other avgs may still be
                            // registered as SUM / COUNT pairs below if their AVG column is missing — confirm downstream handling.
                            computeAvgDividingSumByCount = false;
                            String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
                            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
                            Symbol cubeScanSymbol;
                            if (!symbolAssignments.containsValue(avgCubeColHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(avgCubeColHandle.getColumnName());
                                cubeScanSymbol = symbolAllocator.newSymbol(avgCubeColHandle.getColumnName(), columnMetadata.getType());
                                cubeScanSymbols.add(cubeScanSymbol);
                                symbolAssignments.put(cubeScanSymbol, avgCubeColHandle);
                                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                            } else {
                                // Another avg output already maps to this AVG column; reuse it rather than dropping this output.
                                cubeScanSymbol = findAssignedSymbol(symbolAssignments, avgCubeColHandle);
                            }
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        } else {
                            // AVG is computed later as SUM / COUNT, so both measure columns must be scanned.
                            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), originalColumnName, distinct);
                            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
                            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
                            Symbol sumSymbol;
                            if (!symbolAssignments.containsValue(sumColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(sumColumnHandle.getColumnName());
                                sumSymbol = symbolAllocator.newSymbol("sum_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(sumSymbol);
                                symbolAssignments.put(sumSymbol, sumColumnHandle);
                                symbolMetadataMap.put(sumSymbol, columnMetadata);
                                rewrittenMappings.put(sumSymbol.getName(), sumSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumSymbol, sumSymbol));
                            } else {
                                sumSymbol = findAssignedSymbol(symbolAssignments, sumColumnHandle);
                            }
                            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), originalColumnName, distinct);
                            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
                            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
                            Symbol countSymbol;
                            if (!symbolAssignments.containsValue(countColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(countColumnHandle.getColumnName());
                                countSymbol = symbolAllocator.newSymbol("count_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(countSymbol);
                                symbolAssignments.put(countSymbol, countColumnHandle);
                                symbolMetadataMap.put(countSymbol, columnMetadata);
                                rewrittenMappings.put(countSymbol.getName(), countSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countSymbol, countSymbol));
                            } else {
                                countSymbol = findAssignedSymbol(symbolAssignments, countColumnHandle);
                            }
                            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalAggOutputSymbol, sumSymbol, countSymbol));
                        }
                        break;
                    }
                    default:
                        throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported aggregation function " + aggFunction);
                }
            } else {
                log.info("Aggregation function argument is not a Column Handle. Agg Function = %s, Arguments = %s", aggFunction, arguments);
            }
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnsMetadata.indexOf(symbolMetadataMap.get(outSymbol))));
    TableScanNode tableScanNode = TableScanNode.newInstance(idAllocator.getNextId(), cubeTableHandle, scanOutput, symbolAssignments, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(tableScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}

/**
 * Returns the first scan symbol, in iteration order, that {@code symbolAssignments} maps to {@code column}.
 * Callers invoke this only after {@code symbolAssignments.containsValue(column)} returned {@code true}.
 *
 * @throws IllegalStateException if no symbol is assigned to {@code column}
 */
private Symbol findAssignedSymbol(Map<Symbol, ColumnHandle> symbolAssignments, ColumnHandle column) {
    for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
        if (column.equals(assignment.getValue())) {
            return assignment.getKey();
        }
    }
    throw new IllegalStateException("No scan symbol assigned to column " + column);
}
Also used : TypeProvider(io.prestosql.sql.planner.TypeProvider) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) FunctionHandle(io.prestosql.spi.function.FunctionHandle) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) StandardErrorCode(io.prestosql.spi.StandardErrorCode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) 
Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) Function(java.util.function.Function) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) Session(io.prestosql.Session) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) PrestoException(io.prestosql.spi.PrestoException) List(java.util.List) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) UUID(java.util.UUID) HashSet(java.util.HashSet) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) LongLiteral(io.prestosql.sql.tree.LongLiteral) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) 
TableScanNode(io.prestosql.spi.plan.TableScanNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 2 with CUBE_ERROR

Use of io.prestosql.spi.StandardErrorCode.CUBE_ERROR in project hetu-core by openLooKeng.

From class CubeOptimizer, method createScanNode.

/**
 * Builds the table scan over the cube table that replaces the source of the original aggregation.
 *
 * <p>Reads the optimizer's per-query state (dimensions, filter columns, original aggregations,
 * group-by columns and the selected cube) and maps every required cube column to a scan symbol.
 * AVG is served from the cube's AVG column when the query's group-by exactly matches the cube's
 * group and the cube stores that column. Otherwise AVG is computed later as SUM / COUNT, and both
 * measure columns are scanned.
 *
 * <p>Side effect: the created scan is also stored in {@code cubeScanNode}.
 *
 * @return the cube scan node together with the symbol bookkeeping needed to rewrite the aggregation
 * @throws PrestoException with {@code CUBE_ERROR} if the cube has no column for a required aggregation
 */
public CubeRewriteResult createScanNode() {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> cubeColumnHandles = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    // Metadata of every scan symbol, used only to order the scan output. symbolMetadataMap cannot serve
    // this purpose for two reasons. It has no entries for dimension or filter symbols. It also keys
    // non-avg aggregations by their original output symbol rather than by the scan symbol.
    Map<Symbol, ColumnMetadata> scanSymbolMetadata = new HashMap<>();
    SymbolAllocator symbolAllocator = context.getSymbolAllocator();
    dimensions.forEach(dimension -> {
        // assumed that source table dimension is same in cube table as well
        ColumnWithTable sourceColumn = originalPlanMappings.get(dimension);
        ColumnHandle cubeColumn = cubeColumnsMap.get(sourceColumn.getColumnName());
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        scanSymbolMetadata.put(cubeScanSymbol, columnMetadata);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    filterColumns.forEach(filterColumn -> {
        ColumnHandle cubeColumn = cubeColumnsMap.get(filterColumn);
        if (cubeColumn == null) {
            // Predicate rewrite will take care of remove those columns from predicate
            return;
        }
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        scanSymbolMetadata.put(cubeScanSymbol, columnMetadata);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    // process all aggregations except avg.
    originalAggregationsMap.forEach((originalAgg, originalAggOutputSymbol) -> {
        if (originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // skip average aggregation
            return;
        }
        String cubeColumnName = cubeMetadata.getColumn(originalAgg).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + originalAgg));
        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        scanSymbolMetadata.put(cubeScanSymbol, columnMetadata);
        cubeColumnHandles.put(cubeScanSymbol, cubeColHandle);
        // Returned metadata stays keyed by the original output symbol, as before.
        symbolMetadataMap.put(originalAggOutputSymbol, columnMetadata);
        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
    });
    boolean computeAvgDividingSumByCount = true;
    // process average aggregation
    for (AggregationSignature originalAgg : originalAggregationsMap.keySet()) {
        if (!originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // process only average aggregation. skip others
            continue;
        }
        Symbol originalOutputSymbol = originalAggregationsMap.get(originalAgg);
        String aggSourceColumnName = originalAgg.getDimension();
        AggregationSignature avgAggregationSignature = new AggregationSignature(originalAgg.getFunction(), aggSourceColumnName, originalAgg.isDistinct());
        // exactGroupByMatch is True iff query group by clause by does not contain columns from other table. This is extremely unlikely for Join queries
        boolean exactGroupByMatch = !groupByFromOtherTable && groupBy.size() == cubeMetadata.getGroup().size() && groupBy.stream().map(String::toLowerCase).allMatch(cubeMetadata.getGroup()::contains);
        if (exactGroupByMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
            // Use AVG column for exact group match only if Cube has it. The Cubes created before shipping
            // this optimization will not have AVG column
            computeAvgDividingSumByCount = false;
            String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
            if (!cubeColumnHandles.containsValue(avgCubeColHandle)) {
                ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, avgCubeColHandle);
                cubeColumnHandles.put(originalOutputSymbol, avgCubeColHandle);
                symbolMetadataMap.put(originalOutputSymbol, columnMetadata);
                scanSymbolMetadata.put(originalOutputSymbol, columnMetadata);
                cubeScanSymbols.add(originalOutputSymbol);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalOutputSymbol, originalOutputSymbol));
            }
        } else {
            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
            ColumnMetadata sumColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, sumColumnHandle);
            // Reuse the symbol if the SUM column is already scanned. Allocate a new symbol lazily
            // (orElseGet) so an unused one is not registered with the allocator.
            Symbol sumColumnSymbol = cubeColumnHandles.entrySet().stream()
                    .filter(entry -> sumColumnHandle.equals(entry.getValue()))
                    .findAny()
                    .map(Map.Entry::getKey)
                    .orElseGet(() -> symbolAllocator.newSymbol("sum_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), sumColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(sumColumnSymbol)) {
                cubeScanSymbols.add(sumColumnSymbol);
                cubeColumnHandles.put(sumColumnSymbol, sumColumnHandle);
                symbolMetadataMap.put(sumColumnSymbol, sumColumnMetadata);
                scanSymbolMetadata.put(sumColumnSymbol, sumColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumColumnSymbol, sumColumnSymbol));
            }
            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
            ColumnMetadata countColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, countColumnHandle);
            Symbol countColumnSymbol = cubeColumnHandles.entrySet().stream()
                    .filter(entry -> countColumnHandle.equals(entry.getValue()))
                    .findAny()
                    .map(Map.Entry::getKey)
                    .orElseGet(() -> symbolAllocator.newSymbol("count_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), countColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(countColumnSymbol)) {
                cubeScanSymbols.add(countColumnSymbol);
                cubeColumnHandles.put(countColumnSymbol, countColumnHandle);
                symbolMetadataMap.put(countColumnSymbol, countColumnMetadata);
                scanSymbolMetadata.put(countColumnSymbol, countColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countColumnSymbol, countColumnSymbol));
            }
            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalOutputSymbol, sumColumnSymbol, countColumnSymbol));
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    // Every scan symbol is ordered by its column's position in the cube table.
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnMetadataList.indexOf(scanSymbolMetadata.get(outSymbol))));
    cubeScanNode = TableScanNode.newInstance(context.getIdAllocator().getNextId(), cubeTableHandle, scanOutput, cubeColumnHandles, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(cubeScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
Also used : SymbolAllocator(io.prestosql.spi.SymbolAllocator) LongSupplier(java.util.function.LongSupplier) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) PrestoWarning(io.prestosql.spi.PrestoWarning) TypeProvider(io.prestosql.sql.planner.TypeProvider) SqlParser(io.prestosql.sql.parser.SqlParser) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Identifier(io.prestosql.sql.tree.Identifier) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Objects(java.util.Objects) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) 
List(java.util.List) ExpressionUtils(io.prestosql.sql.ExpressionUtils) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) TypeSignature(io.prestosql.spi.type.TypeSignature) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) Iterables(com.google.common.collect.Iterables) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) Literal(io.prestosql.sql.tree.Literal) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) Lists(com.google.common.collect.Lists) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) RowExpressionRewriter(io.prestosql.expressions.RowExpressionRewriter) RowExpressionTreeRewriter(io.prestosql.expressions.RowExpressionTreeRewriter) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) JoinNode(io.prestosql.spi.plan.JoinNode) ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) EXPIRED_CUBE(io.prestosql.spi.connector.StandardWarningCode.EXPIRED_CUBE) 
Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) TupleDomain(io.prestosql.spi.predicate.TupleDomain) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) RowExpression(io.prestosql.spi.relation.RowExpression) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) ArrayList(java.util.ArrayList) PrestoException(io.prestosql.spi.PrestoException) UUID(java.util.UUID) HashSet(java.util.HashSet) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Aggregations

ImmutableList (com.google.common.collect.ImmutableList)2 ImmutableMap (com.google.common.collect.ImmutableMap)2 Logger (io.airlift.log.Logger)2 CubeAggregateFunction (io.hetu.core.spi.cube.CubeAggregateFunction)2 COUNT (io.hetu.core.spi.cube.CubeAggregateFunction.COUNT)2 SUM (io.hetu.core.spi.cube.CubeAggregateFunction.SUM)2 CubeMetadata (io.hetu.core.spi.cube.CubeMetadata)2 AggregationSignature (io.hetu.core.spi.cube.aggregator.AggregationSignature)2 Session (io.prestosql.Session)2 Metadata (io.prestosql.metadata.Metadata)2 TableMetadata (io.prestosql.metadata.TableMetadata)2 PrestoException (io.prestosql.spi.PrestoException)2 CUBE_ERROR (io.prestosql.spi.StandardErrorCode.CUBE_ERROR)2 SymbolAllocator (io.prestosql.spi.SymbolAllocator)2 ColumnHandle (io.prestosql.spi.connector.ColumnHandle)2 ColumnMetadata (io.prestosql.spi.connector.ColumnMetadata)2 ColumnNotFoundException (io.prestosql.spi.connector.ColumnNotFoundException)2 QualifiedObjectName (io.prestosql.spi.connector.QualifiedObjectName)2 SchemaTableName (io.prestosql.spi.connector.SchemaTableName)2 FunctionHandle (io.prestosql.spi.function.FunctionHandle)2