Usage of io.trino.sql.planner.Symbol in the Trino project (trinodb).
Class FilterStatsCalculator, method extractCorrelatedGroups.
/**
 * Groups filter conjuncts so that terms which (transitively) share a symbol land in the same group,
 * while terms in different groups reference disjoint sets of symbols. All terms that reference no
 * symbol at all are collected into one additional group.
 *
 * @param terms the conjuncts to group
 * @param filterConjunctionIndependenceFactor when exactly 1, no grouping is performed and all terms
 * are returned together as a single group
 * @return the groups of correlated terms
 */
private static List<List<Expression>> extractCorrelatedGroups(List<Expression> terms, double filterConjunctionIndependenceFactor) {
    if (filterConjunctionIndependenceFactor == 1) {
        // Terms are estimated as if uncorrelated, so a single group holding everything suffices
        return ImmutableList.of(terms);
    }

    // Distinct symbols referenced by each term, computed once and reused by both passes below
    ListMultimap<Expression, Symbol> symbolsByTerm = ArrayListMultimap.create();
    for (Expression term : terms) {
        symbolsByTerm.putAll(term, extractUnique(term));
    }

    // Pass 1: union all symbols of each term, so that symbols in different equivalence
    // classes never appear together in any single term
    DisjointSet<Symbol> partitioner = new DisjointSet<>();
    for (Expression term : terms) {
        List<Symbol> symbols = symbolsByTerm.get(term);
        if (symbols.isEmpty()) {
            continue;
        }
        Symbol anchor = symbols.get(0);
        // find() registers the anchor, which matters for single-symbol terms that are never unioned
        partitioner.find(anchor);
        for (Symbol other : symbols.subList(1, symbols.size())) {
            partitioner.findAndUnion(anchor, other);
        }
    }

    // Pass 2: assign each term to the equivalence class of any one of its symbols
    List<Set<Symbol>> symbolPartitions = ImmutableList.copyOf(partitioner.getEquivalentClasses());
    checkState(symbolPartitions.size() <= terms.size(), "symbolPartitions size exceeds number of expressions");
    // Symbol-free terms share the id one past the last symbol class
    int symbolFreePartitionId = symbolPartitions.size();
    ListMultimap<Integer, Expression> termsByPartition = ArrayListMultimap.create();
    for (Expression term : terms) {
        List<Symbol> symbols = symbolsByTerm.get(term);
        int partitionId;
        if (symbols.isEmpty()) {
            partitionId = symbolFreePartitionId;
        }
        else {
            Symbol probe = symbols.get(0);
            partitionId = IntStream.range(0, symbolPartitions.size())
                    .filter(index -> symbolPartitions.get(index).contains(probe))
                    .findFirst()
                    .orElseThrow();
        }
        termsByPartition.put(partitionId, term);
    }
    return termsByPartition.keySet().stream()
            .map(termsByPartition::get)
            .collect(toImmutableList());
}
Usage of io.trino.sql.planner.Symbol in the Trino project (trinodb).
Class AggregationStatsRule, method groupBy.
/**
 * Builds output statistics for a GROUP BY over {@code sourceStats}.
 * Grouping symbols keep their source statistics with an adjusted nulls fraction. The output row
 * count is the estimated number of groups, capped at the source row count. Each aggregation
 * output symbol gets statistics from {@code estimateAggregationStats}.
 */
public static PlanNodeStatsEstimate groupBy(PlanNodeStatsEstimate sourceStats, Collection<Symbol> groupBySymbols, Map<Symbol, Aggregation> aggregations) {
    PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();

    // A key with any NULLs yields one NULL group among (distinct values + 1) groups;
    // a key without NULLs keeps a zero nulls fraction
    for (Symbol groupBySymbol : groupBySymbols) {
        SymbolStatsEstimate keyStats = sourceStats.getSymbolStatistics(groupBySymbol);
        result.addSymbolStatistics(groupBySymbol, keyStats.mapNullsFraction(
                nullsFraction -> nullsFraction == 0.0 ? 0.0 : 1.0 / (keyStats.getDistinctValuesCount() + 1)));
    }

    // There can never be more groups than input rows
    result.setOutputRowCount(min(getRowsCount(sourceStats, groupBySymbols), sourceStats.getOutputRowCount()));

    aggregations.forEach((outputSymbol, aggregation) ->
            result.addSymbolStatistics(outputSymbol, estimateAggregationStats(aggregation, sourceStats)));

    return result.build();
}
Usage of io.trino.sql.planner.Symbol in the Trino project (trinodb).
Class JoinStatsRule, method addJoinComplementStats.
/**
 * Merges the statistics of the inner-join part and of the join-complement part (presumably the
 * unmatched rows an outer join adds; confirm with callers) into statistics for their union.
 * <p>
 * The output row count is the sum of both parts. For symbols with known statistics in the
 * complement, the low/high values and distinct-value count come from {@code sourceStats}, and the
 * nulls fraction is the row-weighted average of both parts. Symbols known only in the inner-join
 * stats are treated as NULL in every complement row. If the complement has zero rows,
 * {@code innerJoinStats} is returned unchanged.
 */
