Use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
From the class SqlStageExecution, method traverseNodesForDynamicFiltering.
/**
 * Recursively walks the given plan nodes and registers dynamic-filter tasks
 * with {@code dynamicFilterService} for every {@link JoinNode} and
 * {@link SemiJoinNode} encountered.
 *
 * <p>A node is handled before its sources, so registration happens in
 * pre-order. Sources are visited for every node, whatever its type.
 *
 * @param nodes the plan nodes to inspect; each node's sources are visited too
 */
private void traverseNodesForDynamicFiltering(List<PlanNode> nodes) {
    for (PlanNode current : nodes) {
        // The explicit casts select the matching registerTasks overload.
        if (current instanceof JoinNode) {
            dynamicFilterService.registerTasks((JoinNode) current, allTasks, getScheduledNodes(), stateMachine);
        }
        else if (current instanceof SemiJoinNode) {
            dynamicFilterService.registerTasks((SemiJoinNode) current, allTasks, getScheduledNodes(), stateMachine);
        }

        // Joins may be nested anywhere below, so always descend.
        traverseNodesForDynamicFiltering(current.getSources());
    }
}
Use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
From the class RelationPlanner, method visitJoin.
/**
 * Plans a {@link Join} AST node into a {@link RelationPlan}.
 *
 * <p>Special cases are dispatched first, in this order: an UNNEST on the right side
 * (rejected unless the join is CROSS or IMPLICIT), a LATERAL on the right side, and
 * a JOIN ... USING criteria. Each of those is delegated to a dedicated planner method.
 *
 * <p>In the general case the join criteria (absent for CROSS/IMPLICIT joins) is split
 * into conjuncts. Each conjunct ends up in one of three places:
 * <ul>
 * <li>an {@link JoinNode.EquiJoinClause} on the join node (comparisons whose operator is EQUAL
 * and whose two operands resolve against opposite sides of the join);</li>
 * <li>the join node's filter (remaining conjuncts, for non-INNER joins);</li>
 * <li>a {@link FilterNode} placed above the join (remaining conjuncts and non-EQUAL
 * comparisons, for INNER joins).</li>
 * </ul>
 *
 * @param node the join to plan
 * @param context visitor context, passed through unchanged to {@code process}
 * @return in the general case, a plan whose field mappings are the left plan's field
 *         mappings followed by the right plan's
 */
