Search in sources :

Example 31 with Assignments

Use of io.prestosql.spi.plan.Assignments in project hetu-core by openLooKeng.

From class CubeOptimizer, method rewriteAggregationNode.

/**
 * Rebuilds the original aggregation on top of a plan that reads from a cube table.
 *
 * <p>Builds a single-step HASH {@code AggregationNode} over {@code inputPlanNode}, grouped by the original
 * grouping keys translated through {@code columnRewritesMap} and {@code optimizedPlanMappings}. Each cube
 * aggregation column is re-aggregated with the function recorded in the cube metadata, except COUNT, which is
 * rolled up with SUM. When the rewrite result has AVG columns, a {@code ProjectNode} is placed on top of the
 * aggregation.
 *
 * <p>Reads instance state: {@code context}, {@code aggregationNode}, {@code columnRewritesMap},
 * {@code optimizedPlanMappings}, {@code metadata}, {@code cubeTableHandle}, {@code cubeMetadata} and
 * {@code originalAggregationsMap}.
 *
 * @param cubeRewriteResult describes the cube scan, its aggregation columns and any AVG columns
 * @param inputPlanNode the plan node the new aggregation consumes
 * @return the new aggregation node, or a projection over it when AVG columns are present
 * @throws ColumnNotFoundException if the cube metadata has no aggregation signature for a cube column
 */
private PlanNode rewriteAggregationNode(CubeRewriteResult cubeRewriteResult, PlanNode inputPlanNode) {
    // Only consulted when building AVG = SUM / COUNT projections below.
    TypeProvider typeProvider = context.getSymbolAllocator().getTypes();
    // Group by the original grouping keys: original symbol name -> rewritten column -> symbol in the optimized plan.
    List<Symbol> groupings = aggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(columnRewritesMap::get).map(optimizedPlanMappings::get).collect(Collectors.toList());
    // Cube scan symbol -> original aggregation output symbol; used later to locate the SUM/COUNT inputs of each AVG.
    Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
    // Rewrite AggregationNode using the cube table: one aggregation per cube aggregation column.
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
        // The type of the ORIGINAL aggregation's output is used both as the call's return type and as the
        // argument type for function lookup below. NOTE(review): presumably the cube column has the same type — confirm.
        Type type = cubeRewriteResult.getSymbolMetadataMap().get(aggregatorSource.getOriginalAggSymbol()).getType();
        TypeSignature typeSignature = type.getTypeSignature();
        ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
        ColumnMetadata cubeColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        // NOTE(review): the exception is raised with an empty SchemaTableName, so the table is not identified in the error.
        AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName("", ""), cubeColHandle.getColumnName()));
        // The cube stores partial results. Stored COUNTs are rolled up by summing them, so COUNT is replaced by SUM;
        // any other stored function is reapplied as-is.
        String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? SUM.getName() : aggregationSignature.getFunction();
        SymbolReference argument = toSymbolReference(aggregatorSource.getScanSymbol());
        FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(typeSignature));
        cubeScanToAggOutputMap.put(aggregatorSource.getScanSymbol(), aggregatorSource.getOriginalAggSymbol());
        // The new aggregation writes to the original output symbol so downstream references stay valid.
        aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(castToRowExpression(argument))), ImmutableList.of(castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
    }
    PlanNode planNode = inputPlanNode;
    AggregationNode aggNode = new AggregationNode(context.getIdAllocator().getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupings), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
    // Without AVG columns the aggregation node alone is the result.
    if (cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
        return aggNode;
    }
    if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
        // Project each original aggregation symbol to a reference to its cube scan symbol.
        // NOTE(review): these expressions reference scan symbols rather than the outputs produced by aggNode
        // (aggNode outputs the grouping keys plus the original aggregation symbols), and the grouping keys are
        // not carried through this projection — verify this is the intended plan shape.
        Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
        for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
            aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
        }
        planNode = new ProjectNode(context.getIdAllocator().getNextId(), aggNode, new Assignments(aggregateAssignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    } else {
        // Compute each AVG as SUM / COUNT over the re-aggregated cube columns.
        // NOTE(review): assignments come from a HashMap, so the projection's column order presumably follows hash
        // iteration order rather than a defined order.
        Map<Symbol, Expression> projections = new HashMap<>();
        // Pass through the grouping keys.
        aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
        // Pass through aggregation outputs that appear as values in originalAggregationsMap.
        aggNode.getAggregations().keySet().stream().filter(originalAggregationsMap::containsValue).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
        // Add AVG = CAST(sum AS avgType) / CAST(count AS avgType).
        for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
            // NOTE(review): get(...) returns null if the AVG's SUM/COUNT scan symbol was not among the aggregation
            // columns processed above — confirm CubeRewriteResult guarantees both are present.
            Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
            Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
            Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
            // Both operands are cast to the AVG result type before dividing.
            ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString()), new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString()));
            projections.put(avgAggSource.getOriginalAggSymbol(), division);
        }
        return new ProjectNode(context.getIdAllocator().getNextId(), aggNode, new Assignments(projections.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    }
    // Reached only from the non SUM/COUNT branch above.
    return planNode;
}
Also used : ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) LongSupplier(java.util.function.LongSupplier) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) PrestoWarning(io.prestosql.spi.PrestoWarning) TypeProvider(io.prestosql.sql.planner.TypeProvider) SqlParser(io.prestosql.sql.parser.SqlParser) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Identifier(io.prestosql.sql.tree.Identifier) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Objects(java.util.Objects) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) 
ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) ExpressionUtils(io.prestosql.sql.ExpressionUtils) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) TypeSignature(io.prestosql.spi.type.TypeSignature) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) Iterables(com.google.common.collect.Iterables) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) Literal(io.prestosql.sql.tree.Literal) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) Lists(com.google.common.collect.Lists) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) RowExpressionRewriter(io.prestosql.expressions.RowExpressionRewriter) RowExpressionTreeRewriter(io.prestosql.expressions.RowExpressionTreeRewriter) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) JoinNode(io.prestosql.spi.plan.JoinNode) ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) 
EXPIRED_CUBE(io.prestosql.spi.connector.StandardWarningCode.EXPIRED_CUBE) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) TupleDomain(io.prestosql.spi.predicate.TupleDomain) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) RowExpression(io.prestosql.spi.relation.RowExpression) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) Assignments(io.prestosql.spi.plan.Assignments) TypeSignature(io.prestosql.spi.type.TypeSignature) PlanNode(io.prestosql.spi.plan.PlanNode) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) TypeProvider(io.prestosql.sql.planner.TypeProvider) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) CallExpression(io.prestosql.spi.relation.CallExpression) 
OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 32 with Assignments

