Search in sources :

Example 1 with ArithmeticBinaryExpression

use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.

the class TestDesugarTryExpressionRewriter method testTryExpressionDesugaringRewriter.

/**
 * Checks that a {@code try(...)} nested inside an arithmetic expression is desugared into a
 * call to the try function taking a zero-argument lambda.
 */
@Test
public void testTryExpressionDesugaringRewriter() {
    // Input: 1 + try(2)
    Expression tryOfTwo = new TryExpression(new DecimalLiteral("2"));
    Expression before = new ArithmeticBinaryExpression(ADD, new DecimalLiteral("1"), tryOfTwo);

    // Expected output: 1 + try_function(() -> 2)
    Expression tryFunctionCall = new FunctionCallBuilder(tester().getMetadata())
            .setName(QualifiedName.of(TryFunction.NAME))
            .addArgument(
                    new FunctionType(ImmutableList.of(), createDecimalType(1)),
                    new LambdaExpression(ImmutableList.of(), new DecimalLiteral("2")))
            .build();
    Expression after = new ArithmeticBinaryExpression(ADD, new DecimalLiteral("1"), tryFunctionCall);

    Expression rewritten = DesugarTryExpressionRewriter.rewrite(
            before,
            tester().getMetadata(),
            tester().getTypeAnalyzer(),
            tester().getSession(),
            new PlanSymbolAllocator());
    assertEquals(rewritten, after);
}
Also used : ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) TryExpression(io.prestosql.sql.tree.TryExpression) Expression(io.prestosql.sql.tree.Expression) FunctionType(io.prestosql.spi.type.FunctionType) DecimalLiteral(io.prestosql.sql.tree.DecimalLiteral) TryExpression(io.prestosql.sql.tree.TryExpression) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) BaseRuleTest(io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest) Test(org.testng.annotations.Test)

Example 2 with ArithmeticBinaryExpression

use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.

the class TestSqlParser method testPrecedenceAndAssociativity.

/**
 * Verifies operator precedence (NOT over AND over OR, unary minus and multiplication over
 * addition) and left-associativity of subtraction and division in parsed expressions.
 */
@Test
public void testPrecedenceAndAssociativity() {
    LogicalBinaryExpression.Operator and = LogicalBinaryExpression.Operator.AND;
    LogicalBinaryExpression.Operator or = LogicalBinaryExpression.Operator.OR;
    ArithmeticBinaryExpression.Operator add = ArithmeticBinaryExpression.Operator.ADD;
    ArithmeticBinaryExpression.Operator subtract = ArithmeticBinaryExpression.Operator.SUBTRACT;
    ArithmeticBinaryExpression.Operator multiply = ArithmeticBinaryExpression.Operator.MULTIPLY;

    LongLiteral one = new LongLiteral("1");
    LongLiteral two = new LongLiteral("2");
    LongLiteral three = new LongLiteral("3");

    // AND binds tighter than OR, regardless of which side it appears on
    LogicalBinaryExpression oneAndTwo = new LogicalBinaryExpression(and, one, two);
    assertExpression("1 AND 2 OR 3", new LogicalBinaryExpression(or, oneAndTwo, three));
    LogicalBinaryExpression twoAndThree = new LogicalBinaryExpression(and, two, three);
    assertExpression("1 OR 2 AND 3", new LogicalBinaryExpression(or, one, twoAndThree));

    // NOT binds tighter than both AND and OR
    assertExpression("NOT 1 AND 2", new LogicalBinaryExpression(and, new NotExpression(one), two));
    assertExpression("NOT 1 OR 2", new LogicalBinaryExpression(or, new NotExpression(one), two));

    // A leading minus belongs to the literal, not to the sum
    assertExpression("-1 + 2", new ArithmeticBinaryExpression(add, new LongLiteral("-1"), two));

    // Subtraction and division group from the left
    ArithmeticBinaryExpression oneMinusTwo = new ArithmeticBinaryExpression(subtract, one, two);
    assertExpression("1 - 2 - 3", new ArithmeticBinaryExpression(subtract, oneMinusTwo, three));
    ArithmeticBinaryExpression oneOverTwo = new ArithmeticBinaryExpression(DIVIDE, one, two);
    assertExpression("1 / 2 / 3", new ArithmeticBinaryExpression(DIVIDE, oneOverTwo, three));

    // Multiplication binds tighter than addition
    ArithmeticBinaryExpression twoTimesThree = new ArithmeticBinaryExpression(multiply, two, three);
    assertExpression("1 + 2 * 3", new ArithmeticBinaryExpression(add, one, twoTimesThree));
}
Also used : LogicalBinaryExpression(io.prestosql.sql.tree.LogicalBinaryExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) LongLiteral(io.prestosql.sql.tree.LongLiteral) NotExpression(io.prestosql.sql.tree.NotExpression) Test(org.testng.annotations.Test)

