Search in sources:

Example 1 with CallArgument

Use of io.trino.sql.tree.CallArgument in the trino project by trinodb.

Class CallTask, method execute.

/**
 * Executes a {@code CALL} statement: resolves the target procedure, binds the call's arguments to
 * the procedure's declared arguments, and invokes the procedure's method handle.
 *
 * <p>Arguments must be passed either all by name or all by position; mixing is rejected. Declared
 * arguments that the call does not supply must be optional and are filled from their defaults.
 * Method parameters of type {@link ConnectorSession} or {@link ConnectorAccessControl} have no
 * corresponding procedure argument; they are injected by this task when building the invocation.
 *
 * @param call the parsed {@code CALL} statement
 * @param stateMachine state machine of the running query; supplies the session and records the
 *        invoked routine
 * @param parameters expressions bound to the statement's parameter placeholders, substituted into
 *        the argument expressions before evaluation
 * @param warningCollector not used by this method
 * @return an already-completed future
 * @throws TrinoException if the session is not in autocommit mode, the catalog does not exist, the
 *         arguments are invalid (mixed, duplicated, unknown, missing, too many, or {@code NULL}
 *         for a primitive parameter), or the procedure invocation itself fails
 */
@Override
public ListenableFuture<Void> execute(Call call, QueryStateMachine stateMachine, List<Expression> parameters, WarningCollector warningCollector) {
    if (!transactionManager.isAutoCommit(stateMachine.getSession().getRequiredTransactionId())) {
        throw new TrinoException(NOT_SUPPORTED, "Procedures cannot be called within a transaction (use autocommit mode)");
    }
    Session session = stateMachine.getSession();
    QualifiedObjectName procedureName = createQualifiedObjectName(session, call, call.getName());
    CatalogName catalogName = plannerContext.getMetadata().getCatalogHandle(session, procedureName.getCatalogName())
            .orElseThrow(() -> semanticException(CATALOG_NOT_FOUND, call, "Catalog '%s' does not exist", procedureName.getCatalogName()));
    Procedure procedure = procedureRegistry.resolve(catalogName, procedureName.asSchemaTableName());

    // map declared argument names to positions
    Map<String, Integer> positions = new HashMap<>();
    for (int i = 0; i < procedure.getArguments().size(); i++) {
        positions.put(procedure.getArguments().get(i).getName(), i);
    }

    // per specification, do not allow mixing argument types
    Predicate<CallArgument> hasName = argument -> argument.getName().isPresent();
    boolean anyNamed = call.getArguments().stream().anyMatch(hasName);
    boolean allNamed = call.getArguments().stream().allMatch(hasName);
    if (anyNamed && !allNamed) {
        throw semanticException(INVALID_ARGUMENTS, call, "Named and positional arguments cannot be mixed");
    }

    // get the argument names in call order (LinkedHashMap keeps evaluation in call order)
    Map<String, CallArgument> names = new LinkedHashMap<>();
    for (int i = 0; i < call.getArguments().size(); i++) {
        CallArgument argument = call.getArguments().get(i);
        if (argument.getName().isPresent()) {
            String name = argument.getName().get().getCanonicalValue();
            if (names.put(name, argument) != null) {
                throw semanticException(INVALID_ARGUMENTS, argument, "Duplicate procedure argument: %s", name);
            }
            if (!positions.containsKey(name)) {
                throw semanticException(INVALID_ARGUMENTS, argument, "Unknown argument name: %s", name);
            }
        }
        else if (i < procedure.getArguments().size()) {
            // positional argument: bind to the declared argument at the same index
            names.put(procedure.getArguments().get(i).getName(), argument);
        }
        else {
            throw semanticException(INVALID_ARGUMENTS, call, "Too many arguments for procedure");
        }
    }

    procedure.getArguments().stream()
            .filter(Argument::isRequired)
            .filter(argument -> !names.containsKey(argument.getName()))
            .map(Argument::getName)
            .findFirst()
            .ifPresent(argument -> {
                throw semanticException(INVALID_ARGUMENTS, call, "Required procedure argument '%s' is missing", argument);
            });

    // get argument values, indexed by declared-argument position
    Object[] values = new Object[procedure.getArguments().size()];
    Map<NodeRef<Parameter>, Expression> parameterLookup = parameterExtractor(call, parameters);
    for (Entry<String, CallArgument> entry : names.entrySet()) {
        CallArgument callArgument = entry.getValue();
        int index = positions.get(entry.getKey());
        Argument argument = procedure.getArguments().get(index);

        Expression expression = ExpressionTreeRewriter.rewriteWith(new ParameterRewriter(parameterLookup), callArgument.getValue());

        Type type = argument.getType();
        Object value = evaluateConstantExpression(expression, type, plannerContext, session, accessControl, parameterLookup);

        values[index] = toTypeObjectValue(session, type, value);
    }

    // fill values with optional arguments defaults
    for (int i = 0; i < procedure.getArguments().size(); i++) {
        Argument argument = procedure.getArguments().get(i);
        if (!names.containsKey(argument.getName())) {
            // the required-argument check above guarantees anything unsupplied is optional
            verify(argument.isOptional());
            values[i] = toTypeObjectValue(session, argument.getType(), argument.getDefaultValue());
        }
    }

    // validate arguments: a SQL NULL cannot be passed to a primitive Java parameter.
    // The method handle may also declare injected ConnectorSession / ConnectorAccessControl
    // parameters that have no declared-argument counterpart, so walk the Java parameter list
    // and skip them. This keeps values[argumentIndex] aligned with its Java parameter type,
    // in the same way the invocation-building loop below consumes values.
    MethodType methodType = procedure.getMethodHandle().type();
    int argumentIndex = 0;
    for (Class<?> parameterType : methodType.parameterList()) {
        if (ConnectorSession.class.equals(parameterType) || ConnectorAccessControl.class.equals(parameterType)) {
            continue;
        }
        if ((values[argumentIndex] == null) && parameterType.isPrimitive()) {
            String name = procedure.getArguments().get(argumentIndex).getName();
            throw new TrinoException(INVALID_PROCEDURE_ARGUMENT, "Procedure argument cannot be null: " + name);
        }
        argumentIndex++;
    }

    // insert session argument
    List<Object> arguments = new ArrayList<>();
    Iterator<Object> valuesIterator = asList(values).iterator();
    for (Class<?> type : methodType.parameterList()) {
        if (ConnectorSession.class.equals(type)) {
            arguments.add(session.toConnectorSession(catalogName));
        }
        else if (ConnectorAccessControl.class.equals(type)) {
            arguments.add(new InjectedConnectorAccessControl(accessControl, session.toSecurityContext(), catalogName.getCatalogName()));
        }
        else {
            arguments.add(valuesIterator.next());
        }
    }

    accessControl.checkCanExecuteProcedure(session.toSecurityContext(), procedureName);
    stateMachine.setRoutines(ImmutableList.of(new RoutineInfo(procedureName.getObjectName(), session.getUser())));

    try {
        procedure.getMethodHandle().invokeWithArguments(arguments);
    }
    catch (Throwable t) {
        if (t instanceof InterruptedException) {
            // preserve the interrupt status for callers up the stack
            Thread.currentThread().interrupt();
        }
        throwIfInstanceOf(t, TrinoException.class);
        throw new TrinoException(PROCEDURE_CALL_FAILED, t);
    }

    return immediateVoidFuture();
}
Also used: InjectedConnectorAccessControl(io.trino.security.InjectedConnectorAccessControl) TransactionManager(io.trino.transaction.TransactionManager) ParameterUtils.parameterExtractor(io.trino.sql.ParameterUtils.parameterExtractor) NOT_SUPPORTED(io.trino.spi.StandardErrorCode.NOT_SUPPORTED) CatalogName(io.trino.connector.CatalogName) Arrays.asList(java.util.Arrays.asList) Map(java.util.Map) SemanticExceptions.semanticException(io.trino.sql.analyzer.SemanticExceptions.semanticException) Argument(io.trino.spi.procedure.Procedure.Argument) Futures.immediateVoidFuture(com.google.common.util.concurrent.Futures.immediateVoidFuture) INVALID_ARGUMENTS(io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS) Predicate(java.util.function.Predicate) ExpressionTreeRewriter(io.trino.sql.tree.ExpressionTreeRewriter) ConnectorAccessControl(io.trino.spi.connector.ConnectorAccessControl) TrinoException(io.trino.spi.TrinoException) TypeUtils.writeNativeValue(io.trino.spi.type.TypeUtils.writeNativeValue) List(java.util.List) AccessControl(io.trino.security.AccessControl) Parameter(io.trino.sql.tree.Parameter) Entry(java.util.Map.Entry) Expression(io.trino.sql.tree.Expression) PROCEDURE_CALL_FAILED(io.trino.spi.StandardErrorCode.PROCEDURE_CALL_FAILED) Session(io.trino.Session) PlannerContext(io.trino.sql.PlannerContext) ListenableFuture(com.google.common.util.concurrent.ListenableFuture) ExpressionInterpreter.evaluateConstantExpression(io.trino.sql.planner.ExpressionInterpreter.evaluateConstantExpression) RoutineInfo(io.trino.spi.eventlistener.RoutineInfo) Type(io.trino.spi.type.Type) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Inject(javax.inject.Inject) LinkedHashMap(java.util.LinkedHashMap) INVALID_PROCEDURE_ARGUMENT(io.trino.spi.StandardErrorCode.INVALID_PROCEDURE_ARGUMENT) ImmutableList(com.google.common.collect.ImmutableList) Procedure(io.trino.spi.procedure.Procedure) Verify.verify(com.google.common.base.Verify.verify)
MetadataUtil.createQualifiedObjectName(io.trino.metadata.MetadataUtil.createQualifiedObjectName) NodeRef(io.trino.sql.tree.NodeRef) Objects.requireNonNull(java.util.Objects.requireNonNull) Iterator(java.util.Iterator) CATALOG_NOT_FOUND(io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND) ConnectorSession(io.trino.spi.connector.ConnectorSession) Throwables.throwIfInstanceOf(com.google.common.base.Throwables.throwIfInstanceOf) CallArgument(io.trino.sql.tree.CallArgument) Call(io.trino.sql.tree.Call) MethodType(java.lang.invoke.MethodType) QualifiedObjectName(io.trino.metadata.QualifiedObjectName) ProcedureRegistry(io.trino.metadata.ProcedureRegistry) WarningCollector(io.trino.execution.warnings.WarningCollector) BlockBuilder(io.trino.spi.block.BlockBuilder) ParameterRewriter(io.trino.sql.planner.ParameterRewriter)

