Use of io.trino.sql.tree.SimpleCaseExpression in the trino project by trinodb.
Class SimplifyFilterPredicate, method simplifyFilterExpression.
/**
 * Tries to rewrite a filter predicate into a simpler expression.
 *
 * <p>Only {@code IfExpression}, {@code NullIfExpression}, {@code SearchedCaseExpression} and
 * {@code SimpleCaseExpression} are inspected; each is delegated to a dedicated helper so that
 * the rules for one node type can be read in isolation.
 *
 * @param expression the predicate to inspect
 * @return the replacement expression, or an empty {@code Optional} when no rule matches
 */
private Optional<Expression> simplifyFilterExpression(Expression expression) {
    if (expression instanceof IfExpression) {
        return simplifyIfExpression((IfExpression) expression);
    }
    if (expression instanceof NullIfExpression) {
        // NULLIF(first, second) is always rewritten to: first AND isFalseOrNullPredicate(second)
        NullIfExpression nullIfExpression = (NullIfExpression) expression;
        return Optional.of(LogicalExpression.and(nullIfExpression.getFirst(), isFalseOrNullPredicate(nullIfExpression.getSecond())));
    }
    if (expression instanceof SearchedCaseExpression) {
        return simplifySearchedCaseExpression((SearchedCaseExpression) expression);
    }
    if (expression instanceof SimpleCaseExpression) {
        return simplifySimpleCaseExpression((SimpleCaseExpression) expression);
    }
    return Optional.empty();
}

/**
 * Applies the IF-specific rules. The rules are tried in the order written; the first match wins.
 */
private Optional<Expression> simplifyIfExpression(IfExpression ifExpression) {
    Expression condition = ifExpression.getCondition();
    Expression trueValue = ifExpression.getTrueValue();
    Optional<Expression> falseValue = ifExpression.getFalseValue();

    // IF(c, TRUE, <absent or not-true>) -> c
    if (trueValue.equals(TRUE_LITERAL) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) {
        return Optional.of(condition);
    }
    // IF(c, <not-true>, TRUE) -> isFalseOrNullPredicate(c)
    if (isNotTrue(trueValue) && falseValue.isPresent() && falseValue.get().equals(TRUE_LITERAL)) {
        return Optional.of(isFalseOrNullPredicate(condition));
    }
    // both branches are the same deterministic expression -> that expression
    if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue, metadata)) {
        return Optional.of(trueValue);
    }
    // neither branch can be true -> FALSE
    if (isNotTrue(trueValue) && (falseValue.isEmpty() || isNotTrue(falseValue.get()))) {
        return Optional.of(FALSE_LITERAL);
    }
    // constant TRUE condition -> the true branch
    if (condition.equals(TRUE_LITERAL)) {
        return Optional.of(trueValue);
    }
    // condition that is never true -> the false branch (FALSE when it is absent)
    if (isNotTrue(condition)) {
        return Optional.of(falseValue.orElse(FALSE_LITERAL));
    }
    return Optional.empty();
}

/**
 * Applies the rules for a searched CASE ({@code CASE WHEN operand THEN result ... ELSE default END}).
 */
