Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
From class PushAggregationThroughOuterJoin, method coalesceWithNullAggregation.
// When the aggregation runs after the join, every outer row that found no match in the inner table feeds a
// row of nulls into the aggregation. For some aggregate functions, such as count, aggregating a single null
// row yields one or zero rather than null. To ensure correct results, each aggregate output of the new outer
// join is coalesced with the result of the same aggregation evaluated over a single null row.
//
// Returns Optional.empty() when createAggregationOverNull cannot build the aggregation over nulls; otherwise a
// ProjectNode on top of an INNER join (no criteria, no filter) between outerJoin and that aggregation.
private Optional<PlanNode> coalesceWithNullAggregation(AggregationNode aggregationNode, PlanNode outerJoin, PlanSymbolAllocator planSymbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) {
    // Build the aggregation over a row of nulls; give up on the rewrite if that is not possible.
    Optional<MappedAggregationInfo> nullRowAggregationInfo = createAggregationOverNull(aggregationNode, planSymbolAllocator, idAllocator, lookup);
    if (!nullRowAggregationInfo.isPresent()) {
        return Optional.empty();
    }

    MappedAggregationInfo mappedInfo = nullRowAggregationInfo.get();
    AggregationNode nullRowAggregation = mappedInfo.getAggregation();
    // Looked up by aggregate output symbol to find the matching output of nullRowAggregation.
    Map<Symbol, Symbol> nullRowSymbolFor = mappedInfo.getSymbolMapping();

    // An INNER join with no equi-criteria and no filter acts as a cross join: the null-row aggregates are
    // appended to every row produced by the outer join.
    ImmutableList<Symbol> crossJoinOutputs = ImmutableList.<Symbol>builder()
            .addAll(outerJoin.getOutputSymbols())
            .addAll(nullRowAggregation.getOutputSymbols())
            .build();
    JoinNode crossJoin = new JoinNode(
            idAllocator.getNextId(),
            JoinNode.Type.INNER,
            outerJoin,
            nullRowAggregation,
            ImmutableList.of(),
            crossJoinOutputs,
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            Optional.empty(),
            ImmutableMap.of());

    // Project the outer join's columns back out, coalescing every aggregate output with its null-row value.
    Assignments.Builder projections = Assignments.builder();
    for (Symbol output : outerJoin.getOutputSymbols()) {
        if (!aggregationNode.getAggregations().containsKey(output)) {
            // Non-aggregate columns pass through unchanged.
            projections.put(output, toVariableReference(output, planSymbolAllocator.getTypes()));
            continue;
        }
        Symbol nullRowValue = nullRowSymbolFor.get(output);
        projections.put(output, new SpecialForm(
                COALESCE,
                planSymbolAllocator.getTypes().get(output),
                ImmutableList.of(
                        toVariableReference(output, planSymbolAllocator.getTypes()),
                        toVariableReference(nullRowValue, planSymbolAllocator.getTypes()))));
    }
    return Optional.of(new ProjectNode(idAllocator.getNextId(), crossJoin, projections.build()));
}
Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
From class PushAggregationThroughOuterJoin, method apply.
/**
 * Pushes the matched aggregation below a LEFT or RIGHT join onto the join's inner side.
 *
 * <p>The rule fires only when the join has no filter, is LEFT or RIGHT, the aggregation groups on all
 * output columns of the outer side, and the outer side is distinct. The pushed-down aggregation groups on
 * the inner side's join keys. The aggregate outputs are then coalesced with an aggregation over a null row
 * (see coalesceWithNullAggregation) so that unmatched outer rows still get the right value.
 *
 * @return the rewritten plan, or {@code Result.empty()} when the rule does not apply
 */