Example 3 with ArithmeticBinaryExpression

use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.

the class AggregationRewriteWithCube method rewrite.

/**
 * Builds a replacement plan for an aggregation that reads from the cube table named by
 * {@code cubeMetadata.getCubeName()}.
 *
 * <p>The plan is assembled bottom-up: the scan node returned by {@code createScanNode}, an optional
 * filter, a re-aggregation over the cube columns when the query's grouping columns do not exactly
 * match the cube's groups, an optional projection (including AVG = SUM / COUNT), and finally a
 * projection that restores the original aggregation's output symbols when they differ.
 *
 * @param originalAggregationNode the aggregation being replaced; its grouping keys and output symbols
 *         drive the shape of the rewritten plan
 * @param filterNode the filter to re-apply above the cube scan, or {@code null} for none; when
 *         non-null it is cast to {@link FilterNode}
 * @return the root of the rewritten plan
 * @throws CubeNotFoundException if no table handle exists for the cube name
 * @throws ColumnNotFoundException if the cube metadata has no aggregation signature for a cube
 *         column used by an aggregation
 */
public PlanNode rewrite(AggregationNode originalAggregationNode, PlanNode filterNode) {
    QualifiedObjectName starTreeTableName = QualifiedObjectName.valueOf(cubeMetadata.getCubeName());
    TableHandle cubeTableHandle = metadata.getTableHandle(session, starTreeTableName).orElseThrow(() -> new CubeNotFoundException(starTreeTableName.toString()));
    Map<String, ColumnHandle> cubeColumnsMap = metadata.getColumnHandles(session, cubeTableHandle);
    TableMetadata cubeTableMetadata = metadata.getTableMetadata(session, cubeTableHandle);
    List<ColumnMetadata> cubeColumnMetadataList = cubeTableMetadata.getColumns();
    // Add group by
    // Only grouping keys whose symbolMappings entry is a ColumnHandle are kept; others are silently skipped.
    // NOTE(review): symbolMappings is a field not visible here — presumably symbol name -> source column; confirm.
    List<Symbol> groupings = new ArrayList<>(originalAggregationNode.getGroupingKeys().size());
    for (Symbol symbol : originalAggregationNode.getGroupingKeys()) {
        Object column = symbolMappings.get(symbol.getName());
        if (column instanceof ColumnHandle) {
            groupings.add(new Symbol(((ColumnHandle) column).getColumnName()));
        }
    }
    // Exact match = same number of groups and every lower-cased grouping name is one of the cube's groups.
    Set<String> cubeGroups = cubeMetadata.getGroup();
    boolean exactGroupsMatch = false;
    if (groupings.size() == cubeGroups.size()) {
        exactGroupsMatch = groupings.stream().map(Symbol::getName).map(String::toLowerCase).allMatch(cubeGroups::contains);
    }
    CubeRewriteResult cubeRewriteResult = createScanNode(originalAggregationNode, filterNode, cubeTableHandle, cubeColumnsMap, cubeColumnMetadataList, exactGroupsMatch);
    PlanNode planNode = cubeRewriteResult.getTableScanNode();
    // Add filter node
    if (filterNode != null) {
        // The original predicate is rewritten through rewrittenMappings before being placed above the cube scan.
        Expression expression = castToExpression(((FilterNode) filterNode).getPredicate());
        expression = rewriteExpression(expression, rewrittenMappings);
        planNode = new FilterNode(idAllocator.getNextId(), planNode, castToRowExpression(expression));
    }
    // When the groups match exactly no re-aggregation is added; the scan (plus filter) is used as-is.
    if (!exactGroupsMatch) {
        // Scan symbol -> original aggregation output symbol; used below to locate the SUM and COUNT outputs for AVG.
        Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
        // Rewrite AggregationNode using Cube table
        ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
        for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
            ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
            ColumnMetadata cubeColumnMetadata = cubeRewriteResult.getSymbolMetadataMap().get(aggregatorSource.getScanSymbol());
            Type type = cubeColumnMetadata.getType();
            AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName(starTreeTableName.getSchemaName(), starTreeTableName.getObjectName()), cubeColHandle.getColumnName()));
            // A cube COUNT column is re-aggregated with "sum"; every other function is reused unchanged.
            // Presumably because the cube stores partial counts that must be added up — confirm.
            String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? "sum" : aggregationSignature.getFunction();
            SymbolReference argument = toSymbolReference(aggregatorSource.getScanSymbol());
            FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(type.getTypeSignature()));
            cubeScanToAggOutputMap.put(aggregatorSource.getScanSymbol(), aggregatorSource.getOriginalAggSymbol());
            aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument))), ImmutableList.of(OriginalExpressionUtils.castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
        }
        // Grouping keys of the new aggregation are the original keys translated through rewrittenMappings.
        List<Symbol> groupingKeys = originalAggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(rewrittenMappings::get).collect(Collectors.toList());
        planNode = new AggregationNode(idAllocator.getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupingKeys), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
        AggregationNode aggNode = (AggregationNode) planNode;
        // A projection is only added above the aggregation when the result reports AVG aggregation columns.
        if (!cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
            if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
                // Project each original aggregation symbol from a reference to its scan symbol.
                Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
                for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
                    aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
                }
                planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(aggregateAssignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
            } else {
                // If there was an AVG aggregation, map it to AVG = SUM/COUNT
                Map<Symbol, Expression> projections = new HashMap<>();
                // Grouping keys pass through unchanged.
                aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
                // Aggregation outputs pass through only if their name appears as a value in symbolMappings.
                aggNode.getAggregations().keySet().stream().filter(symbol -> symbolMappings.containsValue(symbol.getName())).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
                // Add AVG = SUM / COUNT
                for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
                    Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
                    Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
                    // Both operands are cast to the type registered for the original AVG symbol before dividing.
                    Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
                    ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString()), new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString()));
                    projections.put(avgAggSource.getOriginalAggSymbol(), division);
                }
                planNode = new ProjectNode(idAllocator.getNextId(), aggNode, new Assignments(projections.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
            }
        }
    }
    // Safety check to remove redundant symbols and rename original column names to intermediate names
    if (!planNode.getOutputSymbols().equals(originalAggregationNode.getOutputSymbols())) {
        // Map new symbol names to the old symbols
        Map<Symbol, Expression> assignments = new HashMap<>();
        Set<Symbol> planNodeOutput = new HashSet<>(planNode.getOutputSymbols());
        for (Symbol originalAggOutputSymbol : originalAggregationNode.getOutputSymbols()) {
            if (!planNodeOutput.contains(originalAggOutputSymbol)) {
                // Must be grouping key
                // NOTE(review): relies on rewrittenMappings having an entry for this name; a missing entry is not checked here.
                assignments.put(originalAggOutputSymbol, toSymbolReference(rewrittenMappings.get(originalAggOutputSymbol.getName())));
            } else {
                // Should be an expression and must have the same name in the new plan node
                assignments.put(originalAggOutputSymbol, toSymbolReference(originalAggOutputSymbol));
            }
        }
        planNode = new ProjectNode(idAllocator.getNextId(), planNode, new Assignments(assignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    }
    return planNode;
}
Also used : TypeProvider(io.prestosql.sql.planner.TypeProvider) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) FunctionHandle(io.prestosql.spi.function.FunctionHandle) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) StandardErrorCode(io.prestosql.spi.StandardErrorCode) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) 
Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) Function(java.util.function.Function) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) Session(io.prestosql.Session) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) Symbol(io.prestosql.spi.plan.Symbol) Assignments(io.prestosql.spi.plan.Assignments) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) PlanNodeIdAllocator(io.prestosql.spi.plan.PlanNodeIdAllocator) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) FilterNode(io.prestosql.spi.plan.FilterNode) ArrayList(java.util.ArrayList) Assignments(io.prestosql.spi.plan.Assignments) PlanNode(io.prestosql.spi.plan.PlanNode) CubeNotFoundException(io.prestosql.spi.connector.CubeNotFoundException) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) 
HashSet(java.util.HashSet) TableMetadata(io.prestosql.metadata.TableMetadata) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Expression(io.prestosql.sql.tree.Expression) TableHandle(io.prestosql.spi.metadata.TableHandle) ProjectNode(io.prestosql.spi.plan.ProjectNode)

