Search in sources :

Example 1 with WhenClause

use of io.prestosql.sql.tree.WhenClause in project hetu-core by openlookeng.

In class TestEqualityInference, method testExpressionsThatMayReturnNullOnNonNullInput.

/**
 * For each candidate expression over {@code b} (the kinds named by this test as possibly
 * returning null on non-null input), feeds {@code b = x} and {@code a = candidate} into an
 * {@link EqualityInference} and checks that, when partitioning by symbol {@code b}, exactly
 * one scope-straddling equality is produced and that it relates {@code b} and {@code x}.
 */
@Test
public void testExpressionsThatMayReturnNullOnNonNullInput() {
    List<Expression> candidates = ImmutableList.of(
            // try_cast
            new Cast(nameReference("b"), "BIGINT", true),
            // call to TryFunction.NAME with a zero-argument lambda returning b
            new FunctionCallBuilder(metadata)
                    .setName(QualifiedName.of(TryFunction.NAME))
                    .addArgument(
                            new FunctionType(ImmutableList.of(), VARCHAR),
                            new LambdaExpression(ImmutableList.of(), nameReference("b")))
                    .build(),
            // NULLIF
            new NullIfExpression(nameReference("b"), number(1)),
            // IF with a null false-branch
            new IfExpression(nameReference("b"), number(1), new NullLiteral()),
            // field dereference
            new DereferenceExpression(nameReference("b"), identifier("x")),
            // IN over a list containing only null
            new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))),
            // searched CASE without a default
            new SearchedCaseExpression(
                    ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())),
                    Optional.empty()),
            // simple CASE without a default
            new SimpleCaseExpression(
                    nameReference("b"),
                    ImmutableList.of(new WhenClause(number(1), new NullLiteral())),
                    Optional.empty()),
            // subscript into an array constructor
            new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b")));

    for (Expression candidate : candidates) {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.extractInferenceCandidates(equals(nameReference("b"), nameReference("x")));
        builder.extractInferenceCandidates(equals(nameReference("a"), candidate));

        List<Expression> straddling = builder.build()
                .generateEqualitiesPartitionedBy(matchesSymbols("b"))
                .getScopeStraddlingEqualities();
        assertEquals(straddling.size(), 1);

        // The single equality may be expressed in either operand order.
        Expression only = straddling.get(0);
        boolean isXEqualsB = only.equals(equals(nameReference("x"), nameReference("b")));
        boolean isBEqualsX = only.equals(equals(nameReference("b"), nameReference("x")));
        assertTrue(isXEqualsB || isBEqualsX);
    }
}
Also used : Cast(io.prestosql.sql.tree.Cast) NullIfExpression(io.prestosql.sql.tree.NullIfExpression) IfExpression(io.prestosql.sql.tree.IfExpression) DereferenceExpression(io.prestosql.sql.tree.DereferenceExpression) FunctionType(io.prestosql.spi.type.FunctionType) InListExpression(io.prestosql.sql.tree.InListExpression) InPredicate(io.prestosql.sql.tree.InPredicate) IsNotNullPredicate(io.prestosql.sql.tree.IsNotNullPredicate) SimpleCaseExpression(io.prestosql.sql.tree.SimpleCaseExpression) WhenClause(io.prestosql.sql.tree.WhenClause) SearchedCaseExpression(io.prestosql.sql.tree.SearchedCaseExpression) InListExpression(io.prestosql.sql.tree.InListExpression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) SearchedCaseExpression(io.prestosql.sql.tree.SearchedCaseExpression) SimpleCaseExpression(io.prestosql.sql.tree.SimpleCaseExpression) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) DereferenceExpression(io.prestosql.sql.tree.DereferenceExpression) NullIfExpression(io.prestosql.sql.tree.NullIfExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) IfExpression(io.prestosql.sql.tree.IfExpression) Expression(io.prestosql.sql.tree.Expression) NullIfExpression(io.prestosql.sql.tree.NullIfExpression) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) NullLiteral(io.prestosql.sql.tree.NullLiteral) Test(org.testng.annotations.Test)

Example 2 with WhenClause

use of io.prestosql.sql.tree.WhenClause in project hetu-core by openlookeng.

In class TransformCorrelatedScalarSubquery, method apply.

