Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
In class TestEliminateSorts, method assertUnitPlan:
/**
 * Plans {@code sql} with a fixed, minimal optimizer pipeline and asserts the result matches {@code pattern}.
 * The pipeline is an iterative pass (RemoveRedundantIdentityProjections and
 * DetermineTableScanNodePartitioning, the latter with a TaskCountEstimator that always reports 10)
 * followed by AddExchanges.
 */
private void assertUnitPlan(@Language("SQL") String sql, PlanMatchPattern pattern)
{
    TypeAnalyzer typeAnalyzer = createTestingTypeAnalyzer(getQueryRunner().getPlannerContext());

    IterativeOptimizer iterativeOptimizer = new IterativeOptimizer(
            getQueryRunner().getPlannerContext(),
            new RuleStatsRecorder(),
            getQueryRunner().getStatsCalculator(),
            getQueryRunner().getCostCalculator(),
            ImmutableSet.of(
                    new RemoveRedundantIdentityProjections(),
                    new DetermineTableScanNodePartitioning(
                            getQueryRunner().getMetadata(),
                            getQueryRunner().getNodePartitioningManager(),
                            new TaskCountEstimator(() -> 10))));
    AddExchanges addExchanges = new AddExchanges(
            getQueryRunner().getPlannerContext(),
            typeAnalyzer,
            getQueryRunner().getStatsCalculator());

    // Order matters: the iterative pass runs before exchanges are added.
    List<PlanOptimizer> pipeline = ImmutableList.of(iterativeOptimizer, addExchanges);
    assertPlan(sql, pattern, pipeline);
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
In class TestValidateLimitWithPresortedInput, method validatePlan:
/**
 * Builds a plan with {@code planProvider} and runs ValidateLimitWithPresortedInput over it
 * inside a query-runner transaction. Any validation failure propagates to the caller.
 */
private void validatePlan(Function<PlanBuilder, PlanNode> planProvider)
{
    LocalQueryRunner runner = getQueryRunner();
    PlanBuilder planBuilder = new PlanBuilder(idAllocator, runner.getMetadata(), runner.getDefaultSession());
    PlanNode plan = planProvider.apply(planBuilder);
    TypeProvider symbolTypes = planBuilder.getTypes();

    runner.inTransaction(transactionSession -> {
        // metadata.getCatalogHandle() registers the catalog for the transaction
        transactionSession.getCatalog()
                .ifPresent(catalogName -> runner.getMetadata().getCatalogHandle(transactionSession, catalogName));

        TypeAnalyzer typeAnalyzer = createTestingTypeAnalyzer(runner.getPlannerContext());
        ValidateLimitWithPresortedInput validator = new ValidateLimitWithPresortedInput();
        validator.validate(plan, transactionSession, runner.getPlannerContext(), typeAnalyzer, symbolTypes, WarningCollector.NOOP);
        return null;
    });
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
In class ExtractDereferencesFromFilterAboveScan, method apply:
/**
 * Moves the row subscript expressions found in the filter predicate into a projection below the filter.
 *
 * <p>The resulting shape is:
 * Project(identity of the filter's outputs)
 *   <- Filter(predicate with subscripts replaced by symbol references)
 *     <- Project(identities of the source outputs + one new symbol per subscript)
 *       <- source
 *
 * <p>Returns {@code Result.empty()} when extractRowSubscripts finds nothing in the predicate.
 */
@Override
public Result apply(FilterNode node, Captures captures, Context context)
{
    Expression predicate = node.getPredicate();
    // NOTE(review): the meaning of the boolean flag passed to extractRowSubscripts is defined outside this view
    Set<SubscriptExpression> rowSubscripts = extractRowSubscripts(
            ImmutableList.of(predicate),
            true,
            context.getSession(),
            typeAnalyzer,
            context.getSymbolAllocator().getTypes());
    if (rowSubscripts.isEmpty()) {
        return Result.empty();
    }

    // One fresh symbol per subscript expression
    Assignments subscriptAssignments = Assignments.of(rowSubscripts, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Invert symbol -> expression into expression -> reference to that symbol
    Map<Expression, SymbolReference> replacements = HashBiMap.create(subscriptAssignments.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));

    PlanNode source = node.getSource();
    Assignments sourcePlusSubscripts = Assignments.builder()
            .putIdentities(source.getOutputSymbols())
            .putAll(subscriptAssignments)
            .build();
    Expression rewrittenPredicate = replaceExpression(predicate, replacements);
    Assignments originalOutputs = Assignments.identity(node.getOutputSymbols());

    // Constructors stay nested so node ids are drawn outermost-first (Java evaluates arguments left to right)
    return Result.ofPlanNode(
            new ProjectNode(
                    context.getIdAllocator().getNextId(),
                    new FilterNode(
                            context.getIdAllocator().getNextId(),
                            new ProjectNode(context.getIdAllocator().getNextId(), source, sourcePlusSubscripts),
                            rewrittenPredicate),
                    originalOutputs));
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
In class ExtractSpatialJoins, method tryCreateSpatialJoin:
/**
 * Handles a distance-style spatial comparison in a join filter.
 *
 * <p>The comparison is first oriented as {@code distance <op> radius}. With LESS_THAN or
 * LESS_THAN_OR_EQUAL the distance is already on the left (e.g. {@code ST_Distance(a, b) <= r}).
 * Any other operator has the distance on the right (e.g. {@code r >= ST_Distance(a, b)}), so the
 * sides are swapped and the operator is flipped.
 *
 * <p>The radius must either reference no symbols or reference only right-side symbols and no
 * left-side ones; otherwise {@code Result.empty()} is returned. If newRadiusSymbol yields a symbol,
 * the radius is projected under it on the right input. The filter is rewritten with the oriented
 * comparison, and the work is delegated to the overload that takes the distance call and the radius.
 */
private static Result tryCreateSpatialJoin(
        Context context,
        JoinNode joinNode,
        Expression filter,
        PlanNodeId nodeId,
        List<Symbol> outputSymbols,
        ComparisonExpression spatialComparison,
        PlannerContext plannerContext,
        SplitManager splitManager,
        PageSourceManager pageSourceManager,
        TypeAnalyzer typeAnalyzer)
{
    PlanNode leftNode = joinNode.getLeft();
    PlanNode rightNode = joinNode.getRight();
    List<Symbol> leftSymbols = leftNode.getOutputSymbols();
    List<Symbol> rightSymbols = rightNode.getOutputSymbols();

    boolean distanceOnLeft = spatialComparison.getOperator() == LESS_THAN
            || spatialComparison.getOperator() == LESS_THAN_OR_EQUAL;
    Expression distance = distanceOnLeft ? spatialComparison.getLeft() : spatialComparison.getRight();
    Expression radius = distanceOnLeft ? spatialComparison.getRight() : spatialComparison.getLeft();

    // Constant radius, or one computed purely from the right side of the join
    Set<Symbol> radiusSymbols = extractUnique(radius);
    if (!radiusSymbols.isEmpty()
            && !(rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) {
        return Result.empty();
    }

    Optional<Symbol> radiusSymbol = newRadiusSymbol(context, radius);
    ComparisonExpression orientedComparison = new ComparisonExpression(
            distanceOnLeft ? spatialComparison.getOperator() : spatialComparison.getOperator().flip(),
            distance,
            toExpression(radiusSymbol, radius));

    Expression rewrittenFilter = replaceExpression(filter, ImmutableMap.of(spatialComparison, orientedComparison));
    PlanNode projectedRight = radiusSymbol
            .map(symbol -> addProjection(context, rightNode, symbol, radius))
            .orElse(rightNode);
    JoinNode rewrittenJoin = new JoinNode(
            joinNode.getId(),
            joinNode.getType(),
            leftNode,
            projectedRight,
            joinNode.getCriteria(),
            joinNode.getLeftOutputSymbols(),
            joinNode.getRightOutputSymbols(),
            joinNode.isMaySkipOutputDuplicates(),
            Optional.of(rewrittenFilter),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType(),
            joinNode.isSpillable(),
            joinNode.getDynamicFilters(),
            joinNode.getReorderJoinStatsAndCost());

    return tryCreateSpatialJoin(
            context,
            rewrittenJoin,
            rewrittenFilter,
            nodeId,
            outputSymbols,
            (FunctionCall) orientedComparison.getLeft(),
            Optional.of(orientedComparison.getRight()),
            plannerContext,
            splitManager,
            pageSourceManager,
            typeAnalyzer);
}
Use of io.trino.sql.planner.TypeAnalyzer in project trino by trinodb.
In class PushDownDereferenceThroughJoin, method apply:
/**
 * Pushes row subscript expressions used by the projection above a join (and by the join filter)
 * down into projections on the join's left and right inputs.
 *
 * <p>Subscripts whose base symbol is a join criteria symbol are not pushed down. Each remaining
 * subscript gets a new symbol, which is projected on whichever join input produces its base symbol.
 * The projection and the join filter are rewritten to use the new symbols. The new join outputs
 * only the symbols the rewritten projection references.
 *
 * @throws IllegalArgumentException if a subscript's base symbol is produced by neither join input
 */
@Override
public Result apply(ProjectNode projectNode, Captures captures, Context context)
{
    JoinNode joinNode = captures.get(CHILD);

    // Consider dereferences in projections and join filter for pushdown
    ImmutableList.Builder<Expression> expressionsBuilder = ImmutableList.builder();
    expressionsBuilder.addAll(projectNode.getAssignments().getExpressions());
    joinNode.getFilter().ifPresent(expressionsBuilder::add);
    Set<SubscriptExpression> dereferences = extractRowSubscripts(expressionsBuilder.build(), false, context.getSession(), typeAnalyzer, context.getSymbolAllocator().getTypes());

    // Exclude criteria symbols
    ImmutableSet.Builder<Symbol> criteriaSymbolsBuilder = ImmutableSet.builder();
    joinNode.getCriteria().forEach(criteria -> {
        criteriaSymbolsBuilder.add(criteria.getLeft());
        criteriaSymbolsBuilder.add(criteria.getRight());
    });
    Set<Symbol> excludeSymbols = criteriaSymbolsBuilder.build();
    dereferences = dereferences.stream()
            .filter(expression -> !excludeSymbols.contains(getBase(expression)))
            .collect(toImmutableSet());
    if (dereferences.isEmpty()) {
        return Result.empty();
    }

    // Create new symbols for dereference expressions
    Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer);

    // Rewrite project node assignments using new symbols for dereference expressions
    Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap())
            .inverse()
            .entrySet()
            .stream()
            .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));
    Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings));

    // Separate dereferences coming from left and right nodes
    Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
    Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
    for (Map.Entry<Symbol, Expression> entry : dereferenceAssignments.entrySet()) {
        Symbol baseSymbol = getOnlyElement(extractAll(entry.getValue()));
        if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) {
            leftAssignmentsBuilder.put(entry.getKey(), entry.getValue());
        }
        else if (joinNode.getRight().getOutputSymbols().contains(baseSymbol)) {
            rightAssignmentsBuilder.put(entry.getKey(), entry.getValue());
        }
        else {
            throw new IllegalArgumentException(format("Unexpected symbol %s in projectNode", baseSymbol));
        }
    }
    Assignments leftAssignments = leftAssignmentsBuilder.build();
    Assignments rightAssignments = rightAssignmentsBuilder.build();

    PlanNode leftNode = createProjectNodeIfRequired(joinNode.getLeft(), leftAssignments, context.getIdAllocator());
    PlanNode rightNode = createProjectNodeIfRequired(joinNode.getRight(), rightAssignments, context.getIdAllocator());

    // Prepare new output symbols for join node. extractAll reports every occurrence, and several
    // assignments may reference the same symbol, so deduplicate (keeping first-occurrence order)
    // to avoid listing a symbol more than once in the join's outputs.
    List<Symbol> referredSymbolsInAssignments = newAssignments.getExpressions().stream()
            .flatMap(expression -> extractAll(expression).stream())
            .distinct()
            .collect(toList());
    Set<Symbol> leftSourceSymbols = ImmutableSet.copyOf(leftNode.getOutputSymbols());
    Set<Symbol> rightSourceSymbols = ImmutableSet.copyOf(rightNode.getOutputSymbols());
    List<Symbol> newLeftOutputSymbols = referredSymbolsInAssignments.stream()
            .filter(leftSourceSymbols::contains)
            .collect(toList());
    List<Symbol> newRightOutputSymbols = referredSymbolsInAssignments.stream()
            .filter(rightSourceSymbols::contains)
            .collect(toList());

    JoinNode newJoinNode = new JoinNode(
            context.getIdAllocator().getNextId(),
            joinNode.getType(),
            leftNode,
            rightNode,
            joinNode.getCriteria(),
            newLeftOutputSymbols,
            newRightOutputSymbols,
            joinNode.isMaySkipOutputDuplicates(),
            // Use newly created symbols in filter
            joinNode.getFilter().map(expression -> replaceExpression(expression, mappings)),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType(),
            joinNode.isSpillable(),
            joinNode.getDynamicFilters(),
            joinNode.getReorderJoinStatsAndCost());

    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, newAssignments));
}
Aggregations