Search in sources :

Example 1 with NullIfExpression

use of io.trino.sql.tree.NullIfExpression in project trino by trinodb.

The method simplifyFilterExpression of the class SimplifyFilterPredicate.

/**
 * Tries to rewrite a conditional expression into a simpler expression.
 * <p>
 * Handles {@link IfExpression}, {@link NullIfExpression}, {@link SearchedCaseExpression} and
 * {@link SimpleCaseExpression}. The rules within each branch are tried in the order written and
 * the first one that matches wins, so reordering them can change the result.
 * <p>
 * NOTE(review): an absent ELSE/default is replaced with {@code FALSE_LITERAL} in several places,
 * which presumably relies on the expression being evaluated as a filter predicate (where NULL and
 * FALSE both reject the row) - confirm against the caller. {@code isNotTrue} and
 * {@code isFalseOrNullPredicate} are defined outside this block; the comments below take their
 * meaning from their names only.
 *
 * @param expression the expression to simplify
 * @return the simplified expression, or {@link Optional#empty()} when the expression is not one of
 *         the handled types or none of the rewrite rules applies
 */
private Optional<Expression> simplifyFilterExpression(Expression expression) {
    if (expression instanceof IfExpression) {
        IfExpression ifExpression = (IfExpression) expression;
        Expression condition = ifExpression.getCondition();
        Expression trueValue = ifExpression.getTrueValue();
        Optional<Expression> falseValue = ifExpression.getFalseValue();
        // IF(c, TRUE, <absent or not-true>) -> c
        if (trueValue.equals(TRUE_LITERAL) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) {
            return Optional.of(condition);
        }
        // IF(c, <not-true>, TRUE) -> isFalseOrNullPredicate(c)
        if (isNotTrue(trueValue) && falseValue.isPresent() && falseValue.get().equals(TRUE_LITERAL)) {
            return Optional.of(isFalseOrNullPredicate(condition));
        }
        // IF(c, v, v) -> v, but only when v is deterministic
        if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue, metadata)) {
            return Optional.of(trueValue);
        }
        // both branches are not-true (or the false branch is absent) -> FALSE
        if (isNotTrue(trueValue) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) {
            return Optional.of(FALSE_LITERAL);
        }
        // IF(TRUE, v, ...) -> v
        if (condition.equals(TRUE_LITERAL)) {
            return Optional.of(trueValue);
        }
        // IF(<not-true>, ..., f) -> f, or FALSE when there is no false branch
        if (isNotTrue(condition)) {
            return Optional.of(falseValue.orElse(FALSE_LITERAL));
        }
        return Optional.empty();
    }
    if (expression instanceof NullIfExpression) {
        // NULLIF(first, second) -> first AND isFalseOrNullPredicate(second); always rewritten
        NullIfExpression nullIfExpression = (NullIfExpression) expression;
        return Optional.of(LogicalExpression.and(nullIfExpression.getFirst(), isFalseOrNullPredicate(nullIfExpression.getSecond())));
    }
    if (expression instanceof SearchedCaseExpression) {
        SearchedCaseExpression caseExpression = (SearchedCaseExpression) expression;
        Optional<Expression> defaultValue = caseExpression.getDefaultValue();
        // WHEN operands and THEN results, each in clause order
        List<Expression> operands = caseExpression.getWhenClauses().stream().map(WhenClause::getOperand).collect(toImmutableList());
        List<Expression> results = caseExpression.getWhenClauses().stream().map(WhenClause::getResult).collect(toImmutableList());
        long trueResultsCount = results.stream().filter(result -> result.equals(TRUE_LITERAL)).count();
        long notTrueResultsCount = results.stream().filter(SimplifyFilterPredicate::isNotTrue).count();
        // all results true
        if (trueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) {
            return Optional.of(TRUE_LITERAL);
        }
        // all results not true
        if (notTrueResultsCount == results.size() && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) {
            return Optional.of(FALSE_LITERAL);
        }
        // one result true, and remaining results not true
        if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) {
            // Conjunction of isFalseOrNullPredicate(operand) for every clause preceding the single
            // true-result clause, followed by that clause's own operand. Clauses after it are ignored.
            ImmutableList.Builder<Expression> builder = ImmutableList.builder();
            for (WhenClause whenClause : caseExpression.getWhenClauses()) {
                Expression operand = whenClause.getOperand();
                Expression result = whenClause.getResult();
                if (isNotTrue(result)) {
                    builder.add(isFalseOrNullPredicate(operand));
                } else {
                    builder.add(operand);
                    return Optional.of(combineConjuncts(metadata, builder.build()));
                }
            }
        }
        // all results not true, and default true
        if (notTrueResultsCount == results.size() && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) {
            // conjunction of isFalseOrNullPredicate(operand) over every WHEN operand
            ImmutableList.Builder<Expression> builder = ImmutableList.builder();
            operands.stream().forEach(operand -> builder.add(isFalseOrNullPredicate(operand)));
            return Optional.of(combineConjuncts(metadata, builder.build()));
        }
        // skip clauses with not true conditions
        List<WhenClause> whenClauses = new ArrayList<>();
        for (WhenClause whenClause : caseExpression.getWhenClauses()) {
            Expression operand = whenClause.getOperand();
            if (operand.equals(TRUE_LITERAL)) {
                // A literal TRUE operand ends the scan: its result replaces the whole CASE when no
                // earlier clause was kept, otherwise it becomes the default of a shortened CASE.
                if (whenClauses.isEmpty()) {
                    return Optional.of(whenClause.getResult());
                }
                return Optional.of(new SearchedCaseExpression(whenClauses, Optional.of(whenClause.getResult())));
            }
            if (!isNotTrue(operand)) {
                whenClauses.add(whenClause);
            }
        }
        // every clause was dropped -> only the default remains (FALSE when absent)
        if (whenClauses.isEmpty()) {
            return Optional.of(defaultValue.orElse(FALSE_LITERAL));
        }
        // some clauses were dropped -> rebuild the CASE with the kept clauses and the original default
        if (whenClauses.size() < caseExpression.getWhenClauses().size()) {
            return Optional.of(new SearchedCaseExpression(whenClauses, defaultValue));
        }
        return Optional.empty();
    }
    if (expression instanceof SimpleCaseExpression) {
        SimpleCaseExpression caseExpression = (SimpleCaseExpression) expression;
        Optional<Expression> defaultValue = caseExpression.getDefaultValue();
        // CASE NULL WHEN ... -> the default, or FALSE when there is no default
        if (caseExpression.getOperand() instanceof NullLiteral) {
            return Optional.of(defaultValue.orElse(FALSE_LITERAL));
        }
        List<Expression> results = caseExpression.getWhenClauses().stream().map(WhenClause::getResult).collect(toImmutableList());
        // all results true and default true -> TRUE
        if (results.stream().allMatch(result -> result.equals(TRUE_LITERAL)) && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) {
            return Optional.of(TRUE_LITERAL);
        }
        // all results not true and default absent or not true -> FALSE
        if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) {
            return Optional.of(FALSE_LITERAL);
        }
        return Optional.empty();
    }
    // any other expression type is left untouched
    return Optional.empty();
}
Also used : IsNullPredicate(io.trino.sql.tree.IsNullPredicate) SimpleCaseExpression(io.trino.sql.tree.SimpleCaseExpression) Patterns.filter(io.trino.sql.planner.plan.Patterns.filter) NullIfExpression(io.trino.sql.tree.NullIfExpression) FilterNode(io.trino.sql.planner.plan.FilterNode) ArrayList(java.util.ArrayList) Cast(io.trino.sql.tree.Cast) ImmutableList(com.google.common.collect.ImmutableList) SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) DeterminismEvaluator.isDeterministic(io.trino.sql.planner.DeterminismEvaluator.isDeterministic) NotExpression(io.trino.sql.tree.NotExpression) NullLiteral(io.trino.sql.tree.NullLiteral) Rule(io.trino.sql.planner.iterative.Rule) WhenClause(io.trino.sql.tree.WhenClause) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) FALSE_LITERAL(io.trino.sql.tree.BooleanLiteral.FALSE_LITERAL) TRUE_LITERAL(io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL) List(java.util.List) Pattern(io.trino.matching.Pattern) IfExpression(io.trino.sql.tree.IfExpression) Captures(io.trino.matching.Captures) LogicalExpression(io.trino.sql.tree.LogicalExpression) Metadata(io.trino.metadata.Metadata) Optional(java.util.Optional) Expression(io.trino.sql.tree.Expression) ExpressionUtils.extractConjuncts(io.trino.sql.ExpressionUtils.extractConjuncts) ExpressionUtils.combineConjuncts(io.trino.sql.ExpressionUtils.combineConjuncts) NullIfExpression(io.trino.sql.tree.NullIfExpression) IfExpression(io.trino.sql.tree.IfExpression) ImmutableList(com.google.common.collect.ImmutableList) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ArrayList(java.util.ArrayList) SimpleCaseExpression(io.trino.sql.tree.SimpleCaseExpression) WhenClause(io.trino.sql.tree.WhenClause) SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) SimpleCaseExpression(io.trino.sql.tree.SimpleCaseExpression) NullIfExpression(io.trino.sql.tree.NullIfExpression) 
SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) NotExpression(io.trino.sql.tree.NotExpression) IfExpression(io.trino.sql.tree.IfExpression) LogicalExpression(io.trino.sql.tree.LogicalExpression) Expression(io.trino.sql.tree.Expression) NullIfExpression(io.trino.sql.tree.NullIfExpression) NullLiteral(io.trino.sql.tree.NullLiteral)