Example 4 with ArithmeticBinaryExpression

use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.

the class GroupingOperationRewriter method rewriteGroupingOperation.

/**
 * Rewrites a {@code GROUPING(...)} operation into an expression the planner can evaluate.
 *
 * <p>With a single grouping set the result is the constant {@code 0}. Otherwise the grouping value
 * is precomputed for every grouping set, placed in an array constructor, and selected at runtime
 * by subscripting with {@code groupId + 1}.
 *
 * @param expression the grouping operation to rewrite
 * @param groupingSets the grouping sets, in group-id order
 * @param columnReferenceFields field ids for the column references used by the grouping operation
 * @param groupIdSymbol the symbol carrying the group id; must be present when there is more than
 *         one grouping set
 * @return the replacement expression
 * @throws NullPointerException if {@code groupIdSymbol} is null
 * @throws IllegalStateException if the group id symbol is absent with multiple grouping sets, or a
 *         grouping column is missing from {@code columnReferenceFields}
 */
public static Expression rewriteGroupingOperation(GroupingOperation expression, List<Set<Integer>> groupingSets, Map<NodeRef<Expression>, FieldId> columnReferenceFields, Optional<Symbol> groupIdSymbol) {
    requireNonNull(groupIdSymbol, "groupIdSymbol is null");

    // See SQL:2011:4.16.2 and SQL:2011:6.9.10.
    if (groupingSets.size() == 1) {
        return new LongLiteral("0");
    }

    checkState(groupIdSymbol.isPresent(), "groupId symbol is missing");

    // The relation of the first grouping column is the reference for translating every field id.
    RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId();

    List<Integer> columns = expression.getGroupingColumns().stream()
            .map(NodeRef::of)
            .map(columnRef -> {
                checkState(columnReferenceFields.containsKey(columnRef), "the grouping column is not in the columnReferencesField map");
                return translateFieldToInteger(columnReferenceFields.get(columnRef), relationId);
            })
            .collect(toImmutableList());

    // One precomputed grouping value per grouping set, as literals.
    List<Expression> groupingResults = groupingSets.stream()
            .map(groupingSet -> new LongLiteral(String.valueOf(calculateGrouping(groupingSet, columns))))
            .collect(toImmutableList());

    // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1
    Expression oneBasedGroupId = new ArithmeticBinaryExpression(ADD, toSymbolReference(groupIdSymbol.get()), new GenericLiteral("BIGINT", "1"));
    return new SubscriptExpression(new ArrayConstructor(groupingResults), oneBasedGroupId);
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) GroupingOperation(io.prestosql.sql.tree.GroupingOperation) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) RelationId(io.prestosql.sql.analyzer.RelationId) NodeRef(io.prestosql.sql.tree.NodeRef) Preconditions.checkState(com.google.common.base.Preconditions.checkState) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) List(java.util.List) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) LongLiteral(io.prestosql.sql.tree.LongLiteral) ADD(io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor) Optional(java.util.Optional) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) FieldId(io.prestosql.sql.analyzer.FieldId) Expression(io.prestosql.sql.tree.Expression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) LongLiteral(io.prestosql.sql.tree.LongLiteral) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) NodeRef(io.prestosql.sql.tree.NodeRef) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) Expression(io.prestosql.sql.tree.Expression) RelationId(io.prestosql.sql.analyzer.RelationId) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor)

