Usage of io.trino.spi.function.InvocationConvention in the Trino project (trinodb).
Class ChoicesScalarFunctionImplementation, method getScalarFunctionInvoker:
@Override
public FunctionInvoker getScalarFunctionInvoker(InvocationConvention invocationConvention)
{
    // Consider only implementations whose own convention the adapter can convert into the requested one,
    // and take the highest-scoring of them. Ties resolve to the earliest entry in this.choices.
    ScalarImplementationChoice bestChoice = this.choices.stream()
            .filter(candidate -> functionAdapter.canAdapt(candidate.getInvocationConvention(), invocationConvention))
            .max(comparingInt(ScalarImplementationChoice::getScore))
            .orElseThrow(() -> new TrinoException(FUNCTION_NOT_FOUND, format("Function implementation for (%s) cannot be adapted to convention (%s)", boundSignature, invocationConvention)));

    // Convert the chosen implementation's handle to the caller's convention
    MethodHandle adaptedHandle = functionAdapter.adapt(
            bestChoice.getMethodHandle(),
            boundSignature.getArgumentTypes(),
            bestChoice.getInvocationConvention(),
            invocationConvention);
    return new FunctionInvoker(adaptedHandle, bestChoice.getInstanceFactory(), bestChoice.getLambdaInterfaces());
}
Usage of io.trino.spi.function.InvocationConvention in the Trino project (trinodb).
Class TryCastFunction, method specialize:
@Override
public ScalarFunctionImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies)
{
    Type sourceType = boundSignature.getArgumentType(0);
    Type targetType = boundSignature.getReturnType();

    // On failure the handle must be able to yield null, so its result is the boxed form of the target Java type
    Class<?> boxedResultType = wrap(targetType.getJavaType());

    MethodHandle castHandle = functionDependencies
            .getCastInvoker(sourceType, targetType, new InvocationConvention(ImmutableList.of(NEVER_NULL), NULLABLE_RETURN, true, false))
            .getMethodHandle();
    MethodHandle boxedCast = castHandle.asType(methodType(boxedResultType, castHandle.type()));

    // Any RuntimeException raised by the cast is swallowed and replaced with a null result
    MethodHandle returnNull = dropArguments(constant(boxedResultType, null), 0, RuntimeException.class);
    MethodHandle guardedCast = catchException(boxedCast, RuntimeException.class, returnNull);

    return new ChoicesScalarFunctionImplementation(boundSignature, NULLABLE_RETURN, ImmutableList.of(NEVER_NULL), guardedCast);
}
Usage of io.trino.spi.function.InvocationConvention in the Trino project (trinodb).
Class BytecodeUtils, method generateFullInvocation:
/**
 * Generates bytecode that invokes a function through a {@link FunctionInvoker} resolved for a
 * call-site-specific {@link InvocationConvention}.
 *
 * <p>Each argument gets a convention: {@code FUNCTION} for function-typed (lambda) arguments, otherwise whatever
 * {@code getPreferredArgumentConvention} picks. The resolved invoker's method handle is bound through
 * {@code binder}. The parameters of the bound method type are then walked in order, emitting the instance (if the
 * invoker has an instance factory), the session, and each argument in the shape its convention requires.
 *
 * @param scope scope providing the {@code session} and {@code wasNull} variables
 * @param functionName name used in the block description and passed to the emitted invoke
 * @param functionNullability nullability of the function's arguments and return value
 * @param argumentIsFunctionType per-argument flag; {@code true} marks a function-typed (lambda) argument
 * @param functionInvokerProvider resolves an invoker for the computed invocation convention
 * @param instanceFactory generates bytecode from the invoker's instance factory handle, when one is present
 * @param argumentCompilers per-argument bytecode generators; called with the lambda interface for function-typed
 *        arguments and with {@code Optional.empty()} otherwise
 * @param binder call-site binder used to bind the invoker's method handle
 * @return the generated block, which ends with the {@code end} label
 */
