use of io.trino.metadata.BoundSignature in project trino by trinodb.
the class ParametricAggregation method findMatchingImplementation.
/**
 * Selects the aggregation implementation that serves the given bound signature.
 *
 * <p>An implementation registered under the exact signature takes precedence. Only when no exact
 * entry exists are the generic implementations scanned, and exactly one of them may accept the
 * bound types.
 *
 * @param boundSignature the fully bound signature being resolved
 * @return the single matching implementation
 * @throws TrinoException with {@code AMBIGUOUS_FUNCTION_CALL} when more than one generic
 *         implementation accepts the bound types, or with {@code FUNCTION_IMPLEMENTATION_MISSING}
 *         when nothing matches
 */
private AggregationImplementation findMatchingImplementation(BoundSignature boundSignature) {
    Signature exactKey = boundSignature.toSignature();
    Optional<AggregationImplementation> match;

    if (!implementations.getExactImplementations().containsKey(exactKey)) {
        // No exact entry: scan the generic implementations, insisting on a unique acceptor.
        match = Optional.empty();
        for (AggregationImplementation generic : implementations.getGenericImplementations()) {
            if (!generic.areTypesAssignable(boundSignature)) {
                continue;
            }
            if (match.isPresent()) {
                throw new TrinoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
            }
            match = Optional.of(generic);
        }
    }
    else {
        // Exact hit: generic implementations are not consulted at all.
        match = Optional.of(implementations.getExactImplementations().get(exactKey));
    }

    return match.orElseThrow(() -> new TrinoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", boundSignature, getFunctionMetadata().getSignature())));
}
use of io.trino.metadata.BoundSignature in project trino by trinodb.
the class ParametricScalarImplementation method specialize.
/**
 * Attempts to specialize this implementation for the given function binding.
 *
 * <p>The implementation is rejected (empty result) when any of the following holds:
 * <ul>
 *   <li>a specialized type parameter's class is not assignable from the Java type bound to that
 *       type variable;</li>
 *   <li>the declared return container type is neither {@code Object} nor the bound return type's
 *       Java type;</li>
 *   <li>a {@code FunctionType} argument has a declared container type, or any other argument
 *       lacks one;</li>
 *   <li>a non-function argument's declared container type is neither {@code Object} nor the bound
 *       argument's Java type.</li>
 * </ul>
 * Otherwise every choice has its dependencies bound and is wrapped into the returned
 * implementation.
 *
 * @param functionBinding the binding supplying the bound signature and type variables
 * @param functionDependencies dependencies handed to {@code bindDependencies}
 * @return the specialized implementation, or empty when this implementation does not apply
 */
public Optional<ScalarFunctionImplementation> specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) {
    // 1. Specialized type parameters must accept the Java types bound to their variables.
    for (Map.Entry<String, Class<?>> specialized : specializedTypeParameters.entrySet()) {
        boolean accepted = specialized.getValue().isAssignableFrom(functionBinding.getTypeVariable(specialized.getKey()).getJavaType());
        if (!accepted) {
            return Optional.empty();
        }
    }

    BoundSignature boundSignature = functionBinding.getBoundSignature();

    // 2. Return container: Object is a wildcard, anything else must match by identity.
    boolean returnTypeMatches = returnNativeContainerType == Object.class
            || returnNativeContainerType == boundSignature.getReturnType().getJavaType();
    if (!returnTypeMatches) {
        return Optional.empty();
    }

    // 3. Argument containers.
    for (int position = 0; position < boundSignature.getArgumentTypes().size(); position++) {
        boolean functionArgument = boundSignature.getArgumentTypes().get(position) instanceof FunctionType;
        boolean containerDeclared = argumentNativeContainerTypes.get(position).isPresent();

        // A function-typed argument must NOT declare a container type; every other argument must.
        // Both violations are exactly the cases where the two flags agree.
        if (functionArgument == containerDeclared) {
            return Optional.empty();
        }
        if (functionArgument) {
            continue;
        }

        Class<?> boundJavaType = boundSignature.getArgumentTypes().get(position).getJavaType();
        Class<?> declaredContainerType = argumentNativeContainerTypes.get(position).get();
        boolean argumentMatches = declaredContainerType == Object.class || declaredContainerType == boundJavaType;
        if (!argumentMatches) {
            return Optional.empty();
        }
    }

    // 4. Bind dependencies for each choice and collect the resulting implementation choices.
    List<ScalarImplementationChoice> implementationChoices = new ArrayList<>();
    for (ParametricScalarImplementationChoice choice : choices) {
        MethodHandle methodWithDependencies = bindDependencies(choice.getMethodHandle(), choice.getDependencies(), functionBinding, functionDependencies);
        Optional<MethodHandle> constructorWithDependencies = choice.getConstructor().map(rawConstructor -> {
            MethodHandle bound = bindDependencies(rawConstructor, choice.getConstructorDependencies(), functionBinding, functionDependencies);
            checkCondition(bound.type().parameterList().isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "All parameters of a constructor in a function definition class must be Dependencies. Signature: %s", boundSignature);
            return bound;
        });
        ScalarImplementationChoice specializedChoice = new ScalarImplementationChoice(
                choice.getReturnConvention(),
                choice.getArgumentConventions(),
                choice.getLambdaInterfaces(),
                methodWithDependencies.asType(javaMethodType(choice, boundSignature)),
                constructorWithDependencies);
        implementationChoices.add(specializedChoice);
    }
    return Optional.of(new ChoicesScalarFunctionImplementation(boundSignature, implementationChoices));
}
use of io.trino.metadata.BoundSignature in project trino by trinodb.
the class AggregationFunctionAdapter method normalizeInputMethod.
/**
 * Rewrites an aggregation input method so that each {@code INPUT_CHANNEL} parameter is fed by a
 * type-specific value getter, with the tracked parameter type at that slot becoming {@code Block}
 * and a single {@code int} position parameter shared by all rewritten slots.
 *
 * <p>Parameters of kind {@code BLOCK_INPUT_CHANNEL} / {@code NULLABLE_BLOCK_INPUT_CHANNEL} are left
 * untouched. If no parameter is of kind {@code INPUT_CHANNEL}, the handle is returned as passed in.
 *
 * @param inputMethod the handle to adapt; its parameter count minus {@code lambdaCount} must equal
 *        {@code parameterKinds.size()}
 * @param boundSignature supplies the argument types used to choose a value getter; its argument
 *        count minus {@code lambdaCount} must equal the number of input-channel kinds
 * @param parameterKinds kinds of the non-lambda parameters; must be ordered as all {@code STATE}
 *        kinds, then all input-channel kinds, then {@code BLOCK_INDEX} when any block input
 *        channel is present
 * @param lambdaCount number of parameters at the end of the method that are excluded from the
 *        kind checks and kept after the position parameter
 * @return the adapted method handle
 * @throws NullPointerException if a reference argument is null (assuming {@code requireNonNull}
 *         is {@code java.util.Objects.requireNonNull} — confirm against the file's imports)
 * @throws IllegalArgumentException if one of the consistency checks fails (assuming
 *         {@code checkArgument} is Guava's {@code Preconditions.checkArgument} — confirm)
 */
