Search in sources :

Example 6 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

the class TransformCorrelatedGlobalAggregationWithProjection method apply.

/**
 * Rewrites a correlated join whose subquery is a projection over a global aggregation
 * (optionally with a distinct aggregation beneath it) into a LEFT join between the
 * uniquely-identified input and the decorrelated subquery source, followed by the
 * aggregation grouped by the input symbols and the captured projection.
 *
 * @param correlatedJoinNode the matched correlated join; its type must be INNER or LEFT
 * @param captures pattern captures providing the SOURCE, AGGREGATION and PROJECTION nodes
 * @param context rule context supplying the symbol allocator, id allocator and lookup
 * @return the rewritten plan, or an empty result when the subquery source cannot be decorrelated
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    checkArgument(
            correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT,
            "unexpected correlated join type: " + correlatedJoinNode.getType());

    // A distinct aggregation may sit directly below the global aggregation. Detach it here
    // so that it can be re-applied on top of the join further down.
    PlanNode subquerySource = captures.get(SOURCE);
    AggregationNode distinctAggregation = null;
    if (isDistinctOperator(subquerySource)) {
        distinctAggregation = (AggregationNode) subquerySource;
        subquerySource = distinctAggregation.getSource();
    }

    // Pull the correlated filters out of the subquery; give up if that is not possible.
    PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(
            plannerContext,
            context.getSymbolAllocator(),
            context.getLookup());
    Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelated =
            decorrelator.decorrelateFilters(subquerySource, correlatedJoinNode.getCorrelation());
    if (decorrelated.isEmpty()) {
        return Result.empty();
    }
    subquerySource = decorrelated.get().getNode();

    // Tag every subquery row with a constant TRUE marker. After the LEFT join the marker is
    // used to restore the semantics of null-sensitive aggregations.
    Symbol nonNullMarker = context.getSymbolAllocator().newSymbol("non_null", BOOLEAN);
    subquerySource = new ProjectNode(
            context.getIdAllocator().getNextId(),
            subquerySource,
            Assignments.builder()
                    .putIdentities(subquerySource.getOutputSymbols())
                    .put(nonNullMarker, TRUE_LITERAL)
                    .build());

    // Give each input row a unique id so that original input rows can be told apart after the join.
    PlanNode uniquelyIdentifiedInput = new AssignUniqueId(
            context.getIdAllocator().getNextId(),
            correlatedJoinNode.getInput(),
            context.getSymbolAllocator().newSymbol("unique", BIGINT));

    JoinNode join = new JoinNode(
            context.getIdAllocator().getNextId(),
            JoinNode.Type.LEFT,
            uniquelyIdentifiedInput,
            subquerySource,
            ImmutableList.of(),
            uniquelyIdentifiedInput.getOutputSymbols(),
            subquerySource.getOutputSymbols(),
            false,
            decorrelated.get().getCorrelatedPredicates(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of(),
            Optional.empty());

    // Re-apply the detached distinct aggregation, now additionally grouped by the join's
    // left-side symbols and the non-null marker.
    PlanNode current = join;
    if (distinctAggregation != null) {
        ImmutableList<Symbol> distinctGroupingKeys = ImmutableList.<Symbol>builder()
                .addAll(join.getLeftOutputSymbols())
                .add(nonNullMarker)
                .addAll(distinctAggregation.getGroupingKeys())
                .build();
        current = new AggregationNode(
                distinctAggregation.getId(),
                join,
                distinctAggregation.getAggregations(),
                singleGroupingSet(distinctGroupingKeys),
                ImmutableList.of(),
                distinctAggregation.getStep(),
                Optional.empty(),
                Optional.empty());
    }

    // Choose a mask for every original aggregation: agg() becomes agg() mask(non_null).
    // An aggregation that already carries a mask gets a new mask symbol defined as the
    // conjunction of its existing mask and non_null.
    // The masking matters when the nested subquery produced no rows for some input row and
    // the aggregation is null-sensitive, i.e. its result over a single null row differs from
    // its result over empty input (with global grouping) -- e.g. count(*), checksum(), array_agg().
    AggregationNode globalAggregation = captures.get(AGGREGATION);
    ImmutableMap.Builder<Symbol, Symbol> maskByOutput = ImmutableMap.builder();
    Assignments.Builder combinedMasks = Assignments.builder();
    for (Map.Entry<Symbol, Aggregation> entry : globalAggregation.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        if (aggregation.getMask().isEmpty()) {
            maskByOutput.put(entry.getKey(), nonNullMarker);
            continue;
        }
        Symbol combinedMask = context.getSymbolAllocator().newSymbol("mask", BOOLEAN);
        Expression conjunction = and(
                aggregation.getMask().get().toSymbolReference(),
                nonNullMarker.toSymbolReference());
        combinedMasks.put(combinedMask, conjunction);
        maskByOutput.put(entry.getKey(), combinedMask);
    }

    // A projection is only needed when at least one combined mask has to be computed.
    Assignments combinedMaskAssignments = combinedMasks.build();
    if (!combinedMaskAssignments.isEmpty()) {
        current = new ProjectNode(
                context.getIdAllocator().getNextId(),
                current,
                Assignments.builder()
                        .putIdentities(current.getOutputSymbols())
                        .putAll(combinedMaskAssignments)
                        .build());
    }

    // Rebuild the global aggregation on top of the join, grouped by the join's left-side
    // symbols followed by the aggregation's own grouping keys.
    AggregationNode restoredAggregation = new AggregationNode(
            globalAggregation.getId(),
            current,
            rewriteWithMasks(globalAggregation.getAggregations(), maskByOutput.buildOrThrow()),
            singleGroupingSet(ImmutableList.<Symbol>builder()
                    .addAll(join.getLeftOutputSymbols())
                    .addAll(globalAggregation.getGroupingKeys())
                    .build()),
            ImmutableList.of(),
            globalAggregation.getStep(),
            Optional.empty(),
            Optional.empty());

    // Keep only the aggregation outputs that the correlated join exposes, then add the
    // assignments of the captured projection.
    Set<Symbol> requiredOutputs = new HashSet<>(correlatedJoinNode.getOutputSymbols());
    List<Symbol> retainedAggregationOutputs = restoredAggregation.getOutputSymbols().stream()
            .filter(requiredOutputs::contains)
            .collect(toImmutableList());
    Assignments finalAssignments = Assignments.builder()
            .putIdentities(retainedAggregationOutputs)
            .putAll(captures.get(PROJECTION).getAssignments())
            .build();
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            restoredAggregation,
            finalAssignments));
}
Also used : Symbol(io.trino.sql.planner.Symbol) CorrelatedJoinNode(io.trino.sql.planner.plan.CorrelatedJoinNode) JoinNode(io.trino.sql.planner.plan.JoinNode) Assignments(io.trino.sql.planner.plan.Assignments) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNodeDecorrelator(io.trino.sql.planner.optimizations.PlanNodeDecorrelator) PlanNode(io.trino.sql.planner.plan.PlanNode) AssignUniqueId(io.trino.sql.planner.plan.AssignUniqueId) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashSet(java.util.HashSet)

Example 7 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

the class RewriteSpatialPartitioningAggregation method apply.

/**
 * Rewrites each single-argument aggregation whose function name equals {@code NAME} into a
 * two-argument call of the resolved spatial partitioning function, taking the envelope of the
 * geometry and a partition-count symbol. All other aggregations are carried over unchanged.
 * A projection computing the envelopes and the partition count is inserted below the aggregation.
 *
 * @param node the matched aggregation node
 * @param captures pattern captures (not read by this rule)
 * @param context rule context supplying the session and the symbol and id allocators
 * @return the aggregation node rebuilt over the new projection
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    ResolvedFunction spatialPartitioningFunction = plannerContext.getMetadata().resolveFunction(
            context.getSession(),
            QualifiedName.of(NAME),
            fromTypeSignatures(GEOMETRY_TYPE_SIGNATURE, INTEGER.getTypeSignature()));
    ResolvedFunction stEnvelopeFunction = plannerContext.getMetadata().resolveFunction(
            context.getSession(),
            QualifiedName.of("ST_Envelope"),
            fromTypeSignatures(GEOMETRY_TYPE_SIGNATURE));

    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", INTEGER);
    ImmutableMap.Builder<Symbol, Expression> envelopeAssignments = ImmutableMap.builder();

    for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        String functionName = aggregation.getResolvedFunction().getSignature().getName();
        boolean rewritable = functionName.equals(NAME) && aggregation.getArguments().size() == 1;
        if (!rewritable) {
            // Not a single-argument spatial partitioning call: keep the aggregation as it is.
            rewrittenAggregations.put(entry);
            continue;
        }

        Expression geometry = getOnlyElement(aggregation.getArguments());
        Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol(
                "envelope",
                plannerContext.getTypeManager().getType(GEOMETRY_TYPE_SIGNATURE));

        // Reuse the argument when it already is an ST_Envelope call; otherwise wrap it in one.
        Expression envelope = isStEnvelopeFunctionCall(geometry, stEnvelopeFunction)
                ? geometry
                : FunctionCallBuilder.resolve(context.getSession(), plannerContext.getMetadata())
                        .setName(QualifiedName.of("ST_Envelope"))
                        .addArgument(GEOMETRY_TYPE_SIGNATURE, geometry)
                        .build();
        envelopeAssignments.put(envelopeSymbol, envelope);

        // Only the mask of the original aggregation is carried over to the rewritten one.
        rewrittenAggregations.put(
                entry.getKey(),
                new Aggregation(
                        spatialPartitioningFunction,
                        ImmutableList.of(envelopeSymbol.toSymbolReference(), partitionCountSymbol.toSymbolReference()),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        aggregation.getMask()));
    }

    // The new source passes the original outputs through and adds the partition count
    // (taken from the session's hash partition count) plus all envelope expressions.
    return Result.ofPlanNode(new AggregationNode(
            node.getId(),
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    node.getSource(),
                    Assignments.builder()
                            .putIdentities(node.getSource().getOutputSymbols())
                            .put(partitionCountSymbol, new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession()))))
                            .putAll(envelopeAssignments.buildOrThrow())
                            .build()),
            rewrittenAggregations.buildOrThrow(),
            node.getGroupingSets(),
            node.getPreGroupedSymbols(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol()));
}
Also used : LongLiteral(io.trino.sql.tree.LongLiteral) ResolvedFunction(io.trino.metadata.ResolvedFunction) Symbol(io.trino.sql.planner.Symbol) AggregationNode(io.trino.sql.planner.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) Expression(io.trino.sql.tree.Expression) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Example 8 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

the class TestTypeValidator method testInvalidAggregationFunctionCall.

/**
 * Builds an aggregation node whose DOUBLE-typed output symbol is computed by {@code sum}
 * (resolved for DOUBLE) applied to {@code columnA}, and expects type validation to fail
 * with an {@link IllegalArgumentException} reporting an actual type of bigint.
 */