Example 5 with ArithmeticBinaryExpression

use of io.prestosql.sql.tree.ArithmeticBinaryExpression in project hetu-core by openlookeng.

the class CubeOptimizer method rewriteAggregationNode.

/**
 * Builds the aggregation (and, when AVG is involved, a projection above it) that re-aggregates
 * the cube columns on top of {@code inputPlanNode}.
 *
 * @param cubeRewriteResult describes the cube scan, its aggregation columns and any AVG columns
 * @param inputPlanNode the plan node the new aggregation reads from
 * @return the aggregation node when there are no AVG columns, otherwise a project node above it
 * @throws ColumnNotFoundException if the cube metadata has no aggregation signature for a cube column
 */
private PlanNode rewriteAggregationNode(CubeRewriteResult cubeRewriteResult, PlanNode inputPlanNode) {
    TypeProvider typeProvider = context.getSymbolAllocator().getTypes();

    // Grouping keys: original key name -> rewritten column -> symbol in the optimized plan
    List<Symbol> rewrittenGroupingKeys = aggregationNode.getGroupingKeys().stream()
            .map(Symbol::getName)
            .map(columnRewritesMap::get)
            .map(optimizedPlanMappings::get)
            .collect(Collectors.toList());

    // Scan symbol -> original aggregation symbol, needed later to find the SUM and COUNT outputs
    Map<Symbol, Symbol> scanToOriginal = new HashMap<>();
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
    for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
        Symbol originalSymbol = source.getOriginalAggSymbol();
        Symbol scanSymbol = source.getScanSymbol();

        Type type = cubeRewriteResult.getSymbolMetadataMap().get(originalSymbol).getType();
        ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(scanSymbol);
        ColumnMetadata cubeColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        AggregationSignature signature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName())
                .orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName("", ""), cubeColHandle.getColumnName()));

        // A cube COUNT column is re-aggregated with SUM; any other function is reused as-is
        String cubeFunction = signature.getFunction();
        final String aggFunction;
        if (COUNT.getName().equals(cubeFunction)) {
            aggFunction = SUM.getName();
        } else {
            aggFunction = cubeFunction;
        }

        SymbolReference argument = toSymbolReference(scanSymbol);
        FunctionHandle functionHandle = metadata.getFunctionAndTypeManager()
                .lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(type.getTypeSignature()));
        scanToOriginal.put(scanSymbol, originalSymbol);
        aggregations.put(
                originalSymbol,
                new AggregationNode.Aggregation(
                        new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(castToRowExpression(argument))),
                        ImmutableList.of(castToRowExpression(argument)),
                        false,
                        Optional.empty(),
                        Optional.empty(),
                        Optional.empty()));
    }

    AggregationNode aggNode = new AggregationNode(
            context.getIdAllocator().getNextId(),
            inputPlanNode,
            aggregations.build(),
            singleGroupingSet(rewrittenGroupingKeys),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    // Without AVG columns the aggregation itself is the result
    if (cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
        return aggNode;
    }

    Map<Symbol, Expression> projections = new HashMap<>();
    if (cubeRewriteResult.getComputeAvgDividingSumByCount()) {
        // AVG is derived as SUM / COUNT; grouping keys and original aggregations pass through
        for (Symbol groupingKey : aggNode.getGroupingKeys()) {
            projections.put(groupingKey, toSymbolReference(groupingKey));
        }
        aggNode.getAggregations().keySet().stream()
                .filter(originalAggregationsMap::containsValue)
                .forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
        for (CubeRewriteResult.AverageAggregatorSource avgSource : cubeRewriteResult.getAvgAggregationColumns()) {
            Symbol sumSymbol = scanToOriginal.get(avgSource.getSum());
            Symbol countSymbol = scanToOriginal.get(avgSource.getCount());
            // Both operands are cast to the type registered for the original AVG symbol
            String avgTypeName = typeProvider.get(avgSource.getOriginalAggSymbol()).getTypeSignature().toString();
            projections.put(
                    avgSource.getOriginalAggSymbol(),
                    new ArithmeticBinaryExpression(
                            ArithmeticBinaryExpression.Operator.DIVIDE,
                            new Cast(toSymbolReference(sumSymbol), avgTypeName),
                            new Cast(toSymbolReference(countSymbol), avgTypeName)));
        }
    } else {
        // Project each original aggregation symbol from a reference to its scan symbol
        for (CubeRewriteResult.AggregatorSource source : cubeRewriteResult.getAggregationColumns()) {
            projections.put(source.getOriginalAggSymbol(), toSymbolReference(source.getScanSymbol()));
        }
    }

    return new ProjectNode(
            context.getIdAllocator().getNextId(),
            aggNode,
            new Assignments(projections.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
}
Also used : ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) LongSupplier(java.util.function.LongSupplier) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) PrestoWarning(io.prestosql.spi.PrestoWarning) TypeProvider(io.prestosql.sql.planner.TypeProvider) SqlParser(io.prestosql.sql.parser.SqlParser) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Identifier(io.prestosql.sql.tree.Identifier) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Objects(java.util.Objects) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) 
ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) ExpressionUtils(io.prestosql.sql.ExpressionUtils) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) TypeSignature(io.prestosql.spi.type.TypeSignature) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) Iterables(com.google.common.collect.Iterables) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) Literal(io.prestosql.sql.tree.Literal) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) Lists(com.google.common.collect.Lists) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) RowExpressionRewriter(io.prestosql.expressions.RowExpressionRewriter) RowExpressionTreeRewriter(io.prestosql.expressions.RowExpressionTreeRewriter) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) JoinNode(io.prestosql.spi.plan.JoinNode) ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) 
EXPIRED_CUBE(io.prestosql.spi.connector.StandardWarningCode.EXPIRED_CUBE) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) TupleDomain(io.prestosql.spi.predicate.TupleDomain) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) RowExpression(io.prestosql.spi.relation.RowExpression) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) Assignments(io.prestosql.spi.plan.Assignments) TypeSignature(io.prestosql.spi.type.TypeSignature) PlanNode(io.prestosql.spi.plan.PlanNode) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) TypeProvider(io.prestosql.sql.planner.TypeProvider) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) CallExpression(io.prestosql.spi.relation.CallExpression) 
OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Aggregations

