use of com.facebook.presto.spi.relation.SpecialFormExpression in project presto by prestodb.
the class Partitioning method translateToCoalesce.
/**
 * Combines this partitioning with {@code other}, argument by argument.
 * <ul>
 * <li>If the left argument is a constant, it is kept as is.</li>
 * <li>Otherwise, if the right argument is a constant, it is kept as is.</li>
 * <li>If both arguments are variable references, they are mapped to a COALESCE expression over
 * (left, right); their types must be equal.</li>
 * <li>Any other combination makes the translation impossible.</li>
 * </ul>
 * The resulting handle is this handle when {@code metadata.isRefinedPartitioningOver(session, other.handle, this.handle)}
 * holds, and the other handle otherwise.
 *
 * @return the combined partitioning, or {@link Optional#empty()} if some argument pair cannot be combined
 * @throws IllegalArgumentException if the handles are not compatible for coalesce, the argument counts differ,
 *         or a pair of variables has different types
 */
public Optional<Partitioning> translateToCoalesce(Partitioning other, Metadata metadata, Session session) {
    checkArgument(arePartitionHandlesCompatibleForCoalesce(this.handle, other.handle, metadata, session), "incompatible partitioning handles: cannot coalesce %s and %s", this.handle, other.handle);
    int argumentCount = this.arguments.size();
    checkArgument(argumentCount == other.arguments.size(), "incompatible number of partitioning arguments: %s != %s", this.arguments.size(), other.arguments.size());

    ImmutableList.Builder<RowExpression> coalescedArguments = ImmutableList.builder();
    for (int position = 0; position < argumentCount; position++) {
        RowExpression left = this.arguments.get(position);
        RowExpression right = other.arguments.get(position);

        // Constants pass through unchanged; the left side wins when both are constants.
        if (left instanceof ConstantExpression) {
            coalescedArguments.add(left);
            continue;
        }
        if (right instanceof ConstantExpression) {
            coalescedArguments.add(right);
            continue;
        }

        // Only a pair of plain variable references can be coalesced.
        if (!(left instanceof VariableReferenceExpression) || !(right instanceof VariableReferenceExpression)) {
            return Optional.empty();
        }
        VariableReferenceExpression leftVariable = (VariableReferenceExpression) left;
        VariableReferenceExpression rightVariable = (VariableReferenceExpression) right;
        checkArgument(leftVariable.getType().equals(rightVariable.getType()), "incompatible types: %s != %s", leftVariable.getType(), rightVariable.getType());
        coalescedArguments.add(new SpecialFormExpression(COALESCE, leftVariable.getType(), ImmutableList.of(leftVariable, rightVariable)));
    }

    boolean thisHandleIsRefined = metadata.isRefinedPartitioningOver(session, other.handle, this.handle);
    return Optional.of(new Partitioning(thisHandleIsRefined ? this.handle : other.handle, coalescedArguments.build()));
}
use of com.facebook.presto.spi.relation.SpecialFormExpression in project presto by prestodb.
the class TestCursorProcessorCompiler method testCompilerWithCSE.
/**
 * Compiles the same filter and projections into two cursor processors that differ only in the
 * final boolean flag passed to {@code compileCursorProcessor} (presumably the common-sub-expression
 * optimization switch -- confirm against ExpressionCompiler), runs both over the same input and
 * checks that the produced pages are equal.
 */
