Search in sources :

Example 21 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method coalesceWithNullAggregation of the class PushAggregationThroughOuterJoin.

// Once the aggregation has been pushed below the outer join, rows of the outer side that had no
// match on the inner side come out of the join with null aggregation results. Some aggregate
// functions (count, for example) do not return null when evaluated over a single null row; they
// return zero or one instead. To keep results correct, every aggregation output of the rewritten
// join is coalesced with the same aggregation evaluated over a single row of nulls.
private Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
    // Build the aggregation over a row of nulls; give up on the rewrite if that is not possible.
    Optional<MappedAggregationInfo> nullRowInfo = createAggregationOverNull(aggregationNode, planSymbolAllocator, idAllocator, lookup);
    if (!nullRowInfo.isPresent()) {
        return Optional.empty();
    }
    MappedAggregationInfo info = nullRowInfo.get();
    AggregationNode nullRowAggregation = info.getAggregation();
    Map<Symbol, Symbol> nullRowSymbolFor = info.getSymbolMapping();

    // An INNER join with no criteria, i.e. a cross join, attaches the null-row results to every row.
    ImmutableList<Symbol> crossJoinOutputs = ImmutableList.<Symbol>builder()
            .addAll(outerJoin.getOutputSymbols())
            .addAll(nullRowAggregation.getOutputSymbols())
            .build();
    JoinNode crossJoin = new JoinNode(
            idAllocator.getNextId(),
            JoinNode.Type.INNER,
            outerJoin,
            nullRowAggregation,
            ImmutableList.of(),
            crossJoinOutputs,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());

    // Project the original join outputs back out, wrapping each aggregation output in COALESCE.
    Map<Symbol, AggregationNode.Aggregation> aggregations = aggregationNode.getAggregations();
    Assignments.Builder projections = Assignments.builder();
    for (Symbol outputSymbol : outerJoin.getOutputSymbols()) {
        if (!aggregations.containsKey(outputSymbol)) {
            // Symbols that are not aggregation outputs pass through unchanged.
            projections.put(outputSymbol, toVariableReference(outputSymbol, planSymbolAllocator.getTypes()));
            continue;
        }
        Symbol nullRowSymbol = nullRowSymbolFor.get(outputSymbol);
        projections.put(
                outputSymbol,
                new SpecialForm(
                        COALESCE,
                        planSymbolAllocator.getTypes().get(outputSymbol),
                        ImmutableList.of(
                                toVariableReference(outputSymbol, planSymbolAllocator.getTypes()),
                                toVariableReference(nullRowSymbol, planSymbolAllocator.getTypes()))));
    }
    return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, projections.build()));
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) JoinNode(io.prestosql.spi.plan.JoinNode) Assignments(io.prestosql.spi.plan.Assignments) AggregationNode(io.prestosql.spi.plan.AggregationNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) SpecialForm(io.prestosql.spi.relation.SpecialForm)

Example 22 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method apply of the class PushAggregationThroughOuterJoin.

@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    JoinNode join = captures.get(JOIN);

    // Only unfiltered LEFT or RIGHT joins are candidates for this rewrite.
    boolean isOuterJoin = join.getType() == JoinNode.Type.LEFT || join.getType() == JoinNode.Type.RIGHT;
    if (join.getFilter().isPresent() || !isOuterJoin) {
        return Result.empty();
    }
    // The aggregation must group on every column of the outer side...
    if (!groupsOnAllColumns(aggregation, getOuterTable(join).getOutputSymbols())) {
        return Result.empty();
    }
    // ...and the outer side must be distinct.
    if (!isDistinct(context.getLookup().resolve(getOuterTable(join)), context.getLookup()::resolve)) {
        return Result.empty();
    }

    // The pushed-down aggregation groups on the inner side's equi-join columns:
    // the left column of each clause for a RIGHT join, the right column for a LEFT join.
    boolean isRightJoin = join.getType() == JoinNode.Type.RIGHT;
    ImmutableList.Builder<Symbol> innerJoinKeys = ImmutableList.builder();
    for (JoinNode.EquiJoinClause clause : join.getCriteria()) {
        innerJoinKeys.add(isRightJoin ? clause.getLeft() : clause.getRight());
    }
    List<Symbol> groupingKeys = innerJoinKeys.build();

    AggregationNode rewrittenAggregation = new AggregationNode(
            aggregation.getId(),
            getInnerTable(join),
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol(),
            aggregation.getAggregationType(),
            aggregation.getFinalizeSymbol());

    // Put the pre-aggregated inner side back where it came from; the join exposes the outer
    // side's outputs plus the aggregation result symbols.
    PlanNode newLeft;
    PlanNode newRight;
    ImmutableList.Builder<Symbol> joinOutputs = ImmutableList.builder();
    if (join.getType() == JoinNode.Type.LEFT) {
        newLeft = join.getLeft();
        newRight = rewrittenAggregation;
        joinOutputs.addAll(join.getLeft().getOutputSymbols()).addAll(rewrittenAggregation.getAggregations().keySet());
    } else {
        newLeft = rewrittenAggregation;
        newRight = join.getRight();
        joinOutputs.addAll(rewrittenAggregation.getAggregations().keySet()).addAll(join.getRight().getOutputSymbols());
    }
    JoinNode rewrittenJoin = new JoinNode(
            join.getId(),
            join.getType(),
            newLeft,
            newRight,
            join.getCriteria(),
            joinOutputs.build(),
            join.getFilter(),
            join.getLeftHashSymbol(),
            join.getRightHashSymbol(),
            join.getDistributionType(),
            join.isSpillable(),
            join.getDynamicFilters());

    Optional<PlanNode> resultNode = coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup());
    return resultNode.isPresent() ? Result.ofPlanNode(resultNode.get()) : Result.empty();
}
Also used : PlanNode(io.prestosql.spi.plan.PlanNode) JoinNode(io.prestosql.spi.plan.JoinNode) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Example 23 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method computeCounts of the class SetOperationNodeTranslator.