@VisibleForTesting
PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate innerJoinStats, PlanNodeStatsEstimate joinComplementStats) {
    double innerRows = innerJoinStats.getOutputRowCount();
    double complementRows = joinComplementStats.getOutputRowCount();
    if (complementRows == 0) {
        // Nothing is added on top of the inner join result
        return innerJoinStats;
    }
    double totalRows = innerRows + complementRows;

    PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(innerJoinStats);
    outputStats.setOutputRowCount(totalRows);

    for (Symbol symbol : joinComplementStats.getSymbolsWithKnownStatistics()) {
        SymbolStatsEstimate sourceSymbolStats = sourceStats.getSymbolStatistics(symbol);
        SymbolStatsEstimate innerSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
        SymbolStatsEstimate complementSymbolStats = joinComplementStats.getSymbolStatistics(symbol);
        // Row-weighted average of the two parts' nulls fractions
        double innerNulls = innerSymbolStats.getNullsFraction() * innerRows;
        double complementNulls = complementSymbolStats.getNullsFraction() * complementRows;
        double blendedNullsFraction = (innerNulls + complementNulls) / totalRows;
        outputStats.addSymbolStatistics(symbol, SymbolStatsEstimate.buildFrom(innerSymbolStats)
                .setLowValue(sourceSymbolStats.getLowValue())
                .setHighValue(sourceSymbolStats.getHighValue())
                .setDistinctValuesCount(sourceSymbolStats.getDistinctValuesCount())
                .setNullsFraction(blendedNullsFraction)
                .build());
    }

    // Symbols without known statistics in the complement: every complement row counts as NULL
    for (Symbol symbol : difference(innerJoinStats.getSymbolsWithKnownStatistics(), joinComplementStats.getSymbolsWithKnownStatistics())) {
        SymbolStatsEstimate innerSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
        double paddedNullsFraction = (innerSymbolStats.getNullsFraction() * innerRows + complementRows) / totalRows;
        outputStats.addSymbolStatistics(symbol, innerSymbolStats.mapNullsFraction(ignored -> paddedNullsFraction));
    }

    return outputStats.build();
}
Usage of io.trino.sql.planner.Symbol in the Trino project (trinodb).
Class ExtractSpatialJoins, method tryCreateSpatialJoin.
/**
 * Attempts to rewrite {@code joinNode} into a {@link SpatialJoinNode} driven by {@code spatialFunction}.
 * <p>
 * Returns {@code Result.empty()} (plan unchanged) when either argument of the spatial function has the
 * spherical geography type, when either argument references no symbols, or when the argument symbols
 * cannot be matched to the join's left/right sides. Otherwise, arguments may be projected onto new
 * symbols on the side they belong to. For INNER joins with a spatial partitioning table, both sides
 * also get partition-id symbols and partitioning nodes built from a loaded KDB tree.
 *
 * @param filter the join filter; the occurrence of {@code spatialFunction} inside it is replaced by a
 * call over the rewritten arguments
 * @param radius optional radius expression, passed only to the right side's partitioning nodes
 * @return the rewritten {@link SpatialJoinNode}, or {@code Result.empty()} if the rewrite does not apply
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, Expression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, FunctionCall spatialFunction, Optional<Expression> radius, PlannerContext plannerContext, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) {
// TODO Add support for distributed left spatial joins
// Spatial partitioning is only considered for INNER joins; the table name presumably comes from a session property (confirm)
Optional<String> spatialPartitioningTableName = joinNode.getType() == INNER ? getSpatialPartitioningTableName(context.getSession()) : Optional.empty();
Optional<KdbTree> kdbTree = spatialPartitioningTableName.map(tableName -> loadKdbTree(tableName, context.getSession(), plannerContext.getMetadata(), splitManager, pageSourceManager));
List<Expression> arguments = spatialFunction.getArguments();
verify(arguments.size() == 2);
Expression firstArgument = arguments.get(0);
Expression secondArgument = arguments.get(1);
// Spherical geography arguments are not handled by this rewrite; leave the plan unchanged
Type sphericalGeographyType = plannerContext.getTypeManager().getType(SPHERICAL_GEOGRAPHY_TYPE_SIGNATURE);
if (typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), firstArgument).equals(sphericalGeographyType) || typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), secondArgument).equals(sphericalGeographyType)) {
return Result.empty();
}
// Each argument must reference at least one symbol so it can be attributed to a join side
Set<Symbol> firstSymbols = extractUnique(firstArgument);
Set<Symbol> secondSymbols = extractUnique(secondArgument);
if (firstSymbols.isEmpty() || secondSymbols.isEmpty()) {
return Result.empty();
}
// When present, the argument is projected onto this new symbol; presumably empty when no projection
// is needed (e.g. the argument is already a symbol reference) -- NOTE(review): confirm in newGeometrySymbol
Optional<Symbol> newFirstSymbol = newGeometrySymbol(context, firstArgument, plannerContext.getTypeManager());
Optional<Symbol> newSecondSymbol = newGeometrySymbol(context, secondArgument, plannerContext.getTypeManager());
PlanNode leftNode = joinNode.getLeft();
PlanNode rightNode = joinNode.getRight();
PlanNode newLeftNode;
PlanNode newRightNode;
// Check if the order of arguments of the spatial function matches the order of join sides
// alignment > 0: first argument is computed on the left side, second on the right
// alignment < 0: the arguments are swapped relative to the join sides
// alignment == 0: no consistent assignment, so the rewrite is abandoned
int alignment = checkAlignment(joinNode, firstSymbols, secondSymbols);
if (alignment > 0) {
newLeftNode = newFirstSymbol.map(symbol -> addProjection(context, leftNode, symbol, firstArgument)).orElse(leftNode);
newRightNode = newSecondSymbol.map(symbol -> addProjection(context, rightNode, symbol, secondArgument)).orElse(rightNode);
} else if (alignment < 0) {
newLeftNode = newSecondSymbol.map(symbol -> addProjection(context, leftNode, symbol, secondArgument)).orElse(leftNode);
newRightNode = newFirstSymbol.map(symbol -> addProjection(context, rightNode, symbol, firstArgument)).orElse(rightNode);
} else {
return Result.empty();
}
// Refer to the projected symbols (when any) instead of the original argument expressions
Expression newFirstArgument = toExpression(newFirstSymbol, firstArgument);
Expression newSecondArgument = toExpression(newSecondSymbol, secondArgument);
Optional<Symbol> leftPartitionSymbol = Optional.empty();
Optional<Symbol> rightPartitionSymbol = Optional.empty();
// With a KDB tree, each side gets its own INTEGER partition-id symbol and partitioning nodes;
// the radius is applied only on the right side
if (kdbTree.isPresent()) {
leftPartitionSymbol = Optional.of(context.getSymbolAllocator().newSymbol("pid", INTEGER));
rightPartitionSymbol = Optional.of(context.getSymbolAllocator().newSymbol("pid", INTEGER));
if (alignment > 0) {
newLeftNode = addPartitioningNodes(plannerContext, context, newLeftNode, leftPartitionSymbol.get(), kdbTree.get(), newFirstArgument, Optional.empty());
newRightNode = addPartitioningNodes(plannerContext, context, newRightNode, rightPartitionSymbol.get(), kdbTree.get(), newSecondArgument, radius);
} else {
newLeftNode = addPartitioningNodes(plannerContext, context, newLeftNode, leftPartitionSymbol.get(), kdbTree.get(), newSecondArgument, Optional.empty());
newRightNode = addPartitioningNodes(plannerContext, context, newRightNode, rightPartitionSymbol.get(), kdbTree.get(), newFirstArgument, radius);
}
}
// Rebuild the spatial function over the rewritten arguments (both declared as GEOMETRY) and substitute it into the filter
Expression newSpatialFunction = FunctionCallBuilder.resolve(context.getSession(), plannerContext.getMetadata()).setName(spatialFunction.getName()).addArgument(GEOMETRY_TYPE_SIGNATURE, newFirstArgument).addArgument(GEOMETRY_TYPE_SIGNATURE, newSecondArgument).build();
Expression newFilter = replaceExpression(filter, ImmutableMap.of(spatialFunction, newSpatialFunction));
return Result.ofPlanNode(new SpatialJoinNode(nodeId, SpatialJoinNode.Type.fromJoinNodeType(joinNode.getType()), newLeftNode, newRightNode, outputSymbols, newFilter, leftPartitionSymbol, rightPartitionSymbol, kdbTree.map(KdbTreeUtils::toJson)));
}
Usage of io.trino.sql.planner.Symbol in the Trino project (trinodb).
Class ExtractSpatialJoins, method addProjection.
/**
 * Wraps {@code node} in a new {@link ProjectNode}. The project node forwards every output symbol of
 * {@code node} unchanged and additionally assigns {@code expression} to {@code symbol}.
 */
private static PlanNode addProjection(Context context, PlanNode node, Symbol symbol, Expression expression) {
    Assignments.Builder assignments = Assignments.builder();
    // Identity assignments keep all existing outputs visible above the projection
    node.getOutputSymbols().forEach(assignments::putIdentity);
    assignments.put(symbol, expression);
    return new ProjectNode(context.getIdAllocator().getNextId(), node, assignments.build());
}
Aggregations