Use of io.trino.sql.planner.plan.SemiJoinNode in project trino by trinodb.
Class TransformFilteringSemiJoinToInnerJoin, method apply.
/**
 * Rewrites a filter over a semi-join whose predicate contains the semi-join output symbol as a
 * top-level conjunct into an INNER join against a distinct (grouped, aggregation-free) view of
 * the filtering source, topped by a projection that re-exposes the semi-join output as TRUE.
 *
 * <p>The rule does not fire when:
 * <ul>
 *   <li>the semi-join source contains a table scan marked as an update target (DELETE context), or</li>
 *   <li>no top-level conjunct of the filter predicate is exactly the semi-join output symbol.</li>
 * </ul>
 * All other conjuncts, with the semi-join output inlined as TRUE, become the join filter (omitted
 * when they simplify to TRUE). The join reuses the semi-join's id and, when present, its dynamic
 * filter id keyed on the filtering-source join symbol.
 */
@Override
public Result apply(FilterNode filterNode, Captures captures, Context context) {
    SemiJoinNode semiJoin = captures.get(SEMI_JOIN);

    // Do not transform semi-join in context of DELETE
    Predicate<PlanNode> isUpdateTargetScan = node -> node instanceof TableScanNode && ((TableScanNode) node).isUpdateTarget();
    boolean feedsUpdateTarget = PlanNodeSearcher.searchFrom(semiJoin.getSource(), context.getLookup())
            .where(isUpdateTargetScan)
            .matches();
    if (feedsUpdateTarget) {
        return Result.empty();
    }

    Symbol semiJoinOutput = semiJoin.getSemiJoinOutput();
    Predicate<Expression> isSemiJoinOutput = expression -> expression.equals(semiJoinOutput.toSymbolReference());

    List<Expression> conjuncts = extractConjuncts(filterNode.getPredicate());
    boolean filtersOnSemiJoinOutput = conjuncts.stream().anyMatch(isSemiJoinOutput);
    if (!filtersOnSemiJoinOutput) {
        return Result.empty();
    }

    // The bare semi-join output conjunct is enforced by the inner join itself; the remaining
    // conjuncts may therefore treat the semi-join output as TRUE.
    List<Expression> remainingConjuncts = conjuncts.stream()
            .filter(isSemiJoinOutput.negate())
            .collect(toImmutableList());
    Expression remainingPredicate = inlineSymbols(
            symbol -> symbol.equals(semiJoinOutput) ? TRUE_LITERAL : symbol.toSymbolReference(),
            and(remainingConjuncts));
    Optional<Expression> joinFilter = Optional.of(remainingPredicate)
            .filter(predicate -> !predicate.equals(TRUE_LITERAL));

    // Grouping on the join key with no aggregations yields the distinct filtering-source keys,
    // so the inner join cannot duplicate source rows.
    Symbol filteringKey = semiJoin.getFilteringSourceJoinSymbol();
    PlanNode distinctFilteringSource = new AggregationNode(
            context.getIdAllocator().getNextId(),
            semiJoin.getFilteringSource(),
            ImmutableMap.of(),
            singleGroupingSet(ImmutableList.of(filteringKey)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    JoinNode innerJoin = new JoinNode(
            semiJoin.getId(),
            INNER,
            semiJoin.getSource(),
            distinctFilteringSource,
            ImmutableList.of(new EquiJoinClause(semiJoin.getSourceJoinSymbol(), filteringKey)),
            semiJoin.getSource().getOutputSymbols(),
            ImmutableList.of(),
            false,
            joinFilter,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            semiJoin.getDynamicFilterId()
                    .map(id -> ImmutableMap.of(id, filteringKey))
                    .orElse(ImmutableMap.of()),
            Optional.empty());

    // Downstream consumers still reference the semi-join output symbol; every surviving row matched.
    Assignments assignments = Assignments.builder()
            .putIdentities(innerJoin.getOutputSymbols())
            .put(semiJoinOutput, TRUE_LITERAL)
            .build();
    return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), innerJoin, assignments));
}
Use of io.trino.sql.planner.plan.SemiJoinNode in project trino by trinodb.
Class TestEffectivePredicateExtractor, method testSemiJoin.
/**
 * Checks that the effective predicate extracted from a semi-join equals (after conjunct
 * normalization) the predicate of its source side only; the filtering source's predicate
 * ({@code AE > 5}) does not appear in the result.
 */
