Search in sources :

Example 1 with InvocationConvention

use of io.trino.spi.function.InvocationConvention in project trino by trinodb.

From class ChoicesScalarFunctionImplementation, method getScalarFunctionInvoker.

/**
 * Returns an invoker for this function in the requested calling convention.
 * <p>
 * Only the implementation choices whose own convention the function adapter reports it can adapt
 * to {@code invocationConvention} are considered. Among those, the one with the highest score is
 * selected, and its method handle is adapted to the requested convention.
 *
 * @param invocationConvention the convention the caller wants to invoke the function with
 * @return the adapted method handle together with the chosen choice's instance factory and lambda interfaces
 * @throws TrinoException with {@code FUNCTION_NOT_FOUND} when no choice can be adapted
 */
@Override
public FunctionInvoker getScalarFunctionInvoker(InvocationConvention invocationConvention) {
    // Copy all registered choices (keeping their order), then discard those that cannot be adapted
    List<ScalarImplementationChoice> adaptableChoices = new ArrayList<>(this.choices);
    adaptableChoices.removeIf(candidate -> !functionAdapter.canAdapt(candidate.getInvocationConvention(), invocationConvention));

    if (adaptableChoices.isEmpty()) {
        throw new TrinoException(FUNCTION_NOT_FOUND, format("Function implementation for (%s) cannot be adapted to convention (%s)", boundSignature, invocationConvention));
    }

    ScalarImplementationChoice selected = Collections.max(adaptableChoices, comparingInt(ScalarImplementationChoice::getScore));
    InvocationConvention nativeConvention = selected.getInvocationConvention();
    MethodHandle adaptedHandle = functionAdapter.adapt(selected.getMethodHandle(), boundSignature.getArgumentTypes(), nativeConvention, invocationConvention);
    return new FunctionInvoker(adaptedHandle, selected.getInstanceFactory(), selected.getLambdaInterfaces());
}
Also used : ArrayList(java.util.ArrayList) InvocationConvention(io.trino.spi.function.InvocationConvention) TrinoException(io.trino.spi.TrinoException) FunctionInvoker(io.trino.metadata.FunctionInvoker) MethodHandle(java.lang.invoke.MethodHandle)

Example 2 with InvocationConvention

use of io.trino.spi.function.InvocationConvention in project trino by trinodb.

From class TryCastFunction, method specialize.

/**
 * Specializes TRY_CAST for a concrete source/target type pair.
 * <p>
 * Obtains the regular cast invoker (non-null argument, nullable return, with session) and widens
 * its return type to the boxed Java type of the target. Any {@link RuntimeException} thrown by the
 * cast is caught and turned into a {@code null} result instead.
 *
 * @param boundSignature signature whose single argument is the source type and whose return type is the target type
 * @param functionDependencies used to resolve the underlying cast invoker
 * @return a nullable-return implementation taking one non-null argument
 */
@Override
public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) {
    Type sourceType = boundSignature.getArgumentType(0);
    Type targetType = boundSignature.getReturnType();
    // The handle we build must return a boxed type so that null can be produced on failure
    Class<?> boxedResultType = wrap(targetType.getJavaType());

    MethodHandle castHandle = functionDependencies
            .getCastInvoker(sourceType, targetType, new InvocationConvention(ImmutableList.of(NEVER_NULL), NULLABLE_RETURN, true, false))
            .getMethodHandle();
    MethodHandle boxedCast = castHandle.asType(methodType(boxedResultType, castHandle.type()));

    // Fallback invoked with the caught exception: ignores it and returns null
    MethodHandle returnNull = dropArguments(constant(boxedResultType, null), 0, RuntimeException.class);
    MethodHandle guardedCast = catchException(boxedCast, RuntimeException.class, returnNull);

    return new ChoicesScalarFunctionImplementation(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), guardedCast);
}
Also used : MethodType.methodType(java.lang.invoke.MethodType.methodType) Type(io.trino.spi.type.Type) InvocationConvention(io.trino.spi.function.InvocationConvention) MethodHandle(java.lang.invoke.MethodHandle)

Example 3 with InvocationConvention

use of io.trino.spi.function.InvocationConvention in project trino by trinodb.

From class BytecodeUtils, method generateFullInvocation.