@Test
public void testCompilerWithCSE() {
    ExpressionCompiler compiler = new ExpressionCompiler(METADATA, new PageFunctionCompiler(METADATA, 0));

    RowExpression filter = new SpecialFormExpression(AND, BIGINT, ADD_X_Y_GREATER_THAN_2, ADD_X_Y_LESS_THAN_10);
    List<? extends RowExpression> projections = createIfProjectionList(5);

    Supplier<CursorProcessor> withCse = compiler.compileCursorProcessor(SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", true);
    Supplier<CursorProcessor> withoutCse = compiler.compileCursorProcessor(SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", false);

    // Two BIGINT columns holding the values 0..9.
    Page input = createLongBlockPage(2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
    RecordSet records = new PageRecordSet(ImmutableList.of(BIGINT, BIGINT), input);
    PageBuilder output = new PageBuilder(projections.stream().map(RowExpression::getType).collect(toList()));

    withCse.get().process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), records.cursor(), output);
    Page pageFromCSE = output.build();

    // Reuse the same builder for the second run.
    output.reset();
    withoutCse.get().process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), records.cursor(), output);
    Page pageFromNoCSE = output.build();

    checkPageEqual(pageFromCSE, pageFromNoCSE);
}
use of com.facebook.presto.spi.relation.SpecialFormExpression in project presto by prestodb.
the class RowExpressionVerifier method visitSimpleCaseExpression.
/**
 * Verifies that {@code actual} is a SWITCH special form matching the expected simple CASE expression.
 * <p>
 * The first SWITCH argument is compared with the expected operand. The remaining arguments are the
 * WHEN clauses, optionally followed by an else value: if the last argument is itself a WHEN form
 * there is no else value, otherwise the last argument is the else value.
 *
 * @return true only if the operand, the WHEN clauses and the default value all match
 */
@Override
protected Boolean visitSimpleCaseExpression(SimpleCaseExpression expected, RowExpression actual) {
    if (!(actual instanceof SpecialFormExpression)) {
        return false;
    }
    SpecialFormExpression switchExpression = (SpecialFormExpression) actual;
    if (!switchExpression.getForm().equals(SWITCH)) {
        return false;
    }

    List<RowExpression> switchArguments = switchExpression.getArguments();
    if (!process(expected.getOperand(), switchArguments.get(0))) {
        return false;
    }

    int lastIndex = switchArguments.size() - 1;
    RowExpression trailing = switchArguments.get(lastIndex);
    boolean endsWithWhen = trailing instanceof SpecialFormExpression && ((SpecialFormExpression) trailing).getForm().equals(WHEN);

    // When the trailing argument is a WHEN clause it belongs to the clause list; otherwise it is the else value.
    List<RowExpression> whenClauses = switchArguments.subList(1, endsWithWhen ? lastIndex + 1 : lastIndex);
    Optional<RowExpression> elseValue = endsWithWhen ? Optional.<RowExpression>empty() : Optional.of(trailing);

    if (!process(expected.getWhenClauses(), whenClauses)) {
        return false;
    }
    return process(expected.getDefaultValue(), elseValue);
}
use of com.facebook.presto.spi.relation.SpecialFormExpression in project presto by prestodb.
the class RowExpressionVerifier method visitDereferenceExpression.
/**
 * Verifies that {@code actual} is a DEREFERENCE special form matching the expected dereference.
 * <p>
 * The form must have exactly two arguments: a base whose type is a {@code RowType} and a constant
 * that evaluates to a {@code Long} field index. The name of the row field at that index must equal
 * the expected field name, and the base must match the expected base.
 *
 * @return true if the dereference matches; false if {@code actual} has a different shape or does not match
 * @throws IllegalStateException if the constant is not a Long, the index is out of range, or the field is unnamed
 */