public static MethodHandle normalizeInputMethod(MethodHandle inputMethod, BoundSignature boundSignature, List<AggregationParameterKind> parameterKinds, int lambdaCount) {
requireNonNull(inputMethod, "inputMethod is null");
requireNonNull(parameterKinds, "parameterKinds is null");
requireNonNull(boundSignature, "boundSignature is null");
// every non-lambda parameter of the method must have a declared kind
checkArgument(inputMethod.type().parameterCount() - lambdaCount == parameterKinds.size(), "Input method has %s parameters, but parameter kinds only has %s items", inputMethod.type().parameterCount() - lambdaCount, parameterKinds.size());
// split the kinds into state parameters and input-channel parameters (value or block flavoured)
List<AggregationParameterKind> stateArgumentKinds = parameterKinds.stream().filter(STATE::equals).collect(toImmutableList());
List<AggregationParameterKind> inputArgumentKinds = parameterKinds.stream().filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL).collect(toImmutableList());
// despite its name, this flag is true only for the *block* channel kinds; it tracks whether the
// method already has a position parameter (see the BLOCK_INDEX expectation below)
boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL);
checkArgument(boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), "Bound signature has %s arguments, but parameter kinds only has %s input arguments", boundSignature.getArgumentTypes().size() - lambdaCount, inputArgumentKinds.size());
// enforce the canonical order: states, inputs, then BLOCK_INDEX iff a block channel is present
List<AggregationParameterKind> expectedInputArgumentKinds = new ArrayList<>();
expectedInputArgumentKinds.addAll(stateArgumentKinds);
expectedInputArgumentKinds.addAll(inputArgumentKinds);
if (hasInputChannel) {
expectedInputArgumentKinds.add(BLOCK_INDEX);
}
checkArgument(expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds);
// inputMethodType tracks the caller-facing type as it evolves; it is updated by hand in the loop
// and handed to permuteArguments as the new type
MethodType inputMethodType = inputMethod.type();
for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) {
// inputs follow the state parameters, so offset by the number of states
int parameterIndex = stateArgumentKinds.size() + argumentIndex;
AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex);
// only plain value channels need rewriting; block channels are passed through as-is
if (inputArgument != INPUT_CHANNEL) {
continue;
}
Type argumentType = boundSignature.getArgumentType(argumentIndex);
// process argument through type value getter
// (getter is chosen by the Java carrier type and bound to the concrete Type instance;
// NOTE(review): the *_TYPE_GETTER constants are defined outside this block — presumably
// handles of shape (Type, Block, int) -> value; confirm)
MethodHandle valueGetter;
if (argumentType.getJavaType().equals(boolean.class)) {
valueGetter = BOOLEAN_TYPE_GETTER.bindTo(argumentType);
} else if (argumentType.getJavaType().equals(long.class)) {
valueGetter = LONG_TYPE_GETTER.bindTo(argumentType);
} else if (argumentType.getJavaType().equals(double.class)) {
valueGetter = DOUBLE_TYPE_GETTER.bindTo(argumentType);
} else {
valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType);
// adapt the getter's return type to the exact parameter type the input method declares here
valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex)));
}
// replace the value parameter with the getter's own parameters at the same position
inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter);
// move the position argument to the end (and combine with other existing position argument)
inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class);
// reorder[i] = index of the caller-facing argument that feeds parameter i of inputMethod
ArrayList<Integer> reorder;
if (hasInputChannel) {
// a position parameter already exists (last one before the lambdas): start from the identity
// mapping and feed the getter's new position slot from that existing parameter
reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount);
} else {
// no position parameter yet: add an int to the caller-facing type just before the lambdas
inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class);
reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount;
// remove(int) removes by index; the list is still the identity mapping, so this drops the entry
// for the new position parameter, which is then re-inserted right after the block slot
reorder.remove(positionParameterIndex);
reorder.add(parameterIndex + 1, positionParameterIndex);
// later INPUT_CHANNEL arguments reuse the position parameter created here
hasInputChannel = true;
}
inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray());
}
return inputMethod;
}
use of io.trino.metadata.BoundSignature in project trino by trinodb.
the class TestAnnotationEngineForAggregates method testNotDecomposableAggregationParse.
/**
 * Parses {@code NotDecomposableAggregationFunction} and checks the resulting function metadata,
 * the aggregation metadata (not order sensitive, no intermediate types), and that specialization
 * for {@code double -> double} completes without throwing.
 */