@Test
public void testSemiJoin() {
    Expression sourcePredicate = and(greaterThan(AE, bigintLiteral(10)), lessThan(AE, bigintLiteral(100)));
    Expression filteringSourcePredicate = greaterThan(AE, bigintLiteral(5));

    PlanNode semiJoin = new SemiJoinNode(
            newId(),
            filter(baseTableScan, sourcePredicate),
            filter(baseTableScan, filteringSourcePredicate),
            A,
            B,
            C,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty());

    Expression effectivePredicate = effectivePredicateExtractor.extract(SESSION, semiJoin, TypeProvider.empty(), typeAnalyzer);

    // Currently, only pull predicates through the source plan
    assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(sourcePredicate));
}
Use of io.trino.sql.planner.plan.SemiJoinNode in project trino by trinodb.
Class SimpleFilterProjectSemiJoinStatsRule, method calculate.
/**
 * Estimates statistics for {@code filterNode} evaluated directly over {@code semiJoinNode}.
 *
 * <p>The semi-join output filter is extracted from the filter predicate. A negated filter is
 * estimated as an anti-join; otherwise it is estimated as a semi-join. The filter's remaining
 * predicate is then applied to that estimate.
 *
 * @return empty if no semi-join output filter can be extracted from the predicate;
 *         {@link PlanNodeStatsEstimate#unknown()} if the (anti-)join row count is unknown;
 *         the (anti-)join estimate scaled by {@code UNKNOWN_FILTER_COEFFICIENT} if the remaining
 *         predicate's estimate has an unknown row count; otherwise the filtered estimate
 */
