Search in sources:

Example 21 with TypeProvider

use of io.prestosql.sql.planner.TypeProvider in project hetu-core by openlookeng.

From class CubeOptimizer, method rewriteAggregationNode.

/**
 * Rewrites the original aggregation so that it is evaluated on top of {@code inputPlanNode}
 * using the cube's aggregated columns.
 *
 * <p>Builds a SINGLE-step, HASH {@link AggregationNode} over {@code inputPlanNode}. Its grouping
 * keys are the original aggregation's grouping keys, remapped by name through
 * {@code columnRewritesMap} and then {@code optimizedPlanMappings}. Each aggregation applies the
 * cube column's aggregation function (COUNT replaced by SUM) to the cube scan symbol and is keyed
 * by the original aggregation's output symbol. When the rewrite has AVG columns, a
 * {@link ProjectNode} is placed on top of the aggregation.
 *
 * @param cubeRewriteResult describes the cube scan, its aggregation columns and its AVG columns
 * @param inputPlanNode node the new aggregation reads from (presumably the rewritten cube scan; confirm with caller)
 * @return the new aggregation node, or a projection over it when AVG columns are present
 * @throws ColumnNotFoundException if the cube metadata has no aggregation signature for a cube column
 */
private PlanNode rewriteAggregationNode(CubeRewriteResult cubeRewriteResult, PlanNode inputPlanNode) {
    // Symbol types from the symbol allocator; only read in the AVG = SUM / COUNT branch below.
    TypeProvider typeProvider = context.getSymbolAllocator().getTypes();
    // Add group by
    List<Symbol> groupings = aggregationNode.getGroupingKeys().stream().map(Symbol::getName).map(columnRewritesMap::get).map(optimizedPlanMappings::get).collect(Collectors.toList());
    // Cube scan symbol -> original aggregation output symbol; used to locate SUM/COUNT outputs for AVG.
    Map<Symbol, Symbol> cubeScanToAggOutputMap = new HashMap<>();
    // Rewrite AggregationNode using Cube table
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
    for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
        // Type of the original aggregation output; used both as the call's result type and as the
        // argument type for the function lookup.
        Type type = cubeRewriteResult.getSymbolMetadataMap().get(aggregatorSource.getOriginalAggSymbol()).getType();
        TypeSignature typeSignature = type.getTypeSignature();
        ColumnHandle cubeColHandle = cubeRewriteResult.getTableScanNode().getAssignments().get(aggregatorSource.getScanSymbol());
        ColumnMetadata cubeColumnMetadata = metadata.getColumnMetadata(context.getSession(), cubeTableHandle, cubeColHandle);
        AggregationSignature aggregationSignature = cubeMetadata.getAggregationSignature(cubeColumnMetadata.getName()).orElseThrow(() -> new ColumnNotFoundException(new SchemaTableName("", ""), cubeColHandle.getColumnName()));
        // COUNT columns are re-aggregated with SUM, presumably because the cube already stores
        // per-group counts that must be added together (confirm against cube semantics).
        String aggFunction = COUNT.getName().equals(aggregationSignature.getFunction()) ? SUM.getName() : aggregationSignature.getFunction();
        SymbolReference argument = toSymbolReference(aggregatorSource.getScanSymbol());
        FunctionHandle functionHandle = metadata.getFunctionAndTypeManager().lookupFunction(aggFunction, TypeSignatureProvider.fromTypeSignatures(typeSignature));
        cubeScanToAggOutputMap.put(aggregatorSource.getScanSymbol(), aggregatorSource.getOriginalAggSymbol());
        // Aggregation keyed by the original output symbol, applied to the cube scan symbol; not
        // DISTINCT, with no filter, ordering or mask.
        aggregationsBuilder.put(aggregatorSource.getOriginalAggSymbol(), new AggregationNode.Aggregation(new CallExpression(aggFunction, functionHandle, type, ImmutableList.of(castToRowExpression(argument))), ImmutableList.of(castToRowExpression(argument)), false, Optional.empty(), Optional.empty(), Optional.empty()));
    }
    PlanNode planNode = inputPlanNode;
    AggregationNode aggNode = new AggregationNode(context.getIdAllocator().getNextId(), planNode, aggregationsBuilder.build(), singleGroupingSet(groupings), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
    // Without AVG columns the aggregation output is returned as-is.
    if (cubeRewriteResult.getAvgAggregationColumns().isEmpty()) {
        return aggNode;
    }
    if (!cubeRewriteResult.getComputeAvgDividingSumByCount()) {
        // Projects each original aggregation symbol to a reference to its cube scan symbol.
        // NOTE(review): aggNode's aggregations are keyed by the original aggregation symbols, not
        // the scan symbols referenced here, and the grouping keys are not projected in this
        // branch; confirm this projection resolves against aggNode's outputs.
        Map<Symbol, Expression> aggregateAssignments = new HashMap<>();
        for (CubeRewriteResult.AggregatorSource aggregatorSource : cubeRewriteResult.getAggregationColumns()) {
            aggregateAssignments.put(aggregatorSource.getOriginalAggSymbol(), toSymbolReference(aggregatorSource.getScanSymbol()));
        }
        planNode = new ProjectNode(context.getIdAllocator().getNextId(), aggNode, new Assignments(aggregateAssignments.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    } else {
        // If there was an AVG aggregation, map it to AVG = SUM/COUNT
        Map<Symbol, Expression> projections = new HashMap<>();
        // Pass the grouping keys through unchanged.
        aggNode.getGroupingKeys().forEach(symbol -> projections.put(symbol, toSymbolReference(symbol)));
        // Pass through only the aggregation outputs that appear as values of originalAggregationsMap.
        aggNode.getAggregations().keySet().stream().filter(originalAggregationsMap::containsValue).forEach(aggSymbol -> projections.put(aggSymbol, toSymbolReference(aggSymbol)));
        // Add AVG = SUM / COUNT
        for (CubeRewriteResult.AverageAggregatorSource avgAggSource : cubeRewriteResult.getAvgAggregationColumns()) {
            Symbol sumSymbol = cubeScanToAggOutputMap.get(avgAggSource.getSum());
            Symbol countSymbol = cubeScanToAggOutputMap.get(avgAggSource.getCount());
            // Both operands are CAST to the AVG result type before dividing.
            Type avgResultType = typeProvider.get(avgAggSource.getOriginalAggSymbol());
            ArithmeticBinaryExpression division = new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.DIVIDE, new Cast(toSymbolReference(sumSymbol), avgResultType.getTypeSignature().toString()), new Cast(toSymbolReference(countSymbol), avgResultType.getTypeSignature().toString()));
            projections.put(avgAggSource.getOriginalAggSymbol(), division);
        }
        return new ProjectNode(context.getIdAllocator().getNextId(), aggNode, new Assignments(projections.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> castToRowExpression(entry.getValue())))));
    }
    // Only reached from the branch that does not compute AVG as SUM / COUNT.
    return planNode;
}
Also used : ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) Cast(io.prestosql.sql.tree.Cast) LongSupplier(java.util.function.LongSupplier) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) PrestoWarning(io.prestosql.spi.PrestoWarning) TypeProvider(io.prestosql.sql.planner.TypeProvider) SqlParser(io.prestosql.sql.parser.SqlParser) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) Cast(io.prestosql.sql.tree.Cast) FilterNode(io.prestosql.spi.plan.FilterNode) Pair(org.apache.commons.lang3.tuple.Pair) Map(java.util.Map) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) CubeFilter(io.hetu.core.spi.cube.CubeFilter) ENGLISH(java.util.Locale.ENGLISH) Identifier(io.prestosql.sql.tree.Identifier) AggregationNode.singleGroupingSet(io.prestosql.spi.plan.AggregationNode.singleGroupingSet) CubeMetaStore(io.hetu.core.spi.cube.io.CubeMetaStore) PrestoException(io.prestosql.spi.PrestoException) SymbolsExtractor(io.prestosql.sql.planner.SymbolsExtractor) SUM(io.hetu.core.spi.cube.CubeAggregateFunction.SUM) ImmutableMap(com.google.common.collect.ImmutableMap) CubeAggregateFunction(io.hetu.core.spi.cube.CubeAggregateFunction) TableScanNode(io.prestosql.spi.plan.TableScanNode) Set(java.util.Set) PlanNode(io.prestosql.spi.plan.PlanNode) UUID(java.util.UUID) ProjectNode(io.prestosql.spi.plan.ProjectNode) Collectors(java.util.stream.Collectors) Metadata(io.prestosql.metadata.Metadata) String.format(java.lang.String.format) FunctionHandle(io.prestosql.spi.function.FunctionHandle) Objects(java.util.Objects) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) 
ReuseExchangeOperator(io.prestosql.spi.operator.ReuseExchangeOperator) List(java.util.List) ExpressionUtils(io.prestosql.sql.ExpressionUtils) LongLiteral(io.prestosql.sql.tree.LongLiteral) SymbolReference(io.prestosql.sql.tree.SymbolReference) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) Optional(java.util.Optional) NOT_SUPPORTED(io.prestosql.spi.StandardErrorCode.NOT_SUPPORTED) TypeSignature(io.prestosql.spi.type.TypeSignature) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) Iterables(com.google.common.collect.Iterables) COUNT(io.hetu.core.spi.cube.CubeAggregateFunction.COUNT) TableMetadata(io.prestosql.metadata.TableMetadata) Logger(io.airlift.log.Logger) TypeSignatureProvider(io.prestosql.sql.analyzer.TypeSignatureProvider) Literal(io.prestosql.sql.tree.Literal) HashMap(java.util.HashMap) TableHandle(io.prestosql.spi.metadata.TableHandle) ExpressionTreeRewriter(io.prestosql.sql.tree.ExpressionTreeRewriter) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) Lists(com.google.common.collect.Lists) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ImmutableList(com.google.common.collect.ImmutableList) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) BooleanLiteral(io.prestosql.sql.tree.BooleanLiteral) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) RowExpressionRewriter(io.prestosql.expressions.RowExpressionRewriter) RowExpressionTreeRewriter(io.prestosql.expressions.RowExpressionTreeRewriter) ExpressionRewriter(io.prestosql.sql.tree.ExpressionRewriter) JoinNode(io.prestosql.spi.plan.JoinNode) ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Symbol(io.prestosql.spi.plan.Symbol) AssignmentUtils(io.prestosql.sql.planner.plan.AssignmentUtils) 
EXPIRED_CUBE(io.prestosql.spi.connector.StandardWarningCode.EXPIRED_CUBE) Assignments(io.prestosql.spi.plan.Assignments) Rule(io.prestosql.sql.planner.iterative.Rule) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) TupleDomain(io.prestosql.spi.predicate.TupleDomain) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) SymbolAllocator(io.prestosql.spi.SymbolAllocator) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) CUBE_ERROR(io.prestosql.spi.StandardErrorCode.CUBE_ERROR) RowExpression(io.prestosql.spi.relation.RowExpression) CubeMetadata(io.hetu.core.spi.cube.CubeMetadata) Comparator(java.util.Comparator) Expression(io.prestosql.sql.tree.Expression) ColumnMetadata(io.prestosql.spi.connector.ColumnMetadata) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) SymbolReference(io.prestosql.sql.tree.SymbolReference) Assignments(io.prestosql.spi.plan.Assignments) TypeSignature(io.prestosql.spi.type.TypeSignature) PlanNode(io.prestosql.spi.plan.PlanNode) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) TypeProvider(io.prestosql.sql.planner.TypeProvider) AggregationSignature(io.hetu.core.spi.cube.aggregator.AggregationSignature) AggregationNode(io.prestosql.spi.plan.AggregationNode) SchemaTableName(io.prestosql.spi.connector.SchemaTableName) ImmutableMap(com.google.common.collect.ImmutableMap) Type(io.prestosql.spi.type.Type) ColumnNotFoundException(io.prestosql.spi.connector.ColumnNotFoundException) ConstantExpression(io.prestosql.spi.relation.ConstantExpression) CallExpression(io.prestosql.spi.relation.CallExpression) 
OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) ProjectNode(io.prestosql.spi.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) HashMap(java.util.HashMap)