@Override
protected RelationPlan visitJoin(Join node, Void context) {
// TODO: translate the RIGHT join into a mirrored LEFT join when we refactor (@martint)
RelationPlan leftPlan = process(node.getLeft(), context);
// UNNEST on the right side is planned specially and only accepted for CROSS / IMPLICIT joins.
Optional<Unnest> unnest = getUnnest(node.getRight());
if (unnest.isPresent()) {
if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) {
throw notSupportedException(unnest.get(), "UNNEST on other than the right side of CROSS JOIN");
}
return planCrossJoinUnnest(leftPlan, node, unnest.get());
}
// LATERAL on the right side is planned before the right relation is processed on its own.
Optional<Lateral> lateral = getLateral(node.getRight());
if (lateral.isPresent()) {
return planLateralJoin(node, leftPlan, lateral.get());
}
RelationPlan rightPlan = process(node.getRight(), context);
if (node.getCriteria().isPresent() && node.getCriteria().get() instanceof JoinUsing) {
return planJoinUsing(node, leftPlan, rightPlan);
}
PlanBuilder leftPlanBuilder = initializePlanBuilder(leftPlan);
PlanBuilder rightPlanBuilder = initializePlanBuilder(rightPlan);
// NOTE: symbols must be in the same order as the outputDescriptor
List<Symbol> outputSymbols = ImmutableList.<Symbol>builder().addAll(leftPlan.getFieldMappings()).addAll(rightPlan.getFieldMappings()).build();
ImmutableList.Builder<JoinNode.EquiJoinClause> equiClauses = ImmutableList.builder();
// Conjuncts that cannot be expressed as a left-vs-right comparison.
List<Expression> complexJoinExpressions = new ArrayList<>();
// Conditions applied in a FilterNode above the join; only consumed in the INNER branch at the end.
List<Expression> postInnerJoinConditions = new ArrayList<>();
// CROSS and IMPLICIT joins skip criteria analysis entirely.
if (node.getType() != Join.Type.CROSS && node.getType() != Join.Type.IMPLICIT) {
Expression criteria = analysis.getJoinCriteria(node);
RelationType left = analysis.getOutputDescriptor(node.getLeft());
RelationType right = analysis.getOutputDescriptor(node.getRight());
// The three lists below are parallel: index i holds the left operand, the right operand
// and the operator of the i-th left-vs-right comparison.
List<Expression> leftComparisonExpressions = new ArrayList<>();
List<Expression> rightComparisonExpressions = new ArrayList<>();
List<ComparisonExpression.Operator> joinConditionComparisonOperators = new ArrayList<>();
for (Expression conjunct : ExpressionUtils.extractConjuncts(criteria)) {
conjunct = ExpressionUtils.normalize(conjunct);
// For non-INNER joins, anything isEqualComparisonExpression rejects goes straight to the
// complex list. INNER joins fall through, so their other comparisons can still be split
// into left/right operands below.
if (!isEqualComparisonExpression(conjunct) && node.getType() != INNER) {
complexJoinExpressions.add(conjunct);
continue;
}
Set<QualifiedName> dependencies = SymbolsExtractor.extractNames(conjunct, analysis.getColumnReferences());
if (dependencies.stream().allMatch(left::canResolve) || dependencies.stream().allMatch(right::canResolve)) {
// If the conjunct can be evaluated entirely with the inputs on either side of the join, add
// it to the list of complex expressions and let the optimizers figure out how to push it down later.
complexJoinExpressions.add(conjunct);
} else if (conjunct instanceof ComparisonExpression) {
Expression firstExpression = ((ComparisonExpression) conjunct).getLeft();
Expression secondExpression = ((ComparisonExpression) conjunct).getRight();
ComparisonExpression.Operator comparisonOperator = ((ComparisonExpression) conjunct).getOperator();
Set<QualifiedName> firstDependencies = SymbolsExtractor.extractNames(firstExpression, analysis.getColumnReferences());
Set<QualifiedName> secondDependencies = SymbolsExtractor.extractNames(secondExpression, analysis.getColumnReferences());
if (firstDependencies.stream().allMatch(left::canResolve) && secondDependencies.stream().allMatch(right::canResolve)) {
leftComparisonExpressions.add(firstExpression);
rightComparisonExpressions.add(secondExpression);
joinConditionComparisonOperators.add(comparisonOperator);
} else if (firstDependencies.stream().allMatch(right::canResolve) && secondDependencies.stream().allMatch(left::canResolve)) {
// Operands are written in the opposite order to the join sides: swap them and
// flip the operator so the stored comparison reads left-side OP right-side.
leftComparisonExpressions.add(secondExpression);
rightComparisonExpressions.add(firstExpression);
joinConditionComparisonOperators.add(comparisonOperator.flip());
} else {
// the case when we mix symbols from both left and right join side on either side of condition.
complexJoinExpressions.add(conjunct);
}
} else {
complexJoinExpressions.add(conjunct);
}
}
leftPlanBuilder = subqueryPlanner.handleSubqueries(leftPlanBuilder, leftComparisonExpressions, node);
rightPlanBuilder = subqueryPlanner.handleSubqueries(rightPlanBuilder, rightComparisonExpressions, node);
// Add projections for join criteria
leftPlanBuilder = leftPlanBuilder.appendProjections(leftComparisonExpressions, planSymbolAllocator, idAllocator);
rightPlanBuilder = rightPlanBuilder.appendProjections(rightComparisonExpressions, planSymbolAllocator, idAllocator);
for (int i = 0; i < leftComparisonExpressions.size(); i++) {
if (joinConditionComparisonOperators.get(i) == ComparisonExpression.Operator.EQUAL) {
// EQUAL comparisons become equi-join clauses over the projected symbols.
Symbol leftSymbol = leftPlanBuilder.translate(leftComparisonExpressions.get(i));
Symbol rightSymbol = rightPlanBuilder.translate(rightComparisonExpressions.get(i));
equiClauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
} else {
// Any other operator is rebuilt as a comparison and queued for the post-join filter.
// NOTE(review): postInnerJoinConditions is only consumed in the INNER branch below; this
// relies on isEqualComparisonExpression keeping non-EQUAL comparisons of non-INNER joins
// out of these lists - confirm against its definition.
Expression leftExpression = leftPlanBuilder.rewrite(leftComparisonExpressions.get(i));
Expression rightExpression = rightPlanBuilder.rewrite(rightComparisonExpressions.get(i));
postInnerJoinConditions.add(new ComparisonExpression(joinConditionComparisonOperators.get(i), leftExpression, rightExpression));
}
}
}
// Initial join node, built without a filter. For non-INNER joins with complex expressions it is
// rebuilt further down with the rewritten filter condition.
PlanNode root = new JoinNode(idAllocator.getNextId(), JoinNodeUtils.typeConvert(node.getType()), leftPlanBuilder.getRoot(), rightPlanBuilder.getRoot(), equiClauses.build(), ImmutableList.<Symbol>builder().addAll(leftPlanBuilder.getRoot().getOutputSymbols()).addAll(rightPlanBuilder.getRoot().getOutputSymbols()).build(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of());
if (node.getType() != INNER) {
// IN-with-subquery predicates are rejected in the join condition of non-INNER joins.
for (Expression complexExpression : complexJoinExpressions) {
Set<InPredicate> inPredicates = subqueryPlanner.collectInPredicateSubqueries(complexExpression, node);
if (!inPredicates.isEmpty()) {
InPredicate inPredicate = Iterables.getLast(inPredicates);
throw notSupportedException(inPredicate, "IN with subquery predicate in join condition");
}
}
// subqueries can be applied only to one side of join - left side is selected in arbitrary way
leftPlanBuilder = subqueryPlanner.handleUncorrelatedSubqueries(leftPlanBuilder, complexJoinExpressions, node);
}
// Translation map over the combined left + right outputs, used to rewrite the complex expressions.
RelationPlan intermediateRootRelationPlan = new RelationPlan(root, analysis.getScope(node), outputSymbols);
TranslationMap translationMap = new TranslationMap(intermediateRootRelationPlan, analysis, lambdaDeclarationToSymbolMap);
translationMap.setFieldMappings(outputSymbols);
translationMap.putExpressionMappingsFrom(leftPlanBuilder.getTranslations());
translationMap.putExpressionMappingsFrom(rightPlanBuilder.getTranslations());
if (node.getType() != INNER && !complexJoinExpressions.isEmpty()) {
// Rebuild the join from the current leftPlanBuilder root (reassigned by
// handleUncorrelatedSubqueries above) and attach the complex conjuncts as the join filter.
Expression joinedFilterCondition = ExpressionUtils.and(complexJoinExpressions);
Expression rewrittenFilterCondition = translationMap.rewrite(joinedFilterCondition);
root = new JoinNode(idAllocator.getNextId(), JoinNodeUtils.typeConvert(node.getType()), leftPlanBuilder.getRoot(), rightPlanBuilder.getRoot(), equiClauses.build(), ImmutableList.<Symbol>builder().addAll(leftPlanBuilder.getRoot().getOutputSymbols()).addAll(rightPlanBuilder.getRoot().getOutputSymbols()).build(), Optional.of(castToRowExpression(rewrittenFilterCondition)), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableMap.of());
}
if (node.getType() == INNER) {
// rewrite all the other conditions using output symbols from left + right plan node.
PlanBuilder rootPlanBuilder = new PlanBuilder(translationMap, root);
rootPlanBuilder = subqueryPlanner.handleSubqueries(rootPlanBuilder, complexJoinExpressions, node);
for (Expression expression : complexJoinExpressions) {
postInnerJoinConditions.add(rootPlanBuilder.rewrite(expression));
}
root = rootPlanBuilder.getRoot();
Expression postInnerJoinCriteria;
// Only add a FilterNode above the join when there is something left to filter on.
if (!postInnerJoinConditions.isEmpty()) {
postInnerJoinCriteria = ExpressionUtils.and(postInnerJoinConditions);
root = new FilterNode(idAllocator.getNextId(), root, castToRowExpression(postInnerJoinCriteria));
}
}
return new RelationPlan(root, analysis.getScope(node), outputSymbols);
}
Use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
From the class LocalDynamicFilter, method create.
/**
 * Builds a {@link LocalDynamicFilter} for the given join, if any of the join's
 * dynamic filters has a matching probe-side symbol.
 *
 * <p>Probe symbols are collected from the predicates of the filter nodes that
 * {@code findFilterNodeInStage} returns for the join. If it returns none, they
 * are derived from the join criteria instead and the filter type is
 * {@code GLOBAL} rather than {@code LOCAL}.
 *
 * @param planNode the join whose dynamic filters are considered
 * @param partitionCount forwarded to the {@link LocalDynamicFilter} constructor
 * @param session forwarded to the {@link LocalDynamicFilter} constructor
 * @param taskId forwarded to the {@link LocalDynamicFilter} constructor
 * @param stateStoreProvider forwarded to the {@link LocalDynamicFilter} constructor
 * @return the filter, or {@link Optional#empty()} when no dynamic filter ID has a probe symbol
 */