Example 2 with NullIfExpression

use of io.trino.sql.tree.NullIfExpression in project trino by trinodb.

The method visitFunctionCall of the class AstBuilder.

/**
 * Converts a function-call parse context into an AST node.
 * <p>
 * The names {@code if}, {@code nullif}, {@code coalesce}, {@code try}, {@code format} and
 * {@code $internal$bind} (matched case-insensitively) are special forms: each is validated for
 * argument count and for the absence of OVER, DISTINCT, null treatment, processing mode and
 * FILTER, and is then turned into its dedicated expression node. Every other name becomes a
 * generic {@link FunctionCall}.
 *
 * @param context the function-call parse context
 * @return the dedicated expression node for a special form, otherwise a {@link FunctionCall}
 */
@Override
public Node visitFunctionCall(SqlBaseParser.FunctionCallContext context) {
    Optional<Expression> filter = visitIfPresent(context.filter(), Expression.class);
    Optional<Window> window = visitIfPresent(context.over(), Window.class);

    Optional<OrderBy> orderBy = Optional.empty();
    if (context.ORDER() != null) {
        orderBy = Optional.of(new OrderBy(visit(context.sortItem(), SortItem.class)));
    }

    QualifiedName name = getQualifiedName(context.qualifiedName());
    boolean distinct = isDistinct(context.setQuantifier());
    SqlBaseParser.NullTreatmentContext nullTreatment = context.nullTreatment();
    SqlBaseParser.ProcessingModeContext processingMode = context.processingMode();

    // Values shared by the validation of every special form below.
    String functionName = name.toString();
    int argumentCount = context.expression().size();
    boolean noWindow = !window.isPresent();
    boolean noFilter = !filter.isPresent();
    boolean noNullTreatment = nullTreatment == null;
    boolean noProcessingMode = processingMode == null;

    if (functionName.equalsIgnoreCase("if")) {
        check(argumentCount == 2 || argumentCount == 3, "Invalid number of arguments for 'if' function", context);
        check(noWindow, "OVER clause not valid for 'if' function", context);
        check(!distinct, "DISTINCT not valid for 'if' function", context);
        check(noNullTreatment, "Null treatment clause not valid for 'if' function", context);
        check(noProcessingMode, "Running or final semantics not valid for 'if' function", context);
        check(noFilter, "FILTER not valid for 'if' function", context);

        // The optional third argument is the ELSE value; it stays null when absent.
        Expression elseExpression = argumentCount == 3
                ? (Expression) visit(context.expression(2))
                : null;
        return new IfExpression(
                getLocation(context),
                (Expression) visit(context.expression(0)),
                (Expression) visit(context.expression(1)),
                elseExpression);
    }

    if (functionName.equalsIgnoreCase("nullif")) {
        check(argumentCount == 2, "Invalid number of arguments for 'nullif' function", context);
        check(noWindow, "OVER clause not valid for 'nullif' function", context);
        check(!distinct, "DISTINCT not valid for 'nullif' function", context);
        check(noNullTreatment, "Null treatment clause not valid for 'nullif' function", context);
        check(noProcessingMode, "Running or final semantics not valid for 'nullif' function", context);
        check(noFilter, "FILTER not valid for 'nullif' function", context);

        return new NullIfExpression(
                getLocation(context),
                (Expression) visit(context.expression(0)),
                (Expression) visit(context.expression(1)));
    }

    if (functionName.equalsIgnoreCase("coalesce")) {
        check(argumentCount >= 2, "The 'coalesce' function must have at least two arguments", context);
        check(noWindow, "OVER clause not valid for 'coalesce' function", context);
        check(!distinct, "DISTINCT not valid for 'coalesce' function", context);
        check(noNullTreatment, "Null treatment clause not valid for 'coalesce' function", context);
        check(noProcessingMode, "Running or final semantics not valid for 'coalesce' function", context);
        check(noFilter, "FILTER not valid for 'coalesce' function", context);

        return new CoalesceExpression(getLocation(context), visit(context.expression(), Expression.class));
    }

    if (functionName.equalsIgnoreCase("try")) {
        check(argumentCount == 1, "The 'try' function must have exactly one argument", context);
        check(noWindow, "OVER clause not valid for 'try' function", context);
        check(!distinct, "DISTINCT not valid for 'try' function", context);
        check(noNullTreatment, "Null treatment clause not valid for 'try' function", context);
        check(noProcessingMode, "Running or final semantics not valid for 'try' function", context);
        check(noFilter, "FILTER not valid for 'try' function", context);

        return new TryExpression(getLocation(context), (Expression) visit(getOnlyElement(context.expression())));
    }

    if (functionName.equalsIgnoreCase("format")) {
        check(argumentCount >= 2, "The 'format' function must have at least two arguments", context);
        check(noWindow, "OVER clause not valid for 'format' function", context);
        check(!distinct, "DISTINCT not valid for 'format' function", context);
        check(noNullTreatment, "Null treatment clause not valid for 'format' function", context);
        check(noProcessingMode, "Running or final semantics not valid for 'format' function", context);
        check(noFilter, "FILTER not valid for 'format' function", context);

        return new Format(getLocation(context), visit(context.expression(), Expression.class));
    }

    if (functionName.equalsIgnoreCase("$internal$bind")) {
        check(argumentCount >= 1, "The '$internal$bind' function must have at least one arguments", context);
        check(noWindow, "OVER clause not valid for '$internal$bind' function", context);
        check(!distinct, "DISTINCT not valid for '$internal$bind' function", context);
        check(noNullTreatment, "Null treatment clause not valid for '$internal$bind' function", context);
        check(noProcessingMode, "Running or final semantics not valid for '$internal$bind' function", context);
        check(noFilter, "FILTER not valid for '$internal$bind' function", context);

        // Every argument but the last is a bound value; the last one is the function being bound.
        int lastIndex = argumentCount - 1;
        List<Expression> bindArguments = context.expression().stream()
                .map(this::visit)
                .map(Expression.class::cast)
                .collect(toImmutableList());
        return new BindExpression(getLocation(context), bindArguments.subList(0, lastIndex), bindArguments.get(lastIndex));
    }

    // Ordinary function call: translate the optional IGNORE/RESPECT NULLS clause.
    Optional<NullTreatment> nulls = Optional.empty();
    if (nullTreatment != null && nullTreatment.IGNORE() != null) {
        nulls = Optional.of(NullTreatment.IGNORE);
    }
    else if (nullTreatment != null && nullTreatment.RESPECT() != null) {
        nulls = Optional.of(NullTreatment.RESPECT);
    }

    // Translate the optional RUNNING/FINAL processing mode.
    Optional<ProcessingMode> mode = Optional.empty();
    if (processingMode != null && processingMode.RUNNING() != null) {
        mode = Optional.of(new ProcessingMode(getLocation(processingMode), RUNNING));
    }
    else if (processingMode != null && processingMode.FINAL() != null) {
        mode = Optional.of(new ProcessingMode(getLocation(processingMode), FINAL));
    }

    // The expressions are always visited; when a label is present, the argument list is replaced
    // by a single dereference of that label.
    List<Expression> arguments = visit(context.expression(), Expression.class);
    if (context.label != null) {
        Identifier labelIdentifier = (Identifier) visit(context.label);
        arguments = ImmutableList.of(new DereferenceExpression(getLocation(context.label), labelIdentifier));
    }

    return new FunctionCall(Optional.of(getLocation(context)), name, window, filter, orderBy, distinct, nulls, mode, arguments);
}
Also used : ProcessingMode(io.trino.sql.tree.ProcessingMode) SortItem(io.trino.sql.tree.SortItem) Format(io.trino.sql.tree.Format) ExplainFormat(io.trino.sql.tree.ExplainFormat) FunctionCall(io.trino.sql.tree.FunctionCall) Window(io.trino.sql.tree.Window) OrderBy(io.trino.sql.tree.OrderBy) NullIfExpression(io.trino.sql.tree.NullIfExpression) IfExpression(io.trino.sql.tree.IfExpression) DereferenceExpression(io.trino.sql.tree.DereferenceExpression) QualifiedName(io.trino.sql.tree.QualifiedName) TryExpression(io.trino.sql.tree.TryExpression) NullTreatment(io.trino.sql.tree.FunctionCall.NullTreatment) DereferenceExpression(io.trino.sql.tree.DereferenceExpression) LogicalExpression(io.trino.sql.tree.LogicalExpression) SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) BindExpression(io.trino.sql.tree.BindExpression) CoalesceExpression(io.trino.sql.tree.CoalesceExpression) QuantifiedComparisonExpression(io.trino.sql.tree.QuantifiedComparisonExpression) SimpleCaseExpression(io.trino.sql.tree.SimpleCaseExpression) SubqueryExpression(io.trino.sql.tree.SubqueryExpression) LambdaExpression(io.trino.sql.tree.LambdaExpression) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) NullIfExpression(io.trino.sql.tree.NullIfExpression) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) InListExpression(io.trino.sql.tree.InListExpression) NotExpression(io.trino.sql.tree.NotExpression) ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) TryExpression(io.trino.sql.tree.TryExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) IfExpression(io.trino.sql.tree.IfExpression) Expression(io.trino.sql.tree.Expression) BindExpression(io.trino.sql.tree.BindExpression) NullIfExpression(io.trino.sql.tree.NullIfExpression) CoalesceExpression(io.trino.sql.tree.CoalesceExpression)

