Search in sources :

Example 26 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class TransformExistsApplyToLateralNode method rewriteToDefaultAggregation.

/**
 * Rewrites the given {@link ApplyNode} into a {@link LateralJoinNode} whose subquery side is a
 * global {@code count} aggregation followed by a boolean projection.
 *
 * <p>Shape of the returned plan, as built by the single expression below:
 * <ul>
 *   <li>{@code LateralJoinNode} reusing {@code parent.getId()}, with {@code parent.getInput()} as
 *       its input, {@code parent.getCorrelation()} as correlation, join type {@code INNER},
 *       filter {@code TRUE_LITERAL} and {@code parent.getOriginSubquery()} as origin subquery;</li>
 *   <li>its subquery is a {@code ProjectNode} assigning
 *       {@code exists := count > CAST(0 AS BIGINT)};</li>
 *   <li>the projection's source is an {@code AggregationNode} over {@code parent.getSubquery()}
 *       that computes an argument-less {@code count} call into a freshly allocated BIGINT symbol,
 *       using {@code globalAggregation()}, step {@code SINGLE} and aggregation type
 *       {@code HASH}.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code getOnlyElement} is presumably Guava's and would fail unless the parent
 * has exactly one subquery assignment - confirm against the rule's pattern.
 *
 * <p>The id for the {@code ProjectNode} is requested from the id allocator before the id for the
 * nested {@code AggregationNode} (Java evaluates constructor arguments left to right). Keep that
 * order if this expression is ever split into separate statements and plan node ids matter.
 *
 * @param parent the apply node being rewritten
 * @param context rule context supplying the symbol and id allocators
 * @return the lateral join that replaces {@code parent}
 */
private PlanNode rewriteToDefaultAggregation(ApplyNode parent, Context context) {
    Symbol count = context.getSymbolAllocator().newSymbol(COUNT.toString(), BIGINT);
    Symbol exists = getOnlyElement(parent.getSubqueryAssignments().getSymbols());
    return new LateralJoinNode(parent.getId(), parent.getInput(), new ProjectNode(context.getIdAllocator().getNextId(), new AggregationNode(context.getIdAllocator().getNextId(), parent.getSubquery(), ImmutableMap.of(count, new Aggregation(new CallExpression("count", functionResolution.countFunction(), BIGINT, ImmutableList.of(), Optional.empty()), ImmutableList.of(), false, Optional.empty(), Optional.empty(), Optional.empty())), globalAggregation(), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty()), Assignments.of(exists, castToRowExpression(new ComparisonExpression(GREATER_THAN, toSymbolReference(count), new Cast(new LongLiteral("0"), BIGINT.toString()))))), parent.getCorrelation(), INNER, TRUE_LITERAL, parent.getOriginSubquery());
}
Also used : AggregationNode.globalAggregation(io.prestosql.spi.plan.AggregationNode.globalAggregation) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Cast(io.prestosql.sql.tree.Cast) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) LateralJoinNode(io.prestosql.sql.planner.plan.LateralJoinNode) LongLiteral(io.prestosql.sql.tree.LongLiteral) Symbol(io.prestosql.spi.plan.Symbol) ProjectNode(io.prestosql.spi.plan.ProjectNode) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression)

Example 27 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class TransformFilteringSemiJoinToInnerJoin method apply.

/**
 * Rewrites a filter that requires a semi-join's output symbol into an inner join against the
 * de-duplicated filtering source.
 *
 * <p>The rule bails out with {@code Result.empty()} when the semi-join source contains a table
 * scan marked for DELETE, or when none of the filter's top-level conjuncts is the semi-join
 * output symbol itself. Otherwise it:
 * <ol>
 *   <li>drops the conjuncts equal to the semi-join symbol and replaces any remaining reference to
 *       it with {@code TRUE_CONSTANT};</li>
 *   <li>groups the filtering source by its join symbol (an aggregation with no aggregate
 *       functions) so each key appears once;</li>
 *   <li>inner-joins the source with that aggregation on the semi-join key pair, using the
 *       remaining predicate (if not trivially true) as the join filter;</li>
 *   <li>projects the join outputs plus the semi-join symbol bound to {@code TRUE_CONSTANT}.</li>
 * </ol>
 *
 * @param filterNode the matched filter node
 * @param captures pattern captures; must contain the semi-join under {@code SEMI_JOIN}
 * @param context rule context supplying lookup, symbol types and the id allocator
 * @return the replacement projection, or an empty result when the rule does not apply
 */