public static Optional<LocalDynamicFilter> create(JoinNode planNode, int partitionCount, Session session, TaskId taskId, StateStoreProvider stateStoreProvider) {
    Map<String, Symbol> dynamicFilters = planNode.getDynamicFilters();
    Set<String> dynamicFilterIds = dynamicFilters.keySet();

    // Dynamic filter ID -> probe-side symbols that the filter applies to.
    Multimap<String, Symbol> probeSymbolsById = MultimapBuilder.treeKeys().arrayListValues().build();

    List<FilterNode> stageFilters = findFilterNodeInStage(planNode);
    DynamicFilter.Type filterType;
    if (stageFilters.isEmpty()) {
        // No filter node in this stage: fall back to the join criteria.
        mapProbeSymbolsFromCriteria(dynamicFilters, probeSymbolsById, planNode.getCriteria());
        filterType = DynamicFilter.Type.GLOBAL;
    }
    else {
        for (FilterNode stageFilter : stageFilters) {
            mapProbeSymbols(stageFilter.getPredicate(), dynamicFilterIds, probeSymbolsById);
        }
        filterType = DynamicFilter.Type.LOCAL;
    }

    // The build side is the right child of the join in both cases.
    List<Symbol> buildSymbols = planNode.getRight().getOutputSymbols();

    // Dynamic filter ID -> channel index of its build symbol, restricted to IDs with a probe symbol.
    Map<String, Integer> buildChannelsById = dynamicFilters.entrySet().stream()
            .filter(entry -> probeSymbolsById.containsKey(entry.getKey()))
            .collect(toMap(
                    Map.Entry::getKey,
                    entry -> {
                        int channel = buildSymbols.indexOf(entry.getValue());
                        verify(channel >= 0);
                        return channel;
                    }));

    if (buildChannelsById.isEmpty()) {
        return Optional.empty();
    }
    return Optional.of(new LocalDynamicFilter(probeSymbolsById, buildChannelsById, partitionCount, filterType, session, taskId, stateStoreProvider));
}
Use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
From the class ExtractSpatialJoins, method tryCreateSpatialJoin.
/**
 * Attempts to turn a join filtered by a two-argument comparison operator call
 * (e.g. {@code ST_Distance(a, b) <= r} or {@code r >= ST_Distance(a, b)}) into a
 * spatial join.
 *
 * <p>The comparison is rewritten so that the radius is always the second argument,
 * flipping the operator when the radius was written first. The radius must either
 * reference no symbols or reference only symbols of the right join input; otherwise
 * no rewrite happens. If {@code newRadiusSymbol} yields a symbol, the radius is
 * projected onto the right input under that symbol. The rewritten join is then
 * handed to the overload that takes the spatial function and the radius separately.
 *
 * @param spatialComparison comparison call with exactly two arguments and an operator type
 * @return the result of the delegated overload, or {@code Result.empty()} when the
 *         radius depends on symbols outside the right input
 */