@Test
public void testNotDecomposableAggregationParse() {
    Signature expected = new Signature(
            "custom_decomposable_aggregate",
            DoubleType.DOUBLE.getTypeSignature(),
            ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()));

    // The annotated class must yield exactly one aggregation definition.
    ParametricAggregation parsed = getOnlyElement(parseFunctionDefinitions(NotDecomposableAggregationFunction.class));

    // Function-level metadata.
    assertEquals(parsed.getFunctionMetadata().getDescription(), "Aggregate with Decomposable=false");
    assertTrue(parsed.getFunctionMetadata().isDeterministic());
    assertEquals(parsed.getFunctionMetadata().getSignature(), expected);

    // Aggregation-level metadata.
    AggregationFunctionMetadata parsedAggregationMetadata = parsed.getAggregationMetadata();
    assertFalse(parsedAggregationMetadata.isOrderSensitive());
    assertTrue(parsedAggregationMetadata.getIntermediateTypes().isEmpty());

    // Specialization only has to succeed; its result is not inspected.
    BoundSignature doubleToDouble = new BoundSignature(
            parsed.getFunctionMetadata().getSignature().getName(),
            DoubleType.DOUBLE,
            ImmutableList.of(DoubleType.DOUBLE));
    parsed.specialize(doubleToDouble, NO_FUNCTION_DEPENDENCIES);
}
use of io.trino.metadata.BoundSignature in project trino by trinodb.
the class TestAnnotationEngineForAggregates method testSimpleExplicitSpecializedAggregationParse.
// TODO this is not yet supported
// Disabled test: parses ExplicitSpecializedAggregationFunction and checks the generic signature
// T explicit_specialized_aggregate(array(T)), the implementation grouping, and specialization for
// array(double) -> double.
// NOTE(review): as written this test cannot pass even once the feature is supported — see the
// notes on the assertion pairs and on the get(1) lookup below. Resolve them before enabling.
@Test(enabled = false)
public void testSimpleExplicitSpecializedAggregationParse() {
Signature expectedSignature = new Signature("explicit_specialized_aggregate", ImmutableList.of(typeVariable("T")), ImmutableList.of(), new TypeSignature("T"), ImmutableList.of(new TypeSignature(ARRAY, TypeSignatureParameter.typeParameter(new TypeSignature("T")))), false);
ParametricAggregation aggregation = getOnlyElement(parseFunctionDefinitions(ExplicitSpecializedAggregationFunction.class));
assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple explicit specialized aggregate");
assertTrue(aggregation.getFunctionMetadata().isDeterministic());
assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature);
ParametricImplementationsGroup<AggregationImplementation> implementations = aggregation.getImplementations();
// NOTE(review): assertImplementationCount is defined outside this block; presumably the three
// numbers are the expected exact / specialized / generic counts — confirm
assertImplementationCount(implementations, 0, 1, 1);
AggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0);
// NOTE(review): the next two assertions check the same expression for true and then for false,
// so one of them must fail; one was presumably meant to test a different predicate — confirm
assertTrue(implementation1.hasSpecializedTypeParameters());
assertFalse(implementation1.hasSpecializedTypeParameters());
assertEquals(implementation1.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL));
// NOTE(review): if the count check above expects a single specialized implementation, index 1 is
// out of range here; possibly the generic implementation was intended — confirm
AggregationImplementation implementation2 = implementations.getSpecializedImplementations().get(1);
// NOTE(review): same contradictory assertTrue/assertFalse pair as for implementation1
assertTrue(implementation2.hasSpecializedTypeParameters());
assertFalse(implementation2.hasSpecializedTypeParameters());
assertEquals(implementation2.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL));
BoundSignature boundSignature = new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(new ArrayType(DoubleType.DOUBLE)));
AggregationFunctionMetadata aggregationMetadata = aggregation.getAggregationMetadata();
assertFalse(aggregationMetadata.isOrderSensitive());
assertFalse(aggregationMetadata.getIntermediateTypes().isEmpty());
// specialization only has to complete without throwing; its result is not inspected
aggregation.specialize(boundSignature, NO_FUNCTION_DEPENDENCIES);
}
Aggregations