Example 2 with CallArgument

Use of io.trino.sql.tree.CallArgument in the trino project by trinodb.

Class TestSqlParser, method testTableExecute.

@Test
public void testTableExecute() {
    // Fixtures shared by every statement under test.
    Table target = new Table(QualifiedName.of("foo"));
    Identifier procedureName = new Identifier("bar");
    // Expected tree for the filter "WHERE age > 17".
    ComparisonExpression ageAbove17 = new ComparisonExpression(
            ComparisonExpression.Operator.GREATER_THAN,
            new Identifier("age"),
            new LongLiteral("17"));

    // No argument list and no filter.
    assertStatement(
            "ALTER TABLE foo EXECUTE bar",
            new TableExecute(target, procedureName, ImmutableList.of(), Optional.empty()));

    // Named arguments followed by a filter.
    ImmutableList<CallArgument> namedArguments = ImmutableList.of(
            new CallArgument(identifier("bah"), new LongLiteral("1")),
            new CallArgument(identifier("wuh"), new StringLiteral("clap")));
    assertStatement(
            "ALTER TABLE foo EXECUTE bar(bah => 1, wuh => 'clap') WHERE age > 17",
            new TableExecute(target, procedureName, namedArguments, Optional.of(ageAbove17)));

    // Positional arguments followed by a filter.
    ImmutableList<CallArgument> positionalArguments = ImmutableList.of(
            new CallArgument(new LongLiteral("1")),
            new CallArgument(new StringLiteral("clap")));
    assertStatement(
            "ALTER TABLE foo EXECUTE bar(1, 'clap') WHERE age > 17",
            new TableExecute(target, procedureName, positionalArguments, Optional.of(ageAbove17)));
}
Also used : TableExecute(io.trino.sql.tree.TableExecute) CallArgument(io.trino.sql.tree.CallArgument) QuantifiedComparisonExpression(io.trino.sql.tree.QuantifiedComparisonExpression) ComparisonExpression(io.trino.sql.tree.ComparisonExpression) CreateTable(io.trino.sql.tree.CreateTable) DropTable(io.trino.sql.tree.DropTable) Table(io.trino.sql.tree.Table) TruncateTable(io.trino.sql.tree.TruncateTable) RenameTable(io.trino.sql.tree.RenameTable) QueryUtil.quotedIdentifier(io.trino.sql.QueryUtil.quotedIdentifier) Identifier(io.trino.sql.tree.Identifier) StringLiteral(io.trino.sql.tree.StringLiteral) LongLiteral(io.trino.sql.tree.LongLiteral) Test(org.junit.jupiter.api.Test)