Example 3 with NullIfExpression

use of io.trino.sql.tree.NullIfExpression in project trino by trinodb.

The method testExpressionsThatMayReturnNullOnNonNullInput of the class TestEqualityInference.

/**
 * For each candidate expression over {@code b}, builds an inference from {@code b = x} and
 * {@code a = candidate}, partitions it by symbol {@code b}, and checks that the only
 * scope-straddling equality is the one between {@code x} and {@code b} (in either operand order).
 */
@Test
public void testExpressionsThatMayReturnNullOnNonNullInput() {
    // try_cast
    Expression tryCast = new Cast(nameReference("b"), toSqlType(BIGINT), true);
    Expression tryCall = functionResolution.functionCallBuilder(QualifiedName.of(TryFunction.NAME))
            .addArgument(new FunctionType(ImmutableList.of(), VARCHAR), new LambdaExpression(ImmutableList.of(), nameReference("b")))
            .build();
    Expression nullIf = new NullIfExpression(nameReference("b"), number(1));
    Expression ifWithNullElse = new IfExpression(nameReference("b"), number(1), new NullLiteral());
    Expression inNullList = new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral())));
    Expression searchedCase = new SearchedCaseExpression(
            ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())),
            Optional.empty());
    Expression simpleCase = new SimpleCaseExpression(
            nameReference("b"),
            ImmutableList.of(new WhenClause(number(1), new NullLiteral())),
            Optional.empty());
    Expression subscript = new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b"));

    List<Expression> candidates = ImmutableList.of(
            tryCast,
            tryCall,
            nullIf,
            ifWithNullElse,
            inNullList,
            searchedCase,
            simpleCase,
            subscript);

    for (Expression candidate : candidates) {
        EqualityInference inference = EqualityInference.newInstance(
                metadata,
                equals(nameReference("b"), nameReference("x")),
                equals(nameReference("a"), candidate));

        List<Expression> straddling = inference.generateEqualitiesPartitionedBy(symbols("b")).getScopeStraddlingEqualities();
        assertEquals(straddling.size(), 1);

        // The single remaining equality may be written with its operands in either order.
        Expression only = straddling.get(0);
        assertTrue(only.equals(equals(nameReference("x"), nameReference("b"))) || only.equals(equals(nameReference("b"), nameReference("x"))));
    }
}
Also used : Cast(io.trino.sql.tree.Cast) NullIfExpression(io.trino.sql.tree.NullIfExpression) IfExpression(io.trino.sql.tree.IfExpression) FunctionType(io.trino.type.FunctionType) InListExpression(io.trino.sql.tree.InListExpression) InPredicate(io.trino.sql.tree.InPredicate) IsNotNullPredicate(io.trino.sql.tree.IsNotNullPredicate) SimpleCaseExpression(io.trino.sql.tree.SimpleCaseExpression) WhenClause(io.trino.sql.tree.WhenClause) SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) SimpleCaseExpression(io.trino.sql.tree.SimpleCaseExpression) LambdaExpression(io.trino.sql.tree.LambdaExpression) NullIfExpression(io.trino.sql.tree.NullIfExpression) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) SearchedCaseExpression(io.trino.sql.tree.SearchedCaseExpression) ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) IfExpression(io.trino.sql.tree.IfExpression) Expression(io.trino.sql.tree.Expression) InListExpression(io.trino.sql.tree.InListExpression) NullIfExpression(io.trino.sql.tree.NullIfExpression) SubscriptExpression(io.trino.sql.tree.SubscriptExpression) ArrayConstructor(io.trino.sql.tree.ArrayConstructor) LambdaExpression(io.trino.sql.tree.LambdaExpression) NullLiteral(io.trino.sql.tree.NullLiteral) Test(org.testng.annotations.Test)

