Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
Class RewriteAggregationIfToFilter, method shouldRewriteAggregation:
/**
 * Checks whether {@code aggregation} is eligible for the IF-to-filter rewrite.
 *
 * <p>Every one of these conditions must hold, and they are checked in this order:
 * <ul>
 *   <li>the aggregation function is not called on NULL input, because the rewrite drops the rows
 *       where the IF evaluates to NULL;</li>
 *   <li>the aggregation has exactly one argument and it is a variable; the expression defining it
 *       is looked up in {@code sourceProject};</li>
 *   <li>the aggregation has neither a filter nor a mask yet;</li>
 *   <li>the defining expression, after stripping one outer single-argument CAST, is a deterministic
 *       IF special form whose else result is NULL.</li>
 * </ul>
 *
 * @param aggregation the aggregation being considered
 * @param sourceProject the project node below the aggregation that defines its argument
 * @return {@code true} if the rewrite may be applied
 */
private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject) {
    // Filtering out NULL rows would change the result of a function that also consumes NULLs.
    if (functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()).isCalledOnNullInput()) {
        return false;
    }

    // Only a lone variable argument is handled; the real expression sits in the project node below.
    boolean hasSingleVariableArgument = aggregation.getArguments().size() == 1
            && aggregation.getArguments().get(0) instanceof VariableReferenceExpression;
    if (!hasSingleVariableArgument) {
        return false;
    }

    // An existing filter or mask is left untouched.
    boolean alreadyRestricted = aggregation.getFilter().isPresent() || aggregation.getMask().isPresent();
    if (alreadyRestricted) {
        return false;
    }

    VariableReferenceExpression argumentVariable = (VariableReferenceExpression) aggregation.getArguments().get(0);
    RowExpression candidate = sourceProject.getAssignments().get(argumentVariable);

    // Look through a single-argument CAST(...) to the expression it wraps.
    if (candidate instanceof CallExpression
            && ((CallExpression) candidate).getArguments().size() == 1
            && standardFunctionResolution.isCastFunction(((CallExpression) candidate).getFunctionHandle())) {
        candidate = ((CallExpression) candidate).getArguments().get(0);
    }

    boolean deterministicSpecialForm = candidate instanceof SpecialFormExpression
            && rowExpressionDeterminismEvaluator.isDeterministic(candidate);
    if (!deterministicSpecialForm) {
        return false;
    }

    SpecialFormExpression specialForm = (SpecialFormExpression) candidate;
    if (specialForm.getForm() != IF) {
        return false;
    }
    // Rewrite only when the else branch is absent or its result is NULL.
    return Expressions.isNull(specialForm.getArguments().get(2));
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
Class RewriteAggregationIfToFilter, method canUnwrapIf:
/**
 * Returns whether the IF expression feeding an aggregation may be unwrapped, that is, whether its
 * true branch may be evaluated without the IF condition guarding it.
 *
 * <p>Some queries use IF to keep the true branch from raising errors, for example
 * {@code IF(CARDINALITY(array) > 0, array[1])}. Unwrapping such an IF could surface errors for
 * rows that the condition would have excluded, so this method is conservative.
 *
 * @param ifExpression the IF special form; argument 0 is the condition and argument 1 the true branch
 * @param rewriteStrategy the configured strategy. {@code FILTER_WITH_IF} never unwraps. Any other
 *     strategy unwraps when the condition shares no variables with the true branch. When variables
 *     are shared, only {@code UNWRAP_IF} unwraps, and only if the true branch contains no lambda
 *     and no DIVIDE, MODULUS or SUBSCRIPT operator.
 * @return {@code true} if the IF may be unwrapped
 */
