Search in sources:

Example 1 with STATE

Use of io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE in project trino by trinodb.

In class AggregationFunctionAdapter, method normalizeInputMethod.

/**
 * Adapts an aggregation input method so that each {@code INPUT_CHANNEL} argument is read from a
 * {@link Block} at a block position, instead of being passed in as an already-extracted value.
 * <p>
 * {@code parameterKinds} must describe the non-lambda parameters of {@code inputMethod} in exactly this
 * order:
 * <ol>
 *   <li>all {@code STATE} parameters;</li>
 *   <li>one input parameter ({@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL} or
 *       {@code NULLABLE_BLOCK_INPUT_CHANNEL}) per non-lambda argument of {@code boundSignature};</li>
 *   <li>a single {@code BLOCK_INDEX} parameter, present if and only if at least one input is a
 *       {@code BLOCK_INPUT_CHANNEL} or {@code NULLABLE_BLOCK_INPUT_CHANNEL}.</li>
 * </ol>
 * The {@code lambdaCount} lambda parameters come after these and are not described by {@code parameterKinds}.
 * <p>
 * For every {@code INPUT_CHANNEL} parameter, the returned handle takes a {@code Block} in that slot. The value
 * is read from the block at the position using the getter bound to the corresponding argument {@link Type}.
 * All converted inputs share one position parameter, placed immediately before the lambda parameters. The
 * existing {@code BLOCK_INDEX} parameter is reused when present; otherwise an {@code int} parameter is added
 * at that spot. All other parameters are passed through unchanged.
 *
 * @param inputMethod the input method to adapt
 * @param boundSignature bound signature whose argument types select the value getters
 * @param parameterKinds kind of each non-lambda parameter of {@code inputMethod}, in declaration order
 * @param lambdaCount number of trailing lambda parameters on both {@code inputMethod} and {@code boundSignature}
 * @return the adapted method handle
 * @throws NullPointerException if {@code inputMethod}, {@code boundSignature} or {@code parameterKinds} is null
 * @throws IllegalArgumentException if the parameter counts or the kind ordering do not match the layout above
 */