private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider types) {
    PlanNodeStatsEstimate sourceStats = statsProvider.getStats(semiJoinNode.getSource());
    PlanNodeStatsEstimate filteringSourceStats = statsProvider.getStats(semiJoinNode.getFilteringSource());
    Symbol filteringKey = semiJoinNode.getFilteringSourceJoinSymbol();
    Symbol sourceKey = semiJoinNode.getSourceJoinSymbol();

    Optional<SemiJoinOutputFilter> maybeOutputFilter = extractSemiJoinOutputFilter(filterNode.getPredicate(), semiJoinNode.getSemiJoinOutput());
    if (!maybeOutputFilter.isPresent()) {
        return Optional.empty();
    }
    SemiJoinOutputFilter outputFilter = maybeOutputFilter.get();

    PlanNodeStatsEstimate joinStats = outputFilter.isNegated()
            ? computeAntiJoin(sourceStats, filteringSourceStats, sourceKey, filteringKey)
            : computeSemiJoin(sourceStats, filteringSourceStats, sourceKey, filteringKey);
    if (joinStats.isOutputRowCountUnknown()) {
        return Optional.of(PlanNodeStatsEstimate.unknown());
    }

    // apply remaining predicate
    PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(joinStats, outputFilter.getRemainingPredicate(), session, types);
    if (!filteredStats.isOutputRowCountUnknown()) {
        return Optional.of(filteredStats);
    }
    // Remaining predicate could not be estimated: fall back to a fixed selectivity coefficient.
    return Optional.of(joinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT));
}
Use of io.trino.sql.planner.plan.SemiJoinNode in project trino by trinodb.
Class SemiJoinMatcher, method detailMatches.
/**
 * Matches a {@link SemiJoinNode} against this matcher's expectations and, on success, binds
 * {@code outputAlias} to the semi-join output symbol.
 *
 * <p>Checks, in order:
 * <ol>
 *   <li>the source and filtering-source join symbols equal the symbols bound to
 *       {@code sourceSymbolAlias} and {@code filteringSymbolAlias};</li>
 *   <li>if a distribution type is expected, it equals the node's distribution type;</li>
 *   <li>if {@code hasDynamicFilter} is {@code true}: the node has a dynamic filter id, at least one
 *       dynamic-filter descriptor with that id appears in a {@link FilterNode} under the source, and
 *       every such descriptor's input is the symbol bound to {@code sourceSymbolAlias};</li>
 *   <li>if {@code hasDynamicFilter} is {@code false}: the node has no dynamic filter id.</li>
 * </ol>
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    SemiJoinNode semiJoin = (SemiJoinNode) node;

    boolean sourceKeyMatches = symbolAliases.get(sourceSymbolAlias).equals(semiJoin.getSourceJoinSymbol().toSymbolReference());
    if (!sourceKeyMatches) {
        return NO_MATCH;
    }
    boolean filteringKeyMatches = symbolAliases.get(filteringSymbolAlias).equals(semiJoin.getFilteringSourceJoinSymbol().toSymbolReference());
    if (!filteringKeyMatches) {
        return NO_MATCH;
    }

    if (distributionType.isPresent() && !distributionType.equals(semiJoin.getDistributionType())) {
        return NO_MATCH;
    }

    if (!hasDynamicFilter.isPresent()) {
        // No expectation about dynamic filtering.
        return match(outputAlias, semiJoin.getSemiJoinOutput().toSymbolReference());
    }

    boolean expectDynamicFilter = hasDynamicFilter.get();
    if (!expectDynamicFilter) {
        return semiJoin.getDynamicFilterId().isPresent()
                ? NO_MATCH
                : match(outputAlias, semiJoin.getSemiJoinOutput().toSymbolReference());
    }

    if (semiJoin.getDynamicFilterId().isEmpty()) {
        return NO_MATCH;
    }
    DynamicFilterId dynamicFilterId = semiJoin.getDynamicFilterId().get();

    // Collect every dynamic-filter descriptor with this id from filter nodes under the source side.
    List<DynamicFilters.Descriptor> matchingDescriptors = searchFrom(semiJoin.getSource())
            .where(FilterNode.class::isInstance)
            .findAll()
            .stream()
            .flatMap(filterNode -> extractExpressions(filterNode).stream())
            .flatMap(expression -> extractDynamicFilters(expression).getDynamicConjuncts().stream())
            .filter(descriptor -> descriptor.getId().equals(dynamicFilterId))
            .collect(toImmutableList());
    if (matchingDescriptors.isEmpty()) {
        return NO_MATCH;
    }

    boolean allProbeSourceKey = matchingDescriptors.stream()
            .map(descriptor -> Symbol.from(descriptor.getInput()))
            .allMatch(sourceSymbol -> symbolAliases.get(sourceSymbolAlias).equals(sourceSymbol.toSymbolReference()));
    return allProbeSourceKey
            ? match(outputAlias, semiJoin.getSemiJoinOutput().toSymbolReference())
            : NO_MATCH;
}
Use of io.trino.sql.planner.plan.SemiJoinNode in project trino by trinodb.
Class SimpleFilterProjectSemiJoinStatsRule, method doCalculate.
/**
 * Resolves the filter's source and, looking through at most one identity {@link ProjectNode},
 * delegates to {@code calculate} when a {@link SemiJoinNode} is found.
 *
 * @return empty if the source is a non-identity project, or if neither the source nor the
 *         identity project's source is a semi-join; otherwise the result of {@code calculate}
 */
@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
    PlanNode candidate = lookup.resolve(node.getSource());

    if (candidate instanceof ProjectNode) {
        ProjectNode project = (ProjectNode) candidate;
        if (!project.isIdentity()) {
            return Optional.empty();
        }
        // An identity projection does not change rows; inspect what it projects from.
        candidate = lookup.resolve(project.getSource());
    }

    if (!(candidate instanceof SemiJoinNode)) {
        return Optional.empty();
    }
    return calculate(node, (SemiJoinNode) candidate, sourceStats, session, types);
}
Aggregations