use of io.prestosql.sql.relational.OriginalExpressionUtils in project hetu-core by openlookeng.
the class QueryPlanner method aggregate.
/**
 * Plans the aggregation part of a query specification.
 *
 * <p>If the analysis reports that {@code node} is not an aggregation, the input plan is returned
 * unchanged. Otherwise the method:
 * <ol>
 * <li>pre-projects the group-by expressions and all aggregate inputs (arguments except lambdas,
 * ORDER BY sort keys and FILTER expressions);</li>
 * <li>adds a {@code GroupIdNode} when there is more than one grouping set, or a plain
 * {@code ProjectNode} otherwise, followed by a single-step {@code AggregationNode};</li>
 * <li>adds a coercion projection when a rewritten aggregate came back wrapped in a {@code Cast};</li>
 * <li>delegates to {@code handleGroupingOperations} to rewrite grouping functions.</li>
 * </ol>
 *
 * @param inputSubPlan the plan built so far, used as the aggregation input
 * @param node the query specification being planned
 * @return the plan extended with the aggregation, or {@code inputSubPlan} when no aggregation is needed
 */
private PlanBuilder aggregate(PlanBuilder inputSubPlan, QuerySpecification node) {
PlanBuilder subPlan = inputSubPlan;
if (!analysis.isAggregation(node)) {
return subPlan;
}
// 1. Pre-project all scalar inputs (arguments and non-trivial group by expressions)
Set<Expression> groupByExpressions = ImmutableSet.copyOf(analysis.getGroupByExpressions(node));
ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
analysis.getAggregates(node).stream().map(FunctionCall::getArguments).flatMap(List::stream).filter(// lambda expression is generated at execution time
exp -> !(exp instanceof LambdaExpression)).forEach(arguments::add);
// sort keys of aggregates with ORDER BY are inputs to the aggregation as well
analysis.getAggregates(node).stream().map(FunctionCall::getOrderBy).filter(Optional::isPresent).map(Optional::get).map(OrderBy::getSortItems).flatMap(List::stream).map(SortItem::getSortKey).forEach(arguments::add);
// filter expressions need to be projected first
analysis.getAggregates(node).stream().map(FunctionCall::getFilter).filter(Optional::isPresent).map(Optional::get).forEach(arguments::add);
Iterable<Expression> inputs = Iterables.concat(groupByExpressions, arguments.build());
subPlan = handleSubqueries(subPlan, node, inputs);
if (!Iterables.isEmpty(inputs)) {
// avoid an empty projection if the only aggregation is COUNT (which has no arguments)
subPlan = project(subPlan, inputs);
}
// 2. Aggregate
// 2.a. Rewrite aggregate arguments
TranslationMap argumentTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
ImmutableList.Builder<Symbol> aggregationArgumentsBuilder = ImmutableList.builder();
for (Expression argument : arguments.build()) {
Symbol symbol = subPlan.translate(argument);
argumentTranslations.put(argument, symbol);
aggregationArgumentsBuilder.add(symbol);
}
List<Symbol> aggregationArguments = aggregationArgumentsBuilder.build();
// 2.b. Rewrite grouping columns
TranslationMap groupingTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
// maps each freshly allocated grouping output symbol to the pre-projected input symbol it copies;
// a LinkedHashMap is used, so entries iterate in insertion order
Map<Symbol, Symbol> groupingSetMappings = new LinkedHashMap<>();
for (Expression expression : groupByExpressions) {
Symbol input = subPlan.translate(expression);
Symbol output = planSymbolAllocator.newSymbol(expression, analysis.getTypeWithCoercions(expression), "gid");
groupingTranslations.put(expression, output);
groupingSetMappings.put(output, input);
}
// This tracks the grouping sets before complex expressions are considered (see comments below)
// It's also used to compute the descriptors needed to implement grouping()
// Both lists default to a single empty grouping set, which is what is used when there is no GROUP BY clause
List<Set<FieldId>> columnOnlyGroupingSets = ImmutableList.of(ImmutableSet.of());
List<List<Symbol>> groupingSets = ImmutableList.of(ImmutableList.of());
if (node.getGroupBy().isPresent()) {
// For the purpose of "distinct", we need to canonicalize column references that may have varying
// syntactic forms (e.g., "t.a" vs "a"). Thus we need to enumerate grouping sets based on the underlying
// fieldId associated with each column reference expression.
// The catch is that simple group-by expressions can be arbitrary expressions (this is a departure from the SQL specification).
// But, they don't affect the number of grouping sets or the behavior of "distinct" . We can compute all the candidate
// grouping sets in terms of fieldId, dedup as appropriate and then cross-join them with the complex expressions.
Analysis.GroupingSetAnalysis groupingSetAnalysis = analysis.getGroupingSets(node);
columnOnlyGroupingSets = enumerateGroupingSets(groupingSetAnalysis);
if (node.getGroupBy().get().isDistinct()) {
columnOnlyGroupingSets = columnOnlyGroupingSets.stream().distinct().collect(toImmutableList());
}
// add in the complex expressions and materialize the grouping sets in terms of plan columns
ImmutableList.Builder<List<Symbol>> groupingSetBuilder = ImmutableList.builder();
for (Set<FieldId> groupingSet : columnOnlyGroupingSets) {
ImmutableList.Builder<Symbol> columns = ImmutableList.builder();
// every grouping set gets all complex expressions first, then its own column fields
groupingSetAnalysis.getComplexExpressions().stream().map(groupingTranslations::get).forEach(columns::add);
groupingSet.stream().map(field -> groupingTranslations.get(new FieldReference(field.getFieldIndex()))).forEach(columns::add);
groupingSetBuilder.add(columns.build());
}
groupingSets = groupingSetBuilder.build();
}
// 2.c. Generate GroupIdNode (multiple grouping sets) or ProjectNode (single grouping set)
Optional<Symbol> groupIdSymbol = Optional.empty();
if (groupingSets.size() > 1) {
groupIdSymbol = Optional.of(planSymbolAllocator.newSymbol("groupId", BIGINT));
GroupIdNode groupId = new GroupIdNode(idAllocator.getNextId(), subPlan.getRoot(), groupingSets, groupingSetMappings, aggregationArguments, groupIdSymbol.get());
subPlan = new PlanBuilder(groupingTranslations, groupId);
} else {
// single grouping set: pass the aggregation arguments through and copy each grouping input to its output symbol
Assignments.Builder assignments = Assignments.builder();
aggregationArguments.forEach(symbol -> assignments.put(symbol, castToRowExpression(toSymbolReference(symbol))));
groupingSetMappings.forEach((key, value) -> assignments.put(key, castToRowExpression(toSymbolReference(value))));
ProjectNode project = new ProjectNode(idAllocator.getNextId(), subPlan.getRoot(), assignments.build());
subPlan = new PlanBuilder(groupingTranslations, project);
}
TranslationMap aggregationTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToSymbolMap);
aggregationTranslations.copyMappingsFrom(groupingTranslations);
// 2.d. Rewrite aggregates
ImmutableMap.Builder<Symbol, Aggregation> aggregationsBuilder = ImmutableMap.builder();
boolean needPostProjectionCoercion = false;
for (FunctionCall aggregate : analysis.getAggregates(node)) {
Expression rewritten = argumentTranslations.rewrite(aggregate);
// the output symbol is allocated from the rewritten expression before any Cast is stripped below
Symbol newSymbol = planSymbolAllocator.newSymbol(rewritten, analysis.getType(aggregate));
// The rewritten aggregate may come back wrapped in an implicit Cast; unwrap it here and
// apply the coercion in a post-projection instead (see needPostProjectionCoercion below)
if (rewritten instanceof Cast) {
rewritten = ((Cast) rewritten).getExpression();
needPostProjectionCoercion = true;
}
aggregationTranslations.put(aggregate, newSymbol);
FunctionCall functionCall = (FunctionCall) rewritten;
// the converted argument list is passed twice: once inside the call expression and once as the Aggregation's own arguments
aggregationsBuilder.put(newSymbol, new Aggregation(call(aggregate.getName().getSuffix(), analysis.getFunctionHandle(aggregate), analysis.getType(aggregate), functionCall.getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())), functionCall.getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList()), functionCall.isDistinct(), functionCall.getFilter().map(SymbolUtils::from), functionCall.getOrderBy().map(OrderingSchemeUtils::fromOrderBy), Optional.empty()));
}
Map<Symbol, Aggregation> aggregations = aggregationsBuilder.build();
// record the indexes of the empty grouping sets; they are passed on as the global grouping sets
ImmutableSet.Builder<Integer> globalGroupingSets = ImmutableSet.builder();
for (int i = 0; i < groupingSets.size(); i++) {
if (groupingSets.get(i).isEmpty()) {
globalGroupingSets.add(i);
}
}
// grouping keys are the distinct symbols across all grouping sets, plus the group id symbol when present
ImmutableList.Builder<Symbol> groupingKeys = ImmutableList.builder();
groupingSets.stream().flatMap(List::stream).distinct().forEach(groupingKeys::add);
groupIdSymbol.ifPresent(groupingKeys::add);
AggregationNode aggregationNode = new AggregationNode(idAllocator.getNextId(), subPlan.getRoot(), aggregations, groupingSets(groupingKeys.build(), groupingSets.size(), globalGroupingSets.build()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), groupIdSymbol, AggregationNode.AggregationType.HASH, Optional.empty());
subPlan = new PlanBuilder(aggregationTranslations, aggregationNode);
// 3. Post-projection coercion of aggregate outputs (only when a Cast was stripped in 2.d)
// TODO: this is a hack, we should change type coercions to coerce the inputs to functions/operators instead of coercing the output
if (needPostProjectionCoercion) {
// group-by expressions and the group id symbol are passed as already coerced
ImmutableList.Builder<Expression> alreadyCoerced = ImmutableList.builder();
alreadyCoerced.addAll(groupByExpressions);
groupIdSymbol.map(SymbolUtils::toSymbolReference).ifPresent(alreadyCoerced::add);
subPlan = explicitCoercionFields(subPlan, alreadyCoerced.build(), analysis.getAggregates(node));
}
// 4. Project and re-write all grouping functions
return handleGroupingOperations(subPlan, node, groupIdSymbol, columnOnlyGroupingSets);
}
use of io.prestosql.sql.relational.OriginalExpressionUtils in project hetu-core by openlookeng.
the class InlineProjections method apply.
/**
 * Inlines selected child-projection expressions into the parent projection.
 *
 * <p>The symbols chosen by {@code extractInliningTargets} are substituted into the parent's
 * assignments. In the child projection those targets are dropped and replaced by identity
 * assignments for every input symbol the inlined expressions read.
 *
 * @param parent the outer projection
 * @param captures pattern captures; the inner projection is stored under {@code CHILD}
 * @param context rule context, used here for the symbol allocator
 * @return an empty result when nothing can be inlined, otherwise the rewritten pair of projections
 */