@Override
public Result apply(FilterNode filterNode, Captures captures, Context context) {
    SemiJoinNode semiJoin = captures.get(SEMI_JOIN);
    // Do not transform a semi-join in the context of DELETE
    if (PlanNodeSearcher.searchFrom(semiJoin.getSource(), context.getLookup())
            .where(node -> node instanceof TableScanNode && ((TableScanNode) node).isForDelete())
            .matches()) {
        return Result.empty();
    }
    Symbol semiJoinSymbol = semiJoin.getSemiJoinOutput();
    TypeProvider types = context.getSymbolAllocator().getTypes();
    // Single definition of "is the semi-join output"; reused for matching, filtering and inlining.
    Predicate<RowExpression> isSemiJoinSymbol = expression -> expression.equals(toVariableReference(semiJoinSymbol, types));
    List<RowExpression> conjuncts = LogicalRowExpressions.extractConjuncts(filterNode.getPredicate());
    if (conjuncts.stream().noneMatch(isSemiJoinSymbol)) {
        return Result.empty();
    }
    RowExpression filteredPredicate = LogicalRowExpressions.and(conjuncts.stream()
            .filter(isSemiJoinSymbol.negate())
            .collect(toImmutableList()));
    // Past this point the semi-join symbol is known to be true, so nested references collapse to TRUE.
    RowExpression simplifiedPredicate = inlineVariables(variable -> {
        if (isSemiJoinSymbol.test(variable)) {
            return TRUE_CONSTANT;
        }
        return variable;
    }, filteredPredicate);
    Optional<RowExpression> joinFilter = simplifiedPredicate.equals(TRUE_CONSTANT) ? Optional.empty() : Optional.of(simplifiedPredicate);
    // Distinct over the filtering join key: group by it with no aggregate functions.
    PlanNode filteringSourceDistinct = new AggregationNode(
            context.getIdAllocator().getNextId(),
            semiJoin.getFilteringSource(),
            ImmutableMap.of(),
            singleGroupingSet(ImmutableList.of(semiJoin.getFilteringSourceJoinSymbol())),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());
    JoinNode innerJoin = new JoinNode(
            semiJoin.getId(),
            INNER,
            semiJoin.getSource(),
            filteringSourceDistinct,
            ImmutableList.of(new EquiJoinClause(semiJoin.getSourceJoinSymbol(), semiJoin.getFilteringSourceJoinSymbol())),
            semiJoin.getSource().getOutputSymbols(),
            joinFilter,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            // TODO: dynamic filter from SemiJoinNode
            ImmutableMap.of());
    ProjectNode project = new ProjectNode(
            context.getIdAllocator().getNextId(),
            innerJoin,
            Assignments.builder()
                    .putAll(AssignmentUtils.identityAsSymbolReferences(innerJoin.getOutputSymbols()))
                    .put(semiJoinSymbol, TRUE_CONSTANT)
                    .build());
    return Result.ofPlanNode(project);
}
Also used : TRUE_CONSTANT(io.prestosql.expressions.LogicalRowExpressions.TRUE_CONSTANT) Patterns.source(io.prestosql.sql.planner.plan.Patterns.source) EquiJoinClause(io.prestosql.spi.plan.JoinNode.EquiJoinClause) LogicalRowExpressions(io.prestosql.expressions.LogicalRowExpressions) TypeProvider(io.prestosql.sql.planner.TypeProvider) Pattern(io.prestosql.matching.Pattern) SemiJoinNode(io.prestosql.sql.planner.plan.SemiJoinNode) AggregationNode(io.prestosql.spi.plan.AggregationNode) SystemSessionProperties.isRewriteFilteringSemiJoinToInnerJoin(io.prestosql.SystemSessionProperties.isRewriteFilteringSemiJoinToInnerJoin) Capture.newCapture(io.prestosql.matching.Capture.newCapture) ImmutableList(com.google.common.collect.ImmutableList) FilterNode(io.prestosql.spi.plan.FilterNode) SINGLE(io.prestosql.spi.plan.AggregationNode.Step.SINGLE) RowExpressionVariableInliner.inlineVariables(io.prestosql.sql.planner.RowExpressionVariableInliner.inlineVariables) Session(io.prestosql.Session) JoinNode(io.prestosql.spi.plan.JoinNode) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) RowExpressionDeterminismEvaluator(io.prestosql.sql.relational.RowExpressionDeterminismEvaluator) ImmutableMap(com.google.common.collect.ImmutableMap) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) Predicate(java.util.function.Predicate) Patterns.filter(io.prestosql.sql.planner.plan.Patterns.filter) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) TableScanNode(io.prestosql.spi.plan.TableScanNode) Patterns.semiJoin(io.prestosql.sql.planner.plan.Patterns.semiJoin) PlanNode(io.prestosql.spi.plan.PlanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) VariableReferenceSymbolConverter.toVariableReference(io.prestosql.sql.planner.VariableReferenceSymbolConverter.toVariableReference) 
Metadata(io.prestosql.metadata.Metadata) Captures(io.prestosql.matching.Captures) List(java.util.List) FunctionResolution(io.prestosql.sql.relational.FunctionResolution) PlanNodeSearcher(io.prestosql.sql.planner.optimizations.PlanNodeSearcher) Capture(io.prestosql.matching.Capture) RowExpression(io.prestosql.spi.relation.RowExpression) INNER(io.prestosql.spi.plan.JoinNode.Type.INNER) Optional(java.util.Optional) RowExpressionDeterminismEvaluator(io.prestosql.sql.relational.RowExpressionDeterminismEvaluator) Symbol(io.prestosql.spi.plan.Symbol) SemiJoinNode(io.prestosql.sql.planner.plan.SemiJoinNode) JoinNode(io.prestosql.spi.plan.JoinNode) EquiJoinClause(io.prestosql.spi.plan.JoinNode.EquiJoinClause) LogicalRowExpressions(io.prestosql.expressions.LogicalRowExpressions) TypeProvider(io.prestosql.sql.planner.TypeProvider) RowExpression(io.prestosql.spi.relation.RowExpression) AggregationNode(io.prestosql.spi.plan.AggregationNode) FunctionResolution(io.prestosql.sql.relational.FunctionResolution) SemiJoinNode(io.prestosql.sql.planner.plan.SemiJoinNode) PlanNode(io.prestosql.spi.plan.PlanNode) TableScanNode(io.prestosql.spi.plan.TableScanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Example 28 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class PushPartialAggregationThroughExchange method pushPartial.

/**
 * Builds a new exchange whose sources are per-source copies of the given aggregation.
 *
 * <p>For every source of {@code exchange}, the aggregation is re-mapped onto that source's own
 * symbols and wrapped in a projection that renames the results back to the aggregation's output
 * symbols, so all new sources expose the same symbol list.
 *
 * @param aggregation the aggregation to replicate below the exchange
 * @param exchange the exchange being pushed through
 * @param context rule context supplying symbol types and the id allocator
 * @return the replacement exchange over the projected partial aggregations
 */
private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange, Context context) {
    List<PlanNode> exchangeSources = exchange.getSources();
    List<Symbol> exchangeOutputs = exchange.getOutputSymbols();
    List<Symbol> aggregationOutputs = aggregation.getOutputSymbols();

    List<PlanNode> projectedPartials = new ArrayList<>();
    for (int sourceIndex = 0; sourceIndex < exchangeSources.size(); sourceIndex++) {
        PlanNode exchangeSource = exchangeSources.get(sourceIndex);

        // Record only the exchange-output -> source-input pairs that actually differ.
        SymbolMapper.Builder renames = SymbolMapper.builder();
        for (int column = 0; column < exchangeOutputs.size(); column++) {
            Symbol exchangeSymbol = exchangeOutputs.get(column);
            Symbol sourceSymbol = exchange.getInputs().get(sourceIndex).get(column);
            if (exchangeSymbol.equals(sourceSymbol)) {
                continue;
            }
            renames.put(exchangeSymbol, sourceSymbol);
        }
        SymbolMapper mapper = renames.build();
        if (mapper.getTypes() == null) {
            mapper.setTypes(context.getSymbolAllocator().getTypes());
        }

        AggregationNode partialOnSource = mapper.map(aggregation, exchangeSource, context.getIdAllocator());

        // Project the mapped symbols back to the aggregation's original output names.
        Assignments.Builder backToOriginalNames = Assignments.builder();
        for (Symbol aggregationOutput : aggregationOutputs) {
            backToOriginalNames.put(
                    aggregationOutput,
                    VariableReferenceSymbolConverter.toVariableReference(mapper.map(aggregationOutput), context.getSymbolAllocator().getTypes()));
        }
        projectedPartials.add(new ProjectNode(context.getIdAllocator().getNextId(), partialOnSource, backToOriginalNames.build()));
    }

    projectedPartials.forEach(partial -> verify(aggregationOutputs.equals(partial.getOutputSymbols())));

    // Every new exchange source now exposes exactly the partial aggregation's output symbols,
    // so the partitioning function's symbols do not need to be rewritten.
    PartitioningScheme exchangeScheme = exchange.getPartitioningScheme();
    PartitioningScheme partitioning = new PartitioningScheme(
            exchangeScheme.getPartitioning(),
            aggregationOutputs,
            exchangeScheme.getHashColumn(),
            exchangeScheme.isReplicateNullsAndAny(),
            exchangeScheme.getBucketToPartition());

    return new ExchangeNode(
            context.getIdAllocator().getNextId(),
            exchange.getType(),
            exchange.getScope(),
            partitioning,
            projectedPartials,
            ImmutableList.copyOf(Collections.nCopies(projectedPartials.size(), aggregationOutputs)),
            Optional.empty(),
            aggregation.getAggregationType());
}
Also used : SymbolMapper(io.prestosql.sql.planner.optimizations.SymbolMapper) ExchangeNode(io.prestosql.sql.planner.plan.ExchangeNode) Symbol(io.prestosql.spi.plan.Symbol) PartitioningScheme(io.prestosql.sql.planner.PartitioningScheme) ArrayList(java.util.ArrayList) Assignments(io.prestosql.spi.plan.Assignments) AggregationNode(io.prestosql.spi.plan.AggregationNode) PlanNode(io.prestosql.spi.plan.PlanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Example 29 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class PushPartialAggregationThroughJoin method pushPartialToRightChild.

/**
 * Moves the given aggregation onto the right input of the join.
 *
 * <p>The grouping keys for the pushed aggregation come from {@code getPushedDownGroupingSet},
 * which is given the right input's output symbols and the subset of the join's required symbols
 * that the right input produces.
 *
 * @param node the aggregation to push
 * @param child the join below the aggregation
 * @param context rule context, forwarded to {@code pushPartialToJoin}
 * @return the result of {@code pushPartialToJoin} with the left input unchanged and the pushed
 *         aggregation as the right input
 */
private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, Context context) {
    PlanNode rightInput = child.getRight();
    Set<Symbol> rightSymbols = ImmutableSet.copyOf(rightInput.getOutputSymbols());
    Set<Symbol> requiredFromRight = intersection(getJoinRequiredSymbols(child), rightSymbols);
    List<Symbol> pushedGroupingKeys = getPushedDownGroupingSet(node, rightSymbols, requiredFromRight);
    return pushPartialToJoin(
            node,
            child,
            child.getLeft(),
            replaceAggregationSource(node, rightInput, pushedGroupingKeys),
            context);
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Example 30 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class PushPartialAggregationThroughJoin method pushPartialToLeftChild.

/**
 * Moves the given aggregation onto the left input of the join.
 *
 * <p>The grouping keys for the pushed aggregation come from {@code getPushedDownGroupingSet},
 * which is given the left input's output symbols and the subset of the join's required symbols
 * that the left input produces.
 *
 * @param node the aggregation to push
 * @param child the join below the aggregation
 * @param context rule context, forwarded to {@code pushPartialToJoin}
 * @return the result of {@code pushPartialToJoin} with the pushed aggregation as the left input
 *         and the right input unchanged
 */
private PlanNode pushPartialToLeftChild(AggregationNode node, JoinNode child, Context context) {
    PlanNode leftInput = child.getLeft();
    Set<Symbol> leftSymbols = ImmutableSet.copyOf(leftInput.getOutputSymbols());
    Set<Symbol> requiredFromLeft = intersection(getJoinRequiredSymbols(child), leftSymbols);
    List<Symbol> pushedGroupingKeys = getPushedDownGroupingSet(node, leftSymbols, requiredFromLeft);
    return pushPartialToJoin(
            node,
            child,
            replaceAggregationSource(node, leftInput, pushedGroupingKeys),
            child.getRight(),
            context);
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Aggregations

AggregationNode (io.prestosql.spi.plan.AggregationNode): 52, Symbol (io.prestosql.spi.plan.Symbol): 40, PlanNode (io.prestosql.spi.plan.PlanNode): 30, ProjectNode (io.prestosql.spi.plan.ProjectNode): 24, Map (java.util.Map): 23, CallExpression (io.prestosql.spi.relation.CallExpression): 20, ImmutableMap (com.google.common.collect.ImmutableMap): 17, Assignments (io.prestosql.spi.plan.Assignments): 16, Aggregation (io.prestosql.spi.plan.AggregationNode.Aggregation): 14, RowExpression (io.prestosql.spi.relation.RowExpression): 14, Expression (io.prestosql.sql.tree.Expression): 14, List (java.util.List): 14, ImmutableList (com.google.common.collect.ImmutableList): 13, TableScanNode (io.prestosql.spi.plan.TableScanNode): 13, Optional (java.util.Optional): 13, OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression): 12, Test (org.testng.annotations.Test): 12, Session (io.prestosql.Session): 11, FilterNode (io.prestosql.spi.plan.FilterNode): 11, Metadata (io.prestosql.metadata.Metadata): 9