public static MethodHandle normalizeInputMethod(MethodHandle inputMethod, BoundSignature boundSignature, List<AggregationParameterKind> parameterKinds, int lambdaCount) {
    requireNonNull(inputMethod, "inputMethod is null");
    requireNonNull(parameterKinds, "parameterKinds is null");
    requireNonNull(boundSignature, "boundSignature is null");
    // Every non-lambda parameter of the input method must have exactly one kind.
    checkArgument(inputMethod.type().parameterCount() - lambdaCount == parameterKinds.size(), "Input method has %s parameters, but parameter kinds only has %s items", inputMethod.type().parameterCount() - lambdaCount, parameterKinds.size());
    List<AggregationParameterKind> stateArgumentKinds = parameterKinds.stream().filter(STATE::equals).collect(toImmutableList());
    List<AggregationParameterKind> inputArgumentKinds = parameterKinds.stream().filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL).collect(toImmutableList());
    // True when the method already takes a block-based input, and therefore a BLOCK_INDEX position parameter.
    boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL);
    // Each non-lambda argument of the bound signature maps to exactly one input parameter.
    checkArgument(boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), "Bound signature has %s arguments, but parameter kinds only has %s input arguments", boundSignature.getArgumentTypes().size() - lambdaCount, inputArgumentKinds.size());
    // Enforce the layout: states, then inputs, then BLOCK_INDEX only when a block input is present.
    // Any other kind, or any other ordering, makes this comparison fail.
    List<AggregationParameterKind> expectedInputArgumentKinds = new ArrayList<>();
    expectedInputArgumentKinds.addAll(stateArgumentKinds);
    expectedInputArgumentKinds.addAll(inputArgumentKinds);
    if (hasInputChannel) {
        expectedInputArgumentKinds.add(BLOCK_INDEX);
    }
    checkArgument(expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds);
    // inputMethodType is the type of the adapted handle being built. It is updated on each conversion and
    // used as the newType argument of permuteArguments.
    MethodType inputMethodType = inputMethod.type();
    for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) {
        // Inputs follow the states. A conversion replaces a value slot with a Block slot in place,
        // so the indexes of later inputs do not shift.
        int parameterIndex = stateArgumentKinds.size() + argumentIndex;
        AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex);
        // Block-based inputs are left untouched; only INPUT_CHANNEL needs a value getter.
        if (inputArgument != INPUT_CHANNEL) {
            continue;
        }
        Type argumentType = boundSignature.getArgumentType(argumentIndex);
        // process argument through type value getter
        MethodHandle valueGetter;
        if (argumentType.getJavaType().equals(boolean.class)) {
            valueGetter = BOOLEAN_TYPE_GETTER.bindTo(argumentType);
        } else if (argumentType.getJavaType().equals(long.class)) {
            valueGetter = LONG_TYPE_GETTER.bindTo(argumentType);
        } else if (argumentType.getJavaType().equals(double.class)) {
            valueGetter = DOUBLE_TYPE_GETTER.bindTo(argumentType);
        } else {
            valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType);
            // Narrow the object getter's return type to the exact parameter type the input method declares.
            valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex)));
        }
        // Replace the value parameter with the getter's parameters. The collected handle now takes the
        // block at parameterIndex and its position at parameterIndex + 1, so it has one more parameter
        // than inputMethodType.
        inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter);
        // move the position argument to the end (and combine with other existing position argument)
        inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class);
        // reorder[i] is the index in inputMethodType of the argument passed to parameter i of the
        // collected handle (MethodHandles.permuteArguments semantics). Its length must equal the
        // collected handle's parameter count.
        ArrayList<Integer> reorder;
        if (hasInputChannel) {
            // A position parameter already sits just before the lambdas. Also feed it to the getter's
            // position slot, so the same incoming argument is used twice.
            reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
            reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount);
        } else {
            // No position parameter yet: add an int just before the lambdas. Route it only to the getter's
            // position slot, because the collected handle has no position parameter of its own.
            inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class);
            reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
            int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount;
            // remove(int) removes by index; the list is the identity mapping here, so index == value.
            reorder.remove(positionParameterIndex);
            reorder.add(parameterIndex + 1, positionParameterIndex);
            // Later INPUT_CHANNEL arguments reuse the position parameter added here.
            hasInputChannel = true;
        }
        inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray());
    }
    return inputMethod;
}
Also used: IntStream(java.util.stream.IntStream) MethodHandle(java.lang.invoke.MethodHandle) STATE(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE) INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL) BLOCK_INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Type(io.trino.spi.type.Type) MethodHandles.permuteArguments(java.lang.invoke.MethodHandles.permuteArguments) Collectors(java.util.stream.Collectors) MethodHandles.lookup(java.lang.invoke.MethodHandles.lookup) BLOCK_INDEX(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX) ArrayList(java.util.ArrayList) List(java.util.List) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) MethodType(java.lang.invoke.MethodType) ImmutableList(com.google.common.collect.ImmutableList) MethodHandles.collectArguments(java.lang.invoke.MethodHandles.collectArguments) Block(io.trino.spi.block.Block) NULLABLE_BLOCK_INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL) BoundSignature(io.trino.metadata.BoundSignature) Objects.requireNonNull(java.util.Objects.requireNonNull)

Example 2 with STATE

Use of io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE in project trino by trinodb.

In class QuantileDigestAggregationFunction, method specialize.