@Override
protected Boolean visitDereferenceExpression(DereferenceExpression expected, RowExpression actual) {
    boolean isDereference = actual instanceof SpecialFormExpression && ((SpecialFormExpression) actual).getForm().equals(DEREFERENCE);
    if (!isDereference) {
        return false;
    }

    List<RowExpression> operands = ((SpecialFormExpression) actual).getArguments();
    if (operands.size() != 2) {
        return false;
    }
    RowExpression base = operands.get(0);
    RowExpression fieldIndex = operands.get(1);
    if (!(base.getType() instanceof RowType) || !(fieldIndex instanceof ConstantExpression)) {
        return false;
    }

    RowType rowType = (RowType) base.getType();
    Object value = LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) fieldIndex);
    checkState(value instanceof Long);
    long index = (Long) value;

    List<RowType.Field> fields = rowType.getFields();
    checkState(index >= 0 && index < fields.size());
    RowType.Field field = fields.get(toIntExact(index));
    checkState(field.getName().isPresent());

    // Compare the field name first, then the base expression.
    return expected.getField().getValue().equals(field.getName().get()) && process(expected.getBase(), base);
}
use of com.facebook.presto.spi.relation.SpecialFormExpression in project urban-eureka by errir503.
the class RewriteAggregationIfToFilter method apply.
/**
 * Rewrites aggregations whose argument is an IF (or CAST(IF)) assignment of the source project so
 * that the IF condition becomes the aggregation mask.
 * <p>
 * Steps performed:
 * <ol>
 * <li>Select the aggregations accepted by {@code shouldRewriteAggregation}; return an empty result if there are none.</li>
 * <li>For every distinct referenced input, add a new project assignment for the IF condition and,
 * when {@code canUnwrapIf} allows it, another one for the IF true-result (re-wrapped in the
 * original CAST call when the assignment was a call expression).</li>
 * <li>Rebuild each selected aggregation with the condition variable as its mask and, where available,
 * the unwrapped result variable as its only argument. Other aggregations are copied unchanged.</li>
 * <li>Insert a filter between the new project and the aggregation. Its predicate is the OR of all
 * masks when there is no non-empty grouping set and every aggregation was rewritten; otherwise it is
 * {@code TRUE_CONSTANT}.</li>
 * </ol>
 *
 * @param aggregationNode the matched aggregation
 * @param captures pattern captures; {@code CHILD} holds the source project node
 * @param context rule context supplying the session, variable allocator and id allocator
 * @return the rewritten plan, or {@code Result.empty()} when nothing qualifies for the rewrite
 */