private Optional<Expression> simplifySearchedCaseExpression(SearchedCaseExpression caseExpression) {
    List<WhenClause> allClauses = caseExpression.getWhenClauses();
    Optional<Expression> defaultValue = caseExpression.getDefaultValue();
    List<Expression> results = allClauses.stream().map(WhenClause::getResult).collect(toImmutableList());
    long trueResultsCount = results.stream().filter(result -> result.equals(TRUE_LITERAL)).count();
    long notTrueResultsCount = results.stream().filter(SimplifyFilterPredicate::isNotTrue).count();
    boolean defaultIsTrue = defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL);
    boolean defaultIsNotTrue = defaultValue.isEmpty() || isNotTrue(defaultValue.get());

    // all results true
    if (trueResultsCount == results.size() && defaultIsTrue) {
        return Optional.of(TRUE_LITERAL);
    }
    // all results not true
    if (notTrueResultsCount == results.size() && defaultIsNotTrue) {
        return Optional.of(FALSE_LITERAL);
    }
    // one result true, and remaining results not true
    if (trueResultsCount == 1 && notTrueResultsCount == results.size() - 1 && defaultIsNotTrue) {
        ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
        for (WhenClause whenClause : allClauses) {
            Expression operand = whenClause.getOperand();
            if (isNotTrue(whenClause.getResult())) {
                // a preceding not-true clause must not match for the true clause to be reached
                conjuncts.add(isFalseOrNullPredicate(operand));
            } else {
                // the single true clause: clauses after it are irrelevant
                conjuncts.add(operand);
                return Optional.of(combineConjuncts(metadata, conjuncts.build()));
            }
        }
    }
    // all results not true, and default true
    if (notTrueResultsCount == results.size() && defaultIsTrue) {
        ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
        for (WhenClause whenClause : allClauses) {
            conjuncts.add(isFalseOrNullPredicate(whenClause.getOperand()));
        }
        return Optional.of(combineConjuncts(metadata, conjuncts.build()));
    }
    // skip clauses with not true conditions
    List<WhenClause> retainedClauses = new ArrayList<>();
    for (WhenClause whenClause : allClauses) {
        Expression operand = whenClause.getOperand();
        if (operand.equals(TRUE_LITERAL)) {
            // a constant TRUE operand always matches, so its result acts as the default
            if (retainedClauses.isEmpty()) {
                return Optional.of(whenClause.getResult());
            }
            return Optional.of(new SearchedCaseExpression(retainedClauses, Optional.of(whenClause.getResult())));
        }
        if (!isNotTrue(operand)) {
            retainedClauses.add(whenClause);
        }
    }
    if (retainedClauses.isEmpty()) {
        return Optional.of(defaultValue.orElse(FALSE_LITERAL));
    }
    // rebuild only when at least one clause was dropped
    if (retainedClauses.size() < allClauses.size()) {
        return Optional.of(new SearchedCaseExpression(retainedClauses, defaultValue));
    }
    return Optional.empty();
}

/**
 * Applies the rules for a simple CASE ({@code CASE operand WHEN value THEN result ... ELSE default END}).
 */
private Optional<Expression> simplifySimpleCaseExpression(SimpleCaseExpression caseExpression) {
    Optional<Expression> defaultValue = caseExpression.getDefaultValue();
    // a NULL-literal operand goes straight to the default (FALSE when it is absent)
    if (caseExpression.getOperand() instanceof NullLiteral) {
        return Optional.of(defaultValue.orElse(FALSE_LITERAL));
    }
    List<Expression> results = caseExpression.getWhenClauses().stream().map(WhenClause::getResult).collect(toImmutableList());
    // every result and the default are TRUE
    if (results.stream().allMatch(result -> result.equals(TRUE_LITERAL)) && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) {
        return Optional.of(TRUE_LITERAL);
    }
    // no result and no default can be true
    if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) {
        return Optional.of(FALSE_LITERAL);
    }
    return Optional.empty();
}
Use of io.trino.sql.tree.SimpleCaseExpression in the trino project by trinodb.
Class TransformCorrelatedScalarSubquery, method apply.
/**
 * Rewrites a correlated join whose subquery contains an {@code EnforceSingleRowNode}
 * (searched for only through {@code ProjectNode}s).
 *
 * <p>When the remaining subquery is known to yield at most one row, the join is re-created
 * without the enforcement node. Otherwise the input is tagged with a unique id, the join result
 * is mark-distinct'ed on the input symbols, and a filter calls {@code fail} for any row that is
 * not marked as distinct.
 *
 * @return the rewritten plan, or {@code Result.empty()} when no enforcement node is found
 */
