Use of io.trino.sql.tree.SearchedCaseExpression in the trino project by trinodb.
Class SimplifyFilterPredicate, method simplifyFilterExpression.
/**
 * Simplifies a conditional expression (IF, NULLIF, searched CASE or simple CASE) that appears
 * as a filter predicate.
 * <p>
 * Every rewrite below treats a value for which {@code isNotTrue} holds exactly like
 * {@code FALSE_LITERAL}. That is only sound where a row is kept iff the predicate is TRUE
 * (a filter), which is why these rules live in a filter-specific simplification.
 *
 * @param expression the expression to simplify
 * @return the simplified expression, or {@link Optional#empty()} when no rewrite applies
 */
private Optional<Expression> simplifyFilterExpression(Expression expression) {
    if (expression instanceof IfExpression) {
        return simplifyIfFilter((IfExpression) expression);
    }
    if (expression instanceof NullIfExpression) {
        return simplifyNullIfFilter((NullIfExpression) expression);
    }
    if (expression instanceof SearchedCaseExpression) {
        return simplifySearchedCaseFilter((SearchedCaseExpression) expression);
    }
    if (expression instanceof SimpleCaseExpression) {
        return simplifySimpleCaseFilter((SimpleCaseExpression) expression);
    }
    return Optional.empty();
}

/**
 * Simplifies {@code IF(condition, trueValue[, falseValue])} used as a filter. A missing
 * false branch is handled the same way as a not-true one.
 */
private Optional<Expression> simplifyIfFilter(IfExpression ifExpression) {
    Expression condition = ifExpression.getCondition();
    Expression trueValue = ifExpression.getTrueValue();
    Optional<Expression> falseValue = ifExpression.getFalseValue();
    boolean falseValueIsTrue = falseValue.isPresent() && falseValue.get().equals(TRUE_LITERAL);
    boolean falseValueNotTrue = falseValue.isEmpty() || isNotTrue(falseValue.get());

    // IF(c, TRUE, <not true>): the row passes exactly when c is TRUE
    if (trueValue.equals(TRUE_LITERAL) && falseValueNotTrue) {
        return Optional.of(condition);
    }
    // IF(c, <not true>, TRUE): the row passes exactly when c is FALSE or NULL
    if (isNotTrue(trueValue) && falseValueIsTrue) {
        return Optional.of(isFalseOrNullPredicate(condition));
    }
    // IF(c, v, v): both branches agree, so the condition is irrelevant (only for deterministic v)
    if (falseValue.isPresent() && falseValue.get().equals(trueValue) && isDeterministic(trueValue, metadata)) {
        return Optional.of(trueValue);
    }
    // neither branch can be TRUE
    if (isNotTrue(trueValue) && falseValueNotTrue) {
        return Optional.of(FALSE_LITERAL);
    }
    // condition is the TRUE literal: always take the true branch
    if (condition.equals(TRUE_LITERAL)) {
        return Optional.of(trueValue);
    }
    // condition can never be TRUE: always take the false branch (or FALSE when absent)
    if (isNotTrue(condition)) {
        return Optional.of(falseValue.orElse(FALSE_LITERAL));
    }
    return Optional.empty();
}

/**
 * Rewrites {@code NULLIF(first, second)} used as a filter into
 * {@code first AND (second is FALSE or NULL)}.
 */
private Optional<Expression> simplifyNullIfFilter(NullIfExpression nullIfExpression) {
    return Optional.of(LogicalExpression.and(nullIfExpression.getFirst(), isFalseOrNullPredicate(nullIfExpression.getSecond())));
}

/**
 * Simplifies a searched {@code CASE WHEN ... THEN ... [ELSE ...] END} used as a filter.
 * It first tries rules based on the branch results. If none applies, it drops clauses whose
 * operand can never be TRUE and cuts the clause list at the first literal-TRUE operand.
 */
