Use of com.facebook.presto.sql.planner.plan.AggregationNode in project presto by prestodb.
Class SimplifyCountOverConstant, method apply.
/**
 * Rewrites aggregations on top of a projection so that every aggregation accepted by
 * {@code isCountOverConstant} becomes an argument-less {@code count()} returning BIGINT.
 *
 * @return the rewritten aggregation node, or empty when {@code node} is not an aggregation over a
 *     projection or when no aggregation qualified for the rewrite
 */
@Override
public Optional<PlanNode> apply(PlanNode node, Lookup lookup, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
    if (!(node instanceof AggregationNode)) {
        return Optional.empty();
    }
    AggregationNode aggregationNode = (AggregationNode) node;

    PlanNode source = lookup.resolve(aggregationNode.getSource());
    if (!(source instanceof ProjectNode)) {
        return Optional.empty();
    }
    ProjectNode projectNode = (ProjectNode) source;

    // Work on an insertion-ordered copy so untouched assignments keep their original position.
    Map<Symbol, AggregationNode.Aggregation> rewritten = new LinkedHashMap<>(aggregationNode.getAssignments());
    int rewriteCount = 0;
    for (Entry<Symbol, AggregationNode.Aggregation> assignment : aggregationNode.getAssignments().entrySet()) {
        if (!isCountOverConstant(assignment.getValue(), projectNode.getAssignments())) {
            continue;
        }
        FunctionCall countCall = new FunctionCall(QualifiedName.of("count"), ImmutableList.of());
        Signature countSignature = new Signature("count", AGGREGATE, parseTypeSignature(StandardTypes.BIGINT));
        rewritten.put(assignment.getKey(), new AggregationNode.Aggregation(countCall, countSignature));
        rewriteCount++;
    }

    if (rewriteCount == 0) {
        return Optional.empty();
    }
    return Optional.of(new AggregationNode(aggregationNode.getId(), projectNode, rewritten, aggregationNode.getGroupingSets(), aggregationNode.getStep(), aggregationNode.getHashSymbol(), aggregationNode.getGroupIdSymbol()));
}
Use of com.facebook.presto.sql.planner.plan.AggregationNode in project presto by prestodb.
Class AggregationMatcher, method detailMatches.
/**
 * Checks the details of an {@link AggregationNode} against this matcher's expectations.
 * The checks are, in order:
 * <ul>
 *   <li>group-id symbol presence (only presence, not the symbol itself);</li>
 *   <li>the number of grouping sets;</li>
 *   <li>that the symbols of DISTINCT aggregations equal the expected mask keys;</li>
 *   <li>each grouping set, compared position by position.</li>
 * </ul>
 */
@Override
public MatchResult detailMatches(PlanNode node, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    AggregationNode actual = (AggregationNode) node;

    boolean groupIdPresenceDiffers = groupId.isPresent() != actual.getGroupIdSymbol().isPresent();
    if (groupIdPresenceDiffers || groupingSets.size() != actual.getGroupingSets().size()) {
        return NO_MATCH;
    }

    List<Symbol> distinctAggregationSymbols = actual.getAggregations()
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue().isDistinct())
            .map(entry -> entry.getKey())
            .collect(Collectors.toList());
    // Map keys are unique, so equal size plus full containment means the two key sets are equal.
    boolean masksAgree = distinctAggregationSymbols.size() == masks.keySet().size()
            && distinctAggregationSymbols.stream().allMatch(symbol -> masks.keySet().contains(symbol));
    if (!masksAgree) {
        return NO_MATCH;
    }

    for (int setIndex = 0; setIndex < groupingSets.size(); setIndex++) {
        if (!matches(groupingSets.get(setIndex), actual.getGroupingSets().get(setIndex), symbolAliases)) {
            return NO_MATCH;
        }
    }
    return match();
}
Use of com.facebook.presto.sql.planner.plan.AggregationNode in project presto by prestodb.
Class TestEffectivePredicateExtractor, method testGroupByEmpty.
/**
 * An aggregation with a single empty grouping set over a FALSE filter is expected to yield TRUE as
 * its effective predicate.
 */
@Test
public void testGroupByEmpty() throws Exception {
    PlanNode node = new AggregationNode(
            newId(),
            filter(baseTableScan, FALSE_LITERAL),
            ImmutableMap.of(),
            ImmutableMap.of(),
            ImmutableMap.of(),
            ImmutableList.of(ImmutableList.of()),
            AggregationNode.Step.FINAL,
            Optional.empty(),
            Optional.empty());

    assertEquals(EffectivePredicateExtractor.extract(node, TYPES), TRUE_LITERAL);
}
Use of com.facebook.presto.sql.planner.plan.AggregationNode in project presto by prestodb.
Class TestTypeValidator, method testInvalidAggregationFunctionSignature.
/**
 * Builds an aggregation whose output symbol is typed DOUBLE but whose function signature declares
 * BIGINT. Type validation is expected to reject it with an {@link IllegalArgumentException}.
 */
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint")
public void testInvalidAggregationFunctionSignature() throws Exception {
    Symbol sumSymbol = symbolAllocator.newSymbol("sum", DOUBLE);

    Signature mismatchedSignature = new Signature(
            "sum",
            FunctionKind.AGGREGATE,
            ImmutableList.of(),
            ImmutableList.of(),
            BIGINT.getTypeSignature(), // should be DOUBLE
            ImmutableList.of(DOUBLE.getTypeSignature()),
            false);
    Map<Symbol, Signature> functions = ImmutableMap.of(sumSymbol, mismatchedSignature);
    Map<Symbol, FunctionCall> aggregations = ImmutableMap.of(
            sumSymbol,
            new FunctionCall(QualifiedName.of("sum"), ImmutableList.of(columnC.toSymbolReference())));

    PlanNode node = new AggregationNode(
            newId(),
            baseTableScan,
            aggregations,
            functions,
            ImmutableMap.of(),
            ImmutableList.of(ImmutableList.of(columnA, columnB)),
            SINGLE,
            Optional.empty(),
            Optional.empty());
    assertTypesValid(node);
}
Aggregations