Search in sources:

Example 21 with Row

Use of io.trino.sql.tree.Row in project trino by trinodb.

Class ReplaceJoinOverConstantWithProject, method appendProjection.

/**
 * Wraps {@code source} in a {@link ProjectNode} that keeps every symbol in {@code sourceOutputs}
 * as an identity assignment and binds each symbol in {@code constantOutputs} to the expression at
 * the matching position of the single row of {@code constantSource}.
 *
 * <p>{@code constantSource} must be a {@link ValuesNode}; otherwise the cast throws
 * ClassCastException. Its rows must be present, otherwise {@code Optional.get()} throws. It must
 * hold exactly one row, otherwise {@code getOnlyElement} throws. That row must be a {@link Row}
 * literal.
 *
 * @param source the plan the projection is placed on top of
 * @param sourceOutputs symbols of {@code source} passed through unchanged
 * @param constantSource the single-row VALUES node supplying the constant expressions
 * @param constantOutputs symbols of {@code constantSource} to add to the projection
 * @param idAllocator supplies the id of the new {@link ProjectNode}
 * @return a new projection over {@code source}
 */
private ProjectNode appendProjection(PlanNode source, List<Symbol> sourceOutputs, PlanNode constantSource, List<Symbol> constantOutputs, PlanNodeIdAllocator idAllocator) {
    ValuesNode constantValues = (ValuesNode) constantSource;
    Row onlyRow = (Row) getOnlyElement(constantValues.getRows().get());

    // Pair each output symbol of the VALUES node with the row item at the same position.
    List<Symbol> valuesSymbols = constantValues.getOutputSymbols();
    List<Expression> rowItems = onlyRow.getItems();
    Map<Symbol, Expression> constantBySymbol = new HashMap<>();
    int position = 0;
    for (Symbol valuesSymbol : valuesSymbols) {
        constantBySymbol.put(valuesSymbol, rowItems.get(position));
        position++;
    }

    // Source columns pass through as identities; constant columns become inlined expressions.
    Assignments.Builder assignments = Assignments.builder();
    assignments.putIdentities(sourceOutputs);
    for (Symbol constantOutput : constantOutputs) {
        assignments.put(constantOutput, constantBySymbol.get(constantOutput));
    }
    return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
}
Also used : ValuesNode(io.trino.sql.planner.plan.ValuesNode) Expression(io.trino.sql.tree.Expression) HashMap(java.util.HashMap) Symbol(io.trino.sql.planner.Symbol) Assignments(io.trino.sql.planner.plan.Assignments) ProjectNode(io.trino.sql.planner.plan.ProjectNode) Row(io.trino.sql.tree.Row)

Example 22 with Row

Use of io.trino.sql.tree.Row in project trino by trinodb.

Class TestMergeProjectWithValues, method testProjectWithoutOutputSymbols.

@Test
public void testProjectWithoutOutputSymbols() {
    // Every case projects with an empty Assignments over a VALUES node and expects a
    // VALUES pattern whose row count equals the input row count.

    // ValuesNode has two output symbols and two rows
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(),
                    p.valuesOfExpressions(
                            ImmutableList.of(p.symbol("a"), p.symbol("b")),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(new CharLiteral("x"), new BooleanLiteral("true"))),
                                    new Row(ImmutableList.of(new CharLiteral("y"), new BooleanLiteral("false")))))))
            .matches(values(2));

    // ValuesNode has no output symbols and two rows
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(),
                    p.values(
                            ImmutableList.of(),
                            ImmutableList.of(ImmutableList.of(), ImmutableList.of()))))
            .matches(values(2));

    // ValuesNode has two output symbols and no rows
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(),
                    p.values(ImmutableList.of(p.symbol("a"), p.symbol("b")), ImmutableList.of())))
            .matches(values());

    // ValuesNode has no output symbols and no rows
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(),
                    p.values(ImmutableList.of(), ImmutableList.of())))
            .matches(values());
}
Also used : IsNullPredicate(io.trino.sql.tree.IsNullPredicate) TypeSignatureProvider.fromTypes(io.trino.sql.analyzer.TypeSignatureProvider.fromTypes) Test(org.testng.annotations.Test) Cast(io.trino.sql.tree.Cast) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) ImmutableList(com.google.common.collect.ImmutableList) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) LongLiteral(io.trino.sql.tree.LongLiteral) NullLiteral(io.trino.sql.tree.NullLiteral) FunctionCall(io.trino.sql.tree.FunctionCall) ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) Symbol(io.trino.sql.planner.Symbol) RowType(io.trino.spi.type.RowType) StringLiteral(io.trino.sql.tree.StringLiteral) BaseRuleTest(io.trino.sql.planner.iterative.rule.test.BaseRuleTest) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) Assignments(io.trino.sql.planner.plan.Assignments) PlanMatchPattern.values(io.trino.sql.planner.assertions.PlanMatchPattern.values) DoubleLiteral(io.trino.sql.tree.DoubleLiteral) QualifiedName(io.trino.sql.tree.QualifiedName) CharLiteral(io.trino.sql.tree.CharLiteral) ADD(io.trino.sql.tree.ArithmeticBinaryExpression.Operator.ADD) BIGINT(io.trino.spi.type.BigintType.BIGINT) SymbolReference(io.trino.sql.tree.SymbolReference) Row(io.trino.sql.tree.Row) PlanBuilder.expression(io.trino.sql.planner.iterative.rule.test.PlanBuilder.expression) MINUS(io.trino.sql.tree.ArithmeticUnaryExpression.Sign.MINUS) CharLiteral(io.trino.sql.tree.CharLiteral) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) Row(io.trino.sql.tree.Row) Test(org.testng.annotations.Test) BaseRuleTest(io.trino.sql.planner.iterative.rule.test.BaseRuleTest)