private static Result tryCreateSpatialJoin(Context context, JoinNode joinNode, RowExpression filter, PlanNodeId nodeId, List<Symbol> outputSymbols, CallExpression spatialComparison, Metadata metadata, SplitManager splitManager, PageSourceManager pageSourceManager, TypeAnalyzer typeAnalyzer) {
    FunctionMetadata spatialComparisonMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(spatialComparison.getFunctionHandle());
    checkArgument(spatialComparison.getArguments().size() == 2 && spatialComparisonMetadata.getOperatorType().isPresent());

    PlanNode leftNode = joinNode.getLeft();
    PlanNode rightNode = joinNode.getRight();
    List<Symbol> leftSymbols = leftNode.getOutputSymbols();
    List<Symbol> rightSymbols = rightNode.getOutputSymbols();

    OperatorType operatorType = spatialComparisonMetadata.getOperatorType().get();
    // "ST_Distance(a, b) <= r" has the radius second; "r >= ST_Distance(a, b)" has it first.
    boolean radiusIsSecondArgument = operatorType.equals(OperatorType.LESS_THAN) || operatorType.equals(OperatorType.LESS_THAN_OR_EQUAL);
    RowExpression radius = spatialComparison.getArguments().get(radiusIsSecondArgument ? 1 : 0);

    // Guard: a non-constant radius must be computable from the right input alone.
    Set<Symbol> radiusSymbols = extractUnique(radius);
    boolean radiusFromRightOnly = rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols);
    if (!radiusSymbols.isEmpty() && !radiusFromRightOnly) {
        return Result.empty();
    }

    Optional<Symbol> radiusSymbol;
    CallExpression newComparison;
    if (radiusIsSecondArgument) {
        radiusSymbol = newRadiusSymbol(context, radius);
        newComparison = new CallExpression(
                spatialComparison.getDisplayName(),
                spatialComparison.getFunctionHandle(),
                spatialComparison.getType(),
                ImmutableList.of(spatialComparison.getArguments().get(0), mapToExpression(radiusSymbol, radius, context)),
                Optional.empty());
    }
    else {
        // Flip the operator so the rewritten comparison has the radius as its second argument.
        OperatorType flippedOperator = SpatialJoinUtils.flip(operatorType);
        FunctionHandle flippedHandle = getFlippedFunctionHandle(spatialComparison, metadata.getFunctionAndTypeManager());
        radiusSymbol = newRadiusSymbol(context, radius);
        newComparison = new CallExpression(
                flippedOperator.name(),
                flippedHandle,
                spatialComparison.getType(),
                ImmutableList.of(spatialComparison.getArguments().get(1), mapToExpression(radiusSymbol, radius, context)),
                Optional.empty());
    }

    RowExpression newFilter = replaceExpression(filter, ImmutableMap.of(spatialComparison, newComparison));
    // Project the radius onto the right input only when a symbol was produced for it.
    PlanNode newRightNode = radiusSymbol
            .map(symbol -> addProjection(context, rightNode, symbol, radius))
            .orElse(rightNode);
    JoinNode newJoinNode = new JoinNode(
            joinNode.getId(),
            joinNode.getType(),
            leftNode,
            newRightNode,
            joinNode.getCriteria(),
            joinNode.getOutputSymbols(),
            Optional.of(newFilter),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType(),
            joinNode.isSpillable(),
            joinNode.getDynamicFilters());

    CallExpression spatialFunction = (CallExpression) newComparison.getArguments().get(0);
    Optional<RowExpression> rewrittenRadius = Optional.of(newComparison.getArguments().get(1));
    return tryCreateSpatialJoin(context, newJoinNode, newFilter, nodeId, outputSymbols, spatialFunction, rewrittenRadius, metadata, splitManager, pageSourceManager, typeAnalyzer);
}
Use of io.prestosql.spi.plan.JoinNode in project hetu-core by openlookeng.
From the class PruneCrossJoinColumns, method pushDownProjectOff.
/**
 * Restricts the outputs of both join inputs to {@code referencedOutputs} and
 * rebuilds the join on top of the restricted inputs.
 *
 * @param idAllocator supplies the ID for the rebuilt join and is passed to {@code restrictOutputs}
 * @param joinNode the join whose inputs are restricted
 * @param referencedOutputs the symbols passed to {@code restrictOutputs} for each input
 * @return a new join over the restricted inputs, or {@link Optional#empty()} when
 *         {@code restrictOutputs} changed neither input
 */