/**
 * Generates bytecode that calls a scalar function with all of its arguments and leaves the result on the stack.
 * <p>
 * The invocation convention is chosen here. Lambda arguments use {@code FUNCTION}. Every other argument is
 * compiled eagerly and uses the convention returned by {@code getPreferredArgumentConvention}. The return
 * convention is {@code NULLABLE_RETURN} when the function's return is nullable, otherwise {@code FAIL_ON_NULL}.
 * <p>
 * The {@link FunctionInvoker} resolved for that convention is bound through {@code binder}. Each parameter of
 * the bound method handle is then satisfied in order:
 * <ul>
 * <li>the instance, when the invoker has an instance factory;</li>
 * <li>the {@code session} variable, for {@link ConnectorSession} parameters;</li>
 * <li>otherwise, the next function argument in its chosen convention.</li>
 * </ul>
 *
 * @param scope scope providing the {@code session} and {@code wasNull} variables
 * @param functionName name used in the block description and the invoke instruction
 * @param functionNullability argument/return nullability of the function being called
 * @param argumentIsFunctionType per-argument flag; {@code true} marks a lambda argument
 * @param functionInvokerProvider resolves an invoker for the computed invocation convention
 * @param instanceFactory turns the invoker's instance-factory handle into bytecode producing the instance
 * @param argumentCompilers per-argument compilers; lambda arguments receive the expected lambda interface
 * @param binder call-site binder used to bind the invoker's method handle
 * @return the generated block. For {@code NEVER_NULL} arguments and non-nullable {@code BLOCK_POSITION}
 *     arguments, the code checks {@code wasNull} via {@code ifWasNullPopAndGoto}; that check may branch to
 *     the block's end label, which lies after the invoke instruction.
 */
private static BytecodeNode generateFullInvocation(Scope scope, String functionName, FunctionNullability functionNullability, List<Boolean> argumentIsFunctionType, Function<InvocationConvention, FunctionInvoker> functionInvokerProvider, Function<MethodHandle, BytecodeNode> instanceFactory, List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers, CallSiteBinder binder) {
    verify(argumentIsFunctionType.size() == argumentCompilers.size());
    List<InvocationArgumentConvention> argumentConventions = new ArrayList<>();
    List<BytecodeNode> arguments = new ArrayList<>();
    for (int i = 0; i < argumentIsFunctionType.size(); i++) {
        if (argumentIsFunctionType.get(i)) {
            // Lambda arguments are compiled later, once the expected lambda interface is known
            argumentConventions.add(FUNCTION);
            arguments.add(null);
        } else {
            BytecodeNode argument = argumentCompilers.get(i).apply(Optional.empty());
            argumentConventions.add(getPreferredArgumentConvention(argument, argumentCompilers.size(), functionNullability.isArgumentNullable(i)));
            arguments.add(argument);
        }
    }
    InvocationConvention invocationConvention = new InvocationConvention(argumentConventions, functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, true, true);
    FunctionInvoker functionInvoker = functionInvokerProvider.apply(invocationConvention);
    Binding binding = binder.bind(functionInvoker.getMethodHandle());
    LabelNode end = new LabelNode("end");
    BytecodeBlock block = new BytecodeBlock().setDescription("invoke " + functionName);
    Optional<BytecodeNode> instance = functionInvoker.getInstanceFactory().map(instanceFactory);
    // Index of current parameter in the MethodHandle
    int currentParameterIndex = 0;
    // Index of parameter (without @IsNull) in Trino function
    int realParameterIndex = 0;
    // Index of function argument types
    int lambdaArgumentIndex = 0;
    MethodType methodType = binding.getType();
    Class<?> returnType = methodType.returnType();
    Class<?> unboxedReturnType = Primitives.unwrap(returnType);
    // Types pushed so far; passed (reversed) to ifWasNullPopAndGoto when an early exit is emitted
    List<Class<?>> stackTypes = new ArrayList<>();
    boolean instanceIsBound = false;
    while (currentParameterIndex < methodType.parameterArray().length) {
        Class<?> type = methodType.parameterArray()[currentParameterIndex];
        stackTypes.add(type);
        if (instance.isPresent() && !instanceIsBound) {
            // The first handle parameter receives the instance when the invoker has an instance factory
            checkState(type.equals(functionInvoker.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
            block.append(instance.get());
            instanceIsBound = true;
        } else if (type == ConnectorSession.class) {
            block.append(scope.getVariable("session"));
        } else {
            switch(invocationConvention.getArgumentConvention(realParameterIndex)) {
                case NEVER_NULL:
                    block.append(arguments.get(realParameterIndex));
                    checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type");
                    block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                    break;
                case NULL_FLAG:
                    // Push the value, then the current wasNull flag, then reset wasNull.
                    // The flag occupies an extra handle parameter, hence the additional index increment.
                    block.append(arguments.get(realParameterIndex));
                    block.append(scope.getVariable("wasNull"));
                    block.append(scope.getVariable("wasNull").set(constantFalse()));
                    stackTypes.add(boolean.class);
                    currentParameterIndex++;
                    break;
                case BOXED_NULLABLE:
                    block.append(arguments.get(realParameterIndex));
                    block.append(boxPrimitiveIfNecessary(scope, type));
                    block.append(scope.getVariable("wasNull").set(constantFalse()));
                    break;
                case BLOCK_POSITION:
                    // Pushes a block and a position: two handle parameters for one function argument
                    InputReferenceNode inputReferenceNode = (InputReferenceNode) arguments.get(realParameterIndex);
                    block.append(inputReferenceNode.produceBlockAndPosition());
                    stackTypes.add(int.class);
                    if (!functionNullability.isArgumentNullable(realParameterIndex)) {
                        block.append(scope.getVariable("wasNull").set(inputReferenceNode.blockAndPositionIsNull()));
                        block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                    }
                    currentParameterIndex++;
                    break;
                case FUNCTION:
                    // Lambda arguments are compiled now, against the interface the invoker expects
                    Class<?> lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex);
                    block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface)));
                    lambdaArgumentIndex++;
                    break;
                default:
                    throw new UnsupportedOperationException(format("Unsupported argument convention type: %s", invocationConvention.getArgumentConvention(realParameterIndex)));
            }
            realParameterIndex++;
        }
        currentParameterIndex++;
    }
    block.append(invoke(binding, functionName));
    if (functionNullability.isReturnNullable()) {
        block.append(unboxPrimitiveIfNecessary(scope, returnType));
    }
    block.visitLabel(end);
    return block;
}
Also used : LabelNode(io.airlift.bytecode.instruction.LabelNode) MethodType(java.lang.invoke.MethodType) ArrayList(java.util.ArrayList) BytecodeBlock(io.airlift.bytecode.BytecodeBlock) InputReferenceNode(io.trino.sql.gen.InputReferenceCompiler.InputReferenceNode) InvocationArgumentConvention(io.trino.spi.function.InvocationConvention.InvocationArgumentConvention) InvocationConvention(io.trino.spi.function.InvocationConvention) BytecodeNode(io.airlift.bytecode.BytecodeNode) FunctionInvoker(io.trino.metadata.FunctionInvoker) ConnectorSession(io.trino.spi.connector.ConnectorSession)