private Optional<Expression> simplifySearchedCaseFilter(SearchedCaseExpression caseExpression) {
    List<WhenClause> whenClauses = caseExpression.getWhenClauses();
    Optional<Expression> defaultValue = caseExpression.getDefaultValue();
    boolean defaultIsTrue = defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL);
    boolean defaultNotTrue = defaultValue.isEmpty() || isNotTrue(defaultValue.get());
    int clauseCount = whenClauses.size();

    long trueResultsCount = whenClauses.stream()
            .map(WhenClause::getResult)
            .filter(result -> result.equals(TRUE_LITERAL))
            .count();
    long notTrueResultsCount = whenClauses.stream()
            .map(WhenClause::getResult)
            .filter(SimplifyFilterPredicate::isNotTrue)
            .count();

    // all results, including the default, are TRUE
    if (trueResultsCount == clauseCount && defaultIsTrue) {
        return Optional.of(TRUE_LITERAL);
    }
    // no result, including the default, can be TRUE
    if (notTrueResultsCount == clauseCount && defaultNotTrue) {
        return Optional.of(FALSE_LITERAL);
    }
    // exactly one result is TRUE: its operand must hold and every earlier operand must not
    if (trueResultsCount == 1 && notTrueResultsCount == clauseCount - 1 && defaultNotTrue) {
        ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
        for (WhenClause whenClause : whenClauses) {
            if (isNotTrue(whenClause.getResult())) {
                conjuncts.add(isFalseOrNullPredicate(whenClause.getOperand()));
            } else {
                conjuncts.add(whenClause.getOperand());
                return Optional.of(combineConjuncts(metadata, conjuncts.build()));
            }
        }
    }
    // only the default is TRUE: the row passes when no operand is TRUE
    if (notTrueResultsCount == clauseCount && defaultIsTrue) {
        ImmutableList.Builder<Expression> conjuncts = ImmutableList.builder();
        for (WhenClause whenClause : whenClauses) {
            conjuncts.add(isFalseOrNullPredicate(whenClause.getOperand()));
        }
        return Optional.of(combineConjuncts(metadata, conjuncts.build()));
    }

    // drop clauses whose operand can never be TRUE; a literal-TRUE operand ends the CASE
    List<WhenClause> reachableClauses = new ArrayList<>();
    for (WhenClause whenClause : whenClauses) {
        Expression operand = whenClause.getOperand();
        if (operand.equals(TRUE_LITERAL)) {
            if (reachableClauses.isEmpty()) {
                return Optional.of(whenClause.getResult());
            }
            // clauses after this one are unreachable; its result becomes the new default
            return Optional.of(new SearchedCaseExpression(reachableClauses, Optional.of(whenClause.getResult())));
        }
        if (!isNotTrue(operand)) {
            reachableClauses.add(whenClause);
        }
    }
    if (reachableClauses.isEmpty()) {
        return Optional.of(defaultValue.orElse(FALSE_LITERAL));
    }
    if (reachableClauses.size() < clauseCount) {
        return Optional.of(new SearchedCaseExpression(reachableClauses, defaultValue));
    }
    return Optional.empty();
}

/**
 * Simplifies a simple {@code CASE operand WHEN ... THEN ... [ELSE ...] END} used as a filter.
 */
private Optional<Expression> simplifySimpleCaseFilter(SimpleCaseExpression caseExpression) {
    Optional<Expression> defaultValue = caseExpression.getDefaultValue();
    // a NULL operand never matches any WHEN value, so only the default can be produced
    if (caseExpression.getOperand() instanceof NullLiteral) {
        return Optional.of(defaultValue.orElse(FALSE_LITERAL));
    }
    List<Expression> results = caseExpression.getWhenClauses().stream().map(WhenClause::getResult).collect(toImmutableList());
    if (results.stream().allMatch(result -> result.equals(TRUE_LITERAL)) && defaultValue.isPresent() && defaultValue.get().equals(TRUE_LITERAL)) {
        return Optional.of(TRUE_LITERAL);
    }
    if (results.stream().allMatch(SimplifyFilterPredicate::isNotTrue) && (defaultValue.isEmpty() || isNotTrue(defaultValue.get()))) {
        return Optional.of(FALSE_LITERAL);
    }
    return Optional.empty();
}
Use of io.trino.sql.tree.SearchedCaseExpression in the trino project by trinodb.
Class TestSqlParser, method testArraySubscript.
/**
 * Checks that a subscript can follow an array constructor and a searched CASE expression
 * that yields an array.
 */