Example 22 with TypeProvider

use of io.prestosql.sql.planner.TypeProvider in project hetu-core by openlookeng.

From class TestExpressionEquivalence, method assertEquivalent.

/**
 * Asserts that {@code EQUIVALENCE} treats the two SQL expressions as equivalent, checking both
 * argument orders.
 *
 * <p>Both expressions are parsed with decimal literals as doubles, and identifiers are rewritten to
 * symbol references. Every symbol referenced by either side is typed with {@code generateType}.
 *
 * @param left first SQL expression
 * @param right second SQL expression
 */
private static void assertEquivalent(@Language("SQL") String left, @Language("SQL") String right) {
    ParsingOptions options = new ParsingOptions(AS_DOUBLE);
    Expression first = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(left, options));
    Expression second = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(right, options));

    Set<Symbol> referenced = extractUnique(ImmutableList.of(first, second));
    TypeProvider symbolTypes = TypeProvider.copyOf(referenced.stream().collect(toMap(identity(), TestExpressionEquivalence::generateType)));

    boolean forward = EQUIVALENCE.areExpressionsEquivalent(TEST_SESSION, first, second, symbolTypes);
    assertTrue(forward, format("Expected (%s) and (%s) to be equivalent", left, right));

    boolean backward = EQUIVALENCE.areExpressionsEquivalent(TEST_SESSION, second, first, symbolTypes);
    assertTrue(backward, format("Expected (%s) and (%s) to be equivalent", right, left));
}
Also used : ParsingOptions(io.prestosql.sql.parser.ParsingOptions) Expression(io.prestosql.sql.tree.Expression) Symbol(io.prestosql.spi.plan.Symbol) TypeProvider(io.prestosql.sql.planner.TypeProvider)