/**
 * Builds a single-step hash aggregation over {@code sourceNode}, grouped by {@code originalColumns},
 * with one {@code count(marker)} per marker symbol. Because count ignores nulls, each result counts
 * the rows in the group whose marker is non-null. NOTE(review): appendMarkers presumably sets a
 * source's own marker and leaves the others null, which makes each count a per-source row count; confirm.
 *
 * @param sourceNode union of the marker-augmented sources
 * @param originalColumns grouping columns
 * @param markers boolean marker symbols, one per source
 * @param aggregationOutputs output symbols for the counts; index i receives count(markers[i])
 * @return the counting aggregation node
 */
private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> aggregationOutputs) {
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> countsByOutput = ImmutableMap.builder();
    int markerCount = markers.size();
    for (int index = 0; index < markerCount; index++) {
        Symbol countOutput = aggregationOutputs.get(index);
        Symbol marker = markers.get(index);
        CallExpression countCall = new CallExpression(
                COUNT_AGGREGATION_NAME.getObjectName(),
                metadata.getFunctionAndTypeManager().lookupFunction(COUNT_AGGREGATION_NAME.getObjectName(), fromTypes(BOOLEAN)),
                BIGINT,
                ImmutableList.of(castToRowExpression(toSymbolReference(marker))),
                Optional.empty());
        AggregationNode.Aggregation count = new AggregationNode.Aggregation(
                countCall,
                ImmutableList.of(castToRowExpression(toSymbolReference(marker))),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
        countsByOutput.put(countOutput, count);
    }
    return new AggregationNode(
            idAllocator.getNextId(),
            sourceNode,
            countsByOutput.build(),
            singleGroupingSet(originalColumns),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 24 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method makeSetContainmentPlan of the class SetOperationNodeTranslator.

/**
 * Translates a non-union set operation into a union of marker-tagged sources followed by a
 * per-source count aggregation.
 *
 * @param node the set operation to translate; must not be a {@link UnionNode}
 * @return the counting aggregation together with one "count >= GENERIC_LITERAL" predicate per
 *         source, in source order; the caller decides how to combine them
 * @throws IllegalArgumentException if {@code node} is a UnionNode
 */
public TranslationResult makeSetContainmentPlan(SetOperationNode node) {
    checkArgument(!(node instanceof UnionNode), "Cannot simplify a UnionNode");
    List<PlanNode> sources = node.getSources();

    // One boolean marker symbol per source, projected alongside that source's fields.
    List<Symbol> markers = allocateSymbols(sources.size(), MARKER, BOOLEAN);
    List<PlanNode> markedSources = appendMarkers(markers, node.getSources(), node);

    // Union the marked sources; the union keeps the set operation's own output names, then the markers.
    List<Symbol> outputs = node.getOutputSymbols();
    UnionNode union = union(markedSources, ImmutableList.copyOf(concat(outputs, markers)));

    // Count marker hits per group of output columns, then build one presence predicate per count.
    List<Symbol> counts = allocateSymbols(markers.size(), "count", BIGINT);
    AggregationNode aggregation = computeCounts(union, outputs, markers, counts);
    ImmutableList.Builder<Expression> presence = ImmutableList.builder();
    for (Symbol count : counts) {
        presence.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, toSymbolReference(count), GENERIC_LITERAL));
    }
    return new TranslationResult(aggregation, presence.build());
}
Also used : GREATER_THAN_OR_EQUAL(io.prestosql.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL) StandardTypes(io.prestosql.spi.type.StandardTypes) Literal(io.prestosql.sql.tree.Literal) DEFAULT_NAMESPACE(io.prestosql.spi.connector.CatalogSchemaName.DEFAULT_NAMESPACE) TRUE_LITERAL(io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) ImmutableList(com.google.common.collect.ImmutableList) Iterables.concat(com.google.common.collect.Iterables.concat) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Objects.requireNonNull(java.util.Objects.requireNonNull) BOOLEAN(io.prestosql.spi.type.BooleanType.BOOLEAN) Type(io.prestosql.spi.type.Type) BIGINT(io.prestosql.spi.type.BigintType.BIGINT) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) Symbol(io.prestosql.spi.plan.Symbol) SetOperationNodeUtils.sourceSymbolMap(io.prestosql.sql.planner.optimizations.SetOperationNodeUtils.sourceSymbolMap) TypeSignatureProvider.fromTypes(io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes) ImmutableMap(com.google.common.collect.ImmutableMap) Assignments(io.prestosql.spi.plan.Assignments) SetOperationNode(io.prestosql.spi.plan.SetOperationNode) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) NullLiteral(io.prestosql.sql.tree.NullLiteral) PlanNode(io.prestosql.spi.plan.PlanNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) Metadata(io.prestosql.metadata.Metadata) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) 
SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) List(java.util.List) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) ImmutableListMultimap(com.google.common.collect.ImmutableListMultimap) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) UnionNode(io.prestosql.spi.plan.UnionNode) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) PlanNode(io.prestosql.spi.plan.PlanNode) UnionNode(io.prestosql.spi.plan.UnionNode) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Expression(io.prestosql.sql.tree.Expression) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Example 25 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