@Override
protected Optional<PlanNode> pushDownProjectOff(PlanNodeIdAllocator idAllocator, JoinNode joinNode, Set<Symbol> referencedOutputs) {
    Optional<PlanNode> prunedLeft = restrictOutputs(idAllocator, joinNode.getLeft(), referencedOutputs, false, null);
    Optional<PlanNode> prunedRight = restrictOutputs(idAllocator, joinNode.getRight(), referencedOutputs, false, null);

    if (!(prunedLeft.isPresent() || prunedRight.isPresent())) {
        return Optional.empty();
    }

    // Use the restricted input where one was produced, otherwise keep the original.
    PlanNode left = prunedLeft.orElse(joinNode.getLeft());
    PlanNode right = prunedRight.orElse(joinNode.getRight());

    // The join outputs are the left input's outputs followed by the right input's.
    ImmutableList<Symbol> outputSymbols = ImmutableList.<Symbol>builder()
            .addAll(left.getOutputSymbols())
            .addAll(right.getOutputSymbols())
            .build();

    return Optional.of(new JoinNode(
            idAllocator.getNextId(),
            joinNode.getType(),
            left,
            right,
            joinNode.getCriteria(),
            outputSymbols,
            joinNode.getFilter(),
            joinNode.getLeftHashSymbol(),
            joinNode.getRightHashSymbol(),
            joinNode.getDistributionType(),
            joinNode.isSpillable(),
            joinNode.getDynamicFilters()));
}
Aggregations