use of io.trino.cost.StatsProvider in project trino by trinodb.
the class RuleAssert method formatPlan.
/**
 * Renders the given plan as text, annotated with stats and cost estimates.
 *
 * <p>The work runs inside the {@code inTransaction} callback so that a session is
 * available for the stats and cost providers.
 *
 * @param plan the plan to render
 * @param types the type provider handed to the stats/cost providers and the printer
 * @return the value produced by {@code textLogicalPlan} for {@code plan}
 */
private String formatPlan(PlanNode plan, TypeProvider types) {
    return inTransaction(session -> {
        // The cost provider is built on top of the stats provider, so stats come first.
        StatsProvider stats = new CachingStatsProvider(statsCalculator, session, types);
        CostProvider costs = new CachingCostProvider(costCalculator, stats, session, types);
        StatsAndCosts estimates = StatsAndCosts.create(plan, stats, costs);
        // NOTE(review): the trailing 2 and false are presumably the indent level and a
        // verbosity flag - confirm against textLogicalPlan's signature.
        return textLogicalPlan(plan, types, metadata, functionManager, estimates, session, 2, false);
    });
}
use of io.trino.cost.StatsProvider in project trino by trinodb.
the class DetermineJoinDistributionType method getJoinNodeWithCost.
/**
 * Pairs a candidate join with a cost computed from its inputs only (the cost of
 * producing the join output is deliberately left out).
 *
 * @param context rule context supplying the symbol types, stats provider and session
 * @param possibleJoinNode the candidate join; its distribution type must already be set,
 *         because the value is unwrapped with {@code get()}
 * @return the candidate join together with its input-only cost converted to a plan cost
 */
private PlanNodeWithCost getJoinNodeWithCost(Context context, JoinNode possibleJoinNode) {
    TypeProvider symbolTypes = context.getSymbolAllocator().getTypes();
    StatsProvider statsProvider = context.getStatsProvider();
    boolean isReplicated = REPLICATED == possibleJoinNode.getDistributionType().get();

    /*
     * HACK!
     *
     * The cost model currently always has to compute the total cost of an operation.
     * For a JOIN the total cost consists of 4 parts:
     *  - the cost of the exchanges that have to be introduced to execute the JOIN
     *  - the cost of building a hash table
     *  - the cost of probing a hash table
     *  - the cost of building the output for matched rows
     *
     * When the output size of a JOIN cannot be estimated, the cost model returns
     * an UNKNOWN cost for the join.
     *
     * However, assuming the cost of the JOIN output is always the same, cost based
     * decisions can still be made from the input cost of the different types of JOINs.
     *
     * Side flipping can be decided purely from stats (the smaller side always goes to
     * the right), but determining the JOIN type is not that simple: when choosing a
     * REPLICATED join over a REPARTITIONED one, the cost of exchanging and building
     * the hash table scales with the number of nodes where the build side is replicated.
     *
     * TODO Decision about the distribution should be based on LocalCostEstimate only when
     * PlanCostEstimate cannot be calculated. Otherwise cost comparator cannot take
     * query.max-memory into account.
     */
    int sourceTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(context.getSession());
    LocalCostEstimate inputOnlyCost = calculateJoinCostWithoutOutput(
            possibleJoinNode.getLeft(),
            possibleJoinNode.getRight(),
            statsProvider,
            symbolTypes,
            isReplicated,
            sourceTaskCount);
    return new PlanNodeWithCost(inputOnlyCost.toPlanCost(), possibleJoinNode);
}
use of io.trino.cost.StatsProvider in project trino by trinodb.
the class IterativeOptimizer method ruleContext.
/**
 * Adapts the optimizer's own {@code Context} to the {@code Rule.Context} handed to rules.
 *
 * <p>One stats provider and one cost provider are created per call and returned by the
 * adapter's getters; every other getter forwards to the corresponding field of
 * {@code context}.
 *
 * @param context the optimizer context to expose to a rule
 * @return a {@code Rule.Context} view over {@code context}
 */