Example 3 with CallArgument

Use of io.trino.sql.tree.CallArgument in the trino project by trinodb.

Class TestSqlParser, method testCall.

@Test
public void testCall() {
    QualifiedName procedure = QualifiedName.of("foo");

    // An empty argument list parses to a Call with no arguments.
    assertStatement("CALL foo()", new Call(procedure, ImmutableList.of()));

    // The parser accepts positional and named arguments interleaved, preserving their order.
    ImmutableList<CallArgument> mixedArguments = ImmutableList.of(
            new CallArgument(new LongLiteral("123")),
            new CallArgument(identifier("a"), new LongLiteral("1")),
            new CallArgument(identifier("b"), new StringLiteral("go")),
            new CallArgument(new LongLiteral("456")));
    assertStatement("CALL foo(123, a => 1, b => 'go', 456)", new Call(procedure, mixedArguments));
}
Also used : FunctionCall(io.trino.sql.tree.FunctionCall) Call(io.trino.sql.tree.Call) CallArgument(io.trino.sql.tree.CallArgument) StringLiteral(io.trino.sql.tree.StringLiteral) LongLiteral(io.trino.sql.tree.LongLiteral) Test(org.junit.jupiter.api.Test)

Aggregations

CallArgument (io.trino.sql.tree.CallArgument): 3; Call (io.trino.sql.tree.Call): 2; Throwables.throwIfInstanceOf (com.google.common.base.Throwables.throwIfInstanceOf): 1; Verify.verify (com.google.common.base.Verify.verify): 1; ImmutableList (com.google.common.collect.ImmutableList): 1; Futures.immediateVoidFuture (com.google.common.util.concurrent.Futures.immediateVoidFuture): 1; ListenableFuture (com.google.common.util.concurrent.ListenableFuture): 1; Session (io.trino.Session): 1; CatalogName (io.trino.connector.CatalogName): 1; WarningCollector (io.trino.execution.warnings.WarningCollector): 1; MetadataUtil.createQualifiedObjectName (io.trino.metadata.MetadataUtil.createQualifiedObjectName): 1; ProcedureRegistry (io.trino.metadata.ProcedureRegistry): 1; QualifiedObjectName (io.trino.metadata.QualifiedObjectName): 1; AccessControl (io.trino.security.AccessControl): 1; InjectedConnectorAccessControl (io.trino.security.InjectedConnectorAccessControl): 1; CATALOG_NOT_FOUND (io.trino.spi.StandardErrorCode.CATALOG_NOT_FOUND): 1; INVALID_ARGUMENTS (io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS): 1; INVALID_PROCEDURE_ARGUMENT (io.trino.spi.StandardErrorCode.INVALID_PROCEDURE_ARGUMENT): 1; NOT_SUPPORTED (io.trino.spi.StandardErrorCode.NOT_SUPPORTED): 1; PROCEDURE_CALL_FAILED (io.trino.spi.StandardErrorCode.PROCEDURE_CALL_FAILED): 1