Search in sources:

Example 1 with FunctionInvoker

Use of io.trino.metadata.FunctionInvoker in project trino by trinodb.

In class ChoicesScalarFunctionImplementation, method getScalarFunctionInvoker.

/**
 * Selects the highest-scoring implementation choice whose invocation convention the function
 * adapter can adapt to {@code invocationConvention}, adapts that choice's method handle to the
 * requested convention, and wraps it together with the choice's instance factory and lambda
 * interfaces in a {@link FunctionInvoker}.
 *
 * @param invocationConvention the calling convention requested by the caller
 * @return an invoker whose method handle follows {@code invocationConvention}
 * @throws TrinoException with {@code FUNCTION_NOT_FOUND} when no implementation choice can be adapted
 */
@Override
public FunctionInvoker getScalarFunctionInvoker(InvocationConvention invocationConvention) {
    // Keep only implementations whose native convention can be adapted to the requested one.
    List<ScalarImplementationChoice> adaptable = new ArrayList<>();
    for (ScalarImplementationChoice candidate : this.choices) {
        if (!functionAdapter.canAdapt(candidate.getInvocationConvention(), invocationConvention)) {
            continue;
        }
        adaptable.add(candidate);
    }

    if (adaptable.isEmpty()) {
        throw new TrinoException(FUNCTION_NOT_FOUND, format("Function implementation for (%s) cannot be adapted to convention (%s)", boundSignature, invocationConvention));
    }

    // Among the adaptable implementations, prefer the one with the greatest score.
    ScalarImplementationChoice selected = Collections.max(adaptable, comparingInt(ScalarImplementationChoice::getScore));
    MethodHandle adapted = functionAdapter.adapt(
            selected.getMethodHandle(),
            boundSignature.getArgumentTypes(),
            selected.getInvocationConvention(),
            invocationConvention);
    return new FunctionInvoker(adapted, selected.getInstanceFactory(), selected.getLambdaInterfaces());
}
Also used : ArrayList(java.util.ArrayList) InvocationConvention(io.trino.spi.function.InvocationConvention) TrinoException(io.trino.spi.TrinoException) FunctionInvoker(io.trino.metadata.FunctionInvoker) MethodHandle(java.lang.invoke.MethodHandle)

Example 2 with FunctionInvoker

Use of io.trino.metadata.FunctionInvoker in project trino by trinodb.

In class BytecodeUtils, method generateFullInvocation.

/**
 * Generates bytecode that evaluates the function's arguments and invokes it through a bound
 * method handle.
 * <p>
 * Each argument is assigned an {@link InvocationArgumentConvention}:
 * <ul>
 *   <li>{@code FUNCTION} for function-typed (lambda) arguments;</li>
 *   <li>otherwise the convention returned by {@code getPreferredArgumentConvention}.</li>
 * </ul>
 * The return convention is {@code NULLABLE_RETURN} when the function's return value is
 * nullable, otherwise {@code FAIL_ON_NULL}. The resulting convention is handed to
 * {@code functionInvokerProvider}. Its method handle is then bound via {@code binder}.
 * <p>
 * For arguments that must not be null, the generated code calls {@code ifWasNullPopAndGoto},
 * which branches to the {@code end} label when {@code wasNull} is set. That label is placed
 * after the invocation, so the call is skipped in that case.
 * <p>
 * The generated code reads the {@code session} and {@code wasNull} variables from {@code scope}.
 *
 * @param scope scope providing the {@code session} and {@code wasNull} variables
 * @param functionName name used for the block description and the invoke instruction
 * @param functionNullability nullability of each argument and of the return value
 * @param argumentIsFunctionType per-argument flag, {@code true} when the argument is a lambda
 * @param functionInvokerProvider resolves a {@link FunctionInvoker} for the computed convention
 * @param instanceFactory turns the invoker's instance-factory handle into bytecode producing the instance
 * @param argumentCompilers per-argument bytecode generators; the {@code Optional} carries the
 *        lambda interface for function-typed arguments and is empty otherwise
 * @param binder binds the invoker's method handle so the generated code can call it
 * @return the bytecode block performing the invocation
 */
