Usage of com.facebook.presto.spi.plan.AggregationNode in the presto project by prestodb.
From the class TestTranslateExpressions, method testTranslateIntermediateAggregationWithLambda.
@Test
public void testTranslateIntermediateAggregationWithLambda() {
    // Plan under test: a global reduce_agg whose arguments (the input column followed by two
    // lambdas) and whose filter are still given as untranslated expressions.
    PlanNode result = tester()
            .assertThat(new TranslateExpressions(METADATA, new SqlParser()).aggregationRowExpressionRewriteRule())
            .on(p -> p.aggregation(builder -> builder
                    .globalGrouping()
                    .addAggregation(
                            variable("reduce_agg", INTEGER),
                            new AggregationNode.Aggregation(
                                    new CallExpression(
                                            "reduce_agg",
                                            REDUCE_AGG,
                                            INTEGER,
                                            ImmutableList.of(
                                                    castToRowExpression(expression("input")),
                                                    castToRowExpression(expression("(x,y) -> x*y")),
                                                    castToRowExpression(expression("(a,b) -> a*b")))),
                                    Optional.of(castToRowExpression(expression("input > 10"))),
                                    Optional.empty(),
                                    false,
                                    Optional.empty()))
                    .source(p.values(p.variable("input", INTEGER)))))
            .get();

    AggregationNode.Aggregation translated = ((AggregationNode) result)
            .getAggregations()
            .get(variable("reduce_agg", INTEGER));

    // Expected: the same aggregation with every argument and the filter in RowExpression form
    // (variable references, lambda definitions and calls).
    LambdaDefinitionExpression multiplyXY = new LambdaDefinitionExpression(
            Optional.empty(),
            ImmutableList.of(INTEGER, INTEGER),
            ImmutableList.of("x", "y"),
            multiply(variable("x", INTEGER), variable("y", INTEGER)));
    LambdaDefinitionExpression multiplyAB = new LambdaDefinitionExpression(
            Optional.empty(),
            ImmutableList.of(INTEGER, INTEGER),
            ImmutableList.of("a", "b"),
            multiply(variable("a", INTEGER), variable("b", INTEGER)));
    AggregationNode.Aggregation expected = new AggregationNode.Aggregation(
            new CallExpression(
                    "reduce_agg",
                    REDUCE_AGG,
                    INTEGER,
                    ImmutableList.of(variable("input", INTEGER), multiplyXY, multiplyAB)),
            Optional.of(greaterThan(variable("input", INTEGER), constant(10L, INTEGER))),
            Optional.empty(),
            false,
            Optional.empty());

    assertEquals(translated, expected);
    assertFalse(isUntranslated(translated));
}
Usage of com.facebook.presto.spi.plan.AggregationNode in the presto project by prestodb.
From the class TestTranslateExpressions, method testTranslateAggregationWithLambda.
@Test
public void testTranslateAggregationWithLambda() {
    // Plan under test: a global reduce_agg whose arguments (the input column, the literal 0 and
    // two lambdas) and whose filter are still given as untranslated expressions.
    PlanNode result = tester()
            .assertThat(new TranslateExpressions(METADATA, new SqlParser()).aggregationRowExpressionRewriteRule())
            .on(p -> p.aggregation(builder -> builder
                    .globalGrouping()
                    .addAggregation(
                            variable("reduce_agg", INTEGER),
                            new AggregationNode.Aggregation(
                                    new CallExpression(
                                            "reduce_agg",
                                            REDUCE_AGG,
                                            INTEGER,
                                            ImmutableList.of(
                                                    castToRowExpression(expression("input")),
                                                    castToRowExpression(expression("0")),
                                                    castToRowExpression(expression("(x,y) -> x*y")),
                                                    castToRowExpression(expression("(a,b) -> a*b")))),
                                    Optional.of(castToRowExpression(expression("input > 10"))),
                                    Optional.empty(),
                                    false,
                                    Optional.empty()))
                    .source(p.values(p.variable("input", INTEGER)))))
            .get();

    // TODO migrate this to RowExpressionMatcher
    AggregationNode.Aggregation translated = ((AggregationNode) result)
            .getAggregations()
            .get(variable("reduce_agg", INTEGER));

    // Expected: the same aggregation with every argument and the filter in RowExpression form
    // (variable references, constants, lambda definitions and calls).
    LambdaDefinitionExpression multiplyXY = new LambdaDefinitionExpression(
            Optional.empty(),
            ImmutableList.of(INTEGER, INTEGER),
            ImmutableList.of("x", "y"),
            multiply(variable("x", INTEGER), variable("y", INTEGER)));
    LambdaDefinitionExpression multiplyAB = new LambdaDefinitionExpression(
            Optional.empty(),
            ImmutableList.of(INTEGER, INTEGER),
            ImmutableList.of("a", "b"),
            multiply(variable("a", INTEGER), variable("b", INTEGER)));
    AggregationNode.Aggregation expected = new AggregationNode.Aggregation(
            new CallExpression(
                    "reduce_agg",
                    REDUCE_AGG,
                    INTEGER,
                    ImmutableList.of(variable("input", INTEGER), constant(0L, INTEGER), multiplyXY, multiplyAB)),
            Optional.of(greaterThan(variable("input", INTEGER), constant(10L, INTEGER))),
            Optional.empty(),
            false,
            Optional.empty());

    assertEquals(translated, expected);
    assertFalse(isUntranslated(translated));
}
Usage of com.facebook.presto.spi.plan.AggregationNode in the presto project by prestodb.
From the class ScalarAggregationToJoinRewriter, method createAggregationNode.
private Optional<AggregationNode> createAggregationNode(AggregationNode scalarAggregation, JoinNode leftOuterJoin, VariableReferenceExpression nonNull) {
    // Every count aggregation of the scalar aggregation is replaced by a count over the nonNull
    // variable; all other aggregations are reused as-is, under the same output variable.
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : scalarAggregation.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation original = entry.getValue();

        if (!functionResolution.isCountFunction(original.getFunctionHandle())) {
            rewrittenAggregations.put(output, original);
            continue;
        }

        // NOTE(review): the replacement keeps only the original aggregation's mask; its filter,
        // orderBy and distinct flag are not carried over — confirm this is intended.
        CallExpression countOfNonNull = new CallExpression(
                output.getSourceLocation(),
                "count",
                functionResolution.countFunction(nonNull.getType()),
                BIGINT,
                ImmutableList.of(castToRowExpression(asSymbolReference(nonNull))));
        rewrittenAggregations.put(output, new Aggregation(countOfNonNull, Optional.empty(), Optional.empty(), false, original.getMask()));
    }

    // The new node reads from the left outer join and groups by all output variables of the
    // join's left side, reusing the scalar aggregation's step and hash variable.
    AggregationNode aggregationOverJoin = new AggregationNode(
            scalarAggregation.getSourceLocation(),
            idAllocator.getNextId(),
            leftOuterJoin,
            rewrittenAggregations.build(),
            singleGroupingSet(leftOuterJoin.getLeft().getOutputVariables()),
            ImmutableList.of(),
            scalarAggregation.getStep(),
            scalarAggregation.getHashVariable(),
            Optional.empty());
    return Optional.of(aggregationOverJoin);
}
Usage of com.facebook.presto.spi.plan.AggregationNode in the presto project by prestodb.
From the class TestCostCalculator, method testAggregation.
@Test
public void testAggregation() {
    TableScanNode tableScan = tableScan("ts", "orderkey");
    AggregationNode aggregation = aggregation("agg", tableScan);

    // Only the table scan has a pre-computed cost; stats are provided for both nodes.
    Map<String, PlanCostEstimate> costs = ImmutableMap.of("ts", cpuCost(6000));
    Map<String, PlanNodeStatsEstimate> stats = ImmutableMap.of(
            "ts", statsEstimate(tableScan, 6000),
            "agg", statsEstimate(aggregation, 13));
    Map<String, Type> types = ImmutableMap.of(
            "orderkey", BIGINT,
            "count", BIGINT);

    assertCost(aggregation, costs, stats)
            .cpu(6000 * IS_NULL_OVERHEAD + 6000)
            .memory(13 * IS_NULL_OVERHEAD)
            .network(0);

    assertCostEstimatedExchanges(aggregation, costs, stats)
            .cpu((6000 + 6000 + 6000) * IS_NULL_OVERHEAD + 6000)
            .memory(13 * IS_NULL_OVERHEAD)
            .network(6000 * IS_NULL_OVERHEAD);

    assertCostSingleStageFragmentedPlan(aggregation, costs, stats, types)
            .cpu(6000 + 6000 * IS_NULL_OVERHEAD)
            .memory(13 * IS_NULL_OVERHEAD)
            .network(0);

    assertCostHasUnknownComponentsForUnknownStats(aggregation);
}
Usage of com.facebook.presto.spi.plan.AggregationNode in the presto project by prestodb.
From the class MultipleDistinctAggregationToMarkDistinct, method apply.
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    if (!SystemSessionProperties.useMarkDistinct(context.getSession())) {
        return Result.empty();
    }

    // The distinct marker for each set of input columns. The key is a Set, so aggregations over the
    // same inputs share one marker (and one MarkDistinctNode).
    Map<Set<VariableReferenceExpression>, VariableReferenceExpression> markers = new HashMap<>();

    // Insertion-ordered so that the rewritten node lists its aggregations in the same order as the
    // parent; a HashMap would reorder them by hash.
    Map<VariableReferenceExpression, Aggregation> newAggregations = new java.util.LinkedHashMap<>();
    PlanNode subPlan = parent.getSource();

    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : parent.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();

        // Only DISTINCT aggregations without a filter and without an existing mask are rewritten;
        // everything else is copied through unchanged.
        if (aggregation.isDistinct() && !aggregation.getFilter().isPresent() && !aggregation.getMask().isPresent()) {
            Set<VariableReferenceExpression> inputs = aggregation.getArguments().stream()
                    .map(OriginalExpressionUtils::castToExpression)
                    .map(context.getVariableAllocator()::toVariableReference)
                    .collect(toSet());

            VariableReferenceExpression marker = markers.get(inputs);
            if (marker == null) {
                // First aggregation over this input set: allocate a BOOLEAN marker and stack a
                // MarkDistinctNode on top of the current sub-plan. It is keyed on the grouping keys,
                // the inputs, and the group id variable when present.
                marker = context.getVariableAllocator().newVariable(Iterables.getLast(inputs).getName(), BOOLEAN, "distinct");
                markers.put(inputs, marker);

                ImmutableSet.Builder<VariableReferenceExpression> distinctVariables = ImmutableSet.<VariableReferenceExpression>builder()
                        .addAll(parent.getGroupingKeys())
                        .addAll(inputs);
                parent.getGroupIdVariable().ifPresent(distinctVariables::add);

                subPlan = new MarkDistinctNode(
                        subPlan.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        subPlan,
                        marker,
                        ImmutableList.copyOf(distinctVariables.build()),
                        Optional.empty());
            }

            // remove the distinct flag and set the distinct marker
            newAggregations.put(entry.getKey(), new Aggregation(aggregation.getCall(), aggregation.getFilter(), aggregation.getOrderBy(), false, Optional.of(marker)));
        }
        else {
            newAggregations.put(entry.getKey(), aggregation);
        }
    }

    return Result.ofPlanNode(new AggregationNode(
            parent.getSourceLocation(),
            parent.getId(),
            subPlan,
            newAggregations,
            parent.getGroupingSets(),
            ImmutableList.of(),
            parent.getStep(),
            parent.getHashVariable(),
            parent.getGroupIdVariable()));
}
Aggregations