Search in sources:

Example 1 with OriginalExpressionUtils

Use of com.facebook.presto.sql.relational.OriginalExpressionUtils in project presto by prestodb.

Class SingleDistinctAggregationToGroupBy, method apply.

/**
 * Rewrites an aggregation whose DISTINCT aggregates all share a single argument set into
 * two aggregations: an inner aggregation that only groups by the original grouping keys
 * plus the distinct arguments (de-duplicating them), and an outer aggregation that
 * evaluates the original aggregates with their DISTINCT flag removed.
 *
 * <p>{@code Iterables.getOnlyElement} fails fast if the extracted argument sets are not
 * exactly one set.
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    List<Set<RowExpression>> argumentSets = extractArgumentSets(aggregation).collect(Collectors.toList());
    Set<RowExpression> distinctArguments = Iterables.getOnlyElement(argumentSets);

    // Translate the shared distinct arguments into plan variables.
    Set<VariableReferenceExpression> distinctVariables = distinctArguments.stream()
            .map(OriginalExpressionUtils::castToExpression)
            .map(context.getVariableAllocator()::toVariableReference)
            .collect(Collectors.toSet());

    // Inner grouping keys: the original keys followed by the distinct argument variables.
    List<VariableReferenceExpression> innerGroupingKeys = ImmutableList.<VariableReferenceExpression>builder()
            .addAll(aggregation.getGroupingKeys())
            .addAll(distinctVariables)
            .build();

    // Pure GROUP BY (no aggregate functions) that removes duplicate input rows.
    AggregationNode deduplicated = new AggregationNode(
            aggregation.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            aggregation.getSource(),
            ImmutableMap.of(),
            singleGroupingSet(innerGroupingKeys),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    // The outer node keeps the original id, grouping and step; its aggregates are no longer DISTINCT.
    return Result.ofPlanNode(new AggregationNode(
            aggregation.getSourceLocation(),
            aggregation.getId(),
            deduplicated,
            aggregation.getAggregations().entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> removeDistinct(entry.getValue()))),
            aggregation.getGroupingSets(),
            emptyList(),
            aggregation.getStep(),
            aggregation.getHashVariable(),
            aggregation.getGroupIdVariable()));
}
Also used : RowExpression(com.facebook.presto.spi.relation.RowExpression) Iterables(com.google.common.collect.Iterables) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) Patterns.aggregation(com.facebook.presto.sql.planner.plan.Patterns.aggregation) ImmutableMap(com.google.common.collect.ImmutableMap) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) Collections.emptyList(java.util.Collections.emptyList) Rule(com.facebook.presto.sql.planner.iterative.Rule) Captures(com.facebook.presto.matching.Captures) Set(java.util.Set) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) SINGLE(com.facebook.presto.spi.plan.AggregationNode.Step.SINGLE) Collectors(java.util.stream.Collectors) Pattern(com.facebook.presto.matching.Pattern) HashSet(java.util.HashSet) List(java.util.List) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) Stream(java.util.stream.Stream) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) Optional(java.util.Optional) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) Set(java.util.Set) HashSet(java.util.HashSet) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) ImmutableMap(com.google.common.collect.ImmutableMap) Map(java.util.Map)

Example 2 with OriginalExpressionUtils

Use of com.facebook.presto.sql.relational.OriginalExpressionUtils in project presto by prestodb.

Class QueryPlanner, method aggregate.

/**
 * Plans the aggregation part of a query specification on top of {@code subPlan}.
 *
 * <p>Returns {@code subPlan} unchanged when the analysis does not mark {@code node} as an
 * aggregation. Otherwise the method works in these steps:
 * <ol>
 *   <li>It pre-projects the aggregate inputs: call arguments, ORDER BY sort keys, FILTER
 *       expressions and GROUP BY expressions.</li>
 *   <li>It plans a GroupIdNode when there are several grouping sets, and a ProjectNode
 *       otherwise.</li>
 *   <li>It plans a SINGLE-step AggregationNode.</li>
 *   <li>When needed, it adds a post-projection for coercions.</li>
 *   <li>It handles grouping() operations.</li>
 * </ol>
 *
 * @param subPlan the plan producing the aggregation's input rows
 * @param node the query specification being planned
 * @return the plan builder for the aggregated output
 */