Example 23 with TypeProvider

use of io.prestosql.sql.planner.TypeProvider in project hetu-core by openlookeng.

From class PlanPrinter, method jsonFragmentPlan.

/**
 * Renders the plan rooted at {@code root} as JSON.
 *
 * @param root root of the plan to print
 * @param symbols symbol-to-type mapping; an immutable copy is used to build the {@link TypeProvider}
 * @param metadata metadata used for table info and value printing
 * @param session session used for table info and value printing
 * @return the JSON produced by {@code PlanPrinter#toJson()}
 */
public static String jsonFragmentPlan(PlanNode root, Map<Symbol, Type> symbols, Metadata metadata, Session session) {
    // Map entries already have unique keys, so no de-duplication is needed before copying.
    TypeProvider typeProvider = TypeProvider.copyOf(symbols.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)));
    TableInfoSupplier tableInfoSupplier = new TableInfoSupplier(metadata, session);
    ValuePrinter valuePrinter = new ValuePrinter(metadata, session);
    return new PlanPrinter(root, typeProvider, Optional.empty(), tableInfoSupplier, valuePrinter, StatsAndCosts.empty(), Optional.empty(), metadata).toJson();
}
Also used : Entry(java.util.Map.Entry) TypeProvider(io.prestosql.sql.planner.TypeProvider)