Method buildInPredicateEquivalent of the class TransformCorrelatedInPredicateToJoin.

/**
 * Rewrites a correlated IN predicate into a left outer join plus an aggregation, then projects a
 * three-valued result into {@code inPredicateOutputSymbol}: TRUE if some build row equals the probe
 * value, otherwise NULL if some joined build row involved a null probe/build value, otherwise FALSE.
 *
 * NOTE: statement order matters here. Each idAllocator.getNextId() and planSymbolAllocator.newSymbol(...)
 * call below shapes the ids and symbol names of the produced plan, so do not reorder them casually.
 *
 * @param apply the apply node whose input is the probe side
 * @param inPredicate the IN predicate; its value is the probe symbol and its value list the build symbol
 * @param inPredicateOutputSymbol symbol that receives the boolean result of the IN predicate
 * @param decorrelated decorrelated subquery plan plus the correlation predicates lifted out of it
 * @param idAllocator allocator for new plan node ids
 * @param planSymbolAllocator allocator for new symbols
 * @return a projection over the aggregation exposing the apply input's outputs plus the IN result
 */
private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, PlanSymbolAllocator planSymbolAllocator) {
    // Correlation predicates lifted out of the subquery become part of the join condition.
    Expression correlationCondition = and(decorrelated.getCorrelatedPredicates());
    PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();
    // Tag each probe row with a "unique" symbol so that grouping on the probe outputs later yields
    // (presumably) one group per original input row.
    AssignUniqueId probeSide = new AssignUniqueId(idAllocator.getNextId(), apply.getInput(), planSymbolAllocator.newSymbol("unique", BIGINT));
    // A constant 0 on the build side: after the outer join it is non-null exactly when a build row matched.
    Symbol buildSideKnownNonNull = planSymbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT);
    ProjectNode buildSide = new ProjectNode(idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder().putAll(identityAsSymbolReferences(decorrelatedBuildSource.getOutputSymbols())).put(buildSideKnownNonNull, castToRowExpression(bigint(0))).build());
    Symbol probeSideSymbol = SymbolUtils.from(inPredicate.getValue());
    Symbol buildSideSymbol = SymbolUtils.from(inPredicate.getValueList());
    // Join on (probe IS NULL OR probe = build OR build IS NULL) AND correlation, so that rows which
    // could make the IN result NULL are also kept, not only exact matches.
    Expression joinExpression = and(or(new IsNullPredicate(toSymbolReference(probeSideSymbol)), new ComparisonExpression(ComparisonExpression.Operator.EQUAL, toSymbolReference(probeSideSymbol), toSymbolReference(buildSideSymbol)), new IsNullPredicate(toSymbolReference(buildSideSymbol))), correlationCondition);
    JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);
    // A joined row is a real match when both sides are non-null (the join condition then implies equality).
    Symbol matchConditionSymbol = planSymbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN);
    Expression matchCondition = and(isNotNull(probeSideSymbol), isNotNull(buildSideSymbol));
    // A "null match": a build row was joined, but it was not a real match (one side was null).
    Symbol nullMatchConditionSymbol = planSymbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN);
    Expression nullMatchCondition = and(isNotNull(buildSideKnownNonNull), not(matchCondition));
    ProjectNode preProjection = new ProjectNode(idAllocator.getNextId(), leftOuterJoin, Assignments.builder().putAll(AssignmentUtils.identityAsSymbolReferences(leftOuterJoin.getOutputSymbols())).put(matchConditionSymbol, castToRowExpression(matchCondition)).put(nullMatchConditionSymbol, castToRowExpression(nullMatchCondition)).build());
    // Per probe row, count real matches and null matches. countWithFilter presumably counts rows where
    // the given boolean symbol is true; verify against its definition.
    Symbol countMatchesSymbol = planSymbolAllocator.newSymbol("countMatches", BIGINT);
    Symbol countNullMatchesSymbol = planSymbolAllocator.newSymbol("countNullMatches", BIGINT);
    AggregationNode aggregation = new AggregationNode(idAllocator.getNextId(), preProjection, ImmutableMap.<Symbol, AggregationNode.Aggregation>builder().put(countMatchesSymbol, countWithFilter(matchConditionSymbol)).put(countNullMatchesSymbol, countWithFilter(nullMatchConditionSymbol)).build(), singleGroupingSet(probeSide.getOutputSymbols()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
    // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results
    // CASE WHEN countMatches > 0 THEN TRUE WHEN countNullMatches > 0 THEN <booleanConstant(null)> ELSE FALSE END
    SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression(ImmutableList.of(new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null))), Optional.of(booleanConstant(false)));
    // Expose the apply input's original outputs (not the helper symbols) plus the IN result.
    return new ProjectNode(idAllocator.getNextId(), aggregation, Assignments.builder().putAll(identityAsSymbolReferences(apply.getInput().getOutputSymbols())).put(inPredicateOutputSymbol, castToRowExpression(inPredicateEquivalent)).build());
}
Also used : ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) WhenClause(io.prestosql.sql.tree.WhenClause) PlanNode(io.prestosql.spi.plan.PlanNode) AssignUniqueId(io.prestosql.sql.planner.plan.AssignUniqueId) SearchedCaseExpression(io.prestosql.sql.tree.SearchedCaseExpression) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) SearchedCaseExpression(io.prestosql.sql.tree.SearchedCaseExpression) NotExpression(io.prestosql.sql.tree.NotExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Expression(io.prestosql.sql.tree.Expression) Symbol(io.prestosql.spi.plan.Symbol) JoinNode(io.prestosql.spi.plan.JoinNode) IsNullPredicate(io.prestosql.sql.tree.IsNullPredicate) ProjectNode(io.prestosql.spi.plan.ProjectNode) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Aggregations

AggregationNode (io.prestosql.spi.plan.AggregationNode)52 Symbol (io.prestosql.spi.plan.Symbol)40 PlanNode (io.prestosql.spi.plan.PlanNode)30 ProjectNode (io.prestosql.spi.plan.ProjectNode)24 Map (java.util.Map)23 CallExpression (io.prestosql.spi.relation.CallExpression)20 ImmutableMap (com.google.common.collect.ImmutableMap)17 Assignments (io.prestosql.spi.plan.Assignments)16 Aggregation (io.prestosql.spi.plan.AggregationNode.Aggregation)14 RowExpression (io.prestosql.spi.relation.RowExpression)14 Expression (io.prestosql.sql.tree.Expression)14 List (java.util.List)14 ImmutableList (com.google.common.collect.ImmutableList)13 TableScanNode (io.prestosql.spi.plan.TableScanNode)13 Optional (java.util.Optional)13 OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression)12 Test (org.testng.annotations.Test)12 Session (io.prestosql.Session)11 FilterNode (io.prestosql.spi.plan.FilterNode)11 Metadata (io.prestosql.metadata.Metadata)9