@Override
public AggregationMetadata specialize(BoundSignature boundSignature) {
    // The return type is the qdigest type. Its element type selects both the raw input handle and the
    // state serializer.
    QuantileDigestType digestType = (QuantileDigestType) boundSignature.getReturnType();
    Type elementType = digestType.getValueType();
    int argumentCount = boundSignature.getArity();
    QuantileDigestStateSerializer serializer = new QuantileDigestStateSerializer(elementType);

    MethodHandle rawInput = getMethodHandle(elementType, argumentCount);

    // Parameter layout: one STATE followed by one INPUT_CHANNEL per input type.
    ImmutableList.Builder<AggregationParameterKind> kinds = ImmutableList.builder();
    kinds.add(STATE);
    for (Object ignored : getInputTypes(elementType, argumentCount)) {
        kinds.add(INPUT_CHANNEL);
    }
    MethodHandle normalizedInput = normalizeInputMethod(rawInput, boundSignature, kinds.build());

    return new AggregationMetadata(
            normalizedInput,
            Optional.empty(),
            Optional.of(COMBINE_FUNCTION),
            OUTPUT_FUNCTION.bindTo(serializer),
            ImmutableList.of(new AccumulatorStateDescriptor<>(QuantileDigestState.class, serializer, new QuantileDigestStateFactory())));
}
Also used: MethodHandle(java.lang.invoke.MethodHandle) QuantileDigest(io.airlift.stats.QuantileDigest) Signature.comparableTypeParameter(io.trino.metadata.Signature.comparableTypeParameter) AccumulatorStateDescriptor(io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor) MethodHandles.insertArguments(java.lang.invoke.MethodHandles.insertArguments) FunctionNullability(io.trino.metadata.FunctionNullability) Type(io.trino.spi.type.Type) DEFAULT_WEIGHT(io.trino.operator.scalar.QuantileDigestFunctions.DEFAULT_WEIGHT) QuantileDigestStateFactory(io.trino.operator.aggregation.state.QuantileDigestStateFactory) Float.intBitsToFloat(java.lang.Float.intBitsToFloat) QDIGEST(io.trino.spi.type.StandardTypes.QDIGEST) ImmutableList(com.google.common.collect.ImmutableList) TypeSignature.parametricType(io.trino.spi.type.TypeSignature.parametricType) FloatingPointBitsConverterUtil.floatToSortableInt(io.trino.operator.aggregation.FloatingPointBitsConverterUtil.floatToSortableInt) QuantileDigestStateSerializer(io.trino.operator.aggregation.state.QuantileDigestStateSerializer) Signature(io.trino.metadata.Signature) FunctionMetadata(io.trino.metadata.FunctionMetadata) TypeSignature(io.trino.spi.type.TypeSignature) AggregationFunctionAdapter.normalizeInputMethod(io.trino.operator.aggregation.AggregationFunctionAdapter.normalizeInputMethod) QuantileDigestFunctions.verifyWeight(io.trino.operator.scalar.QuantileDigestFunctions.verifyWeight) STATE(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE) INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL) Collections.nCopies(java.util.Collections.nCopies) QuantileDigestType(io.trino.spi.type.QuantileDigestType) StandardTypes(io.trino.spi.type.StandardTypes) QuantileDigestState(io.trino.operator.aggregation.state.QuantileDigestState) FloatingPointBitsConverterUtil.doubleToSortableLong(io.trino.operator.aggregation.FloatingPointBitsConverterUtil.doubleToSortableLong) AGGREGATE(io.trino.metadata.FunctionKind.AGGREGATE) Collectors(java.util.stream.Collectors) DEFAULT_ACCURACY(io.trino.operator.scalar.QuantileDigestFunctions.DEFAULT_ACCURACY) String.format(java.lang.String.format) AggregationParameterKind(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind) DOUBLE(io.trino.spi.type.DoubleType.DOUBLE) AggregationFunctionMetadata(io.trino.metadata.AggregationFunctionMetadata) List(java.util.List) QuantileDigestFunctions.verifyAccuracy(io.trino.operator.scalar.QuantileDigestFunctions.verifyAccuracy) BIGINT(io.trino.spi.type.BigintType.BIGINT) BoundSignature(io.trino.metadata.BoundSignature) Optional(java.util.Optional) BlockBuilder(io.trino.spi.block.BlockBuilder) SqlAggregationFunction(io.trino.metadata.SqlAggregationFunction) Reflection.methodHandle(io.trino.util.Reflection.methodHandle)

Aggregated imports (number of examples above that use each):

ImmutableList (com.google.common.collect.ImmutableList): 2; BoundSignature (io.trino.metadata.BoundSignature): 2; INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL): 2; STATE (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE): 2; Type (io.trino.spi.type.Type): 2; MethodHandle (java.lang.invoke.MethodHandle): 2; List (java.util.List): 2; Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument): 1; ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList): 1; QuantileDigest (io.airlift.stats.QuantileDigest): 1; AggregationFunctionMetadata (io.trino.metadata.AggregationFunctionMetadata): 1; AGGREGATE (io.trino.metadata.FunctionKind.AGGREGATE): 1; FunctionMetadata (io.trino.metadata.FunctionMetadata): 1; FunctionNullability (io.trino.metadata.FunctionNullability): 1; Signature (io.trino.metadata.Signature): 1; Signature.comparableTypeParameter (io.trino.metadata.Signature.comparableTypeParameter): 1; SqlAggregationFunction (io.trino.metadata.SqlAggregationFunction): 1; AggregationParameterKind (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind): 1; BLOCK_INDEX (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX): 1; BLOCK_INPUT_CHANNEL (io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL): 1