@Override
public Result apply(ProjectNode parent, Captures captures, Context context) {
    ProjectNode child = captures.get(CHILD);
    Sets.SetView<Symbol> inlineTargets = extractInliningTargets(parent, child);
    if (inlineTargets.isEmpty()) {
        return Result.empty();
    }

    // Substitute the targeted child expressions into every parent assignment.
    Assignments inlinedExpressions = child.getAssignments().filter(inlineTargets::contains);
    Map<Symbol, RowExpression> rewrittenParent = parent.getAssignments().entrySet().stream()
            .collect(Collectors.toMap(
                    Map.Entry::getKey,
                    parentEntry -> inlineReferences(parentEntry.getValue(), inlinedExpressions)));

    // Symbols read by the inlined expressions. The child has to keep producing them, so they
    // become identity assignments below. If every child assignment ends up as an identity
    // assignment, other rules prune the projection.
    Set<Symbol> requiredInputs = child.getAssignments().entrySet().stream()
            .filter(childEntry -> inlineTargets.contains(childEntry.getKey()))
            .map(Map.Entry::getValue)
            .flatMap(inlined -> extractInputs(inlined).stream())
            .collect(toSet());

    // Keep every child assignment that was not inlined.
    Assignments.Builder childAssignments = Assignments.builder();
    child.getAssignments().entrySet().stream()
            .filter(childEntry -> !inlineTargets.contains(childEntry.getKey()))
            .forEach(childEntry -> childAssignments.put(childEntry));

    // Identity assignments use the same representation as the child's existing assignments:
    // variable references when none of them wraps an original expression, wrapped symbol
    // references otherwise.
    boolean allTranslated = child.getAssignments().entrySet().stream()
            .noneMatch(childEntry -> OriginalExpressionUtils.isExpression(childEntry.getValue()));
    for (Symbol input : requiredInputs) {
        RowExpression identity = allTranslated
                ? new VariableReferenceExpression(input.getName(), context.getSymbolAllocator().getSymbols().get(input))
                : castToRowExpression(toSymbolReference(input));
        childAssignments.put(input, identity);
    }

    ProjectNode newChild = new ProjectNode(child.getId(), child.getSource(), childAssignments.build());
    ProjectNode newParent = new ProjectNode(parent.getId(), newChild, Assignments.copyOf(rewrittenParent));
    return Result.ofPlanNode(newParent);
}
use of io.prestosql.sql.relational.OriginalExpressionUtils in project hetu-core by openlookeng.
the class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate method transformProjectNode.
/**
 * Tries to replace a self join underneath an uncorrelated IN-predicate subquery with a
 * DISTINCT count aggregation over a single scan of the joined table.
 *
 * <p>The expected shape below {@code projectNode} is {@code Filter -> Join}, optionally reached
 * through a {@code CTEScanNode}. When the join is not recognised as a self join by
 * {@code isSelfJoin}, the method recurses into the join's left and right children (when they are
 * projections) and rebuilds the join around any child that was transformed.
 *
 * <p>When the join is a self join, the result is
 * {@code Project -> Filter(count > 1) -> Aggregation(count DISTINCT) -> TableScan}. If the
 * original projection read from a CTE scan, the CTE scan is kept and re-parented on top of the
 * new filter.
 *
 * @param context rule context providing the lookup, id allocator and symbol allocator
 * @param projectNode the projection at the root of the subquery plan to transform
 * @return the transformed projection, or {@link Optional#empty()} when the plan shape does not match
 */