@Override
public Result apply(AggregationNode aggregation, Captures captures, Context context) {
    JoinNode join = captures.get(JOIN);
    JoinNode.Type joinType = join.getType();

    // Guards run in this order on purpose: the outer table is looked up only once the join is known
    // to be LEFT or RIGHT.
    if (join.getFilter().isPresent()) {
        return Result.empty();
    }
    if (joinType != JoinNode.Type.LEFT && joinType != JoinNode.Type.RIGHT) {
        return Result.empty();
    }
    PlanNode outerTable = getOuterTable(join);
    if (!groupsOnAllColumns(aggregation, outerTable.getOutputSymbols())) {
        return Result.empty();
    }
    if (!isDistinct(context.getLookup().resolve(outerTable), context.getLookup()::resolve)) {
        return Result.empty();
    }

    // The inner side's join keys become the grouping keys of the pushed-down aggregation.
    boolean innerIsLeft = joinType == JoinNode.Type.RIGHT;
    List<Symbol> groupingKeys = join.getCriteria().stream()
            .map(innerIsLeft ? JoinNode.EquiJoinClause::getLeft : JoinNode.EquiJoinClause::getRight)
            .collect(toImmutableList());
    AggregationNode rewrittenAggregation = new AggregationNode(
            aggregation.getId(),
            getInnerTable(join),
            aggregation.getAggregations(),
            singleGroupingSet(groupingKeys),
            ImmutableList.of(),
            aggregation.getStep(),
            aggregation.getHashSymbol(),
            aggregation.getGroupIdSymbol(),
            aggregation.getAggregationType(),
            aggregation.getFinalizeSymbol());

    // Keep the outer side where it is and put the aggregation in place of the inner side. Outputs list
    // the left child's symbols first: all outer columns, and only the aggregate outputs from the inner side.
    PlanNode newLeft;
    PlanNode newRight;
    List<Symbol> newOutputs;
    if (joinType == JoinNode.Type.LEFT) {
        newLeft = join.getLeft();
        newRight = rewrittenAggregation;
        newOutputs = ImmutableList.<Symbol>builder()
                .addAll(join.getLeft().getOutputSymbols())
                .addAll(rewrittenAggregation.getAggregations().keySet())
                .build();
    } else {
        newLeft = rewrittenAggregation;
        newRight = join.getRight();
        newOutputs = ImmutableList.<Symbol>builder()
                .addAll(rewrittenAggregation.getAggregations().keySet())
                .addAll(join.getRight().getOutputSymbols())
                .build();
    }
    JoinNode rewrittenJoin = new JoinNode(join.getId(), joinType, newLeft, newRight, join.getCriteria(), newOutputs, join.getFilter(), join.getLeftHashSymbol(), join.getRightHashSymbol(), join.getDistributionType(), join.isSpillable(), join.getDynamicFilters());

    return coalesceWithNullAggregation(rewrittenAggregation, rewrittenJoin, context.getSymbolAllocator(), context.getIdAllocator(), context.getLookup())
            .map(Result::ofPlanNode)
            .orElseGet(Result::empty);
}
Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
From class SetOperationNodeTranslator, method computeCounts.
// Builds a SINGLE-step HASH aggregation over sourceNode, grouped by originalColumns. For every position i,
// it stores count(markers[i]) in aggregationOutputs[i]. aggregationOutputs must be at least as long as markers.
private AggregationNode computeCounts(UnionNode sourceNode, List<Symbol> originalColumns, List<Symbol> markers, List<Symbol> aggregationOutputs) {
    ImmutableMap.Builder<Symbol, AggregationNode.Aggregation> countByOutput = ImmutableMap.builder();
    int position = 0;
    for (Symbol marker : markers) {
        countByOutput.put(aggregationOutputs.get(position), countMarkerAggregation(marker));
        position++;
    }
    return new AggregationNode(
            idAllocator.getNextId(),
            sourceNode,
            countByOutput.build(),
            singleGroupingSet(originalColumns),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());
}

