Search in sources:

Example 1 with UNKNOWN_FILTER_COEFFICIENT

use of io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT in project hetu-core by openlookeng.

The class TableScanStatsRule, method doCalculate.

/**
 * Estimates output statistics for a table scan from the connector's table statistics.
 *
 * <p>Statistics are fetched under the table-layout predicate and converted into per-symbol estimates
 * for every assignment of the scan. The enforced constraint is treated as pushed down when two
 * conditions hold:
 * <ul>
 *   <li>the scan's enforced constraint is neither ALL nor NONE, and</li>
 *   <li>the layout predicate is either ALL or has the same domains as the enforced constraint.</li>
 * </ul>
 * In that case the enforced constraint is translated back into an expression and run through the
 * filter stats calculator against the raw table estimate. Columns that the constraint covers but that
 * are absent from the scan's assignments receive synthetic symbols named after the column, so the
 * predicate can be translated and estimated.
 *
 * @param node the table scan to estimate
 * @param sourceStats stats provider; only its enforce-default-filter-factor flag is read here
 * @param lookup not used by this rule
 * @param session the current session
 * @param types symbol types used when converting column statistics
 * @return the table estimate, filtered by the pushed-down constraint when applicable; never empty
 */
@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(TableScanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
    // TODO Construct predicate like AddExchanges's LayoutConstraintEvaluator
    TupleDomain<ColumnHandle> predicate = metadata.getTableProperties(session, node.getTable()).getPredicate();
    Constraint constraint = new Constraint(predicate);
    TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable(), constraint, true);
    verify(tableStatistics != null, "tableStatistics is null for %s", node);

    Map<Symbol, SymbolStatsEstimate> outputSymbolStats = new HashMap<>();
    Map<ColumnHandle, Symbol> remainingSymbols = new HashMap<>();
    Map<ColumnHandle, Symbol> assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();

    TupleDomain<ColumnHandle> enforcedConstraint = node.getEnforcedConstraint();
    boolean enforcedConstraintIsTrivial = enforcedConstraint.isAll() || enforcedConstraint.isNone();
    // Compare the Optional domain maps directly instead of unwrapping them with get(): a NONE predicate
    // has no domain map, so get() would throw NoSuchElementException. Checking the enforced constraint
    // first also guarantees its domain map is present before it is unwrapped below.
    boolean isPredicatesPushDown = !enforcedConstraintIsTrivial
            && (predicate.isAll() || predicate.getDomains().equals(enforcedConstraint.getDomains()));
    if (isPredicatesPushDown) {
        predicate = enforcedConstraint;
        // Start with every constrained column; the ones the scan already assigns are removed below.
        for (ColumnHandle column : enforcedConstraint.getDomains().get().keySet()) {
            remainingSymbols.put(column, new Symbol(column.getColumnName()));
        }
    }

    for (Map.Entry<Symbol, ColumnHandle> entry : node.getAssignments().entrySet()) {
        outputSymbolStats.put(entry.getKey(), estimateColumnStats(tableStatistics, entry.getValue(), entry.getKey(), types));
        remainingSymbols.remove(entry.getValue());
    }

    if (!remainingSymbols.isEmpty()) {
        // Only reachable on the push-down path. Make constraint-only columns resolvable when the
        // predicate is translated back to an expression, and give them statistics for the filter.
        assignments = ImmutableBiMap.<ColumnHandle, Symbol>builder().putAll(assignments).putAll(remainingSymbols).build();
        for (Map.Entry<ColumnHandle, Symbol> entry : remainingSymbols.entrySet()) {
            outputSymbolStats.put(entry.getValue(), estimateColumnStats(tableStatistics, entry.getKey(), entry.getValue(), types));
        }
    }

    PlanNodeStatsEstimate tableEstimates = PlanNodeStatsEstimate.builder()
            .setOutputRowCount(tableStatistics.getRowCount().getValue())
            .addSymbolStatistics(outputSymbolStats)
            .build();
    if (!isPredicatesPushDown) {
        return Optional.of(tableEstimates);
    }

    Expression pushDownExpression = domainTranslator.toPredicate(predicate.transform(assignments::get));
    PlanNodeStatsEstimate estimate = filterStatsCalculator.filterStats(tableEstimates, pushDownExpression, session, types);
    if ((isDefaultFilterFactorEnabled(session) || sourceStats.isEnforceDefaultFilterFactor()) && estimate.isOutputRowCountUnknown()) {
        // The filter calculator produced no row count: fall back to a fixed fraction of the table row count.
        estimate = tableEstimates.mapOutputRowCount(rowCount -> tableEstimates.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT);
    }
    return Optional.of(estimate);
}