ArithmeticBinaryExpression (io.prestosql.sql.tree.ArithmeticBinaryExpression)6 Expression (io.prestosql.sql.tree.Expression)5 LongLiteral (io.prestosql.sql.tree.LongLiteral)5 Symbol (io.prestosql.spi.plan.Symbol)3 List (java.util.List)3 ImmutableList (com.google.common.collect.ImmutableList)2 ImmutableMap (com.google.common.collect.ImmutableMap)2 Logger (io.airlift.log.Logger)2 CubeAggregateFunction (io.hetu.core.spi.cube.CubeAggregateFunction)2 COUNT (io.hetu.core.spi.cube.CubeAggregateFunction.COUNT)2 SUM (io.hetu.core.spi.cube.CubeAggregateFunction.SUM)2 CubeMetadata (io.hetu.core.spi.cube.CubeMetadata)2 AggregationSignature (io.hetu.core.spi.cube.aggregator.AggregationSignature)2 Session (io.prestosql.Session)2 Metadata (io.prestosql.metadata.Metadata)2 TableMetadata (io.prestosql.metadata.TableMetadata)2 PrestoException (io.prestosql.spi.PrestoException)2 CUBE_ERROR (io.prestosql.spi.StandardErrorCode.CUBE_ERROR)2 SymbolAllocator (io.prestosql.spi.SymbolAllocator)2 ColumnHandle (io.prestosql.spi.connector.ColumnHandle)2