Example 23 with Row

Use of io.trino.sql.tree.Row in project trino by trinodb.

Class TestMergeProjectWithValues, method testCorrelation.

@Test
public void testCorrelation() {
    // In each case "corr" is a symbol that no input of the project or VALUES node defines,
    // i.e. it stands in for a correlated reference.

    // correlation symbol in projection (note: the resulting plan is not yet supported in execution)
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(p.symbol("x"), expression("a + corr")),
                    p.valuesOfExpressions(
                            ImmutableList.of(p.symbol("a")),
                            ImmutableList.of(new Row(ImmutableList.of(new LongLiteral("1")))))))
            .matches(values(
                    ImmutableList.of("x"),
                    ImmutableList.of(ImmutableList.of(
                            new ArithmeticBinaryExpression(ADD, new LongLiteral("1"), new SymbolReference("corr"))))));

    // correlation symbol in values (note: the resulting plan is not yet supported in execution)
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(p.symbol("x"), expression("a")),
                    p.valuesOfExpressions(
                            ImmutableList.of(p.symbol("a")),
                            ImmutableList.of(new Row(ImmutableList.of(new SymbolReference("corr")))))))
            .matches(values(
                    ImmutableList.of("x"),
                    ImmutableList.of(ImmutableList.of(new SymbolReference("corr")))));

    // correlation symbol is not present in the resulting expression
    tester().assertThat(new MergeProjectWithValues(tester().getMetadata()))
            .on(p -> p.project(
                    Assignments.of(p.symbol("x"), expression("1")),
                    p.valuesOfExpressions(
                            ImmutableList.of(p.symbol("a")),
                            ImmutableList.of(new Row(ImmutableList.of(new SymbolReference("corr")))))))
            .matches(values(
                    ImmutableList.of("x"),
                    ImmutableList.of(ImmutableList.of(new LongLiteral("1")))));
}
Also used : IsNullPredicate(io.trino.sql.tree.IsNullPredicate) TypeSignatureProvider.fromTypes(io.trino.sql.analyzer.TypeSignatureProvider.fromTypes) Test(org.testng.annotations.Test) Cast(io.trino.sql.tree.Cast) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) ImmutableList(com.google.common.collect.ImmutableList) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) ArithmeticUnaryExpression(io.trino.sql.tree.ArithmeticUnaryExpression) LongLiteral(io.trino.sql.tree.LongLiteral) NullLiteral(io.trino.sql.tree.NullLiteral) FunctionCall(io.trino.sql.tree.FunctionCall) ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) Symbol(io.trino.sql.planner.Symbol) RowType(io.trino.spi.type.RowType) StringLiteral(io.trino.sql.tree.StringLiteral) BaseRuleTest(io.trino.sql.planner.iterative.rule.test.BaseRuleTest) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) Assignments(io.trino.sql.planner.plan.Assignments) PlanMatchPattern.values(io.trino.sql.planner.assertions.PlanMatchPattern.values) DoubleLiteral(io.trino.sql.tree.DoubleLiteral) QualifiedName(io.trino.sql.tree.QualifiedName) CharLiteral(io.trino.sql.tree.CharLiteral) ADD(io.trino.sql.tree.ArithmeticBinaryExpression.Operator.ADD) BIGINT(io.trino.spi.type.BigintType.BIGINT) SymbolReference(io.trino.sql.tree.SymbolReference) Row(io.trino.sql.tree.Row) PlanBuilder.expression(io.trino.sql.planner.iterative.rule.test.PlanBuilder.expression) MINUS(io.trino.sql.tree.ArithmeticUnaryExpression.Sign.MINUS) ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) LongLiteral(io.trino.sql.tree.LongLiteral) SymbolReference(io.trino.sql.tree.SymbolReference) Row(io.trino.sql.tree.Row) Test(org.testng.annotations.Test) BaseRuleTest(io.trino.sql.planner.iterative.rule.test.BaseRuleTest)