private PlanBuilder aggregate(PlanBuilder subPlan, QuerySpecification node) {
    if (!analysis.isAggregation(node)) {
        return subPlan;
    }
    // 1. Pre-project all scalar inputs (arguments and non-trivial group by expressions)
    Set<Expression> groupByExpressions = ImmutableSet.copyOf(analysis.getGroupByExpressions(node));
    ImmutableList.Builder<Expression> arguments = ImmutableList.builder();
    analysis.getAggregates(node).stream().map(FunctionCall::getArguments).flatMap(List::stream).filter(// lambda expressions are generated at execution time, so they are not pre-projected
    exp -> !(exp instanceof LambdaExpression)).forEach(arguments::add);
    // sort keys of ORDER BY clauses inside aggregate calls are inputs as well
    analysis.getAggregates(node).stream().map(FunctionCall::getOrderBy).filter(Optional::isPresent).map(Optional::get).map(OrderBy::getSortItems).flatMap(List::stream).map(SortItem::getSortKey).forEach(arguments::add);
    // filter expressions need to be projected first
    analysis.getAggregates(node).stream().map(FunctionCall::getFilter).filter(Optional::isPresent).map(Optional::get).forEach(arguments::add);
    Iterable<Expression> inputs = Iterables.concat(groupByExpressions, arguments.build());
    subPlan = handleSubqueries(subPlan, node, inputs);
    if (!Iterables.isEmpty(inputs)) {
        // avoid an empty projection if the only aggregation is COUNT (which has no arguments)
        subPlan = project(subPlan, inputs);
    }
    // 2. Aggregate
    // 2.a. Rewrite aggregate arguments
    // Each pre-projected argument expression is mapped to the variable that now holds its value.
    TranslationMap argumentTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap);
    ImmutableList.Builder<VariableReferenceExpression> aggregationArgumentsBuilder = ImmutableList.builder();
    for (Expression argument : arguments.build()) {
        VariableReferenceExpression variable = subPlan.translate(argument);
        argumentTranslations.put(argument, variable);
        aggregationArgumentsBuilder.add(variable);
    }
    List<VariableReferenceExpression> aggregationArguments = aggregationArgumentsBuilder.build();
    // 2.b. Rewrite grouping columns
    // Every GROUP BY expression gets a fresh "gid" output variable; groupingSetMappings maps output -> input variable.
    TranslationMap groupingTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap);
    Map<VariableReferenceExpression, VariableReferenceExpression> groupingSetMappings = new LinkedHashMap<>();
    for (Expression expression : groupByExpressions) {
        VariableReferenceExpression input = subPlan.translate(expression);
        VariableReferenceExpression output = variableAllocator.newVariable(expression, analysis.getTypeWithCoercions(expression), "gid");
        groupingTranslations.put(expression, output);
        groupingSetMappings.put(output, input);
    }
    // This tracks the grouping sets before complex expressions are considered (see comments below)
    // It's also used to compute the descriptors needed to implement grouping()
    // Without a GROUP BY clause there is exactly one, empty grouping set (a global aggregation).
    List<Set<FieldId>> columnOnlyGroupingSets = ImmutableList.of(ImmutableSet.of());
    List<List<VariableReferenceExpression>> groupingSets = ImmutableList.of(ImmutableList.of());
    if (node.getGroupBy().isPresent()) {
        // For the purpose of "distinct", we need to canonicalize column references that may have varying
        // syntactic forms (e.g., "t.a" vs "a"). Thus we need to enumerate grouping sets based on the underlying
        // fieldId associated with each column reference expression.
        // The catch is that simple group-by expressions can be arbitrary expressions (this is a departure from the SQL specification).
        // But, they don't affect the number of grouping sets or the behavior of "distinct". We can compute all the candidate
        // grouping sets in terms of fieldId, dedup as appropriate and then cross-join them with the complex expressions.
        Analysis.GroupingSetAnalysis groupingSetAnalysis = analysis.getGroupingSets(node);
        columnOnlyGroupingSets = enumerateGroupingSets(groupingSetAnalysis);
        if (node.getGroupBy().get().isDistinct()) {
            columnOnlyGroupingSets = columnOnlyGroupingSets.stream().distinct().collect(toImmutableList());
        }
        // add in the complex expressions and materialize the grouping sets in terms of plan columns
        ImmutableList.Builder<List<VariableReferenceExpression>> groupingSetBuilder = ImmutableList.builder();
        for (Set<FieldId> groupingSet : columnOnlyGroupingSets) {
            ImmutableList.Builder<VariableReferenceExpression> columns = ImmutableList.builder();
            groupingSetAnalysis.getComplexExpressions().stream().map(groupingTranslations::get).forEach(columns::add);
            groupingSet.stream().map(field -> groupingTranslations.get(new FieldReference(field.getFieldIndex()))).forEach(columns::add);
            groupingSetBuilder.add(columns.build());
        }
        groupingSets = groupingSetBuilder.build();
    }
    // 2.c. Generate GroupIdNode (multiple grouping sets) or ProjectNode (single grouping set)
    // The group id variable exists only when there are multiple grouping sets.
    Optional<VariableReferenceExpression> groupIdVariable = Optional.empty();
    if (groupingSets.size() > 1) {
        groupIdVariable = Optional.of(variableAllocator.newVariable("groupId", BIGINT));
        GroupIdNode groupId = new GroupIdNode(subPlan.getRoot().getSourceLocation(), idAllocator.getNextId(), subPlan.getRoot(), groupingSets, groupingSetMappings, aggregationArguments, groupIdVariable.get());
        subPlan = new PlanBuilder(groupingTranslations, groupId);
    } else {
        // Single grouping set: pass the aggregation arguments through unchanged and assign each
        // "gid" output variable from its input variable.
        Assignments.Builder assignments = Assignments.builder();
        aggregationArguments.stream().map(AssignmentUtils::identityAsSymbolReference).forEach(assignments::put);
        groupingSetMappings.forEach((key, value) -> assignments.put(key, castToRowExpression(asSymbolReference(value))));
        ProjectNode project = new ProjectNode(subPlan.getRoot().getSourceLocation(), idAllocator.getNextId(), subPlan.getRoot(), assignments.build(), LOCAL);
        subPlan = new PlanBuilder(groupingTranslations, project);
    }
    TranslationMap aggregationTranslations = new TranslationMap(subPlan.getRelationPlan(), analysis, lambdaDeclarationToVariableMap);
    aggregationTranslations.copyMappingsFrom(groupingTranslations);
    // 2.d. Rewrite aggregates
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregationsBuilder = ImmutableMap.builder();
    boolean needPostProjectionCoercion = false;
    for (FunctionCall aggregate : analysis.getAggregates(node)) {
        Expression rewritten = argumentTranslations.rewrite(aggregate);
        VariableReferenceExpression newVariable = variableAllocator.newVariable(rewritten, analysis.getType(aggregate));
        // TODO: this is a hack, because coercions are applied to the output of expressions rather than
        // to their inputs (see the TODO before the post-projection below). The rewritten aggregate may
        // therefore be wrapped in an implicit Cast; strip it here and apply it in a post-projection.
        if (rewritten instanceof Cast) {
            rewritten = ((Cast) rewritten).getExpression();
            needPostProjectionCoercion = true;
        }
        aggregationTranslations.put(aggregate, newVariable);
        FunctionCall rewrittenFunction = (FunctionCall) rewritten;
        aggregationsBuilder.put(newVariable, new Aggregation(new CallExpression(getSourceLocation(rewrittenFunction), aggregate.getName().getSuffix(), analysis.getFunctionHandle(aggregate), analysis.getType(aggregate), rewrittenFunction.getArguments().stream().map(OriginalExpressionUtils::castToRowExpression).collect(toImmutableList())), rewrittenFunction.getFilter().map(OriginalExpressionUtils::castToRowExpression), rewrittenFunction.getOrderBy().map(orderBy -> toOrderingScheme(orderBy, variableAllocator.getTypes())), rewrittenFunction.isDistinct(), Optional.empty()));
    }
    Map<VariableReferenceExpression, Aggregation> aggregations = aggregationsBuilder.build();
    // Indices of the grouping sets that have no keys, i.e. the global grouping sets.
    ImmutableSet.Builder<Integer> globalGroupingSets = ImmutableSet.builder();
    for (int i = 0; i < groupingSets.size(); i++) {
        if (groupingSets.get(i).isEmpty()) {
            globalGroupingSets.add(i);
        }
    }
    // Grouping keys: the de-duplicated union of all grouping-set columns, plus the group id (if any).
    ImmutableList.Builder<VariableReferenceExpression> groupingKeys = ImmutableList.builder();
    groupingSets.stream().flatMap(List::stream).distinct().forEach(groupingKeys::add);
    groupIdVariable.ifPresent(groupingKeys::add);
    AggregationNode aggregationNode = new AggregationNode(subPlan.getRoot().getSourceLocation(), idAllocator.getNextId(), subPlan.getRoot(), aggregations, groupingSets(groupingKeys.build(), groupingSets.size(), globalGroupingSets.build()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), groupIdVariable);
    subPlan = new PlanBuilder(aggregationTranslations, aggregationNode);
    // 3. Re-apply the coercions stripped from aggregate outputs in step 2.d, if any.
    // TODO: this is a hack, we should change type coercions to coerce the inputs to functions/operators instead of coercing the output
    if (needPostProjectionCoercion) {
        ImmutableList.Builder<Expression> alreadyCoerced = ImmutableList.builder();
        alreadyCoerced.addAll(groupByExpressions);
        groupIdVariable.map(ExpressionTreeUtils::createSymbolReference).ifPresent(alreadyCoerced::add);
        subPlan = explicitCoercionFields(subPlan, alreadyCoerced.build(), analysis.getAggregates(node));
    }
    // 4. Project and re-write all grouping functions
    return handleGroupingOperations(subPlan, node, groupIdVariable, columnOnlyGroupingSets);
}
Also used : FINAL(com.facebook.presto.spi.plan.LimitNode.Step.FINAL) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) SortNode(com.facebook.presto.sql.planner.plan.SortNode) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) FrameBound(com.facebook.presto.sql.tree.FrameBound) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) Field(com.facebook.presto.sql.analyzer.Field) WindowNodeUtil.toBoundType(com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toBoundType) ValuesNode(com.facebook.presto.spi.plan.ValuesNode) Delete(com.facebook.presto.sql.tree.Delete) Map(java.util.Map) LOCAL(com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) CallExpression(com.facebook.presto.spi.relation.CallExpression) OrderingScheme(com.facebook.presto.spi.plan.OrderingScheme) FunctionCall(com.facebook.presto.sql.tree.FunctionCall) OffsetNode(com.facebook.presto.sql.planner.plan.OffsetNode) SymbolReference(com.facebook.presto.sql.tree.SymbolReference) AssignmentUtils.identitiesAsSymbolReferences(com.facebook.presto.sql.planner.plan.AssignmentUtils.identitiesAsSymbolReferences) RelationId(com.facebook.presto.sql.analyzer.RelationId) ImmutableSet(com.google.common.collect.ImmutableSet) Query(com.facebook.presto.sql.tree.Query) WindowNodeUtil.toWindowType(com.facebook.presto.sql.planner.optimizations.WindowNodeUtil.toWindowType) SortOrder(com.facebook.presto.common.block.SortOrder) QuerySpecification(com.facebook.presto.sql.tree.QuerySpecification) ImmutableMap(com.google.common.collect.ImmutableMap) LambdaExpression(com.facebook.presto.sql.tree.LambdaExpression) Ordering(com.facebook.presto.spi.plan.Ordering) ExpressionTreeUtils(com.facebook.presto.sql.analyzer.ExpressionTreeUtils) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) 
Node(com.facebook.presto.sql.tree.Node) Set(java.util.Set) SortItem(com.facebook.presto.sql.tree.SortItem) Sets(com.google.common.collect.Sets) LimitNode(com.facebook.presto.spi.plan.LimitNode) SystemSessionProperties.isSkipRedundantSort(com.facebook.presto.SystemSessionProperties.isSkipRedundantSort) List(java.util.List) Window(com.facebook.presto.sql.tree.Window) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) ExpressionTreeUtils.getSourceLocation(com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getSourceLocation) FieldId(com.facebook.presto.sql.analyzer.FieldId) Analysis(com.facebook.presto.sql.analyzer.Analysis) Optional(java.util.Optional) MoreObjects.firstNonNull(com.google.common.base.MoreObjects.firstNonNull) PlannerUtils.toOrderingScheme(com.facebook.presto.sql.planner.PlannerUtils.toOrderingScheme) IntStream(java.util.stream.IntStream) Iterables(com.google.common.collect.Iterables) LambdaArgumentDeclaration(com.facebook.presto.sql.tree.LambdaArgumentDeclaration) PlannerUtils.toSortOrder(com.facebook.presto.sql.planner.PlannerUtils.toSortOrder) GroupIdNode(com.facebook.presto.sql.planner.plan.GroupIdNode) Assignments(com.facebook.presto.spi.plan.Assignments) Expressions.call(com.facebook.presto.sql.relational.Expressions.call) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) WindowFrame(com.facebook.presto.sql.tree.WindowFrame) FilterNode(com.facebook.presto.spi.plan.FilterNode) AssignmentUtils(com.facebook.presto.sql.planner.plan.AssignmentUtils) ImmutableList(com.google.common.collect.ImmutableList) Objects.requireNonNull(java.util.Objects.requireNonNull) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) TableHandle(com.facebook.presto.spi.TableHandle) Cast(com.facebook.presto.sql.tree.Cast) Type(com.facebook.presto.common.type.Type) RowExpression(com.facebook.presto.spi.relation.RowExpression) BIGINT(com.facebook.presto.common.type.BigintType.BIGINT) 
GroupingOperation(com.facebook.presto.sql.tree.GroupingOperation) OrderBy(com.facebook.presto.sql.tree.OrderBy) PlanNodeIdAllocator(com.facebook.presto.spi.plan.PlanNodeIdAllocator) WindowNode(com.facebook.presto.sql.planner.plan.WindowNode) Session(com.facebook.presto.Session) NodeLocation(com.facebook.presto.sql.tree.NodeLocation) NodeUtils.getSortItemsFromOrderBy(com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy) RelationType(com.facebook.presto.sql.analyzer.RelationType) Offset(com.facebook.presto.sql.tree.Offset) VARBINARY(com.facebook.presto.common.type.VarbinaryType.VARBINARY) TupleDomain(com.facebook.presto.common.predicate.TupleDomain) DeleteNode(com.facebook.presto.sql.planner.plan.DeleteNode) NodeRef(com.facebook.presto.sql.tree.NodeRef) Streams.stream(com.google.common.collect.Streams.stream) Scope(com.facebook.presto.sql.analyzer.Scope) PlanNode(com.facebook.presto.spi.plan.PlanNode) AggregationNode.groupingSets(com.facebook.presto.spi.plan.AggregationNode.groupingSets) Expression(com.facebook.presto.sql.tree.Expression) ColumnHandle(com.facebook.presto.spi.ColumnHandle) TableScanNode(com.facebook.presto.spi.plan.TableScanNode) FieldReference(com.facebook.presto.sql.tree.FieldReference) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) OriginalExpressionUtils.asSymbolReference(com.facebook.presto.sql.relational.OriginalExpressionUtils.asSymbolReference) Metadata(com.facebook.presto.metadata.Metadata) OriginalExpressionUtils.castToRowExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression) Assignments(com.facebook.presto.spi.plan.Assignments) LinkedHashMap(java.util.LinkedHashMap) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) List(java.util.List) ArrayList(java.util.ArrayList) 
ImmutableList(com.google.common.collect.ImmutableList) Optional(java.util.Optional) ImmutableMap(com.google.common.collect.ImmutableMap) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) Analysis(com.facebook.presto.sql.analyzer.Analysis) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) LambdaExpression(com.facebook.presto.sql.tree.LambdaExpression) Cast(com.facebook.presto.sql.tree.Cast) AggregationNode.singleGroupingSet(com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet) ImmutableSet(com.google.common.collect.ImmutableSet) Set(java.util.Set) ImmutableSet.toImmutableSet(com.google.common.collect.ImmutableSet.toImmutableSet) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ImmutableList(com.google.common.collect.ImmutableList) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) GroupIdNode(com.facebook.presto.sql.planner.plan.GroupIdNode) FunctionCall(com.facebook.presto.sql.tree.FunctionCall) CallExpression(com.facebook.presto.spi.relation.CallExpression) FieldReference(com.facebook.presto.sql.tree.FieldReference) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) CallExpression(com.facebook.presto.spi.relation.CallExpression) LambdaExpression(com.facebook.presto.sql.tree.LambdaExpression) RowExpression(com.facebook.presto.spi.relation.RowExpression) Expression(com.facebook.presto.sql.tree.Expression) OriginalExpressionUtils.castToRowExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression) FieldId(com.facebook.presto.sql.analyzer.FieldId) ProjectNode(com.facebook.presto.spi.plan.ProjectNode)