Example 24 with TypeProvider

use of io.prestosql.sql.planner.TypeProvider in project hetu-core by openlookeng.

From class StatsNormalizer, method normalize.

/**
 * Normalizes the per-symbol statistics of {@code stats}.
 *
 * <p>Symbols outside {@code outputSymbols} (when provided) lose their statistics. For the remaining
 * symbols, statistics become {@code SymbolStatsEstimate.zero()} when the row count is 0; otherwise
 * they are passed through {@code normalizeSymbolStats}. Statistics that end up unknown are removed,
 * and only statistics that actually changed are written back.
 *
 * @param stats estimate to normalize
 * @param outputSymbols symbols whose statistics are retained; when empty, all symbols are retained
 * @param types symbol types forwarded to per-symbol normalization
 * @return {@link PlanNodeStatsEstimate#unknown()} if the output row count is unknown, otherwise the
 *     normalized estimate
 */
private PlanNodeStatsEstimate normalize(PlanNodeStatsEstimate stats, Optional<Collection<Symbol>> outputSymbols, TypeProvider types) {
    if (stats.isOutputRowCountUnknown()) {
        return PlanNodeStatsEstimate.unknown();
    }

    PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(stats);

    Predicate<Symbol> isRetained;
    if (outputSymbols.isPresent()) {
        ImmutableSet<Symbol> retainedSymbols = ImmutableSet.copyOf(outputSymbols.get());
        isRetained = retainedSymbols::contains;
    } else {
        isRetained = symbol -> true;
    }

    for (Symbol current : stats.getSymbolsWithKnownStatistics()) {
        if (isRetained.test(current)) {
            SymbolStatsEstimate before = stats.getSymbolStatistics(current);
            SymbolStatsEstimate after;
            if (stats.getOutputRowCount() == 0) {
                after = SymbolStatsEstimate.zero();
            } else {
                after = normalizeSymbolStats(current, before, stats, types);
            }

            if (after.isUnknown()) {
                result.removeSymbolStatistics(current);
            } else if (!Objects.equals(after, before)) {
                result.addSymbolStatistics(current, after);
            }
        } else {
            result.removeSymbolStatistics(current);
        }
    }
    return result.build();
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) BigintType(io.prestosql.spi.type.BigintType) ImmutableSet(com.google.common.collect.ImmutableSet) IntegerType(io.prestosql.spi.type.IntegerType) Predicate(java.util.function.Predicate) Collection(java.util.Collection) TypeProvider(io.prestosql.sql.planner.TypeProvider) DecimalType(io.prestosql.spi.type.DecimalType) TinyintType(io.prestosql.spi.type.TinyintType) Math.pow(java.lang.Math.pow) Objects(java.util.Objects) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) NaN(java.lang.Double.NaN) Double.isNaN(java.lang.Double.isNaN) DateType(io.prestosql.spi.type.DateType) Objects.requireNonNull(java.util.Objects.requireNonNull) Optional(java.util.Optional) Math.floor(java.lang.Math.floor) Type(io.prestosql.spi.type.Type) BooleanType(io.prestosql.spi.type.BooleanType) SmallintType(io.prestosql.spi.type.SmallintType) ImmutableSet(com.google.common.collect.ImmutableSet) Symbol(io.prestosql.spi.plan.Symbol) Predicate(java.util.function.Predicate)