Example 4 with NullIfExpression

use of io.trino.sql.tree.NullIfExpression in project trino by trinodb.

The method testNullIf of the class TestSqlParser.

/**
 * Checks that two-argument {@code nullif} calls parse to {@link NullIfExpression}, and that a
 * wrong argument count, a FILTER clause or an OVER clause is rejected with the expected message.
 */
@Test
public void testNullIf() {
    String wrongArgumentCount = "Invalid number of arguments for 'nullif' function";

    // Accepted forms: exactly two arguments, either of which may be a NULL literal.
    assertExpression(
            "nullif(42, 87)",
            new NullIfExpression(new LongLiteral("42"), new LongLiteral("87")));
    assertExpression(
            "nullif(42, null)",
            new NullIfExpression(new LongLiteral("42"), new NullLiteral()));
    assertExpression(
            "nullif(null, null)",
            new NullIfExpression(new NullLiteral(), new NullLiteral()));

    // Rejected: too few or too many arguments.
    assertInvalidExpression("nullif(1)", wrongArgumentCount);
    assertInvalidExpression("nullif(1, 2, 3)", wrongArgumentCount);

    // Rejected: clauses that only make sense on ordinary function calls.
    assertInvalidExpression("nullif(42, 87) filter (where true)", "FILTER not valid for 'nullif' function");
    assertInvalidExpression("nullif(42, 87) OVER ()", "OVER clause not valid for 'nullif' function");
}
Also used : LongLiteral(io.trino.sql.tree.LongLiteral) NullIfExpression(io.trino.sql.tree.NullIfExpression) NullLiteral(io.trino.sql.tree.NullLiteral) Test(org.junit.jupiter.api.Test)