/**
 * Converts the connector statistics of one column into a symbol estimate.
 *
 * <p>Returns {@link SymbolStatsEstimate#unknown()} when the connector reported no statistics for the
 * column. The symbol's type is looked up only when statistics exist.
 */
private SymbolStatsEstimate estimateColumnStats(TableStatistics tableStatistics, ColumnHandle column, Symbol symbol, TypeProvider types) {
    return Optional.ofNullable(tableStatistics.getColumnStatistics().get(column))
            .map(statistics -> toSymbolStatistics(tableStatistics, statistics, types.get(symbol)))
            .orElse(SymbolStatsEstimate.unknown());
}
Also used : ColumnStatistics(io.prestosql.spi.statistics.ColumnStatistics) TableStatistics(io.prestosql.spi.statistics.TableStatistics) Lookup(io.prestosql.sql.planner.iterative.Lookup) TypeProvider(io.prestosql.sql.planner.TypeProvider) HashMap(java.util.HashMap) Pattern(io.prestosql.matching.Pattern) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) NaN(java.lang.Double.NaN) ExpressionDomainTranslator(io.prestosql.sql.planner.ExpressionDomainTranslator) Verify.verify(com.google.common.base.Verify.verify) FixedWidthType(io.prestosql.spi.type.FixedWidthType) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) Type(io.prestosql.spi.type.Type) ColumnStatistics(io.prestosql.spi.statistics.ColumnStatistics) Constraint(io.prestosql.spi.connector.Constraint) Symbol(io.prestosql.spi.plan.Symbol) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) TupleDomain(io.prestosql.spi.predicate.TupleDomain) TableScanNode(io.prestosql.spi.plan.TableScanNode) SystemSessionProperties.isDefaultFilterFactorEnabled(io.prestosql.SystemSessionProperties.isDefaultFilterFactorEnabled) Metadata(io.prestosql.metadata.Metadata) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) LiteralEncoder(io.prestosql.sql.planner.LiteralEncoder) Optional(java.util.Optional) Patterns.tableScan(io.prestosql.sql.planner.plan.Patterns.tableScan) Expression(io.prestosql.sql.tree.Expression) ColumnHandle(io.prestosql.spi.connector.ColumnHandle) Constraint(io.prestosql.spi.connector.Constraint) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) Expression(io.prestosql.sql.tree.Expression) TableStatistics(io.prestosql.spi.statistics.TableStatistics) HashMap(java.util.HashMap) ImmutableBiMap(com.google.common.collect.ImmutableBiMap) Map(java.util.Map)

Example 2 with UNKNOWN_FILTER_COEFFICIENT

use of io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT in project hetu-core by openlookeng.

The class FilterStatsRule, method doCalculate.

/**
 * Derives the statistics of a filter node from the statistics of its source.
 *
 * <p>AST-expression predicates are passed to the filter stats calculator directly. Any other
 * predicate form is passed together with a channel-index-to-symbol layout built from the source's
 * output symbols, in order. If the result has an unknown row count and a default filter factor is
 * requested (session property or stats-provider flag), the source statistics are returned with their
 * row count scaled by UNKNOWN_FILTER_COEFFICIENT.
 */