Example 3 with OriginalExpressionUtils

Use of com.facebook.presto.sql.relational.OriginalExpressionUtils in project presto by prestodb.

Class MultipleDistinctAggregationToMarkDistinct, method apply.

/**
 * Replaces plain DISTINCT aggregations with MarkDistinctNode-based masking when the
 * mark-distinct session property is enabled.
 *
 * <p>An aggregation is rewritten only if it is DISTINCT and has neither a FILTER nor an
 * existing mask. All aggregations whose arguments resolve to the same set of variables share
 * one marker variable and one MarkDistinctNode. Each MarkDistinctNode is stacked on the
 * current source and de-duplicates over the grouping keys, the arguments and, if present,
 * the group id. All other aggregations are kept unchanged.
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    boolean markDistinctEnabled = SystemSessionProperties.useMarkDistinct(context.getSession());
    if (!markDistinctEnabled) {
        return Result.empty();
    }

    // Distinct marker already created for a given set of input variables.
    Map<Set<VariableReferenceExpression>, VariableReferenceExpression> markerByInputs = new HashMap<>();
    Map<VariableReferenceExpression, Aggregation> rewritten = new HashMap<>();
    PlanNode source = parent.getSource();

    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : parent.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation aggregation = entry.getValue();

        boolean eligible = aggregation.isDistinct()
                && !aggregation.getFilter().isPresent()
                && !aggregation.getMask().isPresent();
        if (!eligible) {
            rewritten.put(output, aggregation);
            continue;
        }

        Set<VariableReferenceExpression> distinctInputs = aggregation.getArguments().stream()
                .map(OriginalExpressionUtils::castToExpression)
                .map(context.getVariableAllocator()::toVariableReference)
                .collect(toSet());

        VariableReferenceExpression marker = markerByInputs.get(distinctInputs);
        if (marker == null) {
            marker = context.getVariableAllocator().newVariable(Iterables.getLast(distinctInputs).getName(), BOOLEAN, "distinct");
            markerByInputs.put(distinctInputs, marker);

            // Distinctness is evaluated over grouping keys, aggregation inputs and the group id (if any).
            ImmutableSet.Builder<VariableReferenceExpression> distinctColumns = ImmutableSet.builder();
            distinctColumns.addAll(parent.getGroupingKeys());
            distinctColumns.addAll(distinctInputs);
            parent.getGroupIdVariable().ifPresent(distinctColumns::add);

            source = new MarkDistinctNode(source.getSourceLocation(), context.getIdAllocator().getNextId(), source, marker, ImmutableList.copyOf(distinctColumns.build()), Optional.empty());
        }

        // Clear the DISTINCT flag and attach the distinct marker as the aggregation mask.
        rewritten.put(output, new Aggregation(aggregation.getCall(), aggregation.getFilter(), aggregation.getOrderBy(), false, Optional.of(marker)));
    }

    return Result.ofPlanNode(new AggregationNode(parent.getSourceLocation(), parent.getId(), source, rewritten, parent.getGroupingSets(), ImmutableList.of(), parent.getStep(), parent.getHashVariable(), parent.getGroupIdVariable()));
}
Also used : ImmutableSet(com.google.common.collect.ImmutableSet) Set(java.util.Set) HashSet(java.util.HashSet) Collectors.toSet(java.util.stream.Collectors.toSet) MarkDistinctNode(com.facebook.presto.spi.plan.MarkDistinctNode) HashMap(java.util.HashMap) AggregationNode(com.facebook.presto.spi.plan.AggregationNode) Aggregation(com.facebook.presto.spi.plan.AggregationNode.Aggregation) PlanNode(com.facebook.presto.spi.plan.PlanNode) ImmutableSet(com.google.common.collect.ImmutableSet) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) HashMap(java.util.HashMap) Map(java.util.Map)

Example 4 with OriginalExpressionUtils

Use of com.facebook.presto.sql.relational.OriginalExpressionUtils in project presto by prestodb.

Class InlineProjections, method apply.

/**
 * Inlines assignments of a child projection into the parent projection that references them.
 *
 * <p>Nothing is rewritten when either projection is REMOTE, when their localities differ, or
 * when {@code extractInliningTargets} finds no targets. Otherwise three things happen:
 * <ul>
 *   <li>The target assignments are substituted into the parent's expressions. The parent's
 *       assignment order is preserved.</li>
 *   <li>The targets are removed from the child.</li>
 *   <li>The child gains identity assignments for the inputs the inlined expressions still
 *       need. These use a symbol-reference form when any child assignment is still an
 *       original (untranslated) expression.</li>
 * </ul>
 * Both rebuilt nodes keep their own ids, source locations and localities.
 */