Aggregations

NullIfExpression (io.trino.sql.tree.NullIfExpression): 4, Expression (io.trino.sql.tree.Expression): 3, IfExpression (io.trino.sql.tree.IfExpression): 3, NullLiteral (io.trino.sql.tree.NullLiteral): 3, SearchedCaseExpression (io.trino.sql.tree.SearchedCaseExpression): 3, SimpleCaseExpression (io.trino.sql.tree.SimpleCaseExpression): 3, ArithmeticBinaryExpression (io.trino.sql.tree.ArithmeticBinaryExpression): 2, Cast (io.trino.sql.tree.Cast): 2, ComparisonExpression (io.trino.sql.tree.ComparisonExpression): 2, InListExpression (io.trino.sql.tree.InListExpression): 2, LambdaExpression (io.trino.sql.tree.LambdaExpression): 2, LogicalExpression (io.trino.sql.tree.LogicalExpression): 2, NotExpression (io.trino.sql.tree.NotExpression): 2, SubscriptExpression (io.trino.sql.tree.SubscriptExpression): 2, WhenClause (io.trino.sql.tree.WhenClause): 2, ImmutableList (com.google.common.collect.ImmutableList): 1, ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 1, Captures (io.trino.matching.Captures): 1, Pattern (io.trino.matching.Pattern): 1, Metadata (io.trino.metadata.Metadata): 1