private static BytecodeNode generateFullInvocation(Scope scope, String functionName, FunctionNullability functionNullability, List<Boolean> argumentIsFunctionType, Function<InvocationConvention, FunctionInvoker> functionInvokerProvider, Function<MethodHandle, BytecodeNode> instanceFactory, List<Function<Optional<Class<?>>, BytecodeNode>> argumentCompilers, CallSiteBinder binder)
{
    verify(argumentIsFunctionType.size() == argumentCompilers.size());

    // Choose a convention per argument. Lambda arguments can only be compiled once the lambda interface is known
    // from the invoker, so a null placeholder is recorded for them here.
    List<InvocationArgumentConvention> argumentConventions = new ArrayList<>();
    List<BytecodeNode> arguments = new ArrayList<>();
    for (int i = 0; i < argumentIsFunctionType.size(); i++) {
        if (argumentIsFunctionType.get(i)) {
            argumentConventions.add(FUNCTION);
            arguments.add(null);
        }
        else {
            BytecodeNode argument = argumentCompilers.get(i).apply(Optional.empty());
            argumentConventions.add(getPreferredArgumentConvention(argument, argumentCompilers.size(), functionNullability.isArgumentNullable(i)));
            arguments.add(argument);
        }
    }
    InvocationConvention invocationConvention = new InvocationConvention(argumentConventions, functionNullability.isReturnNullable() ? NULLABLE_RETURN : FAIL_ON_NULL, true, true);
    FunctionInvoker functionInvoker = functionInvokerProvider.apply(invocationConvention);

    Binding binding = binder.bind(functionInvoker.getMethodHandle());
    // Jump target passed to ifWasNullPopAndGoto; presumably reached when a non-nullable argument is null
    LabelNode end = new LabelNode("end");
    BytecodeBlock block = new BytecodeBlock().setDescription("invoke " + functionName);
    Optional<BytecodeNode> instance = functionInvoker.getInstanceFactory().map(instanceFactory);

    // Index of current parameter in the MethodHandle
    int currentParameterIndex = 0;
    // Index of parameter (without @IsNull) in Trino function
    int realParameterIndex = 0;
    // Index of function argument types
    int lambdaArgumentIndex = 0;

    MethodType methodType = binding.getType();
    Class<?> returnType = methodType.returnType();
    Class<?> unboxedReturnType = Primitives.unwrap(returnType);

    // Types pushed so far; a reversed view is handed to ifWasNullPopAndGoto for each null check
    List<Class<?>> stackTypes = new ArrayList<>();
    boolean instanceIsBound = false;
    // parameterCount()/parameterType(i) avoid cloning the parameter array on every iteration
    while (currentParameterIndex < methodType.parameterCount()) {
        Class<?> type = methodType.parameterType(currentParameterIndex);
        stackTypes.add(type);
        if (instance.isPresent() && !instanceIsBound) {
            // When an instance is required it is the first handle parameter
            checkState(type.equals(functionInvoker.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
            block.append(instance.get());
            instanceIsBound = true;
        }
        else if (type == ConnectorSession.class) {
            block.append(scope.getVariable("session"));
        }
        else {
            switch (invocationConvention.getArgumentConvention(realParameterIndex)) {
                case NEVER_NULL:
                    block.append(arguments.get(realParameterIndex));
                    checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type");
                    block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                    break;
                case NULL_FLAG:
                    // Value followed by an explicit boolean null flag: occupies two handle parameters
                    block.append(arguments.get(realParameterIndex));
                    block.append(scope.getVariable("wasNull"));
                    block.append(scope.getVariable("wasNull").set(constantFalse()));
                    stackTypes.add(boolean.class);
                    currentParameterIndex++;
                    break;
                case BOXED_NULLABLE:
                    block.append(arguments.get(realParameterIndex));
                    block.append(boxPrimitiveIfNecessary(scope, type));
                    block.append(scope.getVariable("wasNull").set(constantFalse()));
                    break;
                case BLOCK_POSITION:
                    // Block followed by an int position: occupies two handle parameters. The null check is only
                    // emitted when the function does not accept a null value for this argument.
                    InputReferenceNode inputReferenceNode = (InputReferenceNode) arguments.get(realParameterIndex);
                    block.append(inputReferenceNode.produceBlockAndPosition());
                    stackTypes.add(int.class);
                    if (!functionNullability.isArgumentNullable(realParameterIndex)) {
                        block.append(scope.getVariable("wasNull").set(inputReferenceNode.blockAndPositionIsNull()));
                        block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                    }
                    currentParameterIndex++;
                    break;
                case FUNCTION:
                    // Lambda arguments are compiled now, against the interface the invoker expects
                    Class<?> lambdaInterface = functionInvoker.getLambdaInterfaces().get(lambdaArgumentIndex);
                    block.append(argumentCompilers.get(realParameterIndex).apply(Optional.of(lambdaInterface)));
                    lambdaArgumentIndex++;
                    break;
                default:
                    throw new UnsupportedOperationException(format("Unsupported argument convention type: %s", invocationConvention.getArgumentConvention(realParameterIndex)));
            }
            realParameterIndex++;
        }
        currentParameterIndex++;
    }
    block.append(invoke(binding, functionName));

    if (functionNullability.isReturnNullable()) {
        block.append(unboxPrimitiveIfNecessary(scope, returnType));
    }
    block.visitLabel(end);

    return block;
}
Usage of io.trino.spi.function.InvocationConvention in the Trino project (trinodb).
Class MapToMapCast, method buildProcessor:
/**
 * Builds the processor that casts a single map key or value.
 * The signature of the returned MethodHandle is (Block fromMap, int position, ConnectorSession session, BlockBuilder mapBlockBuilder)void.
 * The processor reads the entry at {@code position} of {@code fromMap}, casts it from {@code fromType} to
 * {@code toType}, and writes the result to {@code mapBlockBuilder}. When {@code isKey} is set and the cast can
 * return null, the cast result is passed through a null checker before being written.
 */
private MethodHandle buildProcessor(FunctionDependencies functionDependencies, Type fromType, Type toType, boolean isKey)
{
    FunctionNullability castNullability = functionDependencies.getCastNullability(fromType, toType);
    boolean castMayReturnNull = castNullability.isReturnNullable();

    // Request a (block, position) cast; the resolved handle may or may not take a ConnectorSession
    InvocationConvention castConvention = new InvocationConvention(
            ImmutableList.of(BLOCK_POSITION),
            castMayReturnNull ? NULLABLE_RETURN : FAIL_ON_NULL,
            true,
            false);
    MethodHandle castHandle = functionDependencies.getCastInvoker(fromType, toType, castConvention).getMethodHandle();

    // Guarantee a leading ConnectorSession parameter, adding an ignored one if absent
    boolean takesSession = castHandle.type().parameterType(0) == ConnectorSession.class;
    if (!takesSession) {
        castHandle = MethodHandles.dropArguments(castHandle, 0, ConnectorSession.class);
    }

    // Reorder (ConnectorSession, Block, int):T into (Block, int, ConnectorSession):T
    Class<?> castResultType = castHandle.type().returnType();
    castHandle = permuteArguments(castHandle, methodType(castResultType, Block.class, int.class, ConnectorSession.class), 2, 0, 1);

    // A key cast that can return null gets its result checked
    if (isKey && castMayReturnNull) {
        castHandle = compose(nullChecker(castResultType), castHandle);
    }

    // Swap the writer's parameters from (BlockBuilder, T) to (T, BlockBuilder):void
    MethodHandle valueWriter = nativeValueWriter(toType);
    Class<?> writtenType = valueWriter.type().parameterType(1);
    valueWriter = permuteArguments(valueWriter, methodType(void.class, writtenType, BlockBuilder.class), 1, 0);

    // Make the cast return exactly the type the writer accepts, then feed the cast result into the writer
    castHandle = castHandle.asType(castHandle.type().changeReturnType(valueWriter.type().parameterType(0)));
    return compose(valueWriter, castHandle);
}
Usage of io.trino.spi.function.InvocationConvention in the Trino project (trinodb).
Class TestPolymorphicScalarFunction, method testSelectsMultipleChoiceWithBlockPosition:
@Test
public void testSelectsMultipleChoiceWithBlockPosition()
        throws Throwable
{
    // IS DISTINCT FROM over two decimals, offering a NULL_FLAG choice and a BLOCK_POSITION choice
    Signature signature = Signature.builder()
            .operatorType(IS_DISTINCT_FROM)
            .argumentTypes(DECIMAL_SIGNATURE, DECIMAL_SIGNATURE)
            .returnType(BOOLEAN.getTypeSignature())
            .build();
    SqlScalarFunction function = new PolymorphicScalarFunctionBuilder(TestMethods.class)
            .signature(signature)
            .argumentNullability(true, true)
            .deterministic(true)
            .choice(choice -> choice
                    .argumentProperties(NULL_FLAG, NULL_FLAG)
                    .implementation(methodsGroup -> methodsGroup.methods("shortShort", "longLong")))
            .choice(choice -> choice
                    .argumentProperties(BLOCK_POSITION, BLOCK_POSITION)
                    .implementation(methodsGroup -> methodsGroup
                            .methodWithExplicitJavaTypes("blockPositionLongLong", asList(Optional.of(Int128.class), Optional.of(Int128.class)))
                            .methodWithExplicitJavaTypes("blockPositionShortShort", asList(Optional.of(long.class), Optional.of(long.class)))))
            .build();

    // Short decimals: both choices are present, with the expected conventions
    BoundSignature shortSignature = new BoundSignature(signature.getName(), BOOLEAN, ImmutableList.of(SHORT_DECIMAL_BOUND_TYPE, SHORT_DECIMAL_BOUND_TYPE));
    ChoicesScalarFunctionImplementation shortImplementation = (ChoicesScalarFunctionImplementation) function.specialize(
            shortSignature,
            new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of()));
    assertEquals(shortImplementation.getChoices().size(), 2);
    assertEquals(
            shortImplementation.getChoices().get(0).getInvocationConvention(),
            new InvocationConvention(ImmutableList.of(NULL_FLAG, NULL_FLAG), FAIL_ON_NULL, false, false));
    assertEquals(
            shortImplementation.getChoices().get(1).getInvocationConvention(),
            new InvocationConvention(ImmutableList.of(BLOCK_POSITION, BLOCK_POSITION), FAIL_ON_NULL, false, false));

    Block leftBlock = new LongArrayBlock(0, Optional.empty(), new long[0]);
    Block rightBlock = new LongArrayBlock(0, Optional.empty(), new long[0]);
    assertFalse((boolean) shortImplementation.getChoices().get(1).getMethodHandle().invoke(leftBlock, 0, rightBlock, 0));

    // Long decimals: the BLOCK_POSITION choice now dispatches to the long implementation
    BoundSignature longSignature = new BoundSignature(signature.getName(), BOOLEAN, ImmutableList.of(LONG_DECIMAL_BOUND_TYPE, LONG_DECIMAL_BOUND_TYPE));
    ChoicesScalarFunctionImplementation longImplementation = (ChoicesScalarFunctionImplementation) function.specialize(
            longSignature,
            new FunctionDependencies(FUNCTION_MANAGER::getScalarFunctionInvoker, ImmutableMap.of(), ImmutableSet.of()));
    assertTrue((boolean) longImplementation.getChoices().get(1).getMethodHandle().invoke(leftBlock, 0, rightBlock, 0));
}
Aggregations