@Override
public Result apply(ProjectNode parent, Captures captures, Context context) {
    ProjectNode child = captures.get(CHILD);
    // Do not inline remote projections, or if parent and child have different locality
    if (parent.getLocality().equals(REMOTE) || child.getLocality().equals(REMOTE) || !parent.getLocality().equals(child.getLocality())) {
        return Result.empty();
    }
    Sets.SetView<VariableReferenceExpression> targets = extractInliningTargets(parent, child, context);
    if (targets.isEmpty()) {
        return Result.empty();
    }
    // inline the expressions
    Assignments assignments = child.getAssignments().filter(targets::contains);
    // Rebuild the parent assignments in their original order; collecting into a HashMap
    // (as Collectors.toMap does) would scramble the parent's output ordering.
    Builder parentAssignments = Assignments.builder();
    for (Map.Entry<VariableReferenceExpression, RowExpression> entry : parent.getAssignments().entrySet()) {
        parentAssignments.put(entry.getKey(), inlineReferences(entry.getValue(), assignments, context.getVariableAllocator().getTypes()));
    }
    // Synthesize identity assignments for the inputs of expressions that were inlined
    // to place in the child projection.
    // If all assignments end up becoming identity assignments, they'll get pruned by
    // other rules
    Set<VariableReferenceExpression> inputs = child.getAssignments().entrySet().stream().filter(entry -> targets.contains(entry.getKey())).map(Map.Entry::getValue).flatMap(expression -> extractInputs(expression, context.getVariableAllocator().getTypes()).stream()).collect(toSet());
    Builder childAssignments = Assignments.builder();
    for (Map.Entry<VariableReferenceExpression, RowExpression> assignment : child.getAssignments().entrySet()) {
        if (!targets.contains(assignment.getKey())) {
            childAssignments.put(assignment);
        }
    }
    // Once no child assignment is an original expression, identities can be plain variables.
    boolean allTranslated = child.getAssignments().entrySet().stream().map(Map.Entry::getValue).noneMatch(OriginalExpressionUtils::isExpression);
    for (VariableReferenceExpression input : inputs) {
        if (allTranslated) {
            childAssignments.put(input, input);
        } else {
            childAssignments.put(identityAsSymbolReference(input));
        }
    }
    // The rebuilt child keeps its own id, so it also keeps its own source location.
    ProjectNode newChild = new ProjectNode(child.getSourceLocation(), child.getId(), child.getSource(), childAssignments.build(), child.getLocality());
    return Result.ofPlanNode(new ProjectNode(parent.getSourceLocation(), parent.getId(), newChild, parentAssignments.build(), parent.getLocality()));
}
Also used : FunctionAndTypeManager(com.facebook.presto.metadata.FunctionAndTypeManager) Builder(com.facebook.presto.spi.plan.Assignments.Builder) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) OriginalExpressionUtils.isExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression) Captures(com.facebook.presto.matching.Captures) Assignments(com.facebook.presto.spi.plan.Assignments) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) ConstantExpression(com.facebook.presto.spi.relation.ConstantExpression) AstUtils(com.facebook.presto.sql.util.AstUtils) Function(java.util.function.Function) Pattern(com.facebook.presto.matching.Pattern) OriginalExpressionUtils.castToExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression) Capture(com.facebook.presto.matching.Capture) Literal(com.facebook.presto.sql.tree.Literal) TypeProvider(com.facebook.presto.sql.planner.TypeProvider) TryExpression(com.facebook.presto.sql.tree.TryExpression) Map(java.util.Map) AssignmentUtils.isIdentity(com.facebook.presto.sql.planner.plan.AssignmentUtils.isIdentity) FunctionResolution(com.facebook.presto.sql.relational.FunctionResolution) CallExpression(com.facebook.presto.spi.relation.CallExpression) Collectors.toSet(java.util.stream.Collectors.toSet) VariablesExtractor(com.facebook.presto.sql.planner.VariablesExtractor) RowExpression(com.facebook.presto.spi.relation.RowExpression) ImmutableSet(com.google.common.collect.ImmutableSet) RowExpressionVariableInliner(com.facebook.presto.sql.planner.RowExpressionVariableInliner) Rule(com.facebook.presto.sql.planner.iterative.Rule) Patterns.project(com.facebook.presto.sql.planner.plan.Patterns.project) Set(java.util.Set) Collectors(java.util.stream.Collectors) ExpressionVariableInliner(com.facebook.presto.sql.planner.ExpressionVariableInliner) Sets(com.google.common.collect.Sets) 
Patterns.source(com.facebook.presto.sql.planner.plan.Patterns.source) List(java.util.List) REMOTE(com.facebook.presto.spi.plan.ProjectNode.Locality.REMOTE) Expression(com.facebook.presto.sql.tree.Expression) DefaultRowExpressionTraversalVisitor(com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) Capture.newCapture(com.facebook.presto.matching.Capture.newCapture) AssignmentUtils.identityAsSymbolReference(com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAsSymbolReference) ExpressionTreeUtils.createSymbolReference(com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference) OriginalExpressionUtils.castToRowExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression) Builder(com.facebook.presto.spi.plan.Assignments.Builder) Assignments(com.facebook.presto.spi.plan.Assignments) RowExpression(com.facebook.presto.spi.relation.RowExpression) OriginalExpressionUtils.castToRowExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression) Sets(com.google.common.collect.Sets) VariableReferenceExpression(com.facebook.presto.spi.relation.VariableReferenceExpression) ProjectNode(com.facebook.presto.spi.plan.ProjectNode) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) Map(java.util.Map)