private Optional<ProjectNode> transformProjectNode(Context context, ProjectNode projectNode) {
    // IN predicate requires only one projection
    if (projectNode.getOutputSymbols().size() > 1) {
        return Optional.empty();
    }
    PlanNode source = context.getLookup().resolve(projectNode.getSource());
    if (source instanceof CTEScanNode) {
        source = getChildFilterNode(context, context.getLookup().resolve(((CTEScanNode) source).getSource()));
    }
    if (!(source instanceof FilterNode && context.getLookup().resolve(((FilterNode) source).getSource()) instanceof JoinNode)) {
        return Optional.empty();
    }
    FilterNode filter = (FilterNode) source;
    Expression predicate = OriginalExpressionUtils.castToExpression(filter.getPredicate());
    List<SymbolReference> allPredicateSymbols = new ArrayList<>();
    getAllSymbols(predicate, allPredicateSymbols);
    JoinNode joinNode = (JoinNode) context.getLookup().resolve(filter.getSource());
    if (!isSelfJoin(projectNode, predicate, joinNode, context.getLookup())) {
        // Check next level for Self Join
        PlanNode left = context.getLookup().resolve(joinNode.getLeft());
        boolean changed = false;
        if (left instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) left);
            if (transformResult.isPresent()) {
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), transformResult.get(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        PlanNode right = context.getLookup().resolve(joinNode.getRight());
        if (right instanceof ProjectNode) {
            Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) right);
            if (transformResult.isPresent()) {
                joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), transformResult.get(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
                changed = true;
            }
        }
        if (changed) {
            // Rebuild the original Filter and Project on top of the rewritten join.
            FilterNode transformedFilter = new FilterNode(filter.getId(), joinNode, filter.getPredicate());
            ProjectNode transformedProject = new ProjectNode(projectNode.getId(), transformedFilter, projectNode.getAssignments());
            return Optional.of(transformedProject);
        }
        return Optional.empty();
    }
    // Choose the table to use based on projected output.
    // NOTE(review): both casts assume isSelfJoin only accepts joins whose children resolve to table scans - confirm.
    TableScanNode leftTable = (TableScanNode) context.getLookup().resolve(joinNode.getLeft());
    TableScanNode rightTable = (TableScanNode) context.getLookup().resolve(joinNode.getRight());
    TableScanNode tableToUse;
    List<RowExpression> aggregationSymbols;
    AggregationNode.GroupingSetDescriptor groupingSetDescriptor;
    Assignments projectionsForCTE = null;
    Assignments projectionsFromFilter = null;
    // Use non-projected column for aggregation
    if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
        CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
        ProjectNode childProjectOfCte = (ProjectNode) context.getLookup().resolve(cteScanNode.getSource());
        List<Symbol> completeOutputSymbols = new ArrayList<>();
        completeOutputSymbols.addAll(rightTable.getOutputSymbols());
        completeOutputSymbols.addAll(leftTable.getOutputSymbols());
        // Find the table output symbols that the CTE's child projection maps to the projected symbol.
        List<Symbol> outputSymbols = new ArrayList<>();
        for (Symbol outputSymbol : completeOutputSymbols) {
            for (Symbol symbol : projectNode.getOutputSymbols()) {
                if (childProjectOfCte.getAssignments().getMap().containsKey(symbol)) {
                    // NOTE(review): assumes the CTE child projection assigns a plain symbol reference - confirm.
                    SymbolReference assigned = (SymbolReference) OriginalExpressionUtils.castToExpression(childProjectOfCte.getAssignments().getMap().get(symbol));
                    if (assigned.getName().equals(outputSymbol.getName())) {
                        outputSymbols.add(outputSymbol);
                    }
                }
            }
        }
        Map<Symbol, RowExpression> projectionsForCTEMap = new HashMap<>();
        Map<Symbol, RowExpression> projectionsFromFilterMap = new HashMap<>();
        for (Map.Entry<Symbol, RowExpression> entry : childProjectOfCte.getAssignments().getMap().entrySet()) {
            // The assignment producing the single projected symbol is registered twice:
            // under its own key for the CTE projection, and under the matching table symbol
            // for the projection placed directly above the new filter.
            if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
                projectionsForCTEMap.put(entry.getKey(), entry.getValue());
                projectionsFromFilterMap.put(getOnlyElement(outputSymbols), entry.getValue());
            }
        }
        projectionsForCTE = new Assignments(projectionsForCTEMap);
        projectionsFromFilter = new Assignments(projectionsFromFilterMap);
        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(outputSymbols)) ? leftTable : rightTable;
        // Count the predicate symbols that belong to the chosen table and are not grouping keys.
        aggregationSymbols = allPredicateSymbols.stream()
                .filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s)))
                .filter(s -> !outputSymbols.contains(SymbolUtils.from(s)))
                .map(OriginalExpressionUtils::castToRowExpression)
                .collect(Collectors.toList());
        // Create aggregation
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(outputSymbols), 1, ImmutableSet.of());
    } else {
        tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(projectNode.getOutputSymbols())) ? leftTable : rightTable;
        // Count the predicate symbols that belong to the chosen table and are not projected.
        aggregationSymbols = allPredicateSymbols.stream()
                .filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s)))
                .filter(s -> !projectNode.getOutputSymbols().contains(SymbolUtils.from(s)))
                .map(OriginalExpressionUtils::castToRowExpression)
                .collect(Collectors.toList());
        // Create aggregation
        groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(projectNode.getOutputSymbols()), 1, ImmutableSet.of());
    }
    AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(
            new CallExpression("count", functionResolution.countFunction(), BIGINT, aggregationSymbols),
            aggregationSymbols,
            // mark DISTINCT since NOT_EQUALS predicate
            true,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), BIGINT);
    aggregationsBuilder.put(countSymbol, aggregation);
    AggregationNode aggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), tableToUse, aggregationsBuilder.build(), groupingSetDescriptor, ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
    // Keep only the groups whose distinct count is greater than 1, to match the NOT_EQUALS clause in the original query.
    FilterNode filterNode = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, SymbolUtils.toSymbolReference(countSymbol), new GenericLiteral("BIGINT", "1"))));
    // Project the aggregated+filtered rows.
    ProjectNode transformedSubquery = new ProjectNode(projectNode.getId(), filterNode, projectNode.getAssignments());
    if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
        // Re-insert the CTE scan: Project -> CTEScan -> Project(CTE) -> Project(filter) -> Filter.
        CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
        PlanNode projectNodeForCTE = context.getLookup().resolve(cteScanNode.getSource());
        PlanNode projectNodeFromFilter = context.getLookup().resolve(projectNodeForCTE);
        projectNodeFromFilter = new ProjectNode(projectNodeFromFilter.getId(), filterNode, projectionsFromFilter);
        projectNodeForCTE = new ProjectNode(projectNodeForCTE.getId(), projectNodeFromFilter, projectionsForCTE);
        cteScanNode = (CTEScanNode) cteScanNode.replaceChildren(ImmutableList.of(projectNodeForCTE));
        cteScanNode.setOutputSymbols(projectNode.getOutputSymbols());
        transformedSubquery = new ProjectNode(projectNode.getId(), cteScanNode, projectNode.getAssignments());
    }
    return Optional.of(transformedSubquery);
}
use of io.prestosql.sql.relational.OriginalExpressionUtils in project hetu-core by openlookeng.
the class AggregationRewriteWithCube method createScanNode.
/**
 * Builds the table scan over a cube table that replaces the scan feeding
 * {@code originalAggregationNode}, and records how original symbols map to cube scan symbols.
 *
 * <p>Symbols referenced by the filter predicate and group-by output symbols that map to a
 * {@code ColumnHandle} in {@code symbolMappings} become dimension sources. Output symbols that
 * are aggregations are mapped to the cube column looked up in {@code cubeMetadata} for the
 * corresponding aggregation signature. {@code avg} is either read from a cube avg column or
 * recorded as a sum/count pair.
 *
 * <p>The method reads {@code symbolMappings}, {@code cubeMetadata}, {@code symbolAllocator},
 * {@code idAllocator} and {@code log}, and writes to {@code rewrittenMappings}; none of these
 * are declared in this block (presumably fields of the enclosing class).
 *
 * @param originalAggregationNode the aggregation being rewritten
 * @param filterNode optional filter below the aggregation; may be {@code null}, otherwise it is cast to {@code FilterNode}
 * @param cubeTableHandle handle of the cube table to scan
 * @param cubeColumnsMap cube column handles keyed by column name
 * @param cubeColumnsMetadata cube column metadata; its order determines the scan output order
 * @param exactGroupsMatch when {@code true} and the cube has a column for the avg signature, that column is used directly
 * @return the cube scan node together with the dimension, aggregation and average mappings
 * @throws PrestoException if a required cube column cannot be found, or the aggregation function is unsupported
 */