Example 25 with TypeProvider

use of io.prestosql.sql.planner.TypeProvider in project hetu-core by openlookeng.

From class JoinStatsRule, method computeInnerJoinStats.

/**
 * Estimates the statistics of an inner join, starting from the cross-join estimate.
 *
 * <p>With no equi-join criteria, the join filter (if any) is applied directly to
 * {@code crossJoinStats}. Otherwise the equi-join clauses are applied first, and then the join
 * filter (if any) is applied to that result. If the filtered estimate has an unknown row count, the
 * equi-join estimate scaled by {@code UNKNOWN_FILTER_COEFFICIENT} and normalized is returned instead.
 *
 * @param node the join being estimated
 * @param crossJoinStats estimate for the cross product of the join inputs
 * @param session current session
 * @param types symbol types
 * @return the estimated statistics, or {@link PlanNodeStatsEstimate#unknown()} when the equi-join
 *     estimate has an unknown row count
 */
private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
    List<EquiJoinClause> equiJoinCriteria = node.getCriteria();
    if (equiJoinCriteria.isEmpty()) {
        if (!node.getFilter().isPresent()) {
            return crossJoinStats;
        }
        // TODO: this might explode stats
        return applyJoinFilterStats(crossJoinStats, node.getFilter().get(), node, session, types);
    }
    PlanNodeStatsEstimate equiJoinEstimate = filterByEquiJoinClauses(crossJoinStats, equiJoinCriteria, session, types);
    if (equiJoinEstimate.isOutputRowCountUnknown()) {
        return PlanNodeStatsEstimate.unknown();
    }
    if (!node.getFilter().isPresent()) {
        return equiJoinEstimate;
    }
    PlanNodeStatsEstimate filteredEquiJoinEstimate = applyJoinFilterStats(equiJoinEstimate, node.getFilter().get(), node, session, types);
    if (filteredEquiJoinEstimate.isOutputRowCountUnknown()) {
        // The filter could not be estimated; fall back to a fixed selectivity coefficient.
        return normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT), types);
    }
    return filteredEquiJoinEstimate;
}