@Override
public Result apply(AggregationNode aggregationNode, Captures captures, Context context) {
    ProjectNode sourceProject = captures.get(CHILD);
    Set<Aggregation> aggregationsToRewrite = aggregationNode.getAggregations().values().stream()
            .filter(aggregation -> shouldRewriteAggregation(aggregation, sourceProject))
            .collect(toImmutableSet());
    if (aggregationsToRewrite.isEmpty()) {
        return Result.empty();
    }
    context.getSession().getRuntimeStats().addMetricValue(REWRITE_AGGREGATION_IF_TO_FILTER_APPLIED, 1);

    // Get the corresponding assignments in the input project.
    // The aggregationReferences only has the aggregations to rewrite, thus the sourceAssignments only has IF/CAST(IF) expressions with NULL false results.
    // Multiple aggregations may reference the same input. We use a map to dedup them based on the VariableReferenceExpression, so that we only do the rewrite once per input
    // IF expression.
    // The order of sourceAssignments determines the order of generating the new variables for the IF conditions and results. We use a sorted map to get a deterministic
    // order based on the name of the VariableReferenceExpressions.
    Map<VariableReferenceExpression, RowExpression> sourceAssignments = aggregationsToRewrite.stream()
            .map(aggregation -> (VariableReferenceExpression) aggregation.getArguments().get(0))
            .collect(toImmutableSortedMap(VariableReferenceExpression::compareTo, identity(), variable -> sourceProject.getAssignments().get(variable), (left, right) -> left));

    Assignments.Builder newAssignments = Assignments.builder();
    newAssignments.putAll(sourceProject.getAssignments());

    // Map from the aggregation reference to the IF condition reference which will be put in the mask.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToConditionReference = new HashMap<>();
    // Map from the aggregation reference to the IF result reference. This only contains the aggregates where the IF can be safely unwrapped.
    // E.g., SUM(IF(CARDINALITY(array) > 0, array[1])) will not be included in this map as array[1] can return errors if we unwrap the IF.
    Map<VariableReferenceExpression, VariableReferenceExpression> aggregationReferenceToIfResultReference = new HashMap<>();

    AggregationIfToFilterRewriteStrategy rewriteStrategy = getAggregationIfToFilterRewriteStrategy(context.getSession());
    for (Map.Entry<VariableReferenceExpression, RowExpression> entry : sourceAssignments.entrySet()) {
        VariableReferenceExpression outputVariable = entry.getKey();
        RowExpression rowExpression = entry.getValue();
        // A call expression here is the CAST wrapped around the IF; its first argument is the IF itself.
        SpecialFormExpression ifExpression = (SpecialFormExpression) ((rowExpression instanceof CallExpression) ? ((CallExpression) rowExpression).getArguments().get(0) : rowExpression);

        RowExpression condition = ifExpression.getArguments().get(0);
        VariableReferenceExpression conditionReference = context.getVariableAllocator().newVariable(condition);
        newAssignments.put(conditionReference, condition);
        aggregationReferenceToConditionReference.put(outputVariable, conditionReference);

        if (canUnwrapIf(ifExpression, rewriteStrategy)) {
            RowExpression trueResult = ifExpression.getArguments().get(1);
            if (rowExpression instanceof CallExpression) {
                // Wrap the result with CAST(), keeping the source location of the original CAST
                // just like the rebuilt aggregate call below keeps its own.
                CallExpression castExpression = (CallExpression) rowExpression;
                trueResult = new CallExpression(castExpression.getSourceLocation(), castExpression.getDisplayName(), castExpression.getFunctionHandle(), castExpression.getType(), ImmutableList.of(trueResult));
            }
            VariableReferenceExpression ifResultReference = context.getVariableAllocator().newVariable(trueResult);
            newAssignments.put(ifResultReference, trueResult);
            aggregationReferenceToIfResultReference.put(outputVariable, ifResultReference);
        }
    }

    // Build new aggregations.
    ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
    // Stores the masks used to build the filter predicates. Use set to dedup the predicates.
    ImmutableSortedSet.Builder<VariableReferenceExpression> masks = ImmutableSortedSet.naturalOrder();
    for (Map.Entry<VariableReferenceExpression, Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
        VariableReferenceExpression output = entry.getKey();
        Aggregation aggregation = entry.getValue();
        if (!aggregationsToRewrite.contains(aggregation)) {
            aggregations.put(output, aggregation);
            continue;
        }

        VariableReferenceExpression aggregationReference = (VariableReferenceExpression) aggregation.getArguments().get(0);
        CallExpression callExpression = aggregation.getCall();
        // Only present when the IF could be unwrapped; otherwise the aggregate keeps its original argument.
        VariableReferenceExpression ifResultReference = aggregationReferenceToIfResultReference.get(aggregationReference);
        if (ifResultReference != null) {
            callExpression = new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), ImmutableList.of(ifResultReference));
        }

        VariableReferenceExpression mask = aggregationReferenceToConditionReference.get(aggregationReference);
        aggregations.put(output, new Aggregation(callExpression, Optional.empty(), aggregation.getOrderBy(), aggregation.isDistinct(), Optional.of(mask)));
        masks.add(mask);
    }

    RowExpression predicate = TRUE_CONSTANT;
    if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) {
        // All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient.
        predicate = or(masks.build());
    }

    // Ids are allocated in argument order: aggregation, then filter, then project.
    return Result.ofPlanNode(new AggregationNode(
            aggregationNode.getSourceLocation(),
            context.getIdAllocator().getNextId(),
            new FilterNode(
                    aggregationNode.getSourceLocation(),
                    context.getIdAllocator().getNextId(),
                    new ProjectNode(context.getIdAllocator().getNextId(), sourceProject.getSource(), newAssignments.build()),
                    predicate),
            aggregations.build(),
            aggregationNode.getGroupingSets(),
            aggregationNode.getPreGroupedVariables(),
            aggregationNode.getStep(),
            aggregationNode.getHashVariable(),
            aggregationNode.getGroupIdVariable()));
}
Aggregations