Example 4 with InvocationConvention

use of io.trino.spi.function.InvocationConvention in project trino by trinodb.

From class MapToMapCast, method buildProcessor.

/**
 * Builds a handle that reads one value from a source map block, casts it, and appends the cast
 * value to a block builder.
 * <p>
 * The returned MethodHandle has the signature
 * {@code (Block fromMap, int position, ConnectorSession session, BlockBuilder mapBlockBuilder)void}.
 *
 * @param functionDependencies source of the cast invoker and its nullability
 * @param fromType element type being read from the source map
 * @param toType element type being written
 * @param isKey whether this processor handles map keys; when the cast is nullable, key results
 *     are passed through {@code nullChecker}
 */
private MethodHandle buildProcessor(FunctionDependencies functionDependencies, Type fromType, Type toType, boolean isKey) {
    FunctionNullability castNullability = functionDependencies.getCastNullability(fromType, toType);
    boolean castMayReturnNull = castNullability.isReturnNullable();
    // Ask for a block/position cast; the session parameter is optional
    InvocationConvention blockPositionConvention = new InvocationConvention(
            ImmutableList.of(BLOCK_POSITION),
            castMayReturnNull ? NULLABLE_RETURN : FAIL_ON_NULL,
            true,
            false);
    MethodHandle castHandle = functionDependencies.getCastInvoker(fromType, toType, blockPositionConvention).getMethodHandle();

    // Ensure the first parameter is a ConnectorSession, adding an ignored one if absent
    boolean startsWithSession = castHandle.type().parameterType(0) == ConnectorSession.class;
    if (!startsWithSession) {
        castHandle = MethodHandles.dropArguments(castHandle, 0, ConnectorSession.class);
    }

    // (ConnectorSession, Block, int)T  ->  (Block, int, ConnectorSession)T
    Class<?> castResultType = castHandle.type().returnType();
    castHandle = permuteArguments(castHandle, methodType(castResultType, Block.class, int.class, ConnectorSession.class), 2, 0, 1);

    // A nullable key cast gets its result checked for null
    if (isKey && castMayReturnNull) {
        castHandle = compose(nullChecker(castHandle.type().returnType()), castHandle);
    }

    // Writer reordered so the value comes first: (T, BlockBuilder)void
    MethodHandle valueWriter = nativeValueWriter(toType);
    Class<?> writtenValueType = valueWriter.type().parameterType(1);
    valueWriter = permuteArguments(valueWriter, methodType(void.class, writtenValueType, BlockBuilder.class), 1, 0);

    // Align the cast's return type with the writer's value parameter before chaining them
    castHandle = castHandle.asType(castHandle.type().changeReturnType(valueWriter.type().parameterType(0)));
    return compose(valueWriter, castHandle);
}
Also used : InvocationConvention(io.trino.spi.function.InvocationConvention) ConnectorSession(io.trino.spi.connector.ConnectorSession) FunctionNullability(io.trino.metadata.FunctionNullability) MethodHandle(java.lang.invoke.MethodHandle)