private static BytecodeNode generateFullInvocation(Scope scope, String functionName, FunctionNullability functionNullability, List<Boolean> argumentIsFunctionType, Function<InvocationConvention, FunctionInvoker> functionInvokerProvider, Function<MethodHandle, BytecodeNode> instanceFactory, List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers, CallSiteBinder binder) {
    verify(argumentIsFunctionType.size() == argumentCompilers.size());

    // Pick a convention per argument. Lambda arguments cannot be compiled yet: the interface they
    // must implement is only known once the invoker is resolved, so a null placeholder is stored.
    List<InvocationArgumentConvention> argumentConventions = new ArrayList<>();
    List<BytecodeNode> arguments = new ArrayList<>();
    for (int i = 0; i < argumentIsFunctionType.size(); i++) {
        if (argumentIsFunctionType.get(i)) {
            argumentConventions.add(FUNCTION);
            arguments.add(null);
        } else {
            BytecodeNode argument = argumentCompilers.get(i).apply(Optional.empty());
            argumentConventions.add(getPreferredArgumentConvention(argument, argumentCompilers.size(), functionNullability.isArgumentNullable(i)));
            arguments.add(argument);
        }
    }
    // NOTE(review): the two trailing flags presumably enable session and instance-factory support,
    // matching the ConnectorSession / instance handling below — confirm against InvocationConvention.
    InvocationConvention invocationConvention = new InvocationConvention(argumentConventions, functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, true, true);
    FunctionInvoker functionInvoker = functionInvokerProvider.apply(invocationConvention);
    Binding binding = binder.bind(functionInvoker.getMethodHandle());
    LabelNode end = new LabelNode("end");
    BytecodeBlock block = new BytecodeBlock().setDescription("invoke " + functionName);
    Optional<BytecodeNode> instance = functionInvoker.getInstanceFactory().map(instanceFactory);
    // Index of current parameter in the MethodHandle
    int currentParameterIndex = 0;
    // Index of parameter (without @IsNull) in Trino function
    int realParameterIndex = 0;
    // Index of function argument types
    int lambdaArgumentIndex = 0;
    MethodType methodType = binding.getType();
    Class<?> returnType = methodType.returnType();
    Class<?> unboxedReturnType = Primitives.unwrap(returnType);
    // Types pushed onto the operand stack so far; passed (reversed) to ifWasNullPopAndGoto.
    List<Class<?>> stackTypes = new ArrayList<>();
    boolean instanceIsBound = false;
    // parameterCount()/parameterType(i) avoid the array copy parameterArray() makes on every call.
    int parameterCount = methodType.parameterCount();
    while (currentParameterIndex < parameterCount) {
        Class<?> type = methodType.parameterType(currentParameterIndex);
        stackTypes.add(type);
        if (instance.isPresent() && !instanceIsBound) {
            // The first handle parameter receives the function instance.
            checkState(type.equals(functionInvoker.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
            block.append(instance.get());
            instanceIsBound = true;
        } else if (type == ConnectorSession.class) {
            block.append(scope.getVariable("session"));
        } else {
            switch(invocationConvention.getArgumentConvention(realParameterIndex)) {
                case NEVER_NULL:
                    block.append(arguments.get(realParameterIndex));
                    checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type");
                    block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                    break;
                case NULL_FLAG:
                    // Consumes two handle parameters: the value and a boolean null flag.
                    block.append(arguments.get(realParameterIndex));
                    block.append(scope.getVariable("wasNull"));
                    block.append(scope.getVariable("wasNull").set(constantFalse()));
                    stackTypes.add(boolean.class);
                    currentParameterIndex++;
                    break;
                case BOXED_NULLABLE:
                    block.append(arguments.get(realParameterIndex));
                    block.append(boxPrimitiveIfNecessary(scope, type));
                    block.append(scope.getVariable("wasNull").set(constantFalse()));
                    break;
                case BLOCK_POSITION:
                    // Consumes two handle parameters: the block and an int position.
                    InputReferenceNode inputReferenceNode = (InputReferenceNode) arguments.get(realParameterIndex);
                    block.append(inputReferenceNode.produceBlockAndPosition());
                    stackTypes.add(int.class);
                    if (!functionNullability.isArgumentNullable(realParameterIndex)) {
                        block.append(scope.getVariable("wasNull").set(inputReferenceNode.blockAndPositionIsNull()));
                        block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                    }
                    currentParameterIndex++;
                    break;
                case FUNCTION:
                    // Now that the invoker is known, compile the lambda against its required interface.
                    Class<?> lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex);
                    block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface)));
                    lambdaArgumentIndex++;
                    break;
                default:
                    throw new UnsupportedOperationException(format("Unsupported argument convention type: %s", invocationConvention.getArgumentConvention(realParameterIndex)));
            }
            realParameterIndex++;
        }
        currentParameterIndex++;
    }
    block.append(invoke(binding, functionName));
    if (functionNullability.isReturnNullable()) {
        block.append(unboxPrimitiveIfNecessary(scope, returnType));
    }
    // Null short-circuits jump here, skipping the invocation above.
    block.visitLabel(end);
    return block;
}
Also used : LabelNode(io.airlift.bytecode.instruction.LabelNode) MethodType(java.lang.invoke.MethodType) ArrayList(java.util.ArrayList) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) InputReferenceNode(io.trino.sql.gen.InputReferenceCompiler.InputReferenceNode) InvocationArgumentConvention(io.trino.spi.function.InvocationConvention.InvocationArgumentConvention) InvocationConvention(io.trino.spi.function.InvocationConvention) BytecodeNode(io.airlift.bytecode.BytecodeNode) FunctionInvoker(io.trino.metadata.FunctionInvoker) ConnectorSession(io.trino.spi.connector.ConnectorSession)

