use of io.prestosql.spi.plan.Assignments in project hetu-core by openlookeng.
the class InlineProjections method apply.
@Override
public Result apply(ProjectNode parent, Captures captures, Context context) {
ProjectNode child = captures.get(CHILD);
Sets.SetView<Symbol> targets = extractInliningTargets(parent, child);
if (targets.isEmpty()) {
return Result.empty();
}
// inline the expressions
Assignments assignments = child.getAssignments().filter(targets::contains);
Map<Symbol, RowExpression> parentAssignments = parent.getAssignments().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> (inlineReferences(entry.getValue(), assignments))));
// Synthesize identity assignments for the inputs of expressions that were inlined
// to place in the child projection.
// If all assignments end up becoming identity assignments, they'll get pruned by
// other rules
Set<Symbol> inputs = child.getAssignments().entrySet().stream().filter(entry -> targets.contains(entry.getKey())).map(Map.Entry::getValue).flatMap(entry -> extractInputs(entry).stream()).collect(toSet());
Assignments.Builder childAssignments = Assignments.builder();
for (Map.Entry<Symbol, RowExpression> assignment : child.getAssignments().entrySet()) {
if (!targets.contains(assignment.getKey())) {
childAssignments.put(assignment);
}
}
boolean allTranslated = child.getAssignments().entrySet().stream().map(Map.Entry::getValue).noneMatch(OriginalExpressionUtils::isExpression);
for (Symbol input : inputs) {
if (allTranslated) {
Type inputType = context.getSymbolAllocator().getSymbols().get(input);
childAssignments.put(input, new VariableReferenceExpression(input.getName(), inputType));
} else {
childAssignments.put(input, castToRowExpression(toSymbolReference(input)));
}
}
return Result.ofPlanNode(new ProjectNode(parent.getId(), new ProjectNode(child.getId(), child.getSource(), childAssignments.build()), Assignments.copyOf(parentAssignments)));
}
use of io.prestosql.spi.plan.Assignments in project hetu-core by openlookeng.
the class PushProjectionIntoTableScan method apply.
@Override
public Result apply(ProjectNode project, Captures captures, Context context) {
TableScanNode tableScan = captures.get(TABLE_SCAN);
List<ConnectorExpression> projections;
try {
projections = project.getAssignments().getExpressions().stream().map(expression -> ConnectorExpressionTranslator.translate(expression)).collect(toImmutableList());
} catch (UnsupportedOperationException e) {
// rewritten in terms of the new assignments for the columns passed in #2
return Result.empty();
}
Map<String, ColumnHandle> assignments = tableScan.getAssignments().entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
Optional<ProjectionApplicationResult<TableHandle>> result = metadata.applyProjection(context.getSession(), tableScan.getTable(), projections, assignments);
if (!result.isPresent()) {
return Result.empty();
}
List<Symbol> newScanOutputs = new ArrayList<>();
Map<Symbol, ColumnHandle> newScanAssignments = new HashMap<>();
Map<String, Symbol> variableMappings = new HashMap<>();
for (ProjectionApplicationResult.Assignment assignment : result.get().getAssignments()) {
Symbol symbol = context.getSymbolAllocator().newSymbol(assignment.getVariable(), assignment.getType());
newScanOutputs.add(symbol);
newScanAssignments.put(symbol, assignment.getColumn());
variableMappings.put(assignment.getVariable(), symbol);
}
// TODO: ensure newProjections.size == original projections.size
List<RowExpression> newProjections = result.get().getProjections().stream().map(expression -> ConnectorExpressionTranslator.translate(expression, variableMappings, new LiteralEncoder(metadata))).collect(toImmutableList());
Assignments.Builder newProjectionAssignments = Assignments.builder();
for (int i = 0; i < project.getOutputSymbols().size(); i++) {
newProjectionAssignments.put(project.getOutputSymbols().get(i), newProjections.get(i));
}
return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), TableScanNode.newInstance(tableScan.getId(), result.get().getHandle(), newScanOutputs, newScanAssignments, tableScan.getStrategy(), tableScan.getReuseTableScanMappingId(), tableScan.getConsumerTableScanNodeCount(), tableScan.isForDelete()), newProjectionAssignments.build()));
}
use of io.prestosql.spi.plan.Assignments in project hetu-core by openlookeng.
the class PushProjectionThroughUnion method apply.
@Override
public Result apply(ProjectNode parent, Captures captures, Context context) {
UnionNode source = captures.get(CHILD);
// OutputLayout of the resultant Union, will be same as the layout of the Project
List<Symbol> outputLayout = parent.getOutputSymbols();
// Mapping from the output symbol to ordered list of symbols from each of the sources
ImmutableListMultimap.Builder<Symbol, Symbol> mappings = ImmutableListMultimap.builder();
// sources for the resultant UnionNode
ImmutableList.Builder<PlanNode> outputSources = ImmutableList.builder();
for (int i = 0; i < source.getSources().size(); i++) {
// Map: output of union -> input of this source to the union
Map<Symbol, SymbolReference> outputToInput = sourceSymbolMap(source, i);
// assignments for the new ProjectNode
Assignments.Builder assignments = Assignments.builder();
// mapping from current ProjectNode to new ProjectNode, used to identify the output layout
Map<Symbol, Symbol> projectSymbolMapping = new HashMap<>();
// Translate the assignments in the ProjectNode using symbols of the source of the UnionNode
for (Map.Entry<Symbol, RowExpression> entry : parent.getAssignments().entrySet()) {
RowExpression translatedExpression;
Map<VariableReferenceExpression, VariableReferenceExpression> variable = new HashMap<>();
Map<Symbol, Type> symbols = context.getSymbolAllocator().getSymbols();
for (Symbol symbolMap : source.getSymbolMapping().keySet()) {
Symbol symboli = source.getSymbolMapping().get(symbolMap).get(i);
variable.put(new VariableReferenceExpression(symbolMap.getName(), symbols.get(symbolMap)), new VariableReferenceExpression(symboli.getName(), symbols.get(symboli)));
}
translatedExpression = RowExpressionVariableInliner.inlineVariables(variable, entry.getValue());
Symbol symbol = context.getSymbolAllocator().newSymbol(translatedExpression);
assignments.put(symbol, translatedExpression);
projectSymbolMapping.put(entry.getKey(), symbol);
}
outputSources.add(new ProjectNode(context.getIdAllocator().getNextId(), source.getSources().get(i), assignments.build()));
outputLayout.forEach(symbol -> mappings.put(symbol, projectSymbolMapping.get(symbol)));
}
return Result.ofPlanNode(new UnionNode(parent.getId(), outputSources.build(), mappings.build(), ImmutableList.copyOf(mappings.build().keySet())));
}
use of io.prestosql.spi.plan.Assignments in project hetu-core by openlookeng.
the class TransformCorrelatedInPredicateToJoin method apply.
@Override
public Result apply(ApplyNode apply, Captures captures, Context context) {
Assignments subqueryAssignments = apply.getSubqueryAssignments();
if (subqueryAssignments.size() != 1) {
return Result.empty();
}
Expression assignmentExpression = castToExpression(getOnlyElement(subqueryAssignments.getExpressions()));
if (!(assignmentExpression instanceof InPredicate)) {
return Result.empty();
}
InPredicate inPredicate = (InPredicate) assignmentExpression;
Symbol inPredicateOutputSymbol = getOnlyElement(subqueryAssignments.getSymbols());
return apply(apply, inPredicate, inPredicateOutputSymbol, context.getLookup(), context.getIdAllocator(), context.getSymbolAllocator());
}
use of io.prestosql.spi.plan.Assignments in project hetu-core by openlookeng.
the class TransformUnCorrelatedInPredicateSubQuerySelfJoinToAggregate method transformProjectNode.
private Optional<ProjectNode> transformProjectNode(Context context, ProjectNode projectNode) {
// IN predicate requires only one projection
if (projectNode.getOutputSymbols().size() > 1) {
return Optional.empty();
}
PlanNode source = context.getLookup().resolve(projectNode.getSource());
if (source instanceof CTEScanNode) {
source = getChildFilterNode(context, context.getLookup().resolve(((CTEScanNode) source).getSource()));
}
if (!(source instanceof FilterNode && context.getLookup().resolve(((FilterNode) source).getSource()) instanceof JoinNode)) {
return Optional.empty();
}
FilterNode filter = (FilterNode) source;
Expression predicate = OriginalExpressionUtils.castToExpression(filter.getPredicate());
List<SymbolReference> allPredicateSymbols = new ArrayList<>();
getAllSymbols(predicate, allPredicateSymbols);
JoinNode joinNode = (JoinNode) context.getLookup().resolve(((FilterNode) source).getSource());
if (!isSelfJoin(projectNode, predicate, joinNode, context.getLookup())) {
// Check next level for Self Join
PlanNode left = context.getLookup().resolve(joinNode.getLeft());
boolean changed = false;
if (left instanceof ProjectNode) {
Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) left);
if (transformResult.isPresent()) {
joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), transformResult.get(), joinNode.getRight(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
changed = true;
}
}
PlanNode right = context.getLookup().resolve(joinNode.getRight());
if (right instanceof ProjectNode) {
Optional<ProjectNode> transformResult = transformProjectNode(context, (ProjectNode) right);
if (transformResult.isPresent()) {
joinNode = new JoinNode(joinNode.getId(), joinNode.getType(), joinNode.getLeft(), transformResult.get(), joinNode.getCriteria(), joinNode.getOutputSymbols(), joinNode.getFilter(), joinNode.getLeftHashSymbol(), joinNode.getRightHashSymbol(), joinNode.getDistributionType(), joinNode.isSpillable(), joinNode.getDynamicFilters());
changed = true;
}
}
if (changed) {
FilterNode transformedFilter = new FilterNode(filter.getId(), joinNode, filter.getPredicate());
ProjectNode transformedProject = new ProjectNode(projectNode.getId(), transformedFilter, projectNode.getAssignments());
return Optional.of(transformedProject);
}
return Optional.empty();
}
// Choose the table to use based on projected output.
TableScanNode leftTable = (TableScanNode) context.getLookup().resolve(joinNode.getLeft());
TableScanNode rightTable = (TableScanNode) context.getLookup().resolve(joinNode.getRight());
TableScanNode tableToUse;
List<RowExpression> aggregationSymbols;
AggregationNode.GroupingSetDescriptor groupingSetDescriptor;
Assignments projectionsForCTE = null;
Assignments projectionsFromFilter = null;
// Use non-projected column for aggregation
if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
ProjectNode childProjectOfCte = (ProjectNode) context.getLookup().resolve(cteScanNode.getSource());
List<Symbol> completeOutputSymbols = new ArrayList<>();
rightTable.getOutputSymbols().forEach((s) -> completeOutputSymbols.add(s));
leftTable.getOutputSymbols().forEach((s) -> completeOutputSymbols.add(s));
List<Symbol> outputSymbols = new ArrayList<>();
for (int i = 0; i < completeOutputSymbols.size(); i++) {
Symbol outputSymbol = completeOutputSymbols.get(i);
for (Symbol symbol : projectNode.getOutputSymbols()) {
if (childProjectOfCte.getAssignments().getMap().containsKey(symbol)) {
if (((SymbolReference) OriginalExpressionUtils.castToExpression(childProjectOfCte.getAssignments().getMap().get(symbol))).getName().equals(outputSymbol.getName())) {
outputSymbols.add(outputSymbol);
}
}
}
}
Map<Symbol, RowExpression> projectionsForCTEMap = new HashMap<>();
Map<Symbol, RowExpression> projectionsFromFilterMap = new HashMap<>();
for (Map.Entry entry : childProjectOfCte.getAssignments().getMap().entrySet()) {
if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
projectionsForCTEMap.put((Symbol) entry.getKey(), (RowExpression) entry.getValue());
}
if (entry.getKey().equals(getOnlyElement(projectNode.getOutputSymbols()))) {
projectionsFromFilterMap.put(getOnlyElement(outputSymbols), (RowExpression) entry.getValue());
}
}
projectionsForCTE = new Assignments(projectionsForCTEMap);
projectionsFromFilter = new Assignments(projectionsFromFilterMap);
tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(outputSymbols)) ? leftTable : rightTable;
aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !outputSymbols.contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
// Create aggregation
groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(outputSymbols), 1, ImmutableSet.of());
} else {
tableToUse = leftTable.getOutputSymbols().contains(getOnlyElement(projectNode.getOutputSymbols())) ? leftTable : rightTable;
aggregationSymbols = allPredicateSymbols.stream().filter(s -> tableToUse.getOutputSymbols().contains(SymbolUtils.from(s))).filter(s -> !projectNode.getOutputSymbols().contains(SymbolUtils.from(s))).map(OriginalExpressionUtils::castToRowExpression).collect(Collectors.toList());
// Create aggregation
groupingSetDescriptor = new AggregationNode.GroupingSetDescriptor(ImmutableList.copyOf(projectNode.getOutputSymbols()), 1, ImmutableSet.of());
}
AggregationNode.Aggregation aggregation = new AggregationNode.Aggregation(new CallExpression("count", functionResolution.countFunction(), BIGINT, aggregationSymbols), aggregationSymbols, // mark DISTINCT since NOT_EQUALS predicate
true, Optional.empty(), Optional.empty(), Optional.empty());
ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> aggregationsBuilder = ImmutableMap.builder();
Symbol countSymbol = context.getSymbolAllocator().newSymbol(aggregation.getFunctionCall().getDisplayName(), BIGINT);
aggregationsBuilder.put(countSymbol, aggregation);
AggregationNode aggregationNode = new AggregationNode(context.getIdAllocator().getNextId(), tableToUse, aggregationsBuilder.build(), groupingSetDescriptor, ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
// Filter rows with count < 1 from aggregation results to match the NOT_EQUALS clause in original query.
FilterNode filterNode = new FilterNode(context.getIdAllocator().getNextId(), aggregationNode, OriginalExpressionUtils.castToRowExpression(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN, SymbolUtils.toSymbolReference(countSymbol), new GenericLiteral("BIGINT", "1"))));
// Project the aggregated+filtered rows.
ProjectNode transformedSubquery = new ProjectNode(projectNode.getId(), filterNode, projectNode.getAssignments());
if (context.getLookup().resolve(projectNode.getSource()) instanceof CTEScanNode) {
CTEScanNode cteScanNode = (CTEScanNode) context.getLookup().resolve(projectNode.getSource());
PlanNode projectNodeForCTE = context.getLookup().resolve(cteScanNode.getSource());
PlanNode projectNodeFromFilter = context.getLookup().resolve(projectNodeForCTE);
projectNodeFromFilter = new ProjectNode(projectNodeFromFilter.getId(), filterNode, projectionsFromFilter);
projectNodeForCTE = new ProjectNode(projectNodeForCTE.getId(), projectNodeFromFilter, projectionsForCTE);
cteScanNode = (CTEScanNode) cteScanNode.replaceChildren(ImmutableList.of(projectNodeForCTE));
cteScanNode.setOutputSymbols(projectNode.getOutputSymbols());
transformedSubquery = new ProjectNode(projectNode.getId(), cteScanNode, projectNode.getAssignments());
}
return Optional.of(transformedSubquery);
}
Aggregations