Example usage of io.prestosql.spi.SymbolAllocator in the hetu-core project by openLooKeng: the createScanNode method of the CubeOptimizer class.
/**
 * Builds the {@code TableScanNode} over the star-tree cube table together with the
 * symbol/column bookkeeping needed to rewrite the original aggregation plan on top of it.
 * <p>
 * Dimension columns (group-by and filter), pre-aggregated measure columns, and — for AVG —
 * either a direct AVG column (exact group match) or the backing SUM/COUNT column pair are
 * all resolved to cube column handles and fresh plan symbols.
 *
 * @return the rewrite result holding the cube scan node, symbol metadata, dimension sources,
 *         aggregator sources, average-aggregator sources, and whether AVG must be derived
 *         as SUM/COUNT at query time
 */
public CubeRewriteResult createScanNode() {
    Set<Symbol> cubeScanSymbols = new HashSet<>();
    Map<Symbol, ColumnHandle> cubeColumnHandles = new HashMap<>();
    Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
    Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
    Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
    Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
    SymbolAllocator symbolAllocator = context.getSymbolAllocator();
    // Resolve every group-by dimension to its cube column and allocate a scan symbol for it.
    dimensions.forEach(dimension -> {
        // assumed that source table dimension is same in cube table as well
        ColumnWithTable sourceColumn = originalPlanMappings.get(dimension);
        ColumnHandle cubeColumn = cubeColumnsMap.get(sourceColumn.getColumnName());
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    // Resolve filter columns the same way; columns absent from the cube are skipped here.
    filterColumns.forEach(filterColumn -> {
        ColumnHandle cubeColumn = cubeColumnsMap.get(filterColumn);
        if (cubeColumn == null) {
            // Predicate rewrite will take care of remove those columns from predicate
            return;
        }
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColumn);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColumn.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColumn);
        dimensionSymbols.add(new CubeRewriteResult.DimensionSource(cubeScanSymbol, cubeScanSymbol));
    });
    // process all aggregations except avg.
    originalAggregationsMap.forEach((originalAgg, originalAggOutputSymbol) -> {
        if (originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // skip average aggregation
            return;
        }
        String cubeColumnName = cubeMetadata.getColumn(originalAgg).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + originalAgg));
        ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
        ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
        cubeScanSymbols.add(cubeScanSymbol);
        cubeColumnHandles.put(cubeScanSymbol, cubeColHandle);
        // NOTE(review): metadata is keyed on the original output symbol, not the scan symbol,
        // so the scan symbol has no entry for the output-order sort below — confirm intended.
        symbolMetadataMap.put(originalAggOutputSymbol, columnMetadata);
        aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
    });
    boolean computeAvgDividingSumByCount = true;
    // exactGroupByMatch is True iff query group by clause does not contain columns from other table.
    // This is extremely unlikely for Join queries. Loop-invariant, so computed once up front.
    boolean exactGroupByMatch = !groupByFromOtherTable
            && groupBy.size() == cubeMetadata.getGroup().size()
            && groupBy.stream().map(String::toLowerCase).allMatch(cubeMetadata.getGroup()::contains);
    // process average aggregation
    for (AggregationSignature originalAgg : originalAggregationsMap.keySet()) {
        if (!originalAgg.getFunction().equalsIgnoreCase(CubeAggregateFunction.AVG.toString())) {
            // process only average aggregation. skip others
            continue;
        }
        Symbol originalOutputSymbol = originalAggregationsMap.get(originalAgg);
        String aggSourceColumnName = originalAgg.getDimension();
        AggregationSignature avgAggregationSignature = new AggregationSignature(originalAgg.getFunction(), aggSourceColumnName, originalAgg.isDistinct());
        if (exactGroupByMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
            // Use AVG column for exact group match only if Cube has it. The Cubes created before shipping
            // this optimization will not have AVG column
            computeAvgDividingSumByCount = false;
            String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
            ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
            // NOTE(review): when the handle is already mapped, originalOutputSymbol gets no
            // aggregator source at all — verify a duplicate AVG over the same column is impossible here.
            if (!cubeColumnHandles.containsValue(avgCubeColHandle)) {
                ColumnMetadata columnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, avgCubeColHandle);
                cubeColumnHandles.put(originalOutputSymbol, avgCubeColHandle);
                symbolMetadataMap.put(originalOutputSymbol, columnMetadata);
                cubeScanSymbols.add(originalOutputSymbol);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalOutputSymbol, originalOutputSymbol));
            }
        } else {
            // No direct AVG column — scan the backing SUM and COUNT columns and divide at query time.
            AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
            ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
            ColumnMetadata sumColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, sumColumnHandle);
            // Reuse an existing scan symbol for this handle if one exists. orElseGet (not orElse)
            // so a fresh symbol is only allocated when actually needed — orElse would mutate the
            // symbol allocator even when the existing symbol is used.
            Symbol sumColumnSymbol = cubeColumnHandles.entrySet().stream()
                    .filter(entry -> sumColumnHandle.equals(entry.getValue()))
                    .findAny()
                    .map(Map.Entry::getKey)
                    .orElseGet(() -> symbolAllocator.newSymbol("sum_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), sumColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(sumColumnSymbol)) {
                cubeScanSymbols.add(sumColumnSymbol);
                cubeColumnHandles.put(sumColumnSymbol, sumColumnHandle);
                symbolMetadataMap.put(sumColumnSymbol, sumColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumColumnSymbol, sumColumnSymbol));
            }
            AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), aggSourceColumnName, originalAgg.isDistinct());
            String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
            ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
            ColumnMetadata countColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, countColumnHandle);
            // Same reuse-or-allocate pattern as the SUM column; orElseGet for lazy allocation.
            Symbol countColumnSymbol = cubeColumnHandles.entrySet().stream()
                    .filter(entry -> countColumnHandle.equals(entry.getValue()))
                    .findAny()
                    .map(Map.Entry::getKey)
                    .orElseGet(() -> symbolAllocator.newSymbol("count_" + aggSourceColumnName + "_" + originalOutputSymbol.getName(), countColumnMetadata.getType()));
            if (!cubeScanSymbols.contains(countColumnSymbol)) {
                cubeScanSymbols.add(countColumnSymbol);
                cubeColumnHandles.put(countColumnSymbol, countColumnHandle);
                symbolMetadataMap.put(countColumnSymbol, countColumnMetadata);
                aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countColumnSymbol, countColumnSymbol));
            }
            averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalOutputSymbol, sumColumnSymbol, countColumnSymbol));
        }
    }
    // Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
    // Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
    // NOTE(review): symbols without a symbolMetadataMap entry sort via indexOf(null) == -1,
    // i.e. ahead of all known columns in unspecified relative order — confirm that is acceptable.
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
    scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnMetadataList.indexOf(symbolMetadataMap.get(outSymbol))));
    cubeScanNode = TableScanNode.newInstance(context.getIdAllocator().getNextId(), cubeTableHandle, scanOutput, cubeColumnHandles, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
    return new CubeRewriteResult(cubeScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
Aggregations