Example 3 with FunctionInvoker

Use of io.trino.metadata.FunctionInvoker in project trino by trinodb.

In class InterpretedFunctionInvoker, method invoke.

/**
 * Invokes a scalar function with already-evaluated argument values.
 * <p>
 * Arguments must be the native container type for the corresponding SQL types.
 * <p>
 * Returns a value in the native container type corresponding to the declared SQL return type.
 * Suppose an argument is {@code null} and the function does not declare that argument nullable.
 * Then the method handle is not invoked and {@code null} is returned.
 *
 * @param function the resolved function to call
 * @param session passed to the implementation when the handle parameter following the optional
 *        instance is a {@link ConnectorSession}
 * @param arguments argument values in declaration order; function-typed arguments are expected
 *        to be {@link MethodHandle}s and are wrapped in the invoker's lambda interfaces
 * @return the function result, or {@code null} as described above
 */
public Object invoke(ResolvedFunction function, ConnectorSession session, List<Object> arguments) {
    FunctionInvoker invoker = functionManager.getScalarFunctionInvoker(
            function,
            getInvocationConvention(function.getSignature(), function.getFunctionNullability()));
    MethodHandle target = invoker.getMethodHandle();
    List<Object> callArguments = new ArrayList<>();

    // Instance-method implementations receive their instance first, so they can use fields.
    if (invoker.getInstanceFactory().isPresent()) {
        Object receiver;
        try {
            receiver = invoker.getInstanceFactory().get().invoke();
        } catch (Throwable failure) {
            throw propagate(failure);
        }
        callArguments.add(receiver);
    }

    // The session, if requested, immediately follows the (optional) instance.
    int sessionSlot = callArguments.size();
    boolean wantsSession = target.type().parameterCount() > sessionSlot
            && target.type().parameterType(sessionSlot) == ConnectorSession.class;
    if (wantsSession) {
        callArguments.add(session);
    }

    int nextLambda = 0;
    for (int index = 0; index < arguments.size(); index++) {
        Object value = arguments.get(index);
        if (value == null) {
            // A null passed to a null-intolerant parameter makes the whole call yield null.
            if (!function.getFunctionNullability().isArgumentNullable(index)) {
                return null;
            }
        }
        if (function.getSignature().getArgumentTypes().get(index) instanceof FunctionType) {
            value = asInterfaceInstance(invoker.getLambdaInterfaces().get(nextLambda++), (MethodHandle) value);
        }
        callArguments.add(value);
    }

    try {
        return target.invokeWithArguments(callArguments);
    } catch (Throwable failure) {
        throw propagate(failure);
    }
}
Also used : FunctionType(io.trino.type.FunctionType) ArrayList(java.util.ArrayList) FunctionInvoker(io.trino.metadata.FunctionInvoker) ConnectorSession(io.trino.spi.connector.ConnectorSession) MethodHandle(java.lang.invoke.MethodHandle)

Aggregations

FunctionInvoker (io.trino.metadata.FunctionInvoker): 3, ArrayList (java.util.ArrayList): 3, ConnectorSession (io.trino.spi.connector.ConnectorSession): 2, InvocationConvention (io.trino.spi.function.InvocationConvention): 2, MethodHandle (java.lang.invoke.MethodHandle): 2, BytecodeBlock (io.airlift.bytecode.BytecodeBlock): 1, BytecodeNode (io.airlift.bytecode.BytecodeNode): 1, LabelNode (io.airlift.bytecode.instruction.LabelNode): 1, TrinoException (io.trino.spi.TrinoException): 1, InvocationArgumentConvention (io.trino.spi.function.InvocationConvention.InvocationArgumentConvention): 1, InputReferenceNode (io.trino.sql.gen.InputReferenceCompiler.InputReferenceNode): 1, FunctionType (io.trino.type.FunctionType): 1, MethodType (java.lang.invoke.MethodType): 1