Usage of io.prestosql.spi.StandardErrorCode.CUBE_ERROR in the hetu-core project by openLooKeng.
From class AggregationRewriteWithCube, method createScanNode.
/**
 * Builds the table scan over the cube table that replaces the original aggregation's source and
 * records how each relevant original symbol maps onto a cube scan symbol.
 *
 * <p>Three groups of symbols are resolved, and each resolved mapping is written into
 * {@code rewrittenMappings}:
 * <ul>
 *   <li>symbols referenced by the predicate of {@code filterNode} (when it is non-null);</li>
 *   <li>output symbols of {@code originalAggregationNode} that map to original table columns;</li>
 *   <li>aggregation output symbols: count(*)/count(1), min, max, sum, count and avg.</li>
 * </ul>
 * When several original symbols resolve to the same cube column, they all share the single scan
 * symbol already assigned to that column, so no mapping is dropped.
 *
 * @param originalAggregationNode the aggregation being rewritten
 * @param filterNode a {@link FilterNode} whose predicate symbols must be scanned, or {@code null}
 * @param cubeTableHandle handle of the cube table to scan
 * @param cubeColumnsMap cube column name to cube column handle
 * @param cubeColumnsMetadata cube table columns; the scan output is ordered by position in this list
 *        (presumably the cube table's column order - TODO confirm against caller)
 * @param exactGroupsMatch when {@code true} and the cube stores a matching AVG column, that column is
 *        read directly instead of deriving avg from the sum and count columns
 * @return the cube scan node plus the dimension, aggregation and average mappings
 * @throws PrestoException with {@code CUBE_ERROR} if the cube has no column for a required
 *         aggregation, or {@code GENERIC_INTERNAL_ERROR} for an unsupported aggregation function
 */