private Rule.Context ruleContext(Context context) {
    StatsProvider cachedStats = new CachingStatsProvider(
            statsCalculator,
            Optional.of(context.memo),
            context.lookup,
            context.session,
            context.symbolAllocator.getTypes());
    // The cost provider is layered on top of the stats provider created above.
    CostProvider cachedCosts = new CachingCostProvider(
            costCalculator,
            cachedStats,
            Optional.of(context.memo),
            context.session,
            context.symbolAllocator.getTypes());

    return new Rule.Context() {
        // --- providers created for this context ---

        @Override
        public StatsProvider getStatsProvider() {
            return cachedStats;
        }

        @Override
        public CostProvider getCostProvider() {
            return cachedCosts;
        }

        // --- straight delegation to the optimizer context ---

        @Override
        public Lookup getLookup() {
            return context.lookup;
        }

        @Override
        public PlanNodeIdAllocator getIdAllocator() {
            return context.idAllocator;
        }

        @Override
        public SymbolAllocator getSymbolAllocator() {
            return context.symbolAllocator;
        }

        @Override
        public Session getSession() {
            return context.session;
        }

        @Override
        public WarningCollector getWarningCollector() {
            return context.warningCollector;
        }

        @Override
        public void checkTimeoutNotExhausted() {
            context.checkTimeoutNotExhausted();
        }
    };
}
use of io.trino.cost.StatsProvider in project trino by trinodb.
the class RuleAssert method ruleContext.
/**
 * Builds the {@code Rule.Context} used when applying a rule under test.
 *
 * <p>Stats and cost providers are created from the supplied calculators. The id
 * allocator comes from the enclosing instance's {@code idAllocator}, the timeout check
 * does nothing, and warnings go to {@code WarningCollector.NOOP}.
 *
 * @param statsCalculator calculator backing the stats provider
 * @param costCalculator calculator backing the cost provider
 * @param symbolAllocator symbol allocator exposed to the rule; also the source of types
 * @param memo memo passed to both providers
 * @param lookup lookup exposed to the rule and passed to the stats provider
 * @param session session exposed to the rule and passed to both providers
 * @return a {@code Rule.Context} backed by the arguments above
 */
private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator costCalculator, SymbolAllocator symbolAllocator, Memo memo, Lookup lookup, Session session) {
    StatsProvider cachedStats = new CachingStatsProvider(
            statsCalculator,
            Optional.of(memo),
            lookup,
            session,
            symbolAllocator.getTypes());
    // The cost provider is layered on top of the stats provider created above.
    CostProvider cachedCosts = new CachingCostProvider(
            costCalculator,
            cachedStats,
            Optional.of(memo),
            session,
            symbolAllocator.getTypes());

    return new Rule.Context() {
        // --- providers created for this context ---

        @Override
        public StatsProvider getStatsProvider() {
            return cachedStats;
        }

        @Override
        public CostProvider getCostProvider() {
            return cachedCosts;
        }

        // --- values passed in by the caller or held by the enclosing instance ---

        @Override
        public Lookup getLookup() {
            return lookup;
        }

        @Override
        public PlanNodeIdAllocator getIdAllocator() {
            return idAllocator;
        }

        @Override
        public SymbolAllocator getSymbolAllocator() {
            return symbolAllocator;
        }

        @Override
        public Session getSession() {
            return session;
        }

        // --- fixed behavior for this context ---

        @Override
        public WarningCollector getWarningCollector() {
            return WarningCollector.NOOP;
        }

        @Override
        public void checkTimeoutNotExhausted() {
            // Intentionally empty: this context performs no timeout check.
        }
    };
}
use of io.trino.cost.StatsProvider in project trino by trinodb.
the class AggregationMatcher method detailMatches.
/**
 * Compares the expected aggregation description held by this matcher with an actual
 * {@code AggregationNode}.
 *
 * <p>The checks run in a fixed order and stop at the first mismatch: group id presence,
 * grouping keys, grouping set count, global grouping sets, mask symbols, aggregation
 * step, pre-grouped symbols, and finally streamability when pre-grouped symbols are
 * expected. The {@code stats}, {@code session} and {@code metadata} arguments are not
 * used by this implementation.
 *
 * @param node the plan node to inspect; {@code shapeMatches(node)} must hold
 * @param symbolAliases aliases used to resolve expected symbol names
 * @return {@code match()} when every check passes, otherwise {@code NO_MATCH}
 */
@Override
public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) {
    checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName());
    AggregationNode actual = (AggregationNode) node;

    // Grouping-related checks, short-circuiting in the same order as individual guards would.
    boolean groupingAgrees = groupId.isPresent() == actual.getGroupIdSymbol().isPresent()
            && matches(groupingSets.getGroupingKeys(), actual.getGroupingKeys(), symbolAliases)
            && groupingSets.getGroupingSetCount() == actual.getGroupingSetCount()
            && groupingSets.getGlobalGroupingSets().equals(actual.getGlobalGroupingSets());
    if (!groupingAgrees) {
        return NO_MATCH;
    }

    // Masks are compared as sets: the symbols actually used versus the expected aliases resolved to symbols.
    Set<Symbol> actualMasks = actual.getAggregations().values().stream()
            .map(aggregation -> aggregation.getMask())
            .filter(mask -> mask.isPresent())
            .map(mask -> mask.get())
            .collect(toImmutableSet());
    Set<Symbol> expectedMasks = masks.stream()
            .map(name -> new Symbol(symbolAliases.get(name).getName()))
            .collect(toImmutableSet());
    if (!actualMasks.equals(expectedMasks)) {
        return NO_MATCH;
    }

    // When pre-grouped symbols are expected, the node must also report itself as streamable.
    boolean stepAndPreGroupingAgree = step == actual.getStep()
            && matches(preGroupedSymbols, actual.getPreGroupedSymbols(), symbolAliases)
            && (preGroupedSymbols.isEmpty() || actual.isStreamable());
    if (!stepAndPreGroupingAgree) {
        return NO_MATCH;
    }

    return match();
}
Aggregations