Use of io.prestosql.spi.plan.Assignments in project hetu-core by openLooKeng.

From class ScalarAggregationToJoinRewriter, method rewriteScalarAggregation.

/**
 * Replaces a lateral join over a scalar aggregation with a LEFT join followed by an aggregation.
 *
 * <p>The lateral input is tagged with a unique BIGINT column and LEFT-joined to
 * {@code scalarAggregationSource} using the optional {@code joinExpression}. An aggregation is then built on
 * the join via {@code createAggregationNode}. If a projection sits above the subquery (reachable only through
 * {@code EnforceSingleRowNode}s), its assignments are re-applied on top of the new aggregation.
 *
 * @return the rewritten plan, or {@code lateralJoinNode} unchanged when no aggregation could be built
 */
private PlanNode rewriteScalarAggregation(LateralJoinNode lateralJoinNode, AggregationNode scalarAggregation, PlanNode scalarAggregationSource, Optional<Expression> joinExpression, Symbol nonNull) {
    // Attach a unique BIGINT column named "unique" to every row of the lateral input.
    AssignUniqueId inputWithUniqueColumns = new AssignUniqueId(
            idAllocator.getNextId(),
            lateralJoinNode.getInput(),
            planSymbolAllocator.newSymbol("unique", BIGINT));

    // The join exposes the tagged input columns followed by the aggregation source's columns.
    List<Symbol> joinOutputSymbols = ImmutableList.<Symbol>builder()
            .addAll(inputWithUniqueColumns.getOutputSymbols())
            .addAll(scalarAggregationSource.getOutputSymbols())
            .build();
    JoinNode leftOuterJoin = new JoinNode(
            idAllocator.getNextId(),
            JoinNode.Type.LEFT,
            inputWithUniqueColumns,
            scalarAggregationSource,
            ImmutableList.of(),
            joinOutputSymbols,
            joinExpression.map(OriginalExpressionUtils::castToRowExpression),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());

    Optional<AggregationNode> rewrittenAggregation = createAggregationNode(scalarAggregation, leftOuterJoin, nonNull);
    if (!rewrittenAggregation.isPresent()) {
        // No aggregation could be built; keep the lateral join as it was.
        return lateralJoinNode;
    }
    AggregationNode aggregation = rewrittenAggregation.get();

    Optional<ProjectNode> subqueryProjection = searchFrom(lateralJoinNode.getSubquery(), lookup)
            .where(ProjectNode.class::isInstance)
            .recurseOnlyWhen(EnforceSingleRowNode.class::isInstance)
            .findFirst();
    List<Symbol> aggregationOutputSymbols = getTruncatedAggregationSymbols(lateralJoinNode, aggregation);

    if (!subqueryProjection.isPresent()) {
        // Nothing to re-apply: just expose the aggregation outputs.
        return new ProjectNode(idAllocator.getNextId(), aggregation, AssignmentUtils.identityAsSymbolReferences(aggregationOutputSymbols));
    }

    // Expose the aggregation outputs and re-apply the subquery's projection on top of them.
    Assignments projectionAssignments = Assignments.builder()
            .putAll(AssignmentUtils.identityAsSymbolReferences(aggregationOutputSymbols))
            .putAll(subqueryProjection.get().getAssignments())
            .build();
    return new ProjectNode(idAllocator.getNextId(), aggregation, projectionAssignments);
}
Also used : AssignUniqueId(io.prestosql.sql.planner.plan.AssignUniqueId) LateralJoinNode(io.prestosql.sql.planner.plan.LateralJoinNode) JoinNode(io.prestosql.spi.plan.JoinNode) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) AggregationNode(io.prestosql.spi.plan.AggregationNode) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Example 33 with Assignments