public CubeRewriteResult createScanNode(AggregationNode originalAggregationNode, PlanNode filterNode, TableHandle cubeTableHandle, Map<String, ColumnHandle> cubeColumnsMap, List<ColumnMetadata> cubeColumnsMetadata, boolean exactGroupsMatch) {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> symbolAssignments = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    Map<String, ColumnMetadata> columnMetadataMap = cubeColumnsMetadata.stream().collect(Collectors.toMap(ColumnMetadata::getName, Function.identity()));
    boolean computeAvgDividingSumByCount = true;

    // 1. Columns referenced by the filter predicate become scanned dimensions.
    Set<Symbol> filterSymbols = new HashSet<>();
    if (filterNode != null) {
        filterSymbols.addAll(SymbolsExtractor.extractUnique(((FilterNode) filterNode).getPredicate()));
    }
    for (Symbol filterSymbol : filterSymbols) {
        if (symbolMappings.containsKey(filterSymbol.getName()) && symbolMappings.get(filterSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(filterSymbol.getName());
            // NOTE(review): assumes every filtered column exists in the cube; a missing one would NPE here.
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
            cubeScanSymbols.add(cubeScanSymbol);
            symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
            rewrittenMappings.put(filterSymbol.getName(), cubeScanSymbol);
            dimensionSymbols.add(new CubeRewriteResult.DimensionSource(filterSymbol, cubeScanSymbol));
        }
    }

    // 2. Output symbols of the original aggregation: group-by columns and aggregation results.
    for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
        if (symbolMappings.containsKey(originalAggOutputSymbol.getName()) && symbolMappings.get(originalAggOutputSymbol.getName()) instanceof ColumnHandle) {
            // output symbol references of the columns in original table - column part of group by clause
            ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(originalAggOutputSymbol.getName());
            ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
            if (!symbolAssignments.containsValue(cubeScanColumn)) {
                Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
                cubeScanSymbols.add(cubeScanSymbol);
                symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
            } else {
                // Column is already scanned (e.g. it is also filtered on); share its scan symbol.
                Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeScanColumn.equals(symbolAssignments.get(key))).findFirst().get();
                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
            }
        } else if (originalAggregationNode.getAggregations().containsKey(originalAggOutputSymbol)) {
            // output symbol is mapped to an aggregation
            AggregationNode.Aggregation aggregation = originalAggregationNode.getAggregations().get(originalAggOutputSymbol);
            String aggFunction = aggregation.getFunctionCall().getDisplayName();
            List<Expression> arguments = aggregation.getArguments() == null ? null : aggregation.getArguments().stream().map(OriginalExpressionUtils::castToExpression).collect(Collectors.toList());
            if (arguments != null && !arguments.isEmpty() && (!(arguments.get(0) instanceof SymbolReference))) {
                log.info("Not a symbol reference in aggregation function. Agg Function = %s, Arguments = %s", aggFunction, arguments);
                continue;
            }
            Object mappedValue = arguments == null || arguments.isEmpty() ? null : symbolMappings.get(((SymbolReference) arguments.get(0)).getName());
            if (mappedValue == null || (mappedValue instanceof LongLiteral && ((LongLiteral) mappedValue).getValue() == 1)) {
                // COUNT aggregation without a column argument: count(*) or count(1)
                if (CubeAggregateFunction.COUNT.getName().equals(aggFunction) && !aggregation.isDistinct()) {
                    // COUNT 1
                    AggregationSignature aggregationSignature = AggregationSignature.count();
                    String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                    ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                    if (!symbolAssignments.containsValue(cubeColHandle)) {
                        ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                        cubeScanSymbols.add(cubeScanSymbol);
                        symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                        symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                        rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                    } else {
                        // Another count(*) / count(1) output already reads this cube column. Share its scan
                        // symbol so this output symbol is still mapped instead of being silently dropped.
                        Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeColHandle.equals(symbolAssignments.get(key))).findFirst().get();
                        rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                    }
                }
            } else if (mappedValue instanceof ColumnHandle) {
                String originalColumnName = ((ColumnHandle) mappedValue).getColumnName();
                boolean distinct = aggregation.isDistinct();
                switch (aggFunction) {
                    case "min":
                    case "max":
                    case "sum":
                    case "count": {
                        AggregationSignature aggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
                        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
                        if (!symbolAssignments.containsValue(cubeColHandle)) {
                            ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
                            Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
                            cubeScanSymbols.add(cubeScanSymbol);
                            symbolAssignments.put(cubeScanSymbol, cubeColHandle);
                            symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        } else {
                            // Column already scanned (e.g. as the sum/count half of an avg); share its symbol.
                            Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeColHandle.equals(symbolAssignments.get(key))).findFirst().get();
                            rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                            aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                        }
                        break;
                    }
                    case "avg": {
                        AggregationSignature avgAggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
                        if (exactGroupsMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
                            // Read the stored AVG column directly.
                            computeAvgDividingSumByCount = false;
                            String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
                            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
                            if (!symbolAssignments.containsValue(avgCubeColHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(avgCubeColHandle.getColumnName());
                                Symbol cubeScanSymbol = symbolAllocator.newSymbol(avgCubeColHandle.getColumnName(), columnMetadata.getType());
                                cubeScanSymbols.add(cubeScanSymbol);
                                symbolAssignments.put(cubeScanSymbol, avgCubeColHandle);
                                symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
                                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                            } else {
                                // Same AVG column already scanned for another output; share its symbol so
                                // this output symbol is still mapped.
                                Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> avgCubeColHandle.equals(symbolAssignments.get(key))).findFirst().get();
                                rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
                            }
                        } else {
                            // avg is derived later from the cube's sum and count columns for the same argument.
                            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), originalColumnName, distinct);
                            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
                            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
                            Symbol sumSymbol = null;
                            if (!symbolAssignments.containsValue(sumColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(sumColumnHandle.getColumnName());
                                sumSymbol = symbolAllocator.newSymbol("sum_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(sumSymbol);
                                symbolAssignments.put(sumSymbol, sumColumnHandle);
                                symbolMetadataMap.put(sumSymbol, columnMetadata);
                                rewrittenMappings.put(sumSymbol.getName(), sumSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumSymbol, sumSymbol));
                            } else {
                                for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
                                    if (assignment.getValue().equals(sumColumnHandle)) {
                                        sumSymbol = assignment.getKey();
                                        break;
                                    }
                                }
                            }
                            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), originalColumnName, distinct);
                            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
                            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
                            Symbol countSymbol = null;
                            if (!symbolAssignments.containsValue(countColumnHandle)) {
                                ColumnMetadata columnMetadata = columnMetadataMap.get(countColumnHandle.getColumnName());
                                countSymbol = symbolAllocator.newSymbol("count_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
                                cubeScanSymbols.add(countSymbol);
                                symbolAssignments.put(countSymbol, countColumnHandle);
                                symbolMetadataMap.put(countSymbol, columnMetadata);
                                rewrittenMappings.put(countSymbol.getName(), countSymbol);
                                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countSymbol, countSymbol));
                            } else {
                                for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
                                    if (assignment.getValue().equals(countColumnHandle)) {
                                        countSymbol = assignment.getKey();
                                        break;
                                    }
                                }
                            }
                            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalAggOutputSymbol, sumSymbol, countSymbol));
                        }
                        break;
                    }
                    default:
                        throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported aggregation function " + aggFunction);
                }
            } else {
                log.info("Aggregation function argument is not a Column Handle. Agg Function = %s, Arguments = %s", aggFunction, arguments);
            }
        }
    }

    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    // Every scan symbol has an entry in symbolMetadataMap, so its position in cubeColumnsMetadata is known.
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnsMetadata.indexOf(symbolMetadataMap.get(outSymbol))));
    TableScanNode tableScanNode = TableScanNode.newInstance(idAllocator.getNextId(), cubeTableHandle, scanOutput, symbolAssignments, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(tableScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
Usage of io.prestosql.spi.StandardErrorCode.CUBE_ERROR in the hetu-core project by openLooKeng.
From class CubeOptimizer, method createScanNode.
/**
 * Builds the table scan over the cube table and records how the query's dimensions, filter columns
 * and aggregations map onto cube scan symbols.
 *
 * <p>Reads this optimizer's state ({@code dimensions}, {@code filterColumns},
 * {@code originalAggregationsMap}, {@code cubeMetadata}, {@code groupBy}, ...). It also assigns the
 * created scan node to the {@code cubeScanNode} field.
 *
 * <p>Non-avg aggregations read their cube column directly. For each avg aggregation, one of two
 * paths is taken:
 * <ul>
 *   <li>If the group-by matches the cube group exactly and the cube stores an AVG column, that
 *       column is read directly.</li>
 *   <li>Otherwise the sum and count columns are scanned, reusing any scan symbol already assigned to
 *       them, so that avg can be derived from them.</li>
 * </ul>
 *
 * @return the cube scan node with its dimension, aggregation and average mappings
 * @throws PrestoException with {@code CUBE_ERROR} when the cube has no column for a required aggregation
 */
public CubeRewriteResult createScanNode() {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> cubeColumnHandles = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    SymbolAllocator symbolAllocator = context.getSymbolAllocator();
    dimensions.forEach(dimension -> {
        // assumed that source table dimension is same in cube table as well
        ColumnWithTable sourceColumn = originalPlanMappings.get(dimension);
        ColumnHandle cubeColumn = cubeColumnsMap.get(sourceColumn.getColumnName());
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    filterColumns.forEach(filterColumn -> {
        ColumnHandle cubeColumn = cubeColumnsMap.get(filterColumn);
        if (cubeColumn == null) {
            // Predicate rewrite will take care of remove those columns from predicate
            return;
        }
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    // process all aggregations except avg.
    originalAggregationsMap.forEach((originalAgg, originalAggOutputSymbol) -> {
        if (originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // skip average aggregation
            return;
        }
        String cubeColumnName = cubeMetadata.getColumn(originalAgg).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + originalAgg));
        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColHandle);
        // NOTE: keyed by the original output symbol (not the scan symbol), as consumers of the result expect.
        symbolMetadataMap.put(originalAggOutputSymbol, columnMetadata);
        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
    });
    boolean computeAvgDividingSumByCount = true;
    // process average aggregation
    for (AggregationSignature originalAgg : originalAggregationsMap.keySet()) {
        if (!originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // process only average aggregation. skip others
            continue;
        }
        Symbol originalOutputSymbol = originalAggregationsMap.get(originalAgg);
        String aggSourceColumnName = originalAgg.getDimension();
        AggregationSignature avgAggregationSignature = new AggregationSignature(originalAgg.getFunction(), aggSourceColumnName, originalAgg.isDistinct());
        // exactGroupByMatch is True iff query group by clause by does not contain columns from other table. This is extremely unlikely for Join queries
        boolean exactGroupByMatch = !groupByFromOtherTable && groupBy.size() == cubeMetadata.getGroup().size() && groupBy.stream().map(String::toLowerCase).allMatch(cubeMetadata.getGroup()::contains);
        if (exactGroupByMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
            // Use AVG column for exact group match only if Cube has it. The Cubes created before shipping
            // this optimization will not have AVG column
            computeAvgDividingSumByCount = false;
            String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
            if (!cubeColumnHandles.containsValue(avgCubeColHandle)) {
                ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, avgCubeColHandle);
                cubeColumnHandles.put(originalOutputSymbol, avgCubeColHandle);
                symbolMetadataMap.put(originalOutputSymbol, columnMetadata);
                cubeScanSymbols.add(originalOutputSymbol);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalOutputSymbol, originalOutputSymbol));
            }
        } else {
            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
            ColumnMetadata sumColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, sumColumnHandle);
            // Reuse the scan symbol already reading the sum column; allocate a new one only when none exists
            // (orElseGet, so no unused symbol is allocated when reuse succeeds).
            Symbol sumColumnSymbol = cubeColumnHandles.entrySet().stream().filter(entry -> sumColumnHandle.equals(entry.getValue())).findAny().map(Map.Entry::getKey).orElseGet(() -> symbolAllocator.newSymbol("sum_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), sumColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(sumColumnSymbol)) {
                cubeScanSymbols.add(sumColumnSymbol);
                cubeColumnHandles.put(sumColumnSymbol, sumColumnHandle);
                symbolMetadataMap.put(sumColumnSymbol, sumColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumColumnSymbol, sumColumnSymbol));
            }
            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
            ColumnMetadata countColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, countColumnHandle);
            Symbol countColumnSymbol = cubeColumnHandles.entrySet().stream().filter(entry -> countColumnHandle.equals(entry.getValue())).findAny().map(Map.Entry::getKey).orElseGet(() -> symbolAllocator.newSymbol("count_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), countColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(countColumnSymbol)) {
                cubeScanSymbols.add(countColumnSymbol);
                cubeColumnHandles.put(countColumnSymbol, countColumnHandle);
                symbolMetadataMap.put(countColumnSymbol, countColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countColumnSymbol, countColumnSymbol));
            }
            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalOutputSymbol, sumColumnSymbol, countColumnSymbol));
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    // Positions are resolved by column name through cubeColumnHandles, which holds every scan symbol.
    // symbolMetadataMap cannot be used for this: it has no entries for dimension/filter symbols and is keyed
    // by the original output symbol for non-avg aggregations, so most scan symbols would sort as -1.
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    Map<String, Integer> cubeColumnPositions = new HashMap<>();
    for (int position = 0; position < cubeColumnMetadataList.size(); position++) {
        cubeColumnPositions.putIfAbsent(cubeColumnMetadataList.get(position).getName(), position);
    }
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnPositions.getOrDefault(cubeColumnHandles.get(outSymbol).getColumnName(), -1)));
    cubeScanNode = TableScanNode.newInstance(context.getIdAllocator().getNextId(), cubeTableHandle, scanOutput, cubeColumnHandles, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(cubeScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
Aggregations