@Test
public void testArraySubscript() {
    // subscript applied directly to an array constructor
    assertExpression(
            "ARRAY [1, 2][1]",
            new SubscriptExpression(
                    new ArrayConstructor(ImmutableList.of(new LongLiteral("1"), new LongLiteral("2"))),
                    new LongLiteral("1")));

    // subscript applied to the result of a CASE expression
    WhenClause alwaysArray = new WhenClause(
            new BooleanLiteral("true"),
            new ArrayConstructor(ImmutableList.of(new LongLiteral("1"), new LongLiteral("2"))));
    assertExpression(
            "CASE WHEN TRUE THEN ARRAY[1,2] END[1]",
            new SubscriptExpression(
                    new SearchedCaseExpression(ImmutableList.of(alwaysArray), Optional.empty()),
                    new LongLiteral("1")));
}
Use of io.trino.sql.tree.SearchedCaseExpression in the trino project by trinodb.
Class TransformCorrelatedInPredicateToJoin, method buildInPredicateEquivalent.
/**
 * Builds a plan that evaluates {@code inPredicate} per probe row without a correlated subquery.
 * <p>
 * Shape: probe (with a unique id) LEFT JOIN the decorrelated subquery on a null-aware match,
 * then count definite matches and NULL-involving matches per probe row, and finally map the
 * counts to TRUE / NULL / FALSE with a searched CASE.
 *
 * @return a projection exposing the input's symbols plus {@code inPredicateOutputSymbol}
 */