// Builds a BIGINT count(marker) aggregation, resolving "count" for a BOOLEAN argument. The boolean argument
// is false and the optional arguments are empty (presumably no DISTINCT, filter, ordering or mask).
private AggregationNode.Aggregation countMarkerAggregation(Symbol marker) {
    return new AggregationNode.Aggregation(
            new CallExpression(
                    COUNT_AGGREGATION_NAME.getObjectName(),
                    metadata.getFunctionAndTypeManager().lookupFunction(COUNT_AGGREGATION_NAME.getObjectName(), fromTypes(BOOLEAN)),
                    BIGINT,
                    ImmutableList.of(castToRowExpression(toSymbolReference(marker))),
                    Optional.empty()),
            ImmutableList.of(castToRowExpression(toSymbolReference(marker))),
            false,
            Optional.empty(),
            Optional.empty(),
            Optional.empty());
}
Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
From class SetOperationNodeTranslator, method makeSetContainmentPlan.
/**
 * Translates a non-UNION set operation into a marked UNION followed by a per-row count aggregation.
 *
 * <p>No filtering is done here. The returned {@code count_i >= GENERIC_LITERAL} predicates, one per
 * source, are handed back to the caller together with the aggregation.
 *
 * @param node the set operation to translate; must not be a {@code UnionNode}
 * @return the count aggregation and one presence predicate per source
 * @throws IllegalArgumentException if {@code node} is a {@code UnionNode}
 */
public TranslationResult makeSetContainmentPlan(SetOperationNode node) {
    checkArgument(!(node instanceof UnionNode), "Cannot simplify a UnionNode");

    // One boolean marker symbol per source, appended to an identity projection of that source.
    List<Symbol> markers = allocateSymbols(node.getSources().size(), MARKER, BOOLEAN);
    List<PlanNode> markedSources = appendMarkers(markers, node.getSources(), node);

    // Union the marked sources. The union's outputs reuse the original node's output names and add the markers.
    List<Symbol> outputs = node.getOutputSymbols();
    UnionNode union = union(markedSources, ImmutableList.copyOf(concat(outputs, markers)));

    // For each distinct row of the original outputs, count each marker.
    List<Symbol> counts = allocateSymbols(markers.size(), "count", BIGINT);
    AggregationNode aggregation = computeCounts(union, outputs, markers, counts);

    ImmutableList.Builder<Expression> presence = ImmutableList.builder();
    for (Symbol count : counts) {
        presence.add(new ComparisonExpression(GREATER_THAN_OR_EQUAL, toSymbolReference(count), GENERIC_LITERAL));
    }
    return new TranslationResult(aggregation, presence.build());
}
Use of io.prestosql.spi.plan.AggregationNode in project hetu-core by openlookeng.
From class TransformCorrelatedInPredicateToJoin, method buildInPredicateEquivalent.
/**
 * Rewrites a correlated IN predicate ({@code value IN (subquery)}) into a join-and-aggregate plan:
 * <ol>
 * <li>the apply input (probe side) gets a BIGINT "unique" column via AssignUniqueId;</li>
 * <li>the decorrelated subquery (build side) gets a constant BIGINT column {@code buildSideKnownNonNull = 0};</li>
 * <li>both sides are joined with {@code leftOuterJoin} on
 *     {@code (probe IS NULL OR probe = build OR build IS NULL) AND <correlated predicates>};</li>
 * <li>rows are grouped by the probe side's outputs, and two filtered counts are taken:
 *     real matches and NULL-involving matches;</li>
 * <li>a CASE turns the two counts into TRUE, NULL or FALSE.</li>
 * </ol>
 *
 * @param apply the apply node whose input becomes the probe side; its input's output symbols are preserved
 * @param inPredicate the IN predicate; its value and value list are converted to symbols with
 *        {@code SymbolUtils.from} (presumably both must already be plain symbol references — confirm)
 * @param inPredicateOutputSymbol the symbol that receives the result of the rewritten predicate
 * @param decorrelated the decorrelated subquery plan plus the correlated predicates added to the join condition
 * @param idAllocator allocator for the ids of the new plan nodes
 * @param planSymbolAllocator allocator for the new helper symbols
 * @return a projection over the aggregation that exposes the apply input's output symbols plus
 *         {@code inPredicateOutputSymbol}
 */