Example 5 with OriginalExpressionUtils

Use of com.facebook.presto.sql.relational.OriginalExpressionUtils in project presto by prestodb.

Class TranslateExpressions, method createRewriter.

/**
 * Creates a {@link PlanRowExpressionRewriter} that translates original AST expressions
 * (values for which {@code OriginalExpressionUtils.isExpression} is true) into {@link RowExpression}s.
 *
 * <p>A {@link CallExpression} with at least one wrapped-AST argument (the aggregation case) is
 * handled specially: only its wrapped arguments are translated. The types of lambda arguments
 * are derived from the aggregate function's declared signature rather than by analyzing the
 * call itself.
 *
 * @param metadata provides the function and type manager used for type analysis and translation
 * @param sqlParser passed through to expression type analysis
 * @return a rewriter that translates a wrapped AST expression, or the wrapped arguments of a call
 *     expression; any other expression is returned unchanged
 */
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            // Special treatment of the CallExpression in Aggregation: its wrapped-AST arguments
            // (possibly lambdas) are typed against the function signature, not analyzed as one call.
            if (expression instanceof CallExpression && hasOriginalExpressionArgument((CallExpression) expression)) {
                return removeOriginalExpressionArguments((CallExpression) expression, context.getSession(), context.getVariableAllocator());
            }
            return removeOriginalExpression(expression, context);
        }

        // True when at least one argument of the call still wraps an original AST expression.
        private boolean hasOriginalExpressionArgument(CallExpression callExpression) {
            return callExpression.getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression);
        }

        // Returns a copy of the call whose wrapped-AST arguments are translated, using argument
        // types computed once for the whole call; handle, type and display name are kept as-is.
        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanVariableAllocator variableAllocator) {
            Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, variableAllocator.getTypes());
            List<RowExpression> translatedArguments = callExpression.getArguments().stream()
                    .map(argument -> removeOriginalExpression(argument, session, types))
                    .collect(toImmutableList());
            return new CallExpression(
                    callExpression.getSourceLocation(),
                    callExpression.getDisplayName(),
                    callExpression.getFunctionHandle(),
                    callExpression.getType(),
                    translatedArguments);
        }

        // Computes the types of every sub-expression of the call's wrapped-AST arguments.
        // Lambda arguments are typed from the function signature's function-typed parameters,
        // matched positionally. All other wrapped arguments are analyzed against typeProvider.
        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdaExpressions = callExpression.getArguments().stream()
                    .filter(OriginalExpressionUtils::isExpression)
                    .map(OriginalExpressionUtils::castToExpression)
                    .filter(LambdaExpression.class::isInstance)
                    .map(LambdaExpression.class::cast)
                    .collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> builder = ImmutableMap.builder();
            if (!lambdaExpressions.isEmpty()) {
                List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager()
                        .getFunctionMetadata(callExpression.getFunctionHandle())
                        .getArgumentTypes()
                        .stream()
                        .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
                        .map(typeSignature -> (FunctionType) (metadata.getFunctionAndTypeManager().getType(typeSignature)))
                        .collect(toImmutableList());
                InternalAggregationFunction internalAggregationFunction = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class> lambdaInterfaces = internalAggregationFunction.getLambdaInterfaces();
                verify(lambdaExpressions.size() == functionTypes.size(),
                        "%s has %s lambda arguments but %s function-typed parameters",
                        callExpression.getDisplayName(), lambdaExpressions.size(), functionTypes.size());
                verify(lambdaExpressions.size() == lambdaInterfaces.size(),
                        "%s has %s lambda arguments but %s lambda interfaces",
                        callExpression.getDisplayName(), lambdaExpressions.size(), lambdaInterfaces.size());
                for (int lambdaIndex = 0; lambdaIndex < lambdaExpressions.size(); lambdaIndex++) {
                    LambdaExpression lambdaExpression = lambdaExpressions.get(lambdaIndex);
                    FunctionType functionType = functionTypes.get(lambdaIndex);
                    // To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
                    // which requires the types of all sub-expressions.
                    //
                    // In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
                    // is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
                    //
                    // This does not work here since the function call representation in final aggregation node
                    // is currently a hack: it takes intermediate type as input, and may not be a valid
                    // function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // the same mechanism in project and filter expression should be used here.
                    verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size(),
                            "lambda %s declares %s arguments but function type %s expects %s",
                            lambdaExpression, lambdaExpression.getArguments().size(), functionType, functionType.getArgumentTypes().size());
                    Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
                    Map<String, Type> lambdaArgumentSymbolTypes = new HashMap<>();
                    for (int argumentIndex = 0; argumentIndex < lambdaExpression.getArguments().size(); argumentIndex++) {
                        LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(argumentIndex);
                        Type argumentType = functionType.getArgumentTypes().get(argumentIndex);
                        lambdaArgumentExpressionTypes.put(NodeRef.of(argument), argumentType);
                        lambdaArgumentSymbolTypes.put(argument.getName().getValue(), argumentType);
                    }
                    // The lambda expression itself, its argument declarations, then its body typed
                    // with the lambda arguments in scope.
                    builder.put(NodeRef.of(lambdaExpression), functionType)
                            .putAll(lambdaArgumentExpressionTypes)
                            .putAll(getExpressionTypes(session, metadata, sqlParser, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody(), emptyList(), NOOP));
                }
            }
            for (RowExpression argument : callExpression.getArguments()) {
                // Lambdas were typed above; non-AST arguments need no analysis.
                if (!isExpression(argument) || castToExpression(argument) instanceof LambdaExpression) {
                    continue;
                }
                builder.putAll(analyze(castToExpression(argument), session, typeProvider));
            }
            return builder.build();
        }

        // Computes the types of all sub-expressions of an AST expression against the given provider.
        private Map<NodeRef<Expression>, Type> analyze(Expression expression, Session session, TypeProvider typeProvider) {
            return getExpressionTypes(session, metadata, sqlParser, typeProvider, expression, emptyList(), NOOP);
        }

        // Translates an AST expression into a RowExpression using precomputed sub-expression types.
        private RowExpression toRowExpression(Expression expression, Session session, Map<NodeRef<Expression>, Type> types) {
            return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), session);
        }

        // Translates a wrapped AST expression, analyzing it against the context's variable types;
        // any other RowExpression is returned unchanged.
        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context) {
            if (!isExpression(expression)) {
                return expression;
            }
            Expression originalExpression = castToExpression(expression);
            Session session = context.getSession();
            return toRowExpression(originalExpression, session, analyze(originalExpression, session, context.getVariableAllocator().getTypes()));
        }

        // Translates a wrapped AST expression using already-computed types; any other
        // RowExpression is returned unchanged.
        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types) {
            if (!isExpression(rowExpression)) {
                return rowExpression;
            }
            return toRowExpression(castToExpression(rowExpression), session, types);
        }
    };
}
Also used : LambdaArgumentDeclaration(com.facebook.presto.sql.tree.LambdaArgumentDeclaration) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) SqlToRowExpressionTranslator(com.facebook.presto.sql.relational.SqlToRowExpressionTranslator) OriginalExpressionUtils.isExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression) HashMap(java.util.HashMap) OriginalExpressionUtils.castToExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression) Verify.verify(com.google.common.base.Verify.verify) TypeProvider(com.facebook.presto.sql.planner.TypeProvider) Map(java.util.Map) CallExpression(com.facebook.presto.spi.relation.CallExpression) Type(com.facebook.presto.common.type.Type) RowExpression(com.facebook.presto.spi.relation.RowExpression) ImmutableMap(com.google.common.collect.ImmutableMap) LambdaExpression(com.facebook.presto.sql.tree.LambdaExpression) Collections.emptyList(java.util.Collections.emptyList) Session(com.facebook.presto.Session) Rule(com.facebook.presto.sql.planner.iterative.Rule) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) InternalAggregationFunction(com.facebook.presto.operator.aggregation.InternalAggregationFunction) FunctionType(com.facebook.presto.common.type.FunctionType) SqlParser(com.facebook.presto.sql.parser.SqlParser) NodeRef(com.facebook.presto.sql.tree.NodeRef) List(java.util.List) Expression(com.facebook.presto.sql.tree.Expression) PlanVariableAllocator(com.facebook.presto.sql.planner.PlanVariableAllocator) NOOP(com.facebook.presto.spi.WarningCollector.NOOP) ExpressionAnalyzer.getExpressionTypes(com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes) Metadata(com.facebook.presto.metadata.Metadata) HashMap(java.util.HashMap) InternalAggregationFunction(com.facebook.presto.operator.aggregation.InternalAggregationFunction) NodeRef(com.facebook.presto.sql.tree.NodeRef) 
LambdaArgumentDeclaration(com.facebook.presto.sql.tree.LambdaArgumentDeclaration) PlanVariableAllocator(com.facebook.presto.sql.planner.PlanVariableAllocator) CallExpression(com.facebook.presto.spi.relation.CallExpression) FunctionType(com.facebook.presto.common.type.FunctionType) RowExpression(com.facebook.presto.spi.relation.RowExpression) TypeProvider(com.facebook.presto.sql.planner.TypeProvider) ImmutableMap(com.google.common.collect.ImmutableMap) Type(com.facebook.presto.common.type.Type) FunctionType(com.facebook.presto.common.type.FunctionType) OriginalExpressionUtils.isExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.isExpression) OriginalExpressionUtils.castToExpression(com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression) CallExpression(com.facebook.presto.spi.relation.CallExpression) RowExpression(com.facebook.presto.spi.relation.RowExpression) LambdaExpression(com.facebook.presto.sql.tree.LambdaExpression) Expression(com.facebook.presto.sql.tree.Expression) OriginalExpressionUtils(com.facebook.presto.sql.relational.OriginalExpressionUtils) LambdaExpression(com.facebook.presto.sql.tree.LambdaExpression) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) Session(com.facebook.presto.Session)