/**
 * Rewrites a lateral join whose subquery contains an {@link EnforceSingleRowNode} reachable
 * through {@link ProjectNode}s only.
 *
 * <p>The first such node is removed from the subquery. If the remaining subquery's extracted
 * cardinality lies within [0, 1], the lateral join is rebuilt directly (keeping the original
 * join type when the cardinality is exactly 1, otherwise using LEFT). Otherwise the input is
 * tagged with a unique id, joined LEFT, marked distinct on the input's output symbols, and
 * filtered by a CASE expression that calls {@code fail} when the distinct marker is not true.
 *
 * @return an empty result when no matching {@link EnforceSingleRowNode} is found, otherwise
 *         the rewritten plan node
 */
@Override
public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context context) {
    PlanNode subquery = context.getLookup().resolve(lateralJoinNode.getSubquery());

    boolean hasEnforceSingleRow = searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(ProjectNode.class::isInstance)
            .matches();
    if (!hasEnforceSingleRow) {
        return Result.empty();
    }

    PlanNode rewrittenSubquery = searchFrom(subquery, context.getLookup())
            .where(EnforceSingleRowNode.class::isInstance)
            .recurseOnlyWhen(ProjectNode.class::isInstance)
            .removeFirst();

    Range<Long> subqueryCardinality = extractCardinality(rewrittenSubquery, context.getLookup());
    if (Range.closed(0L, 1L).encloses(subqueryCardinality)) {
        boolean producesSingleRow = Range.singleton(1L).encloses(subqueryCardinality);
        LateralJoinNode withoutEnforcement = new LateralJoinNode(
                context.getIdAllocator().getNextId(),
                lateralJoinNode.getInput(),
                rewrittenSubquery,
                lateralJoinNode.getCorrelation(),
                producesSingleRow ? lateralJoinNode.getType() : LEFT,
                lateralJoinNode.getFilter(),
                lateralJoinNode.getOriginSubquery());
        return Result.ofPlanNode(withoutEnforcement);
    }

    Symbol unique = context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT);

    // The join's id is requested before the nested AssignUniqueId's id (arguments are
    // evaluated left to right), so the two nodes stay nested in one expression.
    LateralJoinNode rewrittenLateralJoinNode = new LateralJoinNode(
            context.getIdAllocator().getNextId(),
            new AssignUniqueId(context.getIdAllocator().getNextId(), lateralJoinNode.getInput(), unique),
            rewrittenSubquery,
            lateralJoinNode.getCorrelation(),
            LEFT,
            lateralJoinNode.getFilter(),
            lateralJoinNode.getOriginSubquery());

    Symbol isDistinct = context.getSymbolAllocator().newSymbol("is_distinct", BooleanType.BOOLEAN);
    MarkDistinctNode markDistinctNode = new MarkDistinctNode(
            context.getIdAllocator().getNextId(),
            rewrittenLateralJoinNode,
            isDistinct,
            rewrittenLateralJoinNode.getInput().getOutputSymbols(),
            Optional.empty());

    // CASE is_distinct WHEN true THEN true ELSE CAST(fail(code, message) AS boolean) END
    Cast failOnMultipleRows = new Cast(
            new FunctionCallBuilder(metadata)
                    .setName(QualifiedName.of("fail"))
                    .addArgument(INTEGER, new LongLiteral(Integer.toString(SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode())))
                    .addArgument(VARCHAR, new StringLiteral("Scalar sub-query has returned multiple rows"))
                    .build(),
            BOOLEAN);
    SimpleCaseExpression singleRowCheck = new SimpleCaseExpression(
            toSymbolReference(isDistinct),
            ImmutableList.of(new WhenClause(TRUE_LITERAL, TRUE_LITERAL)),
            Optional.of(failOnMultipleRows));

    FilterNode filterNode = new FilterNode(
            context.getIdAllocator().getNextId(),
            markDistinctNode,
            castToRowExpression(singleRowCheck));

    // Project back to the output symbols of the original lateral join.
    return Result.ofPlanNode(new ProjectNode(
            context.getIdAllocator().getNextId(),
            filterNode,
            AssignmentUtils.identityAsSymbolReferences(lateralJoinNode.getOutputSymbols())));
}
Also used : Cast(io.prestosql.sql.tree.Cast) MarkDistinctNode(io.prestosql.spi.plan.MarkDistinctNode) LongLiteral(io.prestosql.sql.tree.LongLiteral) Symbol(io.prestosql.spi.plan.Symbol) FilterNode(io.prestosql.spi.plan.FilterNode) SimpleCaseExpression(io.prestosql.sql.tree.SimpleCaseExpression) WhenClause(io.prestosql.sql.tree.WhenClause) PlanNode(io.prestosql.spi.plan.PlanNode) AssignUniqueId(io.prestosql.sql.planner.plan.AssignUniqueId) LateralJoinNode(io.prestosql.sql.planner.plan.LateralJoinNode) StringLiteral(io.prestosql.sql.tree.StringLiteral) EnforceSingleRowNode(io.prestosql.sql.planner.plan.EnforceSingleRowNode) ProjectNode(io.prestosql.spi.plan.ProjectNode) FunctionCallBuilder(io.prestosql.sql.planner.FunctionCallBuilder)