private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, PlanSymbolAllocator planSymbolAllocator) {
Expression correlationCondition = and(decorrelated.getCorrelatedPredicates());
PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();
// Probe side: the apply input plus a "unique" id column. The id is presumably unique per row, so grouping
// on probeSide's outputs further down yields one group per input row.
AssignUniqueId probeSide = new AssignUniqueId(idAllocator.getNextId(), apply.getInput(), planSymbolAllocator.newSymbol("unique", BIGINT));
// Build side: the decorrelated subquery plus a constant 0 column. After the left outer join this column is
// non-NULL only for rows that actually joined a build-side row.
Symbol buildSideKnownNonNull = planSymbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT);
ProjectNode buildSide = new ProjectNode(idAllocator.getNextId(), decorrelatedBuildSource, Assignments.builder().putAll(identityAsSymbolReferences(decorrelatedBuildSource.getOutputSymbols())).put(buildSideKnownNonNull, castToRowExpression(bigint(0))).build());
// Symbols for the tested value (probe) and for the subquery's output column (build).
Symbol probeSideSymbol = SymbolUtils.from(inPredicate.getValue());
Symbol buildSideSymbol = SymbolUtils.from(inPredicate.getValueList());
// Join rows whose values are equal, and also rows where either side is NULL; those rows are needed to tell
// a FALSE result from a NULL one. The correlated predicates still apply in every case.
Expression joinExpression = and(or(new IsNullPredicate(toSymbolReference(probeSideSymbol)), new ComparisonExpression(ComparisonExpression.Operator.EQUAL, toSymbolReference(probeSideSymbol), toSymbolReference(buildSideSymbol)), new IsNullPredicate(toSymbolReference(buildSideSymbol))), correlationCondition);
JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);
// Real match: both sides are non-NULL, which together with the join condition means probe = build.
Symbol matchConditionSymbol = planSymbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN);
Expression matchCondition = and(isNotNull(probeSideSymbol), isNotNull(buildSideSymbol));
// A build row was joined (buildSideKnownNonNull is set) but it is not a real match, so it joined only
// because one side was NULL.
Symbol nullMatchConditionSymbol = planSymbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN);
Expression nullMatchCondition = and(isNotNull(buildSideKnownNonNull), not(matchCondition));
// Store both conditions as boolean columns so the counts below can filter on them.
ProjectNode preProjection = new ProjectNode(idAllocator.getNextId(), leftOuterJoin, Assignments.builder().putAll(AssignmentUtils.identityAsSymbolReferences(leftOuterJoin.getOutputSymbols())).put(matchConditionSymbol, castToRowExpression(matchCondition)).put(nullMatchConditionSymbol, castToRowExpression(nullMatchCondition)).build());
// For each probe row, count real matches and NULL-involving matches. countWithFilter presumably counts the
// rows where the given boolean symbol is true.
Symbol countMatchesSymbol = planSymbolAllocator.newSymbol("countMatches", BIGINT);
Symbol countNullMatchesSymbol = planSymbolAllocator.newSymbol("countNullMatches", BIGINT);
AggregationNode aggregation = new AggregationNode(idAllocator.getNextId(), preProjection, ImmutableMap.<Symbol, AggregationNode.Aggregation>builder().put(countMatchesSymbol, countWithFilter(matchConditionSymbol)).put(countNullMatchesSymbol, countWithFilter(nullMatchConditionSymbol)).build(), singleGroupingSet(probeSide.getOutputSymbols()), ImmutableList.of(), AggregationNode.Step.SINGLE, Optional.empty(), Optional.empty(), AggregationNode.AggregationType.HASH, Optional.empty());
// TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results
// SQL IN semantics: TRUE if any real match, otherwise NULL if any NULL-involving match, otherwise FALSE.
SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression(ImmutableList.of(new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true)), new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null))), Optional.of(booleanConstant(false)));
// Expose only the apply input's columns (the "unique" id column is dropped) plus the predicate result.
return new ProjectNode(idAllocator.getNextId(), aggregation, Assignments.builder().putAll(identityAsSymbolReferences(apply.getInput().getOutputSymbols())).put(inPredicateOutputSymbol, castToRowExpression(inPredicateEquivalent)).build());
}
Aggregations