@Override
public Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) {
    PlanNodeStatsEstimate inputStats = statsProvider.getStats(node.getSource());
    PlanNodeStatsEstimate filteredStats;
    if (!isExpression(node.getPredicate())) {
        // Channel i maps to the i-th output symbol of the source; the map size doubles as the next index.
        Map<Integer, Symbol> channelLayout = new HashMap<>();
        for (Symbol outputSymbol : node.getSource().getOutputSymbols()) {
            channelLayout.put(channelLayout.size(), outputSymbol);
        }
        filteredStats = filterStatsCalculator.filterStats(inputStats, node.getPredicate(), session, types, channelLayout);
    } else {
        filteredStats = filterStatsCalculator.filterStats(inputStats, castToExpression(node.getPredicate()), session, types);
    }

    boolean defaultFactorRequested = isDefaultFilterFactorEnabled(session) || statsProvider.isEnforceDefaultFilterFactor();
    if (defaultFactorRequested && filteredStats.isOutputRowCountUnknown()) {
        filteredStats = inputStats.mapOutputRowCount(rowCount -> inputStats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT);
    }
    return Optional.of(filteredStats);
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) Lookup(io.prestosql.sql.planner.iterative.Lookup) Patterns.filter(io.prestosql.sql.planner.plan.Patterns.filter) TypeProvider(io.prestosql.sql.planner.TypeProvider) HashMap(java.util.HashMap) Pattern(io.prestosql.matching.Pattern) SystemSessionProperties.isDefaultFilterFactorEnabled(io.prestosql.SystemSessionProperties.isDefaultFilterFactorEnabled) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) FilterNode(io.prestosql.spi.plan.FilterNode) Map(java.util.Map) Session(io.prestosql.Session) Optional(java.util.Optional) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol)

Example 3 with UNKNOWN_FILTER_COEFFICIENT

use of io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT in project hetu-core by openlookeng.

The class GroupIdStatsRule, method groupBy.

/**
 * Estimates the output of a GroupId node by summing an estimated row count over its grouping sets.
 *
 * <p>Each grouping-column symbol gets its source statistics, with the nulls fraction rewritten to 0
 * when it was 0 and to {@code 1 / (distinctValues + 1)} otherwise.
 *
 * <p>Row-count contributions:
 * <ul>
 *   <li>An empty grouping set contributes exactly one row.</li>
 *   <li>A non-empty set contributes the product of {@code distinctValues (+1 if nullable)} over its
 *       columns with known statistics, capped at the source row count.</li>
 *   <li>A multi-column set additionally contributes {@code sourceRows / thatCount}.</li>
 * </ul>
 * If at least one grouping column had known statistics, the result is that sum plus one
 * (NOTE(review): the reason for the extra +1 is not evident here — confirm). Otherwise the result is
 * the source row count scaled by UNKNOWN_FILTER_COEFFICIENT.
 *
 * @param sourceStats statistics of the GroupId source
 * @param groupingSets grouping sets, expressed in the node's output symbols
 * @param groupingColumns mapping from output grouping symbol to the source symbol it reads
 * @return the estimated statistics
 */