Example 3 with WhenClause

use of io.prestosql.sql.tree.WhenClause in project hetu-core by openlookeng.

In class TransformCorrelatedInPredicateToJoin, method buildInPredicateEquivalent.

/**
 * Builds a plan that computes the value of an IN predicate using a left outer join followed
 * by an aggregation, instead of the given apply node.
 *
 * <p>The probe side (the apply input tagged with a unique id) is left-outer-joined to the
 * decorrelated build side on
 * {@code (probe IS NULL OR probe = build OR build IS NULL) AND correlationCondition}.
 * Two filtered counts are then aggregated, grouped by the probe side's output symbols, and
 * the result symbol is assigned a searched CASE over those counts.
 *
 * @param apply the apply node whose input forms the probe side
 * @param inPredicate the predicate supplying the probe value and the build value list symbols
 * @param inPredicateOutputSymbol the symbol that receives the computed predicate value
 * @param decorrelated the decorrelated build source together with its correlated predicates
 * @param idAllocator allocator for the ids of the newly created plan nodes
 * @param planSymbolAllocator allocator for the helper symbols introduced by the rewrite
 * @return a projection exposing the apply input's symbols plus {@code inPredicateOutputSymbol}
 */
private PlanNode buildInPredicateEquivalent(ApplyNode apply, InPredicate inPredicate, Symbol inPredicateOutputSymbol, Decorrelated decorrelated, PlanNodeIdAllocator idAllocator, PlanSymbolAllocator planSymbolAllocator) {
    Expression correlationCondition = and(decorrelated.getCorrelatedPredicates());
    PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

    AssignUniqueId probeSide = new AssignUniqueId(
            idAllocator.getNextId(),
            apply.getInput(),
            planSymbolAllocator.newSymbol("unique", BIGINT));

    // Build side: all original outputs plus a constant marker column.
    Symbol buildSideKnownNonNull = planSymbolAllocator.newSymbol("buildSideKnownNonNull", BIGINT);
    Assignments buildSideAssignments = Assignments.builder()
            .putAll(identityAsSymbolReferences(decorrelatedBuildSource.getOutputSymbols()))
            .put(buildSideKnownNonNull, castToRowExpression(bigint(0)))
            .build();
    ProjectNode buildSide = new ProjectNode(idAllocator.getNextId(), decorrelatedBuildSource, buildSideAssignments);

    Symbol probeSideSymbol = SymbolUtils.from(inPredicate.getValue());
    Symbol buildSideSymbol = SymbolUtils.from(inPredicate.getValueList());

    // (probe IS NULL OR probe = build OR build IS NULL) AND correlationCondition
    Expression probeIsNull = new IsNullPredicate(toSymbolReference(probeSideSymbol));
    Expression valuesEqual = new ComparisonExpression(
            ComparisonExpression.Operator.EQUAL,
            toSymbolReference(probeSideSymbol),
            toSymbolReference(buildSideSymbol));
    Expression buildIsNull = new IsNullPredicate(toSymbolReference(buildSideSymbol));
    Expression joinExpression = and(or(probeIsNull, valuesEqual, buildIsNull), correlationCondition);

    JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);

    Symbol matchConditionSymbol = planSymbolAllocator.newSymbol("matchConditionSymbol", BOOLEAN);
    Expression matchCondition = and(isNotNull(probeSideSymbol), isNotNull(buildSideSymbol));

    Symbol nullMatchConditionSymbol = planSymbolAllocator.newSymbol("nullMatchConditionSymbol", BOOLEAN);
    Expression nullMatchCondition = and(isNotNull(buildSideKnownNonNull), not(matchCondition));

    // Expose both conditions as columns so they can serve as aggregation filters.
    Assignments preProjectionAssignments = Assignments.builder()
            .putAll(AssignmentUtils.identityAsSymbolReferences(leftOuterJoin.getOutputSymbols()))
            .put(matchConditionSymbol, castToRowExpression(matchCondition))
            .put(nullMatchConditionSymbol, castToRowExpression(nullMatchCondition))
            .build();
    ProjectNode preProjection = new ProjectNode(idAllocator.getNextId(), leftOuterJoin, preProjectionAssignments);

    Symbol countMatchesSymbol = planSymbolAllocator.newSymbol("countMatches", BIGINT);
    Symbol countNullMatchesSymbol = planSymbolAllocator.newSymbol("countNullMatches", BIGINT);

    ImmutableMap<Symbol, AggregationNode.Aggregation> filteredCounts = ImmutableMap.<Symbol, AggregationNode.Aggregation>builder()
            .put(countMatchesSymbol, countWithFilter(matchConditionSymbol))
            .put(countNullMatchesSymbol, countWithFilter(nullMatchConditionSymbol))
            .build();
    AggregationNode aggregation = new AggregationNode(
            idAllocator.getNextId(),
            preProjection,
            filteredCounts,
            singleGroupingSet(probeSide.getOutputSymbols()),
            ImmutableList.of(),
            AggregationNode.Step.SINGLE,
            Optional.empty(),
            Optional.empty(),
            AggregationNode.AggregationType.HASH,
            Optional.empty());

    // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results
    WhenClause whenAnyMatch = new WhenClause(isGreaterThan(countMatchesSymbol, 0), booleanConstant(true));
    WhenClause whenAnyNullMatch = new WhenClause(isGreaterThan(countNullMatchesSymbol, 0), booleanConstant(null));
    SearchedCaseExpression inPredicateEquivalent = new SearchedCaseExpression(
            ImmutableList.of(whenAnyMatch, whenAnyNullMatch),
            Optional.of(booleanConstant(false)));

    Assignments outputAssignments = Assignments.builder()
            .putAll(identityAsSymbolReferences(apply.getInput().getOutputSymbols()))
            .put(inPredicateOutputSymbol, castToRowExpression(inPredicateEquivalent))
            .build();
    return new ProjectNode(idAllocator.getNextId(), aggregation, outputAssignments);
}
Also used : ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) WhenClause(io.prestosql.sql.tree.WhenClause) PlanNode(io.prestosql.spi.plan.PlanNode) AssignUniqueId(io.prestosql.sql.planner.plan.AssignUniqueId) SearchedCaseExpression(io.prestosql.sql.tree.SearchedCaseExpression) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) SearchedCaseExpression(io.prestosql.sql.tree.SearchedCaseExpression) NotExpression(io.prestosql.sql.tree.NotExpression) ComparisonExpression(io.prestosql.sql.tree.ComparisonExpression) Expression(io.prestosql.sql.tree.Expression) Symbol(io.prestosql.spi.plan.Symbol) JoinNode(io.prestosql.spi.plan.JoinNode) IsNullPredicate(io.prestosql.sql.tree.IsNullPredicate) ProjectNode(io.prestosql.spi.plan.ProjectNode) AggregationNode(io.prestosql.spi.plan.AggregationNode)