@Test
public void testInvalidAggregationFunctionCall() {
    Symbol sumSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
    Aggregation sumOverColumnA = new Aggregation(
            functionResolution.resolveFunction(QualifiedName.of("sum"), fromTypes(DOUBLE)),
            ImmutableList.of(columnA.toSymbolReference()),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    PlanNode node = new AggregationNode(
            newId(),
            baseTableScan,
            ImmutableMap.of(sumSymbol, sumOverColumnA),
            singleGroupingSet(ImmutableList.of(columnA, columnB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    assertThatThrownBy(() -> assertTypesValid(node))
            .isInstanceOf(IllegalArgumentException.class)
            .hasMessageMatching("type of symbol 'sum(_[0-9]+)?' is expected to be double, but the actual type is bigint");
}
Also used : Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Test(org.testng.annotations.Test)

Example 9 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

the class TestTypeValidator method testValidAggregation.

/**
 * Builds an aggregation node whose DOUBLE-typed output symbol is computed by {@code sum}
 * (resolved for DOUBLE) applied to {@code columnC}, and expects type validation to pass.
 */
@Test
public void testValidAggregation() {
    Symbol sumSymbol = symbolAllocator.newSymbol("sum", DOUBLE);
    Aggregation sumOverColumnC = new Aggregation(
            functionResolution.resolveFunction(QualifiedName.of("sum"), fromTypes(DOUBLE)),
            ImmutableList.of(columnC.toSymbolReference()),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    PlanNode node = new AggregationNode(
            newId(),
            baseTableScan,
            ImmutableMap.of(sumSymbol, sumOverColumnC),
            singleGroupingSet(ImmutableList.of(columnA, columnB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    assertTypesValid(node);
}
Also used : Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Test(org.testng.annotations.Test)

Example 10 with Aggregation

use of io.trino.sql.planner.plan.AggregationNode.Aggregation in project trino by trinodb.

the class TestTypeValidator method testInvalidAggregationFunctionSignature.

/**
 * Builds an aggregation node whose BIGINT-typed output symbol is computed by {@code sum}
 * resolved for DOUBLE over {@code columnC}, and expects type validation to fail with an
 * {@link IllegalArgumentException} reporting that bigint was expected but double was found.
 */
@Test
public void testInvalidAggregationFunctionSignature() {
    Symbol sumSymbol = symbolAllocator.newSymbol("sum", BIGINT);
    Aggregation doubleSumOverColumnC = new Aggregation(
            functionResolution.resolveFunction(QualifiedName.of("sum"), fromTypes(DOUBLE)),
            ImmutableList.of(columnC.toSymbolReference()),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
    PlanNode node = new AggregationNode(
            newId(),
            baseTableScan,
            ImmutableMap.of(sumSymbol, doubleSumOverColumnC),
            singleGroupingSet(ImmutableList.of(columnA, columnB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    assertThatThrownBy(() -> assertTypesValid(node))
            .isInstanceOf(IllegalArgumentException.class)
            .hasMessageMatching("type of symbol 'sum(_[0-9]+)?' is expected to be bigint, but the actual type is double");
}
Also used : Aggregation(io.trino.sql.planner.plan.AggregationNode.Aggregation) PlanNode(io.trino.sql.planner.plan.PlanNode) AggregationNode(io.trino.sql.planner.plan.AggregationNode) Test(org.testng.annotations.Test)

Aggregations

Aggregation (io.trino.sql.planner.plan.AggregationNode.Aggregation) 20, AggregationNode (io.trino.sql.planner.plan.AggregationNode) 17, Symbol (io.trino.sql.planner.Symbol) 14, Map (java.util.Map) 14, ImmutableMap (com.google.common.collect.ImmutableMap) 11, PlanNode (io.trino.sql.planner.plan.PlanNode) 9, Expression (io.trino.sql.tree.Expression) 9, ProjectNode (io.trino.sql.planner.plan.ProjectNode) 7, ImmutableList (com.google.common.collect.ImmutableList) 5, ResolvedFunction (io.trino.metadata.ResolvedFunction) 5, Assignments (io.trino.sql.planner.plan.Assignments) 5, Type (io.trino.spi.type.Type) 4, CorrelatedJoinNode (io.trino.sql.planner.plan.CorrelatedJoinNode) 4, ComparisonExpression (io.trino.sql.tree.ComparisonExpression) 4, Test (org.testng.annotations.Test) 4, Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument) 3, ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList) 3, ImmutableSet (com.google.common.collect.ImmutableSet) 3, Session (io.trino.Session) 3, Cast (io.trino.sql.tree.Cast) 3