Search in sources :

Example 46 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class TestTypeValidator method testValidAggregation.

/**
 * Builds a SINGLE-step, HASH-type aggregation over {@code baseTableScan} whose only
 * output is a "sum" symbol declared as DOUBLE and fed by {@code columnC}, then checks
 * that the plan passes {@code assertTypesValid}.
 */
@Test
public void testValidAggregation() {
    Symbol sumOutput = planSymbolAllocator.newSymbol("sum", DOUBLE);

    // The call and the aggregation each receive their own argument list, as the node expects.
    CallExpression sumCall = new CallExpression(
            "sum",
            SUM,
            DOUBLE,
            ImmutableList.of(castToRowExpression(toSymbolReference(columnC))));
    Aggregation sumAggregation = new Aggregation(
            sumCall,
            ImmutableList.of(castToRowExpression(toSymbolReference(columnC))),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    PlanNode aggregationPlan = new AggregationNode(
            newId(),
            baseTableScan,
            ImmutableMap.of(sumOutput, sumAggregation),
            singleGroupingSet(ImmutableList.of(columnA, columnB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    assertTypesValid(aggregationPlan);
}
Also used : Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) PlanNode(io.prestosql.spi.plan.PlanNode) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Test(org.testng.annotations.Test)

Example 47 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class TestTypeValidator method testInvalidAggregationFunctionCall.

/**
 * Builds the same shape of aggregation as the valid case but feeds the DOUBLE-declared
 * "sum" call from {@code columnA}; validation is expected to fail with the type-mismatch
 * message given in the annotation (presumably because columnA is bigint — confirm against
 * the test fixture).
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
public void testInvalidAggregationFunctionCall() {
    Symbol sumOutput = planSymbolAllocator.newSymbol("sum", DOUBLE);

    // The call and the aggregation each receive their own argument list, as the node expects.
    CallExpression sumCall = new CallExpression(
            "sum",
            SUM,
            DOUBLE,
            ImmutableList.of(castToRowExpression(toSymbolReference(columnA))));
    Aggregation sumAggregation = new Aggregation(
            sumCall,
            ImmutableList.of(castToRowExpression(toSymbolReference(columnA))),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    PlanNode aggregationPlan = new AggregationNode(
            newId(),
            baseTableScan,
            ImmutableMap.of(sumOutput, sumAggregation),
            singleGroupingSet(ImmutableList.of(columnA, columnB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    assertTypesValid(aggregationPlan);
}
Also used : Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) PlanNode(io.prestosql.spi.plan.PlanNode) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Test(org.testng.annotations.Test)

Example 48 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class AggregationFunctionMatcher method getAssignedSymbol.

/**
 * Finds the output symbol of the single aggregation in {@code node} that matches the
 * expected function call.
 *
 * @return the matching aggregation's symbol, or empty when {@code node} is not an
 *         {@link AggregationNode} or no aggregation matches
 * @throws IllegalStateException if more than one aggregation matches the expected call
 */
@Override
public Optional<Symbol> getAssignedSymbol(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    if (!(node instanceof AggregationNode)) {
        return Optional.empty();
    }

    AggregationNode aggregationNode = (AggregationNode) node;
    FunctionCall wanted = callMaker.getExpectedValue(symbolAliases);

    Optional<Symbol> found = Optional.empty();
    for (Map.Entry<Symbol, Aggregation> candidate : aggregationNode.getAggregations().entrySet()) {
        if (!aggregationMatches(candidate.getValue(), wanted, metadata)) {
            continue;
        }
        // A second hit means the expected call does not identify a unique aggregation.
        checkState(!found.isPresent(), "Ambiguous function calls in %s", aggregationNode);
        found = Optional.of(candidate.getKey());
    }
    return found;
}
Also used : Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode) FunctionCall(io.prestosql.sql.tree.FunctionCall) Map(java.util.Map)

Example 49 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class AggregationMatcher method detailMatches.

/**
 * Compares the expected aggregation description held by this matcher against an
 * {@link AggregationNode}.
 *
 * <p>The checks run in order: group-id presence, grouping keys, grouping-set count,
 * global grouping sets, the set of symbols compared against {@code masks}, the
 * aggregation step, and the pre-grouped symbols. The first mismatch yields
 * {@code NO_MATCH}.
 *
 * @throws IllegalStateException if {@code shapeMatches(node)} is false
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    AggregationNode actual = (AggregationNode) node;

    boolean groupingAgrees = groupId.isPresent() == actual.getGroupIdSymbol().isPresent()
            && matches(groupingSets.getGroupingKeys(), actual.getGroupingKeys(), symbolAliases)
            && groupingSets.getGroupingSetCount() == actual.getGroupingSetCount()
            && groupingSets.getGlobalGroupingSets().equals(actual.getGlobalGroupingSets());
    if (!groupingAgrees) {
        return NO_MATCH;
    }

    // NOTE(review): the selection is driven by isDistinct() although the result is
    // compared with the mask keys — confirm this is the intended criterion.
    List<Symbol> maskedSymbols = actual.getAggregations().entrySet().stream()
            .filter(assignment -> assignment.getValue().isDistinct())
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    boolean masksAgree = maskedSymbols.size() == masks.keySet().size()
            && masks.keySet().containsAll(maskedSymbols);
    if (!masksAgree) {
        return NO_MATCH;
    }

    boolean stepAndPreGroupingAgree = step == actual.getStep()
            && matches(preGroupedSymbols, actual.getPreGroupedSymbols(), symbolAliases);
    return stepAndPreGroupingAgree ? match() : NO_MATCH;
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) Step(io.prestosql.spi.plan.AggregationNode.Step) StatsProvider(io.prestosql.cost.StatsProvider) Collection(java.util.Collection) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) MatchResult.match(io.prestosql.sql.planner.assertions.MatchResult.match) PlanNode(io.prestosql.spi.plan.PlanNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) Preconditions.checkState(com.google.common.base.Preconditions.checkState) AggregationNode(io.prestosql.spi.plan.AggregationNode) List(java.util.List) Map(java.util.Map) Session(io.prestosql.Session) Optional(java.util.Optional) NO_MATCH(io.prestosql.sql.planner.assertions.MatchResult.NO_MATCH) MoreObjects.toStringHelper(com.google.common.base.MoreObjects.toStringHelper) Symbol(io.prestosql.spi.plan.Symbol) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Example 50 with AggregationNode

use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.

the class AggregationStepMatcher method detailMatches.

/**
 * Matches when the {@link AggregationNode}'s step is the step this matcher expects.
 *
 * @throws IllegalStateException if {@code shapeMatches(node)} is false
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    AggregationNode actual = (AggregationNode) node;
    return actual.getStep() == step ? match() : NO_MATCH;
}
Also used : AggregationNode(io.prestosql.spi.plan.AggregationNode)

Aggregations

AggregationNode (io.prestosql.spi.plan.AggregationNode): 52, Symbol (io.prestosql.spi.plan.Symbol): 40, PlanNode (io.prestosql.spi.plan.PlanNode): 30, ProjectNode (io.prestosql.spi.plan.ProjectNode): 24, Map (java.util.Map): 23, CallExpression (io.prestosql.spi.relation.CallExpression): 20, ImmutableMap (com.google.common.collect.ImmutableMap): 17, Assignments (io.prestosql.spi.plan.Assignments): 16, Aggregation (io.prestosql.spi.plan.AggregationNode.Aggregation): 14, RowExpression (io.prestosql.spi.relation.RowExpression): 14, Expression (io.prestosql.sql.tree.Expression): 14, List (java.util.List): 14, ImmutableList (com.google.common.collect.ImmutableList): 13, TableScanNode (io.prestosql.spi.plan.TableScanNode): 13, Optional (java.util.Optional): 13, OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression): 12, Test (org.testng.annotations.Test): 12, Session (io.prestosql.Session): 11, FilterNode (io.prestosql.spi.plan.FilterNode): 11, Metadata (io.prestosql.metadata.Metadata): 9