public CubeRewriteResult createScanNode(AggregationNode originalAggregationNode, PlanNode filterNode, TableHandle cubeTableHandle, Map<String, ColumnHandle> cubeColumnsMap, List<ColumnMetadata> cubeColumnsMetadata, boolean exactGroupsMatch) {
Set<Symbol> cubeScanSymbols = new HashSet<>();
Map<Symbol, ColumnHandle> symbolAssignments = new HashMap<>();
Set<CubeRewriteResult.DimensionSource> dimensionSymbols = new HashSet<>();
Set<CubeRewriteResult.AggregatorSource> aggregationColumns = new HashSet<>();
Set<CubeRewriteResult.AverageAggregatorSource> averageAggregationColumns = new HashSet<>();
Map<Symbol, ColumnMetadata> symbolMetadataMap = new HashMap<>();
// Collectors.toMap without a merge function throws on duplicate keys, i.e. on duplicate column names
Map<String, ColumnMetadata> columnMetadataMap = cubeColumnsMetadata.stream().collect(Collectors.toMap(ColumnMetadata::getName, Function.identity()));
// stays true unless an avg aggregation is served directly from a cube avg column below
boolean computeAvgDividingSumByCount = true;
Set<Symbol> filterSymbols = new HashSet<>();
if (filterNode != null) {
filterSymbols.addAll(SymbolsExtractor.extractUnique(((FilterNode) filterNode).getPredicate()));
}
// 1. Columns referenced by the filter predicate become dimensions of the cube scan
for (Symbol filterSymbol : filterSymbols) {
if (symbolMappings.containsKey(filterSymbol.getName()) && symbolMappings.get(filterSymbol.getName()) instanceof ColumnHandle) {
// output symbol references of the columns in original table
ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(filterSymbol.getName());
// NOTE(review): no null check on the two lookups below; presumably the caller has verified the cube contains this column - confirm
ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
cubeScanSymbols.add(cubeScanSymbol);
symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
rewrittenMappings.put(filterSymbol.getName(), cubeScanSymbol);
dimensionSymbols.add(new CubeRewriteResult.DimensionSource(filterSymbol, cubeScanSymbol));
}
}
// 2. Output symbols of the original aggregation: group-by columns and aggregations
for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
if (symbolMappings.containsKey(originalAggOutputSymbol.getName()) && symbolMappings.get(originalAggOutputSymbol.getName()) instanceof ColumnHandle) {
// output symbol references of the columns in original table - column part of group by clause
ColumnHandle originalColumn = (ColumnHandle) symbolMappings.get(originalAggOutputSymbol.getName());
ColumnHandle cubeScanColumn = cubeColumnsMap.get(originalColumn.getColumnName());
ColumnMetadata columnMetadata = columnMetadataMap.get(cubeScanColumn.getColumnName());
if (!symbolAssignments.containsValue(cubeScanColumn)) {
Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeScanColumn.getColumnName(), columnMetadata.getType());
cubeScanSymbols.add(cubeScanSymbol);
symbolAssignments.put(cubeScanSymbol, cubeScanColumn);
symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
} else {
// the cube column is already scanned (e.g. it was also used in the filter): reuse its symbol;
// get() cannot fail here because containsValue was true just above
Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeScanColumn.equals(symbolAssignments.get(key))).findFirst().get();
rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
dimensionSymbols.add(new CubeRewriteResult.DimensionSource(originalAggOutputSymbol, cubeScanSymbol));
}
} else if (originalAggregationNode.getAggregations().containsKey(originalAggOutputSymbol)) {
// output symbol is mapped to an aggregation
AggregationNode.Aggregation aggregation = originalAggregationNode.getAggregations().get(originalAggOutputSymbol);
String aggFunction = aggregation.getFunctionCall().getDisplayName();
List<Expression> arguments = aggregation.getArguments() == null ? null : aggregation.getArguments().stream().map(OriginalExpressionUtils::castToExpression).collect(Collectors.toList());
// only the first argument is inspected; aggregations over anything but a plain symbol are skipped
if (arguments != null && !arguments.isEmpty() && (!(arguments.get(0) instanceof SymbolReference))) {
log.info("Not a symbol reference in aggregation function. Agg Function = %s, Arguments = %s", aggFunction, arguments);
continue;
}
Object mappedValue = arguments == null || arguments.isEmpty() ? null : symbolMappings.get(((SymbolReference) arguments.get(0)).getName());
// no argument, an unmapped argument, or the literal 1
if (mappedValue == null || (mappedValue instanceof LongLiteral && ((LongLiteral) mappedValue).getValue() == 1)) {
// COUNT aggregation
if (CubeAggregateFunction.COUNT.getName().equals(aggFunction) && !aggregation.isDistinct()) {
// COUNT 1
AggregationSignature aggregationSignature = AggregationSignature.count();
String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
// NOTE(review): when the count column is already assigned, nothing is recorded for this output symbol,
// unlike the min/max/sum/count branch below which reuses the existing symbol - confirm this is intended
if (!symbolAssignments.containsValue(cubeColHandle)) {
ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
cubeScanSymbols.add(cubeScanSymbol);
symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
symbolAssignments.put(cubeScanSymbol, cubeColHandle);
rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
}
}
} else if (mappedValue instanceof ColumnHandle) {
// aggregation over a column of the original table
String originalColumnName = ((ColumnHandle) mappedValue).getColumnName();
boolean distinct = originalAggregationNode.getAggregations().get(originalAggOutputSymbol).isDistinct();
switch(aggFunction) {
case "min":
case "max":
case "sum":
case "count":
AggregationSignature aggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
String cubeColumnName = cubeMetadata.getColumn(aggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + aggregationSignature));
ColumnHandle cubeColHandle = cubeColumnsMap.get(cubeColumnName);
if (!symbolAssignments.containsValue(cubeColHandle)) {
ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
Symbol cubeScanSymbol = symbolAllocator.newSymbol(cubeColHandle.getColumnName(), columnMetadata.getType());
cubeScanSymbols.add(cubeScanSymbol);
symbolAssignments.put(cubeScanSymbol, cubeColHandle);
symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
} else {
// the cube column is already scanned: reuse its symbol and map this output symbol to it as well
// (the symbol is re-added to cubeScanSymbols, symbolAssignments and symbolMetadataMap under the same key)
ColumnMetadata columnMetadata = columnMetadataMap.get(cubeColHandle.getColumnName());
Symbol cubeScanSymbol = symbolAssignments.keySet().stream().filter(key -> cubeColHandle.equals(symbolAssignments.get(key))).findFirst().get();
cubeScanSymbols.add(cubeScanSymbol);
symbolAssignments.put(cubeScanSymbol, cubeColHandle);
symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
}
break;
case "avg":
AggregationSignature avgAggregationSignature = new AggregationSignature(aggFunction, originalColumnName, distinct);
if (exactGroupsMatch && cubeMetadata.getColumn(avgAggregationSignature).isPresent()) {
// the cube has a column for this avg signature and the groups match exactly: read it directly
computeAvgDividingSumByCount = false;
String avgCubeColumnName = cubeMetadata.getColumn(avgAggregationSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + avgAggregationSignature));
ColumnHandle avgCubeColHandle = cubeColumnsMap.get(avgCubeColumnName);
// NOTE(review): as with COUNT(*) above, nothing is recorded for this output symbol when the avg column is already assigned - confirm
if (!symbolAssignments.containsValue(avgCubeColHandle)) {
ColumnMetadata columnMetadata = columnMetadataMap.get(avgCubeColHandle.getColumnName());
Symbol cubeScanSymbol = symbolAllocator.newSymbol(avgCubeColHandle.getColumnName(), columnMetadata.getType());
cubeScanSymbols.add(cubeScanSymbol);
symbolAssignments.put(cubeScanSymbol, avgCubeColHandle);
symbolMetadataMap.put(cubeScanSymbol, columnMetadata);
rewrittenMappings.put(originalAggOutputSymbol.getName(), cubeScanSymbol);
aggregationColumns.add(new CubeRewriteResult.AggregatorSource(originalAggOutputSymbol, cubeScanSymbol));
}
} else {
// otherwise scan the sum and count columns and record them as an average source for this output symbol
AggregationSignature sumSignature = new AggregationSignature(SUM.getName(), originalColumnName, distinct);
String sumColumnName = cubeMetadata.getColumn(sumSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + sumSignature));
ColumnHandle sumColumnHandle = cubeColumnsMap.get(sumColumnName);
Symbol sumSymbol = null;
if (!symbolAssignments.containsValue(sumColumnHandle)) {
ColumnMetadata columnMetadata = columnMetadataMap.get(sumColumnHandle.getColumnName());
sumSymbol = symbolAllocator.newSymbol("sum_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
cubeScanSymbols.add(sumSymbol);
symbolAssignments.put(sumSymbol, sumColumnHandle);
symbolMetadataMap.put(sumSymbol, columnMetadata);
// the helper symbol has no counterpart in the original plan, so it is mapped to itself
rewrittenMappings.put(sumSymbol.getName(), sumSymbol);
aggregationColumns.add(new CubeRewriteResult.AggregatorSource(sumSymbol, sumSymbol));
} else {
// reuse the symbol already assigned to the sum column
for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
if (assignment.getValue().equals(sumColumnHandle)) {
sumSymbol = assignment.getKey();
break;
}
}
}
AggregationSignature countSignature = new AggregationSignature(COUNT.getName(), originalColumnName, distinct);
String countColumnName = cubeMetadata.getColumn(countSignature).orElseThrow(() -> new PrestoException(CUBE_ERROR, "Cannot find column associated with aggregation " + countSignature));
ColumnHandle countColumnHandle = cubeColumnsMap.get(countColumnName);
Symbol countSymbol = null;
if (!symbolAssignments.containsValue(countColumnHandle)) {
ColumnMetadata columnMetadata = columnMetadataMap.get(countColumnHandle.getColumnName());
countSymbol = symbolAllocator.newSymbol("count_" + originalColumnName + "_" + originalAggOutputSymbol.getName(), columnMetadata.getType());
cubeScanSymbols.add(countSymbol);
symbolAssignments.put(countSymbol, countColumnHandle);
symbolMetadataMap.put(countSymbol, columnMetadata);
rewrittenMappings.put(countSymbol.getName(), countSymbol);
aggregationColumns.add(new CubeRewriteResult.AggregatorSource(countSymbol, countSymbol));
} else {
// reuse the symbol already assigned to the count column
for (Map.Entry<Symbol, ColumnHandle> assignment : symbolAssignments.entrySet()) {
if (assignment.getValue().equals(countColumnHandle)) {
countSymbol = assignment.getKey();
break;
}
}
}
averageAggregationColumns.add(new CubeRewriteResult.AverageAggregatorSource(originalAggOutputSymbol, sumSymbol, countSymbol));
}
break;
default:
throw new PrestoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported aggregation function " + aggFunction);
}
} else {
log.info("Aggregation function argument is not a Column Handle. Agg Function = %s, Arguments = %s", aggFunction, arguments);
}
}
}
// Scan output order is important for partitioned cubes. Otherwise, incorrect results may be produced.
// Refer: https://gitee.com/openlookeng/hetu-core/issues/I4LAYC
// The symbols are therefore sorted by the position of their column metadata in cubeColumnsMetadata.
List<Symbol> scanOutput = new ArrayList<>(cubeScanSymbols);
scanOutput.sort(Comparator.comparingInt(outSymbol -> cubeColumnsMetadata.indexOf(symbolMetadataMap.get(outSymbol))));
TableScanNode tableScanNode = TableScanNode.newInstance(idAllocator.getNextId(), cubeTableHandle, scanOutput, symbolAssignments, ReuseExchangeOperator.STRATEGY.REUSE_STRATEGY_DEFAULT, new UUID(0, 0), 0, false);
return new CubeRewriteResult(tableScanNode, symbolMetadataMap, dimensionSymbols, aggregationColumns, averageAggregationColumns, computeAvgDividingSumByCount);
}
use of io.prestosql.sql.relational.OriginalExpressionUtils in project hetu-core by openlookeng.
the class TranslateExpressions method createRewriter.
/**
 * Creates the rewriter that replaces row expressions still wrapping an original AST
 * expression with translated and optimized row expressions.
 *
 * @param metadata source of function metadata and types; also used to build the optimizer
 * @param sqlParser handed to the {@code TypeAnalyzer} used for type inference
 * @return a rewriter that performs the translation for a single row expression
 */
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            // A CallExpression in an aggregation carries original expressions as its arguments
            // and therefore needs its own translation path.
            boolean callWithOriginalArguments = expression instanceof CallExpression
                    && ((CallExpression) expression).getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression);
            if (!callWithOriginalArguments) {
                return removeOriginalExpression(expression, context, new HashMap<>());
            }
            return removeOriginalExpressionArguments((CallExpression) expression, context.getSession(), context.getSymbolAllocator(), context);
        }

        /**
         * Rebuilds the call with every original-expression argument translated to a row expression.
         * The rebuilt call is created with an empty trailing Optional argument.
         */
        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Rule.Context context) {
            Map<NodeRef<Expression>, Type> argumentTypes = analyzeCallExpressionTypes(callExpression, session, planSymbolAllocator.getTypes());
            List<RowExpression> translatedArguments = callExpression.getArguments().stream()
                    .map(argument -> removeOriginalExpression(argument, session, argumentTypes, context))
                    .collect(toImmutableList());
            return new CallExpression(
                    callExpression.getDisplayName(),
                    callExpression.getFunctionHandle(),
                    callExpression.getType(),
                    translatedArguments,
                    Optional.empty());
        }

        /**
         * Collects the types of all sub-expressions found in the original-expression arguments
         * of the given call, treating lambda arguments separately from the remaining ones.
         */
        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdas = callExpression.getArguments().stream()
                    .filter(OriginalExpressionUtils::isExpression)
                    .map(OriginalExpressionUtils::castToExpression)
                    .filter(LambdaExpression.class::isInstance)
                    .map(LambdaExpression.class::cast)
                    .collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> collectedTypes = ImmutableMap.builder();
            TypeAnalyzer analyzer = new TypeAnalyzer(sqlParser, metadata);

            if (!lambdas.isEmpty()) {
                List<FunctionType> lambdaTypes = metadata.getFunctionAndTypeManager()
                        .getFunctionMetadata(callExpression.getFunctionHandle())
                        .getArgumentTypes().stream()
                        .filter(signature -> signature.getBase().equals(FunctionType.NAME))
                        .map(metadata::getType)
                        .map(FunctionType.class::cast)
                        .collect(toImmutableList());
                InternalAggregationFunction aggregationFunction = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class<?>> lambdaInterfaces = aggregationFunction.getLambdaInterfaces();
                verify(lambdas.size() == lambdaTypes.size());
                verify(lambdas.size() == lambdaInterfaces.size());

                for (int lambdaIndex = 0; lambdaIndex < lambdas.size(); lambdaIndex++) {
                    LambdaExpression lambda = lambdas.get(lambdaIndex);
                    FunctionType lambdaType = lambdaTypes.get(lambdaIndex);

                    // Compiling a lambda means turning the LambdaExpression into a
                    // LambdaDefinitionExpression, and that needs the type of every sub-expression.
                    //
                    // Project and filter compilation obtain those types through
                    // ExpressionAnalyzer.getExpressionTypesFromInput (see visitScanFilterAndProject
                    // and visitFilter).
                    //
                    // That approach cannot be used here: the function call representation in the
                    // final aggregation node is currently a hack that takes the intermediate type
                    // as input, so it may not be a valid function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // switch to the same mechanism that project and filter expressions use.
                    List<LambdaArgumentDeclaration> declarations = lambda.getArguments();
                    verify(declarations.size() == lambdaType.getArgumentTypes().size());

                    Map<NodeRef<Expression>, Type> typesByDeclaration = new HashMap<>();
                    Map<Symbol, Type> typesBySymbol = new HashMap<>();
                    for (int argumentIndex = 0; argumentIndex < declarations.size(); argumentIndex++) {
                        LambdaArgumentDeclaration declaration = declarations.get(argumentIndex);
                        Type declaredType = lambdaType.getArgumentTypes().get(argumentIndex);
                        typesByDeclaration.put(NodeRef.of(declaration), declaredType);
                        typesBySymbol.put(new Symbol(declaration.getName().getValue()), declaredType);
                    }

                    // Register the lambda itself, then its declared arguments, then its body.
                    collectedTypes.put(NodeRef.of(lambda), lambdaType);
                    collectedTypes.putAll(typesByDeclaration);
                    collectedTypes.putAll(analyzer.getTypes(session, TypeProvider.copyOf(typesBySymbol), lambda.getBody()));
                }
            }

            // Non-lambda original expressions are analyzed against the plan's own type provider.
            for (RowExpression argument : callExpression.getArguments()) {
                if (isExpression(argument) && !(castToExpression(argument) instanceof LambdaExpression)) {
                    collectedTypes.putAll(analyzer.getTypes(session, typeProvider, castToExpression(argument)));
                }
            }
            return collectedTypes.build();
        }

        /**
         * Translates an AST expression to a scalar row expression and optimizes it
         * at the SERIALIZABLE level.
         */
        private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Session session) {
            RowExpression translated = SqlToRowExpressionTranslator.translate(
                    expression,
                    FunctionKind.SCALAR,
                    types,
                    layout,
                    metadata.getFunctionAndTypeManager(),
                    session,
                    false);
            RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata);
            return optimizer.optimize(translated, RowExpressionInterpreter.Level.SERIALIZABLE, session.toConnectorSession());
        }

        /**
         * Translates a wrapped original expression, inferring types from the rule context;
         * any other row expression is returned unchanged.
         */
        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context, Map<Symbol, Integer> layout) {
            if (!isExpression(expression)) {
                return expression;
            }
            TypeAnalyzer analyzer = new TypeAnalyzer(sqlParser, metadata);
            Expression original = castToExpression(expression);
            Map<NodeRef<Expression>, Type> types = analyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), castToExpression(expression));
            return toRowExpression(original, types, layout, context.getSession());
        }

        /**
         * Translates a wrapped original expression using the supplied, pre-computed types and an
         * empty layout; any other row expression is returned unchanged.
         */
        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types, Rule.Context context) {
            if (!isExpression(rowExpression)) {
                return rowExpression;
            }
            return toRowExpression(castToExpression(rowExpression), types, new HashMap<>(), session);
        }
    };
}
Aggregations