Usage of io.prestosql.sql.planner.TypeProvider in the hetu-core project by openLooKeng.
Class SimpleFilterProjectSemiJoinStatsRule, method calculate.
/**
 * Estimates statistics for a filter placed directly over a semi join, where the filter
 * predicate references the semi join output symbol.
 *
 * <p>Steps:
 * <ol>
 *   <li>Split off the part of the predicate that filters on the semi join output.</li>
 *   <li>Estimate the semi join itself, as an anti join when that filter is negated and as a
 *       semi join otherwise.</li>
 *   <li>Apply whatever predicate remains using {@code filterStatsCalculator}.</li>
 * </ol>
 *
 * @param filterNode the filter whose predicate is analysed
 * @param semiJoinNode the semi join beneath the filter
 * @param statsProvider source of statistics for the semi join's source and filtering source
 * @param session current session, forwarded to the filter stats calculator
 * @param types symbol types. In the non-{@code Expression} branch they are used to type the
 *     semi join output variable, and in both branches they are forwarded to the filter stats
 *     calculator.
 * @return empty when no filter on the semi join output could be extracted from the predicate;
 *     an unknown estimate when the semi join row count is unknown; otherwise the filtered
 *     estimate. If the remaining predicate's row count is unknown, the semi join estimate is
 *     scaled by {@code UNKNOWN_FILTER_COEFFICIENT} instead.
 */
private Optional<PlanNodeStatsEstimate> calculate(FilterNode filterNode, SemiJoinNode semiJoinNode, StatsProvider statsProvider, Session session, TypeProvider types) {
    PlanNodeStatsEstimate probeStats = statsProvider.getStats(semiJoinNode.getSource());
    PlanNodeStatsEstimate buildStats = statsProvider.getStats(semiJoinNode.getFilteringSource());
    Symbol buildKey = semiJoinNode.getFilteringSourceJoinSymbol();
    Symbol probeKey = semiJoinNode.getSourceJoinSymbol();

    // The predicate comes in one of two representations.
    // Decide which one once and reuse the answer for both places that branch on it.
    boolean expressionPredicate = isExpression(filterNode.getPredicate());

    Optional<SemiJoinOutputFilter> outputFilter;
    if (expressionPredicate) {
        outputFilter = extractSemiJoinOutputFilter(castToExpression(filterNode.getPredicate()), semiJoinNode.getSemiJoinOutput());
    } else {
        // Non-Expression predicates reference variables, so wrap the semi join output symbol
        // as a typed variable reference before searching the predicate for it.
        Symbol outputSymbol = semiJoinNode.getSemiJoinOutput();
        VariableReferenceExpression outputVariable = new VariableReferenceExpression(outputSymbol.getName(), types.get(outputSymbol));
        outputFilter = extractSemiJoinOutputFilter(filterNode.getPredicate(), outputVariable);
    }

    if (!outputFilter.isPresent()) {
        return Optional.empty();
    }
    SemiJoinOutputFilter semiJoinFilter = outputFilter.get();

    // A negated filter on the semi join output keeps the non-matching rows, i.e. an anti join.
    PlanNodeStatsEstimate joinStats = semiJoinFilter.isNegated()
            ? computeAntiJoin(probeStats, buildStats, probeKey, buildKey)
            : computeSemiJoin(probeStats, buildStats, probeKey, buildKey);

    if (joinStats.isOutputRowCountUnknown()) {
        return Optional.of(PlanNodeStatsEstimate.unknown());
    }

    // Estimate the effect of the predicate left over after the semi join output filter was removed.
    PlanNodeStatsEstimate remainingStats;
    if (expressionPredicate) {
        remainingStats = filterStatsCalculator.filterStats(joinStats, castToExpression(semiJoinFilter.getRemainingPredicate()), session, types);
    } else {
        Map<Integer, Symbol> symbolLayout = SymbolUtils.toLayOut(filterNode.getOutputSymbols());
        remainingStats = filterStatsCalculator.filterStats(joinStats, semiJoinFilter.getRemainingPredicate(), session, types, symbolLayout);
    }

    // When the remaining predicate cannot be estimated, fall back to a fixed selectivity factor.
    return Optional.of(remainingStats.isOutputRowCountUnknown()
            ? joinStats.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT)
            : remainingStats);
}
Aggregations