public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, List<List<Symbol>> groupingSets, Map<Symbol, Symbol> groupingColumns) {
    PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
    double finalRowCount = 0;
    boolean knowsStatsFound = false;
    for (List<Symbol> groupingSet : groupingSets) {
        if (groupingSet.isEmpty()) {
            // The empty grouping set yields one row. Accumulate rather than assign, so rows counted for
            // grouping sets processed earlier are not discarded (the result must not depend on order).
            finalRowCount += 1;
            continue;
        }
        double rowsCount = 1;
        for (Symbol groupBySymbol : groupingSet) {
            Symbol currentCol = groupingColumns.get(groupBySymbol);
            SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(currentCol);
            result.addSymbolStatistics(currentCol, symbolStatistics.mapNullsFraction(nullsFraction -> {
                if (nullsFraction == 0.0) {
                    return 0.0;
                }
                return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1);
            }));
            if (symbolStatistics.isUnknown()) {
                // Columns without statistics are not used for the row-count estimate.
                continue;
            }
            knowsStatsFound = true;
            int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1;
            rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow;
        }
        rowsCount = min(rowsCount, sourceStats.getOutputRowCount());
        // Per the original authors, each multi-column grouping set also gets one NULL value added, and
        // this term adjusts for it. Skip it when rowsCount is 0 to avoid dividing by zero.
        double multiColumnAdjustment = (groupingSet.size() > 1 && rowsCount > 0) ? sourceStats.getOutputRowCount() / rowsCount : 0;
        finalRowCount += rowsCount + multiColumnAdjustment;
    }
    if (knowsStatsFound) {
        result.setOutputRowCount(finalRowCount + 1);
    } else {
        result.setOutputRowCount(sourceStats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT);
    }
    return result.build();
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) GroupIdNode(io.prestosql.spi.plan.GroupIdNode) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) List(java.util.List) Patterns.groupId(io.prestosql.sql.planner.plan.Patterns.groupId) Lookup(io.prestosql.sql.planner.iterative.Lookup) Map(java.util.Map) Session(io.prestosql.Session) TypeProvider(io.prestosql.sql.planner.TypeProvider) Optional(java.util.Optional) Pattern(io.prestosql.matching.Pattern) Math.min(java.lang.Math.min) Symbol(io.prestosql.spi.plan.Symbol)

Example 4 with UNKNOWN_FILTER_COEFFICIENT

use of io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT in project hetu-core by openlookeng.

The class AggregationStatsRule, method groupBy.

/**
 * Estimates the output of a single-grouping-set aggregation.
 *
 * <p>Each grouping symbol gets its source statistics, with the nulls fraction rewritten to 0 when it
 * was 0 and to {@code 1 / (distinctValues + 1)} otherwise.
 *
 * <p>The output row count is chosen as follows:
 * <ul>
 *   <li>The base value is the product of {@code distinctValues (+1 if nullable)} over grouping symbols
 *       with known statistics, capped at the source row count.</li>
 *   <li>With no grouping symbols at all (a global aggregation) the product is 1, giving a single-row
 *       estimate.</li>
 *   <li>Only when grouping symbols exist but none has known statistics is a fixed fraction
 *       (UNKNOWN_FILTER_COEFFICIENT) of the source rows used.</li>
 * </ul>
 * Each aggregation output gets statistics from {@code estimateAggregationStats}.
 *
 * @param sourceStats statistics of the aggregation input
 * @param groupBySymbols grouping keys; may be empty
 * @param aggregations aggregation outputs keyed by their output symbol
 * @return the estimated statistics
 */
public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, Aggregation> aggregations) {
    PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
    double rowsCount = 1;
    boolean knowsStatsFound = false;
    for (Symbol groupBySymbol : groupBySymbols) {
        SymbolStatsEstimate symbolStatistics = sourceStats.getSymbolStatistics(groupBySymbol);
        result.addSymbolStatistics(groupBySymbol, symbolStatistics.mapNullsFraction(nullsFraction -> {
            if (nullsFraction == 0.0) {
                return 0.0;
            }
            return 1.0 / (symbolStatistics.getDistinctValuesCount() + 1);
        }));
        if (symbolStatistics.isUnknown()) {
            // Symbols without statistics are not used for the row-count estimate.
            continue;
        }
        knowsStatsFound = true;
        int nullRow = (symbolStatistics.getNullsFraction() == 0.0) ? 0 : 1;
        rowsCount *= symbolStatistics.getDistinctValuesCount() + nullRow;
    }
    if (knowsStatsFound || groupBySymbols.isEmpty()) {
        // A global aggregation (no keys) produces one row; rowsCount is still 1 in that case.
        result.setOutputRowCount(min(rowsCount, sourceStats.getOutputRowCount()));
    } else {
        // If there was no proper stats in any of the grouping symbols, we just take fixed % of source.
        result.setOutputRowCount(sourceStats.getOutputRowCount() * UNKNOWN_FILTER_COEFFICIENT);
    }
    for (Map.Entry<Symbol, Aggregation> aggregationEntry : aggregations.entrySet()) {
        result.addSymbolStatistics(aggregationEntry.getKey(), estimateAggregationStats(aggregationEntry.getValue(), sourceStats));
    }
    return result.build();
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) Lookup(io.prestosql.sql.planner.iterative.Lookup) Patterns.aggregation(io.prestosql.sql.planner.plan.Patterns.aggregation) Collection(java.util.Collection) TypeProvider(io.prestosql.sql.planner.TypeProvider) Pattern(io.prestosql.matching.Pattern) FINAL(io.prestosql.spi.plan.AggregationNode.Step.FINAL) Math.min(java.lang.Math.min) PARTIAL(io.prestosql.spi.plan.AggregationNode.Step.PARTIAL) AggregationNode(io.prestosql.spi.plan.AggregationNode) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) SINGLE(io.prestosql.spi.plan.AggregationNode.Step.SINGLE) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) Session(io.prestosql.Session) Optional(java.util.Optional) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Symbol(io.prestosql.spi.plan.Symbol) Map(java.util.Map)