Aggregations

WhenClause (io.prestosql.sql.tree.WhenClause)3 PlanNode (io.prestosql.spi.plan.PlanNode)2 ProjectNode (io.prestosql.spi.plan.ProjectNode)2 Symbol (io.prestosql.spi.plan.Symbol)2 AssignUniqueId (io.prestosql.sql.planner.plan.AssignUniqueId)2 Cast (io.prestosql.sql.tree.Cast)2 ComparisonExpression (io.prestosql.sql.tree.ComparisonExpression)2 Expression (io.prestosql.sql.tree.Expression)2 SearchedCaseExpression (io.prestosql.sql.tree.SearchedCaseExpression)2 SimpleCaseExpression (io.prestosql.sql.tree.SimpleCaseExpression)2 AggregationNode (io.prestosql.spi.plan.AggregationNode)1 FilterNode (io.prestosql.spi.plan.FilterNode)1 JoinNode (io.prestosql.spi.plan.JoinNode)1 MarkDistinctNode (io.prestosql.spi.plan.MarkDistinctNode)1 CallExpression (io.prestosql.spi.relation.CallExpression)1 FunctionType (io.prestosql.spi.type.FunctionType)1 FunctionCallBuilder (io.prestosql.sql.planner.FunctionCallBuilder)1 EnforceSingleRowNode (io.prestosql.sql.planner.plan.EnforceSingleRowNode)1 LateralJoinNode (io.prestosql.sql.planner.plan.LateralJoinNode)1 OriginalExpressionUtils.castToExpression (io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression)1