/**
 * Applies a join filter to {@code inputStats} using {@code filterStatsCalculator}.
 *
 * <p>A filter that is a wrapped original {@code Expression} is estimated in that form. Otherwise the
 * {@code RowExpression} is estimated with a layout mapping each channel index to the join output
 * symbol at that position.
 *
 * @param inputStats statistics to filter
 * @param filter the join filter
 * @param node the join whose output symbols define the channel layout
 * @param session current session
 * @param types symbol types
 * @return the filtered statistics estimate
 */
private PlanNodeStatsEstimate applyJoinFilterStats(PlanNodeStatsEstimate inputStats, RowExpression filter, JoinNode node, Session session, TypeProvider types) {
    if (isExpression(filter)) {
        return filterStatsCalculator.filterStats(inputStats, castToExpression(filter), session, types);
    }
    // Channel i corresponds to the i-th output symbol of the join; built only on this path.
    Map<Integer, Symbol> layout = new HashMap<>();
    int channel = 0;
    for (Symbol symbol : node.getOutputSymbols()) {
        layout.put(channel++, symbol);
    }
    return filterStatsCalculator.filterStats(inputStats, filter, session, types, layout);
}
Also used : EquiJoinClause(io.prestosql.spi.plan.JoinNode.EquiJoinClause) Lookup(io.prestosql.sql.planner.iterative.Lookup) Patterns.join(io.prestosql.sql.planner.plan.Patterns.join) LogicalRowExpressions(io.prestosql.expressions.LogicalRowExpressions) TypeProvider(io.prestosql.sql.planner.TypeProvider) HashMap(java.util.HashMap) Pattern(io.prestosql.matching.Pattern) SymbolStatsEstimate.buildFrom(io.prestosql.cost.SymbolStatsEstimate.buildFrom) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) Sets.difference(com.google.common.collect.Sets.difference) NaN(java.lang.Double.NaN) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) EQUAL(io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) LinkedList(java.util.LinkedList) JoinNode(io.prestosql.spi.plan.JoinNode) Symbol(io.prestosql.spi.plan.Symbol) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) Collection(java.util.Collection) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Math.min(java.lang.Math.min) Comparator.comparingDouble(java.util.Comparator.comparingDouble) MoreMath(io.prestosql.util.MoreMath) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) List(java.util.List) Double.isNaN(java.lang.Double.isNaN) RowExpression(io.prestosql.spi.relation.RowExpression) Optional(java.util.Optional) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Queue(java.util.Queue) ExpressionUtils.extractConjuncts(io.prestosql.sql.ExpressionUtils.extractConjuncts) HashMap(java.util.HashMap) 
EquiJoinClause(io.prestosql.spi.plan.JoinNode.EquiJoinClause) Symbol(io.prestosql.spi.plan.Symbol)

Aggregations

TypeProvider (io.prestosql.sql.planner.TypeProvider)26 Symbol (io.prestosql.spi.plan.Symbol)19 Optional (java.util.Optional)12 Session (io.prestosql.Session)11 PlanNode (io.prestosql.spi.plan.PlanNode)9 RowExpression (io.prestosql.spi.relation.RowExpression)9 Expression (io.prestosql.sql.tree.Expression)9 Map (java.util.Map)9 Metadata (io.prestosql.metadata.Metadata)8 List (java.util.List)8 Objects.requireNonNull (java.util.Objects.requireNonNull)8 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)7 OriginalExpressionUtils.castToExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression)7 HashMap (java.util.HashMap)7 ImmutableList (com.google.common.collect.ImmutableList)6 Pattern (io.prestosql.matching.Pattern)6 FilterNode (io.prestosql.spi.plan.FilterNode)6 OriginalExpressionUtils.isExpression (io.prestosql.sql.relational.OriginalExpressionUtils.isExpression)6 Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument)5 ImmutableMap (com.google.common.collect.ImmutableMap)5