Example 5 with UNKNOWN_FILTER_COEFFICIENT

use of io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT in project hetu-core by openlookeng.

The class TestJoinStatsRule, method testStatsForInnerJoinWithTwoEquiClausesAndNonEqualityFunction.

/**
 * Checks the inner-join estimate when there are two equi-join clauses plus a
 * {@code LEFT_JOIN_COLUMN < 10} join filter.
 */
@Test
public void testStatsForInnerJoinWithTwoEquiClausesAndNonEqualityFunction() {
    // Driving join clause: LEFT_JOIN_COLUMN_2 = RIGHT_JOIN_COLUMN_2.
    double driverClauseRowCount = LEFT_ROWS_COUNT * RIGHT_ROWS_COUNT / LEFT_JOIN_COLUMN_2_NDV * LEFT_JOIN_COLUMN_2_NON_NULLS * RIGHT_JOIN_COLUMN_2_NON_NULLS;
    // Factor applied for the LEFT_JOIN_COLUMN < 10 non-equality filter.
    double lessThanTenSelectivity = 0.3333333333;
    // The auxiliary join clause contributes UNKNOWN_FILTER_COEFFICIENT.
    double innerJoinRowCount = driverClauseRowCount * UNKNOWN_FILTER_COEFFICIENT * lessThanTenSelectivity;

    PlanNodeStatsEstimate expectedStats = planNodeStats(
            innerJoinRowCount,
            symbolStatistics(LEFT_JOIN_COLUMN, 5.0, 10.0, 0.0, RIGHT_JOIN_COLUMN_NDV * lessThanTenSelectivity),
            symbolStatistics(RIGHT_JOIN_COLUMN, 5.0, 20.0, 0.0, RIGHT_JOIN_COLUMN_NDV),
            symbolStatistics(LEFT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV),
            symbolStatistics(RIGHT_JOIN_COLUMN_2, 100.0, 200.0, 0.0, RIGHT_JOIN_COLUMN_2_NDV));

    tester().assertStatsFor(pb -> {
        Symbol leftColumn = pb.symbol(LEFT_JOIN_COLUMN, BIGINT);
        Symbol rightColumn = pb.symbol(RIGHT_JOIN_COLUMN, DOUBLE);
        Symbol leftColumn2 = pb.symbol(LEFT_JOIN_COLUMN_2, BIGINT);
        Symbol rightColumn2 = pb.symbol(RIGHT_JOIN_COLUMN_2, DOUBLE);
        ComparisonExpression leftColumnLessThanTen = new ComparisonExpression(
                ComparisonExpression.Operator.LESS_THAN,
                toSymbolReference(leftColumn),
                new LongLiteral("10"));
        return pb.join(
                INNER,
                pb.values(leftColumn, leftColumn2),
                pb.values(rightColumn, rightColumn2),
                ImmutableList.of(new EquiJoinClause(leftColumn2, rightColumn2), new EquiJoinClause(leftColumn, rightColumn)),
                ImmutableList.of(leftColumn, leftColumn2, rightColumn, rightColumn2),
                Optional.of(castToRowExpression(leftColumnLessThanTen)));
    })
            .withSourceStats(0, planNodeStats(LEFT_ROWS_COUNT, LEFT_JOIN_COLUMN_STATS, LEFT_JOIN_COLUMN_2_STATS))
            .withSourceStats(1, planNodeStats(RIGHT_ROWS_COUNT, RIGHT_JOIN_COLUMN_STATS, RIGHT_JOIN_COLUMN_2_STATS))
            .check(stats -> stats.equalTo(expectedStats));
}
Also used : EquiJoinClause(io.prestosql.spi.plan.JoinNode.EquiJoinClause) TypeProvider(io.prestosql.sql.planner.TypeProvider) Assert.assertEquals(org.testng.Assert.assertEquals) Test(org.testng.annotations.Test) PlanNodeStatsAssertion.assertThat(io.prestosql.cost.PlanNodeStatsAssertion.assertThat) ImmutableList(com.google.common.collect.ImmutableList) NaN(java.lang.Double.NaN) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) DOUBLE(io.prestosql.spi.type.DoubleType.DOUBLE) Type(io.prestosql.spi.type.Type) BIGINT(io.prestosql.spi.type.BigintType.BIGINT) JoinNode(io.prestosql.spi.plan.JoinNode) Symbol(io.prestosql.spi.plan.Symbol) UNKNOWN_FILTER_COEFFICIENT(io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT) ImmutableMap(com.google.common.collect.ImmutableMap) FULL(io.prestosql.spi.plan.JoinNode.Type.FULL) MetadataManager.createTestMetadataManager(io.prestosql.metadata.MetadataManager.createTestMetadataManager) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Metadata(io.prestosql.metadata.Metadata) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) RIGHT(io.prestosql.spi.plan.JoinNode.Type.RIGHT) LongLiteral(io.prestosql.sql.tree.LongLiteral) INNER(io.prestosql.spi.plan.JoinNode.Type.INNER) Optional(java.util.Optional) LEFT(io.prestosql.spi.plan.JoinNode.Type.LEFT) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) LongLiteral(io.prestosql.sql.tree.LongLiteral) Symbol(io.prestosql.spi.plan.Symbol) EquiJoinClause(io.prestosql.spi.plan.JoinNode.EquiJoinClause) Test(org.testng.annotations.Test)

Aggregations

UNKNOWN_FILTER_COEFFICIENT (io.prestosql.cost.FilterStatsCalculator.UNKNOWN_FILTER_COEFFICIENT)8 Symbol (io.prestosql.spi.plan.Symbol)8 TypeProvider (io.prestosql.sql.planner.TypeProvider)8 Optional (java.util.Optional)8 Session (io.prestosql.Session)7 Pattern (io.prestosql.matching.Pattern)7 Lookup (io.prestosql.sql.planner.iterative.Lookup)7 Map (java.util.Map)6 Math.min (java.lang.Math.min)4 Objects.requireNonNull (java.util.Objects.requireNonNull)4 Metadata (io.prestosql.metadata.Metadata)3 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)2 SystemSessionProperties.isDefaultFilterFactorEnabled (io.prestosql.SystemSessionProperties.isDefaultFilterFactorEnabled)2 LogicalRowExpressions (io.prestosql.expressions.LogicalRowExpressions)2 FilterNode (io.prestosql.spi.plan.FilterNode)2 JoinNode (io.prestosql.spi.plan.JoinNode)2 EquiJoinClause (io.prestosql.spi.plan.JoinNode.EquiJoinClause)2 RowExpression (io.prestosql.spi.relation.RowExpression)2 Type (io.prestosql.spi.type.Type)2 ExpressionUtils.extractConjuncts (io.prestosql.sql.ExpressionUtils.extractConjuncts)2