Search in sources:

Example 1 with DisjointSet

Usage of io.trino.util.DisjointSet in the trino project by trinodb.

Class FilterStatsCalculator, method extractCorrelatedGroups.

/**
 * Splits filter conjuncts into groups of potentially correlated terms.
 * <p>
 * Terms are connected when they reference a common symbol, and connectivity is transitive
 * (symbols are merged through a {@link DisjointSet}). Each resulting group holds all terms whose
 * symbols fall into the same symbol partition. Terms that reference no symbols at all are placed
 * together in one additional group. Within a group, terms keep the order they have in {@code terms}.
 *
 * @param terms the conjuncts of a filter expression
 * @param filterConjunctionIndependenceFactor when exactly 1, no grouping is performed and all
 *        terms are returned as a single group
 * @return the groups of terms; every element of {@code terms} appears in exactly one group
 */
private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor) {
    if (filterConjunctionIndependenceFactor == 1) {
        // Allows the filters to be estimated as if there is no correlation between any of the terms
        return ImmutableList.of(terms);
    }
    ListMultimap<Expression, Symbol> expressionUniqueSymbols = ArrayListMultimap.create();
    terms.forEach(expression -> expressionUniqueSymbols.putAll(expression, extractUnique(expression)));
    // Partition symbols into disjoint sets such that the symbols belonging to different disjoint sets
    // do not appear together in any expression.
    DisjointSet<Symbol> symbolsPartitioner = new DisjointSet<>();
    for (Expression term : terms) {
        List<Symbol> expressionSymbols = expressionUniqueSymbols.get(term);
        if (expressionSymbols.isEmpty()) {
            continue;
        }
        // Ensure that symbol is added to DisjointSet when there is only one symbol in the list
        symbolsPartitioner.find(expressionSymbols.get(0));
        for (int i = 1; i < expressionSymbols.size(); i++) {
            symbolsPartitioner.findAndUnion(expressionSymbols.get(0), expressionSymbols.get(i));
        }
    }
    // Use disjoint sets of symbols to partition the given list of expressions
    List<Set<Symbol>> symbolPartitions = ImmutableList.copyOf(symbolsPartitioner.getEquivalentClasses());
    checkState(symbolPartitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions");

    // Index each symbol by its partition position once, rather than scanning every partition for every term.
    // Partitions are disjoint, so each symbol maps to exactly one id.
    Map<Symbol, Integer> partitionIdBySymbol = new java.util.HashMap<>();
    for (int partitionId = 0; partitionId < symbolPartitions.size(); partitionId++) {
        for (Symbol symbol : symbolPartitions.get(partitionId)) {
            partitionIdBySymbol.put(symbol, partitionId);
        }
    }
    // Expressions without symbols share one extra partition whose id follows all symbol partitions
    int symbolFreePartitionId = symbolPartitions.size();

    ListMultimap<Integer, Expression> expressionPartitions = ArrayListMultimap.create();
    for (Expression term : terms) {
        List<Symbol> expressionSymbols = expressionUniqueSymbols.get(term);
        int expressionPartitionId;
        if (expressionSymbols.isEmpty()) {
            expressionPartitionId = symbolFreePartitionId;
        } else {
            // All symbols of a term were unioned above, so any one of them identifies the partition
            Symbol symbol = expressionSymbols.get(0);
            Integer partitionId = partitionIdBySymbol.get(symbol);
            checkState(partitionId != null, "symbol %s not found in any partition", symbol);
            expressionPartitionId = partitionId;
        }
        expressionPartitions.put(expressionPartitionId, term);
    }
    return expressionPartitions.keySet().stream().map(expressionPartitions::get).collect(toImmutableList());
}
Also used : ArrayListMultimap(com.google.common.collect.ArrayListMultimap) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) ListMultimap(com.google.common.collect.ListMultimap) SymbolsExtractor.extractUnique(io.trino.sql.planner.SymbolsExtractor.extractUnique) SystemSessionProperties.getFilterConjunctionIndependenceFactor(io.trino.SystemSessionProperties.getFilterConjunctionIndependenceFactor) Scope(io.trino.sql.analyzer.Scope) Double.min(java.lang.Double.min) StatsUtil.toStatsRepresentation(io.trino.spi.statistics.StatsUtil.toStatsRepresentation) Node(io.trino.sql.tree.Node) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) NaN(java.lang.Double.NaN) Map(java.util.Map) DisjointSet(io.trino.util.DisjointSet) FunctionCall(io.trino.sql.tree.FunctionCall) ExpressionUtils.getExpressionTypes(io.trino.sql.ExpressionUtils.getExpressionTypes) ComparisonStatsCalculator.estimateExpressionToExpressionComparison(io.trino.cost.ComparisonStatsCalculator.estimateExpressionToExpressionComparison) BetweenPredicate(io.trino.sql.tree.BetweenPredicate) Double.isInfinite(java.lang.Double.isInfinite) PlanNodeStatsEstimateMath.capStats(io.trino.cost.PlanNodeStatsEstimateMath.capStats) ImmutableMap(com.google.common.collect.ImmutableMap) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ExpressionInterpreter(io.trino.sql.planner.ExpressionInterpreter) Set(java.util.Set) ComparisonStatsCalculator.estimateExpressionToLiteralComparison(io.trino.cost.ComparisonStatsCalculator.estimateExpressionToLiteralComparison) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) String.format(java.lang.String.format) InPredicate(io.trino.sql.tree.InPredicate) Preconditions.checkState(com.google.common.base.Preconditions.checkState) LESS_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL) EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.EQUAL) 
List(java.util.List) SymbolReference(io.trino.sql.tree.SymbolReference) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) InListExpression(io.trino.sql.tree.InListExpression) PlanNodeStatsEstimateMath.subtractSubsetStats(io.trino.cost.PlanNodeStatsEstimateMath.subtractSubsetStats) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) IntStream(java.util.stream.IntStream) PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(io.trino.cost.PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues) AllowAllAccessControl(io.trino.security.AllowAllAccessControl) LiteralEncoder(io.trino.sql.planner.LiteralEncoder) ExpressionInterpreter.evaluateConstantExpression(io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression) Type(io.trino.spi.type.Type) NoOpSymbolResolver(io.trino.sql.planner.NoOpSymbolResolver) OptionalDouble(java.util.OptionalDouble) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) Inject(javax.inject.Inject) ImmutableList(com.google.common.collect.ImmutableList) Verify.verify(com.google.common.base.Verify.verify) PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount(io.trino.cost.PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate) ExpressionAnalyzer(io.trino.sql.analyzer.ExpressionAnalyzer) NodeRef(io.trino.sql.tree.NodeRef) NotExpression(io.trino.sql.tree.NotExpression) Objects.requireNonNull(java.util.Objects.requireNonNull) GREATER_THAN_OR_EQUAL(io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL) Nullable(javax.annotation.Nullable) AstVisitor(io.trino.sql.tree.AstVisitor) VerifyException(com.google.common.base.VerifyException) ExpressionUtils(io.trino.sql.ExpressionUtils) Symbol(io.trino.sql.planner.Symbol) PlanNodeStatsEstimateMath.intersectCorrelatedStats(io.trino.cost.PlanNodeStatsEstimateMath.intersectCorrelatedStats) ExpressionUtils.and(io.trino.sql.ExpressionUtils.and) 
DynamicFilters.isDynamicFilter(io.trino.sql.DynamicFilters.isDynamicFilter) Double.isNaN(java.lang.Double.isNaN) WarningCollector(io.trino.execution.warnings.WarningCollector) LogicalExpression(io.trino.sql.tree.LogicalExpression) TypeProvider(io.trino.sql.planner.TypeProvider) DisjointSet(io.trino.util.DisjointSet) Set(java.util.Set) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Expression(io.trino.sql.tree.Expression) InListExpression(io.trino.sql.tree.InListExpression) ExpressionInterpreter.evaluateConstantExpression(io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression) NotExpression(io.trino.sql.tree.NotExpression) LogicalExpression(io.trino.sql.tree.LogicalExpression) Symbol(io.trino.sql.planner.Symbol) DisjointSet(io.trino.util.DisjointSet)