Example 24 with Row

Use of io.trino.sql.tree.Row in project trino by trinodb.

Class TreePrinter, method print.

/**
 * Prints an indented, one-node-per-line rendering of the AST rooted at {@code root}.
 *
 * <p>Each line is emitted through this class's two-argument {@code print(indentLevel, text)}
 * helper. The integer visitor context is the current indent level, and children are printed one
 * or more levels deeper than their parent.
 *
 * <p>Some node types have no override below and are not otherwise handled by
 * {@link DefaultTraversalVisitor}, so their traversal reaches {@code visitNode}. For those,
 * printing fails with {@link UnsupportedOperationException}.
 *
 * @param root the AST node to print
 */
public void print(Node root) {
    AstVisitor<Void, Integer> printer = new DefaultTraversalVisitor<Integer>() {

        @Override
        protected Void visitNode(Node node, Integer indentLevel) {
            throw new UnsupportedOperationException("not yet implemented: " + node);
        }

        @Override
        protected Void visitQuery(Query node, Integer indentLevel) {
            print(indentLevel, "Query ");
            indentLevel++;
            print(indentLevel, "QueryBody");
            process(node.getQueryBody(), indentLevel);
            if (node.getOrderBy().isPresent()) {
                print(indentLevel, "OrderBy");
                process(node.getOrderBy().get(), indentLevel + 1);
            }
            if (node.getLimit().isPresent()) {
                print(indentLevel, "Limit: " + node.getLimit().get());
            }
            return null;
        }

        @Override
        protected Void visitQuerySpecification(QuerySpecification node, Integer indentLevel) {
            print(indentLevel, "QuerySpecification ");
            indentLevel++;
            process(node.getSelect(), indentLevel);
            if (node.getFrom().isPresent()) {
                print(indentLevel, "From");
                process(node.getFrom().get(), indentLevel + 1);
            }
            if (node.getWhere().isPresent()) {
                print(indentLevel, "Where");
                process(node.getWhere().get(), indentLevel + 1);
            }
            if (node.getGroupBy().isPresent()) {
                String distinct = "";
                if (node.getGroupBy().get().isDistinct()) {
                    distinct = "[DISTINCT]";
                }
                print(indentLevel, "GroupBy" + distinct);
                for (GroupingElement groupingElement : node.getGroupBy().get().getGroupingElements()) {
                    // Each element is labelled with its own kind one level under "GroupBy", and
                    // its expressions are nested one level under that label. Previously every
                    // element was also labelled "SimpleGroupBy", and Cube/Rollup columns were
                    // printed at the same level as their own label.
                    if (groupingElement instanceof SimpleGroupBy) {
                        print(indentLevel + 1, "SimpleGroupBy");
                        for (Expression column : groupingElement.getExpressions()) {
                            process(column, indentLevel + 2);
                        }
                    } else if (groupingElement instanceof GroupingSets) {
                        print(indentLevel + 1, "GroupingSets");
                        for (List<Expression> set : ((GroupingSets) groupingElement).getSets()) {
                            print(indentLevel + 2, "GroupingSet[");
                            for (Expression expression : set) {
                                process(expression, indentLevel + 3);
                            }
                            print(indentLevel + 2, "]");
                        }
                    } else if (groupingElement instanceof Cube) {
                        print(indentLevel + 1, "Cube");
                        for (Expression column : groupingElement.getExpressions()) {
                            process(column, indentLevel + 2);
                        }
                    } else if (groupingElement instanceof Rollup) {
                        print(indentLevel + 1, "Rollup");
                        for (Expression column : groupingElement.getExpressions()) {
                            process(column, indentLevel + 2);
                        }
                    }
                }
            }
            if (node.getHaving().isPresent()) {
                print(indentLevel, "Having");
                process(node.getHaving().get(), indentLevel + 1);
            }
            if (!node.getWindows().isEmpty()) {
                print(indentLevel, "Window");
                for (WindowDefinition windowDefinition : node.getWindows()) {
                    process(windowDefinition, indentLevel + 1);
                }
            }
            if (node.getOrderBy().isPresent()) {
                print(indentLevel, "OrderBy");
                process(node.getOrderBy().get(), indentLevel + 1);
            }
            if (node.getLimit().isPresent()) {
                print(indentLevel, "Limit: " + node.getLimit().get());
            }
            return null;
        }

        @Override
        protected Void visitOrderBy(OrderBy node, Integer indentLevel) {
            for (SortItem sortItem : node.getSortItems()) {
                process(sortItem, indentLevel);
            }
            return null;
        }

        @Override
        protected Void visitWindowDefinition(WindowDefinition node, Integer indentLevel) {
            print(indentLevel, "WindowDefinition[" + node.getName() + "]");
            process(node.getWindow(), indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitWindowReference(WindowReference node, Integer indentLevel) {
            print(indentLevel, "WindowReference[" + node.getName() + "]");
            return null;
        }

        @Override
        public Void visitWindowSpecification(WindowSpecification node, Integer indentLevel) {
            if (node.getExistingWindowName().isPresent()) {
                print(indentLevel, "ExistingWindowName " + node.getExistingWindowName().get());
            }
            if (!node.getPartitionBy().isEmpty()) {
                print(indentLevel, "PartitionBy");
                for (Expression expression : node.getPartitionBy()) {
                    process(expression, indentLevel + 1);
                }
            }
            if (node.getOrderBy().isPresent()) {
                print(indentLevel, "OrderBy");
                process(node.getOrderBy().get(), indentLevel + 1);
            }
            if (node.getFrame().isPresent()) {
                print(indentLevel, "Frame");
                process(node.getFrame().get(), indentLevel + 1);
            }
            return null;
        }

        @Override
        protected Void visitSelect(Select node, Integer indentLevel) {
            String distinct = "";
            if (node.isDistinct()) {
                distinct = "[DISTINCT]";
            }
            print(indentLevel, "Select" + distinct);
            // visit children
            super.visitSelect(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitAllColumns(AllColumns node, Integer indent) {
            StringBuilder aliases = new StringBuilder();
            if (!node.getAliases().isEmpty()) {
                aliases.append(" [Aliases: ");
                Joiner.on(", ").appendTo(aliases, node.getAliases());
                aliases.append("]");
            }
            print(indent, "All columns" + aliases.toString());
            if (node.getTarget().isPresent()) {
                // visit child
                super.visitAllColumns(node, indent + 1);
            }
            return null;
        }

        @Override
        protected Void visitSingleColumn(SingleColumn node, Integer indent) {
            if (node.getAlias().isPresent()) {
                print(indent, "Alias: " + node.getAlias().get());
            }
            // visit children
            super.visitSingleColumn(node, indent + 1);
            return null;
        }

        @Override
        protected Void visitComparisonExpression(ComparisonExpression node, Integer indentLevel) {
            print(indentLevel, node.getOperator().toString());
            super.visitComparisonExpression(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitArithmeticBinary(ArithmeticBinaryExpression node, Integer indentLevel) {
            print(indentLevel, node.getOperator().toString());
            super.visitArithmeticBinary(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitLogicalExpression(LogicalExpression node, Integer indentLevel) {
            print(indentLevel, node.getOperator().toString());
            super.visitLogicalExpression(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitStringLiteral(StringLiteral node, Integer indentLevel) {
            print(indentLevel, "String[" + node.getValue() + "]");
            return null;
        }

        @Override
        protected Void visitBinaryLiteral(BinaryLiteral node, Integer indentLevel) {
            print(indentLevel, "Binary[" + node.toHexString() + "]");
            return null;
        }

        @Override
        protected Void visitBooleanLiteral(BooleanLiteral node, Integer indentLevel) {
            print(indentLevel, "Boolean[" + node.getValue() + "]");
            return null;
        }

        @Override
        protected Void visitLongLiteral(LongLiteral node, Integer indentLevel) {
            print(indentLevel, "Long[" + node.getValue() + "]");
            return null;
        }

        @Override
        protected Void visitLikePredicate(LikePredicate node, Integer indentLevel) {
            print(indentLevel, "LIKE");
            super.visitLikePredicate(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitIdentifier(Identifier node, Integer indentLevel) {
            // resolvedNameReferences is a member of the enclosing class (not visible here);
            // when it has an entry for this node, the resolved name is appended after "=>".
            QualifiedName resolved = resolvedNameReferences.get(node);
            String resolvedName = "";
            if (resolved != null) {
                resolvedName = "=>" + resolved.toString();
            }
            print(indentLevel, "Identifier[" + node.getValue() + resolvedName + "]");
            return null;
        }

        @Override
        protected Void visitDereferenceExpression(DereferenceExpression node, Integer indentLevel) {
            QualifiedName resolved = resolvedNameReferences.get(node);
            String resolvedName = "";
            if (resolved != null) {
                resolvedName = "=>" + resolved.toString();
            }
            // Children are not visited: the whole dereference is printed as one line.
            print(indentLevel, "DereferenceExpression[" + node + resolvedName + "]");
            return null;
        }

        @Override
        protected Void visitFunctionCall(FunctionCall node, Integer indentLevel) {
            String name = Joiner.on('.').join(node.getName().getParts());
            print(indentLevel, "FunctionCall[" + name + "]");
            super.visitFunctionCall(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitTable(Table node, Integer indentLevel) {
            String name = Joiner.on('.').join(node.getName().getParts());
            print(indentLevel, "Table[" + name + "]");
            return null;
        }

        @Override
        protected Void visitValues(Values node, Integer indentLevel) {
            print(indentLevel, "Values");
            super.visitValues(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitRow(Row node, Integer indentLevel) {
            print(indentLevel, "Row");
            super.visitRow(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitAliasedRelation(AliasedRelation node, Integer indentLevel) {
            print(indentLevel, "Alias[" + node.getAlias() + "]");
            super.visitAliasedRelation(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitSampledRelation(SampledRelation node, Integer indentLevel) {
            print(indentLevel, "TABLESAMPLE[" + node.getType() + " (" + node.getSamplePercentage() + ")]");
            super.visitSampledRelation(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitTableSubquery(TableSubquery node, Integer indentLevel) {
            print(indentLevel, "SubQuery");
            super.visitTableSubquery(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitInPredicate(InPredicate node, Integer indentLevel) {
            print(indentLevel, "IN");
            super.visitInPredicate(node, indentLevel + 1);
            return null;
        }

        @Override
        protected Void visitSubqueryExpression(SubqueryExpression node, Integer indentLevel) {
            print(indentLevel, "SubQuery");
            super.visitSubqueryExpression(node, indentLevel + 1);
            return null;
        }
    };
    printer.process(root, 0);
}
Also used : ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) SimpleGroupBy(io.trino.sql.tree.SimpleGroupBy) Query(io.trino.sql.tree.Query) Rollup(io.trino.sql.tree.Rollup) BooleanLiteral(io.trino.sql.tree.BooleanLiteral) Node(io.trino.sql.tree.Node) Values(io.trino.sql.tree.Values) AllColumns(io.trino.sql.tree.AllColumns) WindowReference(io.trino.sql.tree.WindowReference) SubqueryExpression(io.trino.sql.tree.SubqueryExpression) QuerySpecification(io.trino.sql.tree.QuerySpecification) LogicalExpression(io.trino.sql.tree.LogicalExpression) SortItem(io.trino.sql.tree.SortItem) Identifier(io.trino.sql.tree.Identifier) List(java.util.List) FunctionCall(io.trino.sql.tree.FunctionCall) SampledRelation(io.trino.sql.tree.SampledRelation) GroupingSets(io.trino.sql.tree.GroupingSets) OrderBy(io.trino.sql.tree.OrderBy) DereferenceExpression(io.trino.sql.tree.DereferenceExpression) Table(io.trino.sql.tree.Table) LongLiteral(io.trino.sql.tree.LongLiteral) QualifiedName(io.trino.sql.tree.QualifiedName) WindowSpecification(io.trino.sql.tree.WindowSpecification) DefaultTraversalVisitor(io.trino.sql.tree.DefaultTraversalVisitor) SingleColumn(io.trino.sql.tree.SingleColumn) LikePredicate(io.trino.sql.tree.LikePredicate) TableSubquery(io.trino.sql.tree.TableSubquery) InPredicate(io.trino.sql.tree.InPredicate) GroupingElement(io.trino.sql.tree.GroupingElement) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) BinaryLiteral(io.trino.sql.tree.BinaryLiteral) StringLiteral(io.trino.sql.tree.StringLiteral) SubqueryExpression(io.trino.sql.tree.SubqueryExpression) ArithmeticBinaryExpression(io.trino.sql.tree.ArithmeticBinaryExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) DereferenceExpression(io.trino.sql.tree.DereferenceExpression) LogicalExpression(io.trino.sql.tree.LogicalExpression) Expression(io.trino.sql.tree.Expression) Cube(io.trino.sql.tree.Cube) Select(io.trino.sql.tree.Select) Row(io.trino.sql.tree.Row) WindowDefinition(io.trino.sql.tree.WindowDefinition) AliasedRelation(io.trino.sql.tree.AliasedRelation)

Example 25 with Row

Use of io.trino.sql.tree.Row in project trino by trinodb.

Class TestEffectivePredicateExtractor, method testValues.

@Test
public void testValues() {
    // Symbol types shared by most cases below; the "Real NaN" case supplies its own.
    TypeProvider types = TypeProvider.copyOf(ImmutableMap.<Symbol, Type>builder()
            .put(A, BIGINT)
            .put(B, BIGINT)
            .put(D, DOUBLE)
            .put(R, RowType.anonymous(ImmutableList.of(BIGINT, BIGINT)))
            .buildOrThrow());

    // one column
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(A),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(bigintLiteral(1))),
                                    new Row(ImmutableList.of(bigintLiteral(2))))),
                    types,
                    typeAnalyzer),
            new InPredicate(AE, new InListExpression(ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))));

    // one column with null
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(A),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(bigintLiteral(1))),
                                    new Row(ImmutableList.of(bigintLiteral(2))),
                                    new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT)))))),
                    types,
                    typeAnalyzer),
            or(
                    new InPredicate(AE, new InListExpression(ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))),
                    new IsNullPredicate(AE)));

    // all nulls
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(A),
                            ImmutableList.of(new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT)))))),
                    types,
                    typeAnalyzer),
            new IsNullPredicate(AE));

    // nested row
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(R),
                            ImmutableList.of(new Row(ImmutableList.of(
                                    new Row(ImmutableList.of(bigintLiteral(1), new NullLiteral())))))),
                    types,
                    typeAnalyzer),
            TRUE_LITERAL);

    // many rows
    List<Expression> rows = IntStream.range(0, 500)
            .mapToObj(TestEffectivePredicateExtractor::bigintLiteral)
            .map(ImmutableList::of)
            .map(Row::new)
            .collect(toImmutableList());
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(newId(), ImmutableList.of(A), rows),
                    types,
                    typeAnalyzer),
            new BetweenPredicate(AE, bigintLiteral(0), bigintLiteral(499)));

    // NaN
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(D),
                            ImmutableList.of(new Row(ImmutableList.of(doubleLiteral(Double.NaN))))),
                    types,
                    typeAnalyzer),
            new NotExpression(new IsNullPredicate(DE)));

    // NaN and NULL
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(D),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(DOUBLE)))),
                                    new Row(ImmutableList.of(doubleLiteral(Double.NaN))))),
                    types,
                    typeAnalyzer),
            TRUE_LITERAL);

    // NaN and value
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(D),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(doubleLiteral(42.))),
                                    new Row(ImmutableList.of(doubleLiteral(Double.NaN))))),
                    types,
                    typeAnalyzer),
            new NotExpression(new IsNullPredicate(DE)));

    // Real NaN
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(D),
                            ImmutableList.of(new Row(ImmutableList.of(new Cast(doubleLiteral(Double.NaN), toSqlType(REAL)))))),
                    TypeProvider.copyOf(ImmutableMap.of(D, REAL)),
                    typeAnalyzer),
            new NotExpression(new IsNullPredicate(DE)));

    // multiple columns
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(A, B),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(bigintLiteral(1), bigintLiteral(100))),
                                    new Row(ImmutableList.of(bigintLiteral(2), bigintLiteral(200))))),
                    types,
                    typeAnalyzer),
            and(
                    new InPredicate(AE, new InListExpression(ImmutableList.of(bigintLiteral(1), bigintLiteral(2)))),
                    new InPredicate(BE, new InListExpression(ImmutableList.of(bigintLiteral(100), bigintLiteral(200))))));

    // multiple columns with null
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(A, B),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(bigintLiteral(1), new Cast(new NullLiteral(), toSqlType(BIGINT)))),
                                    new Row(ImmutableList.of(new Cast(new NullLiteral(), toSqlType(BIGINT)), bigintLiteral(200))))),
                    types,
                    typeAnalyzer),
            and(
                    or(new ComparisonExpression(EQUAL, AE, bigintLiteral(1)), new IsNullPredicate(AE)),
                    or(new ComparisonExpression(EQUAL, BE, bigintLiteral(200)), new IsNullPredicate(BE))));

    // non-deterministic
    ResolvedFunction rand = functionResolution.resolveFunction(QualifiedName.of("rand"), ImmutableList.of());
    ValuesNode node = new ValuesNode(
            newId(),
            ImmutableList.of(A, B),
            ImmutableList.of(new Row(ImmutableList.of(
                    bigintLiteral(1),
                    new FunctionCall(rand.toQualifiedName(), ImmutableList.of())))));
    assertEquals(extract(types, node), new ComparisonExpression(EQUAL, AE, bigintLiteral(1)));

    // non-constant
    assertEquals(
            effectivePredicateExtractor.extract(
                    SESSION,
                    new ValuesNode(
                            newId(),
                            ImmutableList.of(A),
                            ImmutableList.of(
                                    new Row(ImmutableList.of(bigintLiteral(1))),
                                    new Row(ImmutableList.of(BE)))),
                    types,
                    typeAnalyzer),
            TRUE_LITERAL);
}
Also used : Cast(io.trino.sql.tree.Cast) ValuesNode(io.trino.sql.planner.plan.ValuesNode) BetweenPredicate(io.trino.sql.tree.BetweenPredicate) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ImmutableList(com.google.common.collect.ImmutableList) ResolvedFunction(io.trino.metadata.ResolvedFunction) InListExpression(io.trino.sql.tree.InListExpression) NotExpression(io.trino.sql.tree.NotExpression) InPredicate(io.trino.sql.tree.InPredicate) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) RowType(io.trino.spi.type.RowType) TypeSignatureTranslator.toSqlType(io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType) Type(io.trino.spi.type.Type) InListExpression(io.trino.sql.tree.InListExpression) NotExpression(io.trino.sql.tree.NotExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) Expression(io.trino.sql.tree.Expression) IsNullPredicate(io.trino.sql.tree.IsNullPredicate) Row(io.trino.sql.tree.Row) FunctionCall(io.trino.sql.tree.FunctionCall) NullLiteral(io.trino.sql.tree.NullLiteral) Test(org.testng.annotations.Test)

Aggregations

Row (io.trino.sql.tree.Row)26 ImmutableList (com.google.common.collect.ImmutableList)15 Cast (io.trino.sql.tree.Cast)15 TypeSignatureTranslator.toSqlType (io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType)14 ValuesNode (io.trino.sql.planner.plan.ValuesNode)13 Symbol (io.trino.sql.planner.Symbol)12 Expression (io.trino.sql.tree.Expression)11 FunctionCall (io.trino.sql.tree.FunctionCall)11 LongLiteral (io.trino.sql.tree.LongLiteral)11 Assignments (io.trino.sql.planner.plan.Assignments)10 NullLiteral (io.trino.sql.tree.NullLiteral)10 QualifiedName (io.trino.sql.tree.QualifiedName)10 StringLiteral (io.trino.sql.tree.StringLiteral)10 Test (org.testng.annotations.Test)10 RowType (io.trino.spi.type.RowType)9 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)8 BIGINT (io.trino.spi.type.BigintType.BIGINT)8 VARCHAR (io.trino.spi.type.VarcharType.VARCHAR)8 Type (io.trino.spi.type.Type)7 TypeSignatureProvider.fromTypes (io.trino.sql.analyzer.TypeSignatureProvider.fromTypes)7