Use of io.prestosql.spi.plan.Assignments in project hetu-core by openLooKeng.

From class TestTypeValidator, method testValidTypeOnlyCoercion.

/**
 * A projection may assign an expression to a symbol whose type differs only by an implicit coercion:
 * here a varchar(3) column is projected into a plain VARCHAR symbol, next to an explicit BIGINT cast.
 */
@Test
public void testValidTypeOnlyCoercion() {
    Expression bigintCast = new Cast(toSymbolReference(columnB), StandardTypes.BIGINT);
    Expression columnEReference = toSymbolReference(columnE);

    Assignments assignments = Assignments.builder()
            .put(planSymbolAllocator.newSymbol(bigintCast, BIGINT), castToRowExpression(bigintCast))
            // implicit coercion from varchar(3) to varchar
            .put(planSymbolAllocator.newSymbol(columnEReference, VARCHAR), castToRowExpression(columnEReference))
            .build();

    assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
}
Also used : Cast(io.prestosql.sql.tree.Cast) PlanNode(io.prestosql.spi.plan.PlanNode) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Expression(io.prestosql.sql.tree.Expression) Assignments(io.prestosql.spi.plan.Assignments) ProjectNode(io.prestosql.spi.plan.ProjectNode) Test(org.testng.annotations.Test)

Example 34 with Assignments

Use of io.prestosql.spi.plan.Assignments in project hetu-core by openLooKeng.

From class TestTypeValidator, method testValidProject.

/**
 * A projection whose expressions are explicit BIGINT casts, assigned to BIGINT symbols, passes type validation.
 */
@Test
public void testValidProject() {
    Expression castColumnB = new Cast(toSymbolReference(columnB), StandardTypes.BIGINT);
    Expression castColumnC = new Cast(toSymbolReference(columnC), StandardTypes.BIGINT);

    Assignments assignments = Assignments.builder()
            .put(planSymbolAllocator.newSymbol(castColumnB, BIGINT), castToRowExpression(castColumnB))
            .put(planSymbolAllocator.newSymbol(castColumnC, BIGINT), castToRowExpression(castColumnC))
            .build();

    assertTypesValid(new ProjectNode(newId(), baseTableScan, assignments));
}
Also used : Cast(io.prestosql.sql.tree.Cast) PlanNode(io.prestosql.spi.plan.PlanNode) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Expression(io.prestosql.sql.tree.Expression) Assignments(io.prestosql.spi.plan.Assignments) ProjectNode(io.prestosql.spi.plan.ProjectNode) Test(org.testng.annotations.Test)

Aggregations

Assignments (io.prestosql.spi.plan.Assignments)34 ProjectNode (io.prestosql.spi.plan.ProjectNode)31 Symbol (io.prestosql.spi.plan.Symbol)25 PlanNode (io.prestosql.spi.plan.PlanNode)20 Expression (io.prestosql.sql.tree.Expression)18 OriginalExpressionUtils.castToRowExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression)16 Cast (io.prestosql.sql.tree.Cast)15 RowExpression (io.prestosql.spi.relation.RowExpression)14 ImmutableList (com.google.common.collect.ImmutableList)13 Type (io.prestosql.spi.type.Type)13 HashMap (java.util.HashMap)13 Map (java.util.Map)13 AggregationNode (io.prestosql.spi.plan.AggregationNode)12 ImmutableMap (com.google.common.collect.ImmutableMap)11 TableScanNode (io.prestosql.spi.plan.TableScanNode)11 List (java.util.List)11 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)10 Metadata (io.prestosql.metadata.Metadata)10 CallExpression (io.prestosql.spi.relation.CallExpression)10 ArrayList (java.util.ArrayList)10