@Override
public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Context context) {
    PlanNode subquery = context.getLookup().resolve(correlatedJoinNode.getSubquery());

    boolean hasEnforceSingleRow = searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(ProjectNode.class::isInstance)
            .matches();
    if (!hasEnforceSingleRow) {
        return Result.empty();
    }

    PlanNode rewrittenSubquery = searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(ProjectNode.class::isInstance)
            .removeFirst();

    Range<Long> cardinality = extractCardinality(rewrittenSubquery, context.getLookup());
    if (Range.closed(0L, 1L).encloses(cardinality)) {
        // At most one row: keep the original join type only when exactly one row is guaranteed.
        boolean exactlyOneRow = Range.singleton(1L).encloses(cardinality);
        return Result.ofPlanNode(new CorrelatedJoinNode(
                context.getIdAllocator().getNextId(),
                correlatedJoinNode.getInput(),
                rewrittenSubquery,
                correlatedJoinNode.getCorrelation(),
                exactlyOneRow ? correlatedJoinNode.getType() : LEFT,
                correlatedJoinNode.getFilter(),
                correlatedJoinNode.getOriginSubquery()));
    }

    // Cardinality is not bounded by one: enforce the single-row contract at runtime.
    Symbol uniqueSymbol = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);
    // The join id is allocated before the AssignUniqueId id; keep the nesting to retain that order.
    CorrelatedJoinNode leftJoin = new CorrelatedJoinNode(
            context.getIdAllocator().getNextId(),
            new AssignUniqueId(context.getIdAllocator().getNextId(), correlatedJoinNode.getInput(), uniqueSymbol),
            rewrittenSubquery,
            correlatedJoinNode.getCorrelation(),
            LEFT,
            correlatedJoinNode.getFilter(),
            correlatedJoinNode.getOriginSubquery());

    Symbol distinctMarker = context.getSymbolAllocator().newSymbol("is_distinct", BOOLEAN);
    MarkDistinctNode markDistinct = new MarkDistinctNode(
            context.getIdAllocator().getNextId(),
            leftJoin,
            distinctMarker,
            leftJoin.getInput().getOutputSymbols(),
            Optional.empty());

    // CASE is_distinct WHEN TRUE THEN TRUE ELSE CAST(fail(code, message) AS boolean) END
    SimpleCaseExpression passOrFail = new SimpleCaseExpression(
            distinctMarker.toSymbolReference(),
            ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)),
            Optional.of(new Cast(
                    FunctionCallBuilder.resolve(context.getSession(), metadata)
                            .setName(QualifiedName.of("fail"))
                            .addArgument(INTEGER, new LongLiteral(Integer.toString(SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode())))
                            .addArgument(VARCHAR, new StringLiteral("Scalar sub-query has returned multiple rows"))
                            .build(),
                    toSqlType(BOOLEAN))));
    FilterNode filter = new FilterNode(context.getIdAllocator().getNextId(), markDistinct, passOrFail);

    // Project back to the original join's output symbols.
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            filter,
            Assignments.identity(correlatedJoinNode.getOutputSymbols())));
}
Use of io.trino.sql.tree.SimpleCaseExpression in the trino project by trinodb.
Class TestEqualityInference, method testExpressionsThatMayReturnNullOnNonNullInput.
/**
 * For each candidate expression over {@code b}, builds an inference from {@code b = x} and
 * {@code a = candidate}, partitions it by {@code b}, and checks that the only scope-straddling
 * equality is the one between {@code x} and {@code b} (in either operand order).
 */
@Test
public void testExpressionsThatMayReturnNullOnNonNullInput() {
    List<Expression> candidates = ImmutableList.<Expression>builder()
            // try_cast
            .add(new Cast(nameReference("b"), toSqlType(BIGINT), true))
            // call of TryFunction.NAME with a zero-argument lambda returning b
            .add(functionResolution.functionCallBuilder(QualifiedName.of(TryFunction.NAME))
                    .addArgument(new FunctionType(ImmutableList.of(), VARCHAR), new LambdaExpression(ImmutableList.of(), nameReference("b")))
                    .build())
            // NULLIF
            .add(new NullIfExpression(nameReference("b"), number(1)))
            // IF with a NULL false branch
            .add(new IfExpression(nameReference("b"), number(1), new NullLiteral()))
            // IN over a list containing only NULL
            .add(new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))))
            // searched CASE producing NULL
            .add(new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()))
            // simple CASE producing NULL
            .add(new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()))
            // subscript into an array holding NULL
            .add(new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b")))
            .build();

    for (Expression candidate : candidates) {
        EqualityInference inference = EqualityInference.newInstance(
                metadata,
                equals(nameReference("b"), nameReference("x")),
                equals(nameReference("a"), candidate));

        List<Expression> straddling = inference.generateEqualitiesPartitionedBy(symbols("b")).getScopeStraddlingEqualities();
        assertEquals(straddling.size(), 1);

        // The derived equality may be oriented either way.
        Expression only = straddling.get(0);
        boolean relatesXAndB = only.equals(equals(nameReference("x"), nameReference("b")))
                || only.equals(equals(nameReference("b"), nameReference("x")));
        assertTrue(relatesXAndB);
    }
}
Aggregations