Search in sources :

Example 1 with NULLABLE_BLOCK_INPUT_CHANNEL

Use of io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL in the trino project by trinodb.

Source: method normalizeInputMethod of class AggregationFunctionAdapter.

/**
 * Adapts an aggregation input method so that every {@code INPUT_CHANNEL} parameter is fed from a
 * {@link Block} plus a position, instead of receiving an already-extracted value.
 *
 * <p>{@code parameterKinds} describes every parameter of {@code inputMethod} except the trailing
 * {@code lambdaCount} parameters (the lambda arguments). It must be ordered as follows:
 * <ul>
 *   <li>all {@code STATE} parameters;</li>
 *   <li>then the input parameters ({@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL} or
 *       {@code NULLABLE_BLOCK_INPUT_CHANNEL}), positionally matching the bound signature's argument
 *       types;</li>
 *   <li>then exactly one {@code BLOCK_INDEX} parameter, if and only if at least one input is a
 *       block channel.</li>
 * </ul>
 * No other kinds are accepted.
 *
 * <p>For each {@code INPUT_CHANNEL} parameter, the method does three things:
 * <ul>
 *   <li>It composes a value getter in front of the parameter. The getter is selected by the
 *       argument type's Java type: {@code boolean}, {@code long}, {@code double}, or any other
 *       type.</li>
 *   <li>It changes that parameter's type to {@code Block}.</li>
 *   <li>It routes the getter's {@code int} position argument to one shared position parameter.
 *       That parameter is the last non-lambda parameter. If no {@code BLOCK_INDEX} parameter exists
 *       yet, an {@code int} parameter is inserted just before the lambda parameters on the first
 *       conversion and reused by later conversions.</li>
 * </ul>
 * Block-channel parameters are left untouched.
 *
 * @param inputMethod the input method handle to adapt
 * @param boundSignature resolved signature; argument type {@code i} is used for input parameter {@code i}
 * @param parameterKinds kinds of the non-lambda parameters of {@code inputMethod}, in declaration order
 * @param lambdaCount number of trailing parameters not described by {@code parameterKinds}
 * @return a handle whose {@code INPUT_CHANNEL} parameters are typed {@code Block} and that share one
 *     {@code int} position parameter placed immediately before the lambda parameters
 * @throws NullPointerException if {@code inputMethod}, {@code boundSignature} or {@code parameterKinds} is null
 * @throws IllegalArgumentException if the parameter counts disagree, or {@code parameterKinds} is not
 *     ordered as described above
 */