private PlanNode buildInPredicateEquivalent(
        ApplyNode apply,
        InPredicate inPredicate,
        Symbol inPredicateOutputSymbol,
        Decorrelated decorrelated,
        PlanNodeIdAllocator idAllocator,
        SymbolAllocator symbolAllocator,
        Session session) {
    Expression correlation = and(decorrelated.getCorrelatedPredicates());
    PlanNode subquery = decorrelated.getDecorrelatedNode();

    // tag each probe row with a unique id; the aggregation below groups by the probe's outputs
    AssignUniqueId probe = new AssignUniqueId(
            idAllocator.getNextId(),
            apply.getInput(),
            symbolAllocator.newSymbol("unique", BIGINT));

    // constant non-null marker: after the outer join it is NULL iff no build row was joined
    Symbol buildSideKnownNonNull = symbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT);
    ProjectNode build = new ProjectNode(
            idAllocator.getNextId(),
            subquery,
            Assignments.builder()
                    .putIdentities(subquery.getOutputSymbols())
                    .put(buildSideKnownNonNull, bigint(0))
                    .build());

    Symbol probeValue = Symbol.from(inPredicate.getValue());
    Symbol buildValue = Symbol.from(inPredicate.getValueList());

    // join every build row that might match: equal values, or a NULL on either side
    Expression valuesMayMatch = or(
            new IsNullPredicate(probeValue.toSymbolReference()),
            new ComparisonExpression(ComparisonExpression.Operator.EQUAL, probeValue.toSymbolReference(), buildValue.toSymbolReference()),
            new IsNullPredicate(buildValue.toSymbolReference()));
    JoinNode join = leftOuterJoin(idAllocator, probe, build, and(valuesMayMatch, correlation));

    // definite match: both sides non-null (and, by the join condition, equal)
    Symbol matchConditionSymbol = symbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN);
    Expression definiteMatch = and(isNotNull(probeValue), isNotNull(buildValue));

    // ambiguous match: a build row was joined, but only because of a NULL
    Symbol nullMatchConditionSymbol = symbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN);
    Expression ambiguousMatch = and(isNotNull(buildSideKnownNonNull), not(definiteMatch));

    ProjectNode flagged = new ProjectNode(
            idAllocator.getNextId(),
            join,
            Assignments.builder()
                    .putIdentities(join.getOutputSymbols())
                    .put(matchConditionSymbol, definiteMatch)
                    .put(nullMatchConditionSymbol, ambiguousMatch)
                    .build());

    Symbol countMatchesSymbol = symbolAllocator.newSymbol("countMatches", BIGINT);
    Symbol countNullMatchesSymbol = symbolAllocator.newSymbol("countNullMatches", BIGINT);
    AggregationNode aggregation = new AggregationNode(
            idAllocator.getNextId(),
            flagged,
            ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
                    .put(countMatchesSymbol, countWithFilter(session, matchConditionSymbol))
                    .put(countNullMatchesSymbol, countWithFilter(session, nullMatchConditionSymbol))
                    .buildOrThrow(),
            singleGroupingSet(probe.getOutputSymbols()),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty());

    // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results
    WhenClause anyMatch = new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true));
    WhenClause anyNullMatch = new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null));
    SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression(
            ImmutableList.of(anyMatch, anyNullMatch),
            Optional.of(booleanConstant(false)));

    return new ProjectNode(
            idAllocator.getNextId(),
            aggregation,
            Assignments.builder()
                    .putIdentities(apply.getInput().getOutputSymbols())
                    .put(inPredicateOutputSymbol, inPredicateEquivalent)
                    .build());
}
Use of io.trino.sql.tree.SearchedCaseExpression in the trino project by trinodb.
Class TestEqualityInference, method testExpressionsThatMayReturnNullOnNonNullInput.
/**
 * Each candidate below can evaluate to NULL even for a non-null {@code b}.
 * The test asserts that, when partitioning by {@code b}, the only scope-straddling equality
 * produced is {@code b = x} (in either operand order). In other words,
 * {@code a = candidate} must not contribute one.
 */
@Test
public void testExpressionsThatMayReturnNullOnNonNullInput() {
    List<Expression> candidates = ImmutableList.of(
            // try_cast
            new Cast(nameReference("b"), toSqlType(BIGINT), true),
            // try(b)
            functionResolution.functionCallBuilder(QualifiedName.of(TryFunction.NAME))
                    .addArgument(new FunctionType(ImmutableList.of(), VARCHAR), new LambdaExpression(ImmutableList.of(), nameReference("b")))
                    .build(),
            // nullif(b, 1)
            new NullIfExpression(nameReference("b"), number(1)),
            // if(b, 1, NULL)
            new IfExpression(nameReference("b"), number(1), new NullLiteral()),
            // b IN (NULL)
            new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))),
            // CASE WHEN b IS NOT NULL THEN NULL END
            new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()),
            // CASE b WHEN 1 THEN NULL END
            new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()),
            // ARRAY[NULL][b]
            new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b")));

    for (Expression nullableCandidate : candidates) {
        EqualityInference inference = EqualityInference.newInstance(
                metadata,
                equals(nameReference("b"), nameReference("x")),
                equals(nameReference("a"), nullableCandidate));
        List<Expression> straddling = inference.generateEqualitiesPartitionedBy(symbols("b")).getScopeStraddlingEqualities();
        assertEquals(straddling.size(), 1);

        Expression onlyEquality = straddling.get(0);
        assertTrue(onlyEquality.equals(equals(nameReference("x"), nameReference("b")))
                || onlyEquality.equals(equals(nameReference("b"), nameReference("x"))));
    }
}
Aggregations