private boolean canUnwrapIf(SpecialFormExpression ifExpression, AggregationIfToFilterRewriteStrategy rewriteStrategy) {
    if (rewriteStrategy == FILTER_WITH_IF) {
        return false;
    }
    // If the condition references none of the variables used by the true branch, the condition
    // cannot be guarding the true branch against bad inputs, so unwrapping should not cause errors.
    Set<VariableReferenceExpression> ifConditionReferences = VariablesExtractor.extractUnique(ifExpression.getArguments().get(0));
    Set<VariableReferenceExpression> ifResultReferences = VariablesExtractor.extractUnique(ifExpression.getArguments().get(1));
    if (ifConditionReferences.stream().noneMatch(ifResultReferences::contains)) {
        return true;
    }
    if (rewriteStrategy != UNWRAP_IF) {
        return false;
    }
    // Scan the true branch and refuse to unwrap if it contains a construct known to fail on some inputs.
    AtomicBoolean result = new AtomicBoolean(true);
    ifExpression.getArguments().get(1).accept(new DefaultRowExpressionTraversalVisitor<AtomicBoolean>() {
        @Override
        public Void visitLambda(LambdaDefinitionExpression lambda, AtomicBoolean result) {
            // Unwrapping might cause issues if the true branch returns errors for rows not matching the filter.
            // To be safe, we don't unwrap when the true branch contains a lambda; its body is not traversed.
            result.set(false);
            return null;
        }

        @Override
        public Void visitCall(CallExpression call, AtomicBoolean result) {
            Optional<OperatorType> operatorType = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle()).getOperatorType();
            // array[1] can fail with an out-of-bounds error, and both a / b and a % b can fail with
            // DIVISION_BY_ZERO, so we don't unwrap the IF expression when any of these operators appear.
            // NOTE(review): this list is likely not exhaustive (e.g. a CAST of malformed input can also fail);
            // confirm whether further operators should be covered.
            if (operatorType.isPresent()
                    && (operatorType.get() == OperatorType.DIVIDE
                    || operatorType.get() == OperatorType.MODULUS
                    || operatorType.get() == OperatorType.SUBSCRIPT)) {
                result.set(false);
                return null;
            }
            return super.visitCall(call, result);
        }
    }, result);
    return result.get();
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
Class TestTypeValidator, method testInvalidAggregationFunctionCall:
@Test(expectedExceptions = IllegalArgumentException.class, expectedExceptionsMessageRegExp = "Expected input types are \\[double\\] but getting \\[bigint\\]")
public void testInvalidAggregationFunctionCall() {
    // The SUM handle expects DOUBLE input but is applied to variableA; judging by the expected
    // message, variableA is BIGINT, so the type validator must reject this aggregation.
    VariableReferenceExpression sumOutput = variableAllocator.newVariable("sum", DOUBLE);
    CallExpression mismatchedSumCall = new CallExpression("sum", SUM, DOUBLE, ImmutableList.of(variableA));
    Aggregation mismatchedAggregation = new Aggregation(mismatchedSumCall, Optional.empty(), Optional.empty(), false, Optional.empty());

    PlanNode aggregationNode = new AggregationNode(
            Optional.empty(),
            newId(),
            baseTableScan,
            ImmutableMap.of(sumOutput, mismatchedAggregation),
            singleGroupingSet(ImmutableList.of(variableA, variableB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    assertTypesValid(aggregationNode);
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
Class TestTypeValidator, method testValidAggregation:
@Test
public void testValidAggregation() {
    // SUM over variableC with a DOUBLE result; the type validator is expected to accept this plan.
    VariableReferenceExpression sumOutput = variableAllocator.newVariable("sum", DOUBLE);
    CallExpression sumCall = new CallExpression("sum", SUM, DOUBLE, ImmutableList.of(variableC));
    Aggregation sumAggregation = new Aggregation(sumCall, Optional.empty(), Optional.empty(), false, Optional.empty());

    PlanNode aggregationNode = new AggregationNode(
            Optional.empty(),
            newId(),
            baseTableScan,
            ImmutableMap.of(sumOutput, sumAggregation),
            singleGroupingSet(ImmutableList.of(variableA, variableB)),
            ImmutableList.of(),
            SINGLE,
            Optional.empty(),
            Optional.empty());

    assertTypesValid(aggregationNode);
}
Use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
Class TestRowExpressionRewriter, method testSimple:
@Test
public void testSimple() {
    // Case 1: a boolean comparison is expected to come back wrapped as not(<comparison>).
    RowExpression greaterThan = call(
            GREATER_THAN.name(),
            functionAndTypeManager.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
            BOOLEAN,
            constant(1L, BIGINT),
            constant(2L, BIGINT));
    RowExpression rewritten = rewrite(greaterThan);
    assertEquals(rewritten.getType(), BOOLEAN);
    assertTrue(rewritten instanceof CallExpression);
    CallExpression notCall = (CallExpression) rewritten;
    RowExpression wrapped = notCall.getArguments().get(0);
    assertTrue(wrapped instanceof CallExpression);
    assertEquals(notCall.getDisplayName(), "not");
    assertEquals(((CallExpression) wrapped).getDisplayName(), GREATER_THAN.name());

    // Case 2: a BIGINT addition is expected to be returned unchanged.
    RowExpression addition = call(
            ADD.name(),
            functionAndTypeManager.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
            BIGINT,
            constant(1L, BIGINT),
            constant(2L, BIGINT));
    RowExpression unchanged = rewrite(addition);
    assertEquals(unchanged, addition);
}
Aggregations