Aggregations

OriginalExpressionUtils (com.facebook.presto.sql.relational.OriginalExpressionUtils): 6; Map (java.util.Map): 5; RowExpression (com.facebook.presto.spi.relation.RowExpression): 4; VariableReferenceExpression (com.facebook.presto.spi.relation.VariableReferenceExpression): 4; ImmutableMap (com.google.common.collect.ImmutableMap): 4; List (java.util.List): 4; Captures (com.facebook.presto.matching.Captures): 3; Pattern (com.facebook.presto.matching.Pattern): 3; AggregationNode (com.facebook.presto.spi.plan.AggregationNode): 3; Aggregation (com.facebook.presto.spi.plan.AggregationNode.Aggregation): 3; PlanNode (com.facebook.presto.spi.plan.PlanNode): 3; Rule (com.facebook.presto.sql.planner.iterative.Rule): 3; Set (java.util.Set): 3; Session (com.facebook.presto.Session): 2; Type (com.facebook.presto.common.type.Type): 2; Metadata (com.facebook.presto.metadata.Metadata): 2; AggregationNode.singleGroupingSet (com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet): 2; Assignments (com.facebook.presto.spi.plan.Assignments): 2; ProjectNode (com.facebook.presto.spi.plan.ProjectNode): 2; CallExpression (com.facebook.presto.spi.relation.CallExpression): 2