public static MethodHandle normalizeInputMethod(MethodHandle inputMethod, BoundSignature boundSignature, List<AggregationParameterKind> parameterKinds, int lambdaCount) {
    requireNonNull(inputMethod, "inputMethod is null");
    requireNonNull(parameterKinds, "parameterKinds is null");
    requireNonNull(boundSignature, "boundSignature is null");
    // The trailing lambdaCount parameters are not described by parameterKinds.
    checkArgument(inputMethod.type().parameterCount() - lambdaCount == parameterKinds.size(), "Input method has %s parameters, but parameter kinds only has %s items", inputMethod.type().parameterCount() - lambdaCount, parameterKinds.size());
    List<AggregationParameterKind> stateArgumentKinds = parameterKinds.stream().filter(STATE::equals).collect(toImmutableList());
    List<AggregationParameterKind> inputArgumentKinds = parameterKinds.stream().filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL).collect(toImmutableList());
    // True when the method already takes a position (BLOCK_INDEX) parameter, i.e. some input is a
    // block channel. Reused inside the loop below to mean "a position parameter now exists".
    boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL);
    checkArgument(boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), "Bound signature has %s arguments, but parameter kinds only has %s input arguments", boundSignature.getArgumentTypes().size() - lambdaCount, inputArgumentKinds.size());
    // Enforce the layout STATE*, inputs*, [BLOCK_INDEX]. Because the filtered lists keep their
    // original order, equality also rejects interleaving and any other parameter kinds.
    List<AggregationParameterKind> expectedInputArgumentKinds = new ArrayList<>();
    expectedInputArgumentKinds.addAll(stateArgumentKinds);
    expectedInputArgumentKinds.addAll(inputArgumentKinds);
    if (hasInputChannel) {
        expectedInputArgumentKinds.add(BLOCK_INDEX);
    }
    checkArgument(expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds);
    // Target type of the adapted handle. It is updated in step with inputMethod on every iteration,
    // so later parameterIndex/reorder computations see the already-converted layout.
    MethodType inputMethodType = inputMethod.type();
    for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) {
        // Input parameters start right after the state parameters.
        int parameterIndex = stateArgumentKinds.size() + argumentIndex;
        AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex);
        if (inputArgument != INPUT_CHANNEL) {
            // Block channels already receive a Block; nothing to adapt.
            continue;
        }
        Type argumentType = boundSignature.getArgumentType(argumentIndex);
        // process argument through type value getter
        // NOTE(review): the *_TYPE_GETTER handles are defined outside this view. From the reorder
        // below they take (Block, int position) once bound to the type, and presumably read the
        // value via the Type's getBoolean/getLong/getDouble/getObject. Confirm against their definitions.
        MethodHandle valueGetter;
        if (argumentType.getJavaType().equals(boolean.class)) {
            valueGetter = BOOLEAN_TYPE_GETTER.bindTo(argumentType);
        } else if (argumentType.getJavaType().equals(long.class)) {
            valueGetter = LONG_TYPE_GETTER.bindTo(argumentType);
        } else if (argumentType.getJavaType().equals(double.class)) {
            valueGetter = DOUBLE_TYPE_GETTER.bindTo(argumentType);
        } else {
            valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType);
            // Cast the generic getter's return type to the exact type the input method declares at
            // this slot. inputMethodType still holds the original type here; it is changed below.
            valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex)));
        }
        // Replace the value parameter with the getter's parameters: a Block at parameterIndex and
        // the position at parameterIndex + 1.
        inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter);
        // move the position argument to the end (and combine with other existing position argument)
        inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class);
        // reorder[i] is the target-type parameter index that feeds parameter i of the collected
        // handle. Its length therefore equals the collected handle's parameter count.
        ArrayList<Integer> reorder;
        if (hasInputChannel) {
            // A position parameter already exists as the last non-lambda parameter. Feed the
            // getter's position argument (collected index parameterIndex + 1) from it, so the
            // target type keeps its parameter count.
            reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
            reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount);
        } else {
            // No position parameter yet: insert an int just before the lambda parameters.
            inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class);
            reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
            int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount;
            // int overload: removes by index. Since reorder[i] == i at this point, this is the
            // entry for the new position parameter. Moving it to parameterIndex + 1 maps the
            // getter's position argument to the newly inserted int.
            reorder.remove(positionParameterIndex);
            reorder.add(parameterIndex + 1, positionParameterIndex);
            // Later INPUT_CHANNEL conversions reuse this position parameter.
            hasInputChannel = true;
        }
        inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray());
    }
    return inputMethod;
}
Also used: IntStream (java.util.stream.IntStream), MethodHandle (java.lang.invoke.MethodHandle), STATE (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE), INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL), BLOCK_INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL), ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList), Type (io.trino.spi.type.Type), MethodHandles.permuteArguments (java.lang.invoke.MethodHandles.permuteArguments), Collectors (java.util.stream.Collectors), MethodHandles.lookup (java.lang.invoke.MethodHandles.lookup), BLOCK_INDEX (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX), ArrayList (java.util.ArrayList), List (java.util.List), Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument), MethodType (java.lang.invoke.MethodType), ImmutableList (com.google.common.collect.ImmutableList), MethodHandles.collectArguments (java.lang.invoke.MethodHandles.collectArguments), Block (io.trino.spi.block.Block), NULLABLE_BLOCK_INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL), BoundSignature (io.trino.metadata.BoundSignature), Objects.requireNonNull (java.util.Objects.requireNonNull)

Aggregations

Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument): 1; ImmutableList (com.google.common.collect.ImmutableList): 1; ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 1; BoundSignature (io.trino.metadata.BoundSignature): 1; BLOCK_INDEX (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX): 1; BLOCK_INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL): 1; INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL): 1; NULLABLE_BLOCK_INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL): 1; STATE (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE): 1; Block (io.trino.spi.block.Block): 1; Type (io.trino.spi.type.Type): 1; MethodHandle (java.lang.invoke.MethodHandle): 1; MethodHandles.collectArguments (java.lang.invoke.MethodHandles.collectArguments): 1; MethodHandles.lookup (java.lang.invoke.MethodHandles.lookup): 1; MethodHandles.permuteArguments (java.lang.invoke.MethodHandles.permuteArguments): 1; MethodType (java.lang.invoke.MethodType): 1; ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; Objects.requireNonNull (java.util.Objects.requireNonNull): 1; Collectors (java.util.stream.Collectors): 1