Example 5 with InvocationConvention

use of io.trino.spi.function.InvocationConvention in project trino by trinodb.

From class TestPolymorphicScalarFunction, method testSelectsMultipleChoiceWithBlockPosition.

/**
 * Checks that a polymorphic IS DISTINCT FROM function exposes both a NULL_FLAG choice and a
 * BLOCK_POSITION choice, and that the BLOCK_POSITION choice binds different methods for short
 * and long decimals.
 */
@Test
public void testSelectsMultipleChoiceWithBlockPosition() throws Throwable {
    Signature signature = Signature.builder()
            .operatorType(IS_DISTINCT_FROM)
            .argumentTypes(DECIMAL_SIGNATURE, DECIMAL_SIGNATURE)
            .returnType(BOOLEAN.getTypeSignature())
            .build();
    SqlScalarFunction function = new PolymorphicScalarFunctionBuilder(TestMethods.class)
            .signature(signature)
            .argumentNullability(true, true)
            .deterministic(true)
            .choice(nullFlagChoice -> nullFlagChoice
                    .argumentProperties(NULL_FLAG, NULL_FLAG)
                    .implementation(group -> group.methods("shortShort", "longLong")))
            .choice(blockPositionChoice -> blockPositionChoice
                    .argumentProperties(BLOCK_POSITION, BLOCK_POSITION)
                    .implementation(group -> group
                            .methodWithExplicitJavaTypes("blockPositionLongLong", asList(Optional.of(Int128.class), Optional.of(Int128.class)))
                            .methodWithExplicitJavaTypes("blockPositionShortShort", asList(Optional.of(long.class), Optional.of(long.class)))))
            .build();

    // Short decimal specialization: both choices present, with the expected conventions
    BoundSignature shortDecimalBoundSignature = new BoundSignature(signature.getName(), BOOLEAN, ImmutableList.of(SHORT_DECIMAL_BOUND_TYPE, SHORT_DECIMAL_BOUND_TYPE));
    ChoicesScalarFunctionImplementation shortImplementation = (ChoicesScalarFunctionImplementation) function.specialize(
            shortDecimalBoundSignature,
            new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of()));
    assertEquals(shortImplementation.getChoices().size(), 2);
    assertEquals(shortImplementation.getChoices().get(0).getInvocationConvention(), new InvocationConvention(ImmutableList.of(NULL_FLAG, NULL_FLAG), FAIL_ON_NULL, false, false));
    assertEquals(shortImplementation.getChoices().get(1).getInvocationConvention(), new InvocationConvention(ImmutableList.of(BLOCK_POSITION, BLOCK_POSITION), FAIL_ON_NULL, false, false));

    Block leftBlock = new LongArrayBlock(0, Optional.empty(), new long[0]);
    Block rightBlock = new LongArrayBlock(0, Optional.empty(), new long[0]);
    assertFalse((boolean) shortImplementation.getChoices().get(1).getMethodHandle().invoke(leftBlock, 0, rightBlock, 0));

    // Long decimal specialization: the BLOCK_POSITION choice now binds the long/long method
    BoundSignature longDecimalBoundSignature = new BoundSignature(signature.getName(), BOOLEAN, ImmutableList.of(LONG_DECIMAL_BOUND_TYPE, LONG_DECIMAL_BOUND_TYPE));
    ChoicesScalarFunctionImplementation longImplementation = (ChoicesScalarFunctionImplementation) function.specialize(
            longDecimalBoundSignature,
            new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of()));
    assertTrue((boolean) longImplementation.getChoices().get(1).getMethodHandle().invoke(leftBlock, 0, rightBlock, 0));
}
Also used : VARCHAR_TO_VARCHAR_RETURN_VALUE(io.trino.metadata.TestPolymorphicScalarFunction.TestMethods.VARCHAR_TO_VARCHAR_RETURN_VALUE) BLOCK_POSITION(io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION) Slice(io.airlift.slice.Slice) FAIL_ON_NULL(io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL) MAX_SHORT_PRECISION(io.trino.spi.type.Decimals.MAX_SHORT_PRECISION) TypeSignatureParameter.typeVariable(io.trino.spi.type.TypeSignatureParameter.typeVariable) BOOLEAN(io.trino.spi.type.BooleanType.BOOLEAN) Assert.assertEquals(org.testng.Assert.assertEquals) ChoicesScalarFunctionImplementation(io.trino.operator.scalar.ChoicesScalarFunctionImplementation) Test(org.testng.annotations.Test) IS_DISTINCT_FROM(io.trino.spi.function.OperatorType.IS_DISTINCT_FROM) FunctionManager.createTestingFunctionManager(io.trino.metadata.FunctionManager.createTestingFunctionManager) VARCHAR(io.trino.spi.type.VarcharType.VARCHAR) ImmutableList(com.google.common.collect.ImmutableList) NULL_FLAG(io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NULL_FLAG) Assertions.assertThatThrownBy(org.assertj.core.api.Assertions.assertThatThrownBy) Block(io.trino.spi.block.Block) Arrays.asList(java.util.Arrays.asList) Slices(io.airlift.slice.Slices) LongArrayBlock(io.trino.spi.block.LongArrayBlock) Assert.assertFalse(org.testng.Assert.assertFalse) TypeSignature(io.trino.spi.type.TypeSignature) Int128(io.trino.spi.type.Int128) ImmutableSet(com.google.common.collect.ImmutableSet) ImmutableMap(com.google.common.collect.ImmutableMap) Signature.comparableWithVariadicBound(io.trino.metadata.Signature.comparableWithVariadicBound) VARCHAR_TO_BIGINT_RETURN_VALUE(io.trino.metadata.TestPolymorphicScalarFunction.TestMethods.VARCHAR_TO_BIGINT_RETURN_VALUE) InvocationConvention(io.trino.spi.function.InvocationConvention) ADD(io.trino.spi.function.OperatorType.ADD) BIGINT(io.trino.spi.type.BigintType.BIGINT) 
Optional(java.util.Optional) Assert.assertTrue(org.testng.Assert.assertTrue) VarcharType.createVarcharType(io.trino.spi.type.VarcharType.createVarcharType) DecimalType(io.trino.spi.type.DecimalType) LongArrayBlock(io.trino.spi.block.LongArrayBlock) TypeSignature(io.trino.spi.type.TypeSignature) InvocationConvention(io.trino.spi.function.InvocationConvention) Block(io.trino.spi.block.Block) LongArrayBlock(io.trino.spi.block.LongArrayBlock) ChoicesScalarFunctionImplementation(io.trino.operator.scalar.ChoicesScalarFunctionImplementation) Test(org.testng.annotations.Test)

Aggregations

InvocationConvention (io.trino.spi.function.InvocationConvention)9 MethodHandle (java.lang.invoke.MethodHandle)7 ConnectorSession (io.trino.spi.connector.ConnectorSession)5 TrinoException (io.trino.spi.TrinoException)3 ArrayList (java.util.ArrayList)3 FunctionInvoker (io.trino.metadata.FunctionInvoker)2 Block (io.trino.spi.block.Block)2 InvocationArgumentConvention (io.trino.spi.function.InvocationConvention.InvocationArgumentConvention)2 BLOCK_POSITION (io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION)2 FAIL_ON_NULL (io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL)2 Type (io.trino.spi.type.Type)2 MethodType.methodType (java.lang.invoke.MethodType.methodType)2 Optional (java.util.Optional)2 ImmutableList (com.google.common.collect.ImmutableList)1 ImmutableMap (com.google.common.collect.ImmutableMap)1 ImmutableSet (com.google.common.collect.ImmutableSet)1 BytecodeBlock (io.airlift.bytecode.BytecodeBlock)1 BytecodeNode (io.airlift.bytecode.BytecodeNode)1 LabelNode (io.airlift.bytecode.instruction.LabelNode)1 Slice (io.airlift.slice.Slice)1