Use of io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL in the trino project (trinodb).
Class AggregationFunctionAdapter, method normalizeInputMethod.
/**
 * Normalizes an aggregation input method so that every {@code INPUT_CHANNEL} argument is read from a
 * {@link Block} at a shared position instead of being passed as a plain value.
 * <p>
 * {@code parameterKinds} describes every parameter of {@code inputMethod} except the trailing
 * {@code lambdaCount} parameters, and must be laid out as: all {@code STATE} kinds, then all input
 * channel kinds ({@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL}, {@code NULLABLE_BLOCK_INPUT_CHANNEL}),
 * then a single {@code BLOCK_INDEX} if and only if a block input channel is present.
 * <p>
 * For each {@code INPUT_CHANNEL} argument, a type value getter (boolean, long, double or object, chosen
 * from the bound argument type's Java type) is bound to that type and collected into the argument slot.
 * The getter's extra {@code int} argument is then routed to one shared position parameter that sits just
 * before the trailing lambda parameters. That parameter is created there if the method had no
 * {@code BLOCK_INDEX} yet. Block input channels are left untouched.
 *
 * @param inputMethod the raw input method handle
 * @param boundSignature the bound signature; argument {@code i} supplies the type of input channel {@code i}
 * @param parameterKinds kinds of the non-lambda parameters of {@code inputMethod}
 * @param lambdaCount number of trailing parameters excluded from {@code parameterKinds} and kept in place
 * @return the adapted method handle
 * @throws NullPointerException if {@code inputMethod}, {@code boundSignature} or {@code parameterKinds} is null
 * @throws IllegalArgumentException (via {@code checkArgument}) if the parameter counts or kind layout are inconsistent
 */
public static MethodHandle normalizeInputMethod(MethodHandle inputMethod, BoundSignature boundSignature, List<AggregationParameterKind> parameterKinds, int lambdaCount)
{
    requireNonNull(inputMethod, "inputMethod is null");
    requireNonNull(parameterKinds, "parameterKinds is null");
    requireNonNull(boundSignature, "boundSignature is null");

    int nonLambdaParameterCount = inputMethod.type().parameterCount() - lambdaCount;
    checkArgument(
            nonLambdaParameterCount == parameterKinds.size(),
            "Input method has %s parameters, but parameter kinds only has %s items",
            nonLambdaParameterCount,
            parameterKinds.size());

    List<AggregationParameterKind> stateKinds = parameterKinds.stream()
            .filter(STATE::equals)
            .collect(toImmutableList());
    List<AggregationParameterKind> channelKinds = parameterKinds.stream()
            .filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL)
            .collect(toImmutableList());
    // only block-based channels come with their own position (BLOCK_INDEX) parameter
    boolean hasPositionParameter = parameterKinds.stream()
            .anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL);

    int nonLambdaArgumentCount = boundSignature.getArgumentTypes().size() - lambdaCount;
    checkArgument(
            nonLambdaArgumentCount == channelKinds.size(),
            "Bound signature has %s arguments, but parameter kinds only has %s input arguments",
            nonLambdaArgumentCount,
            channelKinds.size());

    // required layout: STATE..., input channels..., then BLOCK_INDEX iff a block channel exists
    List<AggregationParameterKind> expectedKinds = new ArrayList<>(stateKinds);
    expectedKinds.addAll(channelKinds);
    if (hasPositionParameter) {
        expectedKinds.add(BLOCK_INDEX);
    }
    checkArgument(expectedKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedKinds, parameterKinds);

    MethodType targetType = inputMethod.type();
    for (int channel = 0; channel < channelKinds.size(); channel++) {
        if (channelKinds.get(channel) != INPUT_CHANNEL) {
            // block channels already receive the block and the shared position directly
            continue;
        }
        int parameterIndex = stateKinds.size() + channel;

        // replace the value parameter by the getter's (Block, position) parameters
        MethodHandle valueGetter = inputValueGetter(boundSignature.getArgumentType(channel), targetType, parameterIndex);
        inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter);

        // the collected handle has one extra int: fold it into the shared position slot before the lambdas
        targetType = targetType.changeParameterType(parameterIndex, Block.class);
        if (!hasPositionParameter) {
            targetType = targetType.insertParameterTypes(targetType.parameterCount() - lambdaCount, int.class);
        }
        int positionIndex = targetType.parameterCount() - 1 - lambdaCount;
        int[] reorder = positionReorder(targetType.parameterCount(), parameterIndex + 1, positionIndex, hasPositionParameter);
        inputMethod = permuteArguments(inputMethod, targetType, reorder);
        hasPositionParameter = true;
    }
    return inputMethod;
}

/**
 * Returns the value getter matching {@code argumentType}'s Java type (boolean, long, double, otherwise
 * object), bound to {@code argumentType}. The object getter's return type is adapted to the type of
 * parameter {@code parameterIndex} of {@code targetType}.
 */
private static MethodHandle inputValueGetter(Type argumentType, MethodType targetType, int parameterIndex)
{
    Class<?> javaType = argumentType.getJavaType();
    if (javaType.equals(boolean.class)) {
        return BOOLEAN_TYPE_GETTER.bindTo(argumentType);
    }
    if (javaType.equals(long.class)) {
        return LONG_TYPE_GETTER.bindTo(argumentType);
    }
    if (javaType.equals(double.class)) {
        return DOUBLE_TYPE_GETTER.bindTo(argumentType);
    }
    MethodHandle objectGetter = OBJECT_TYPE_GETTER.bindTo(argumentType);
    return objectGetter.asType(objectGetter.type().changeReturnType(targetType.parameterType(parameterIndex)));
}

/**
 * Builds the {@code permuteArguments} reorder array for a freshly collected handle.
 * The collected handle's extra position argument (at {@code insertionIndex}) is mapped onto the target
 * parameter {@code positionIndex}; all other arguments keep their relative order. If the position
 * parameter did not exist before, its identity slot is removed first so that only the getter feeds it.
 */
private static int[] positionReorder(int targetParameterCount, int insertionIndex, int positionIndex, boolean positionAlreadyPresent)
{
    List<Integer> order = IntStream.range(0, targetParameterCount)
            .boxed()
            .collect(Collectors.toCollection(ArrayList::new));
    if (!positionAlreadyPresent) {
        // index-based List.remove(int), not removal by value
        order.remove(positionIndex);
    }
    order.add(insertionIndex, positionIndex);
    return order.stream().mapToInt(Integer::intValue).toArray();
}
Aggregations