Example 2 with DisjointSet

Usage of io.trino.util.DisjointSet in the trino project by trinodb.

Class EqualityInference, method newInstance.

/**
 * Builds an {@code EqualityInference} from the equality conjuncts found in {@code expressions}.
 * <p>
 * Every conjunct accepted by {@code isInferenceCandidate} is treated as a comparison whose two
 * sides are merged into one equivalence class. Then, for each expression that was not itself
 * produced by rewriting, each proper sub-expression with known equivalents is substituted by
 * those equivalents. The rewritten expression is merged with the original and recorded as derived.
 *
 * @param metadata used to decide which conjuncts are inference candidates
 * @param expressions expressions whose conjuncts are scanned for equalities
 * @return the inference built from the equality sets, the member-to-canonical mapping and the
 *         derived expressions
 */
public static EqualityInference newInstance(Metadata metadata, Collection<Expression> expressions) {
    DisjointSet<Expression> equalities = new DisjointSet<>();
    for (Expression expression : expressions) {
        for (Expression conjunct : extractConjuncts(expression)) {
            if (isInferenceCandidate(metadata, conjunct)) {
                ComparisonExpression comparison = (ComparisonExpression) conjunct;
                equalities.findAndUnion(comparison.getLeft(), comparison.getRight());
            }
        }
    }

    // Point every member expression at the equivalence class that contains it
    Map<Expression, Set<Expression>> byExpression = new HashMap<>();
    for (Set<Expression> equivalenceClass : equalities.getEquivalentClasses()) {
        for (Expression member : equivalenceClass) {
            byExpression.put(member, equivalenceClass);
        }
    }

    // Expressions produced by sub-expression substitution; these are not themselves expanded further
    Set<Expression> derivedExpressions = new LinkedHashSet<>();
    for (Expression expression : byExpression.keySet()) {
        if (derivedExpressions.contains(expression)) {
            continue;
        }
        SubExpressionExtractor.extract(expression)
                .filter(candidate -> !candidate.equals(expression))
                .forEach(subExpression -> {
                    for (Expression equivalent : byExpression.getOrDefault(subExpression, ImmutableSet.of())) {
                        if (equivalent.equals(subExpression)) {
                            continue;
                        }
                        Expression rewritten = replaceExpression(expression, ImmutableMap.of(subExpression, equivalent));
                        equalities.findAndUnion(expression, rewritten);
                        derivedExpressions.add(rewritten);
                    }
                });
    }

    Multimap<Expression, Expression> equalitySets = makeEqualitySets(equalities);
    // Invert canonical -> members into member -> canonical; buildOrThrow rejects duplicate members
    ImmutableMap.Builder<Expression, Expression> canonicalMappings = ImmutableMap.builder();
    equalitySets.forEach((canonical, member) -> canonicalMappings.put(member, canonical));
    return new EqualityInference(equalitySets, canonicalMappings.buildOrThrow(), derivedExpressions);
}
Also used : Arrays(java.util.Arrays) HashMap(java.util.HashMap) Multimap(com.google.common.collect.Multimap) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) DeterminismEvaluator.isDeterministic(io.trino.sql.planner.DeterminismEvaluator.isDeterministic) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) DisjointSet(io.trino.util.DisjointSet) LinkedHashSet(java.util.LinkedHashSet) ImmutableSetMultimap(com.google.common.collect.ImmutableSetMultimap) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Predicate(java.util.function.Predicate) Collection(java.util.Collection) NullabilityAnalyzer.mayReturnNullOnNonNullInput(io.trino.sql.planner.NullabilityAnalyzer.mayReturnNullOnNonNullInput) ToIntFunction(java.util.function.ToIntFunction) Set(java.util.Set) Collectors(java.util.stream.Collectors) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Objects(java.util.Objects) List(java.util.List) Stream(java.util.stream.Stream) SymbolReference(io.trino.sql.tree.SymbolReference) Metadata(io.trino.metadata.Metadata) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) Expression(io.trino.sql.tree.Expression) ExpressionUtils.extractConjuncts(io.trino.sql.ExpressionUtils.extractConjuncts) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Comparator(java.util.Comparator) LinkedHashSet(java.util.LinkedHashSet) DisjointSet(io.trino.util.DisjointSet) LinkedHashSet(java.util.LinkedHashSet) ImmutableSet(com.google.common.collect.ImmutableSet) Set(java.util.Set) HashMap(java.util.HashMap) ImmutableMap(com.google.common.collect.ImmutableMap) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) ExpressionNodeInliner.replaceExpression(io.trino.sql.planner.ExpressionNodeInliner.replaceExpression) 
Expression(io.trino.sql.tree.Expression) DisjointSet(io.trino.util.DisjointSet) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap)

Aggregations

Usage counts: ImmutableList (com.google.common.collect.ImmutableList): 2; ImmutableMap (com.google.common.collect.ImmutableMap): 2; ComparisonExpression (io.trino.sql.tree.ComparisonExpression): 2; Expression (io.trino.sql.tree.Expression): 2; SymbolReference (io.trino.sql.tree.SymbolReference): 2; DisjointSet (io.trino.util.DisjointSet): 2; List (java.util.List): 2; Map (java.util.Map): 2; Objects.requireNonNull (java.util.Objects.requireNonNull): 2; Set (java.util.Set): 2; VisibleForTesting (com.google.common.annotations.VisibleForTesting): 1; Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument): 1; Preconditions.checkState (com.google.common.base.Preconditions.checkState): 1; Verify.verify (com.google.common.base.Verify.verify): 1; VerifyException (com.google.common.base.VerifyException): 1; ArrayListMultimap (com.google.common.collect.ArrayListMultimap): 1; ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 1; ImmutableSet (com.google.common.collect.ImmutableSet): 1; ImmutableSetMultimap (com.google.common.collect.ImmutableSetMultimap): 1; ListMultimap (com.google.common.collect.ListMultimap): 1