Search in sources:

Example 1 with BoundSignature

Use of io.trino.metadata.BoundSignature in project trino by trinodb.

In class ParametricAggregation, method findMatchingImplementation.

/**
 * Picks the aggregation implementation that serves the given bound signature.
 *
 * <p>An implementation registered under the exact signature is returned directly. Otherwise the
 * generic implementations are scanned and the single one whose types are assignable from the
 * bound signature is used.
 *
 * @param boundSignature the concrete signature being resolved
 * @return the matching implementation
 * @throws TrinoException with {@code AMBIGUOUS_FUNCTION_CALL} if two or more generic
 *         implementations match, or {@code FUNCTION_IMPLEMENTATION_MISSING} if none does
 */
private AggregationImplementation findMatchingImplementation(BoundSignature boundSignature) {
    Signature signature = boundSignature.toSignature();
    Optional<AggregationImplementation> match = implementations.getExactImplementations().containsKey(signature)
            ? Optional.of(implementations.getExactImplementations().get(signature))
            : findSingleAssignableGenericImplementation(boundSignature);
    return match.orElseThrow(() -> new TrinoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", boundSignature, getFunctionMetadata().getSignature())));
}

/**
 * Returns the only generic implementation whose types are assignable from {@code boundSignature},
 * or empty when no generic implementation matches. Fails fast with {@code AMBIGUOUS_FUNCTION_CALL}
 * as soon as a second match is seen.
 */
private Optional<AggregationImplementation> findSingleAssignableGenericImplementation(BoundSignature boundSignature) {
    AggregationImplementation selected = null;
    for (AggregationImplementation candidate : implementations.getGenericImplementations()) {
        if (!candidate.areTypesAssignable(boundSignature)) {
            continue;
        }
        if (selected != null) {
            throw new TrinoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
        }
        selected = candidate;
    }
    return Optional.ofNullable(selected);
}
Also used : Signature(io.trino.metadata.Signature) BoundSignature(io.trino.metadata.BoundSignature) TrinoException(io.trino.spi.TrinoException)

Example 2 with BoundSignature

Use of io.trino.metadata.BoundSignature in project trino by trinodb.

In class ParametricScalarImplementation, method specialize.

/**
 * Builds a concrete scalar implementation for the given binding, or returns empty when this
 * parametric implementation cannot serve it.
 *
 * <p>The binding is rejected when:
 * <ul>
 *   <li>a specialized type parameter is bound to a type whose Java type is not assignable to the
 *       declared class;</li>
 *   <li>the declared return native container type is neither {@code Object.class} nor the bound
 *       return type's Java type;</li>
 *   <li>a function-typed argument declares a native container type, or a non-function argument
 *       declares none;</li>
 *   <li>a non-function argument's declared native container type is neither {@code Object.class}
 *       nor the bound argument's Java type.</li>
 * </ul>
 *
 * @param functionBinding the concrete type binding to specialize for
 * @param functionDependencies dependencies bound into each choice's method handle and constructor
 * @return the specialized implementation, or {@link Optional#empty()} if the binding does not fit
 */
public Optional<ScalarFunctionImplementation> specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) {
    for (Map.Entry<String, Class<?>> specializedParameter : specializedTypeParameters.entrySet()) {
        Class<?> boundParameterJavaType = functionBinding.getTypeVariable(specializedParameter.getKey()).getJavaType();
        if (!specializedParameter.getValue().isAssignableFrom(boundParameterJavaType)) {
            return Optional.empty();
        }
    }

    BoundSignature boundSignature = functionBinding.getBoundSignature();

    // Object.class acts as a wildcard container type; anything else must match exactly.
    boolean returnTypeFits = returnNativeContainerType == Object.class
            || returnNativeContainerType == boundSignature.getReturnType().getJavaType();
    if (!returnTypeFits) {
        return Optional.empty();
    }

    int argumentCount = boundSignature.getArgumentTypes().size();
    for (int index = 0; index < argumentCount; index++) {
        boolean functionTyped = boundSignature.getArgumentTypes().get(index) instanceof FunctionType;
        if (functionTyped) {
            // Function-typed arguments must not declare a native container type.
            if (argumentNativeContainerTypes.get(index).isPresent()) {
                return Optional.empty();
            }
            continue;
        }
        // Every other argument must declare one, and it must be compatible with the bound type.
        if (argumentNativeContainerTypes.get(index).isEmpty()) {
            return Optional.empty();
        }
        Class<?> actualJavaType = boundSignature.getArgumentTypes().get(index).getJavaType();
        Class<?> declaredContainerType = argumentNativeContainerTypes.get(index).get();
        boolean argumentFits = declaredContainerType == Object.class || declaredContainerType == actualJavaType;
        if (!argumentFits) {
            return Optional.empty();
        }
    }

    List<ScalarImplementationChoice> implementationChoices = new ArrayList<>();
    for (ParametricScalarImplementationChoice choice : choices) {
        MethodHandle dependencyBoundHandle = bindDependencies(choice.getMethodHandle(), choice.getDependencies(), functionBinding, functionDependencies);
        Optional<MethodHandle> dependencyBoundConstructor = choice.getConstructor().map(rawConstructor -> {
            MethodHandle boundConstructorHandle = bindDependencies(rawConstructor, choice.getConstructorDependencies(), functionBinding, functionDependencies);
            // After binding, a constructor must take no remaining parameters.
            checkCondition(boundConstructorHandle.type().parameterList().isEmpty(), FUNCTION_IMPLEMENTATION_ERROR, "All parameters of a constructor in a function definition class must be Dependencies. Signature: %s", boundSignature);
            return boundConstructorHandle;
        });
        MethodHandle adaptedHandle = dependencyBoundHandle.asType(javaMethodType(choice, boundSignature));
        implementationChoices.add(new ScalarImplementationChoice(choice.getReturnConvention(), choice.getArgumentConventions(), choice.getLambdaInterfaces(), adaptedHandle, dependencyBoundConstructor));
    }
    return Optional.of(new ChoicesScalarFunctionImplementation(boundSignature, implementationChoices));
}
Also used : FunctionType(io.trino.type.FunctionType) ArrayList(java.util.ArrayList) ScalarImplementationChoice(io.trino.operator.scalar.ChoicesScalarFunctionImplementation.ScalarImplementationChoice) BoundSignature(io.trino.metadata.BoundSignature) LongVariableConstraint(io.trino.metadata.LongVariableConstraint) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) ChoicesScalarFunctionImplementation(io.trino.operator.scalar.ChoicesScalarFunctionImplementation) MethodHandle(java.lang.invoke.MethodHandle) Reflection.constructorMethodHandle(io.trino.util.Reflection.constructorMethodHandle)

Example 3 with BoundSignature

Use of io.trino.metadata.BoundSignature in project trino by trinodb.

In class AggregationFunctionAdapter, method normalizeInputMethod.

/**
 * Rewrites an aggregation input method so that each {@code INPUT_CHANNEL} parameter (a single
 * value) is replaced by a {@code Block} parameter whose value is read through a type-specific
 * getter bound to the argument's {@link Type}. The position needed by that getter is shared with
 * a single {@code int} position parameter placed after the input arguments; one is inserted if
 * no block input channel already supplies it.
 *
 * <p>The expected layout of {@code parameterKinds} is: all {@code STATE} parameters, then all
 * input parameters ({@code INPUT_CHANNEL}, {@code BLOCK_INPUT_CHANNEL},
 * {@code NULLABLE_BLOCK_INPUT_CHANNEL}), then {@code BLOCK_INDEX} if any block input channel is
 * present. The last {@code lambdaCount} method parameters are not described by
 * {@code parameterKinds} (presumably lambda arguments — confirm against callers).
 *
 * @param inputMethod the user-declared input method to adapt
 * @param boundSignature the bound signature; its argument types drive the value getters
 * @param parameterKinds the kind of each non-lambda parameter of {@code inputMethod}
 * @param lambdaCount number of trailing parameters excluded from {@code parameterKinds}
 * @return the adapted method handle
 * @throws IllegalArgumentException if parameter counts or the kind layout do not match
 */
public static MethodHandle normalizeInputMethod(MethodHandle inputMethod, BoundSignature boundSignature, List<AggregationParameterKind> parameterKinds, int lambdaCount) {
    requireNonNull(inputMethod, "inputMethod is null");
    requireNonNull(parameterKinds, "parameterKinds is null");
    requireNonNull(boundSignature, "boundSignature is null");
    // Every non-lambda parameter must have a declared kind.
    checkArgument(inputMethod.type().parameterCount() - lambdaCount == parameterKinds.size(), "Input method has %s parameters, but parameter kinds only has %s items", inputMethod.type().parameterCount() - lambdaCount, parameterKinds.size());
    List<AggregationParameterKind> stateArgumentKinds = parameterKinds.stream().filter(STATE::equals).collect(toImmutableList());
    List<AggregationParameterKind> inputArgumentKinds = parameterKinds.stream().filter(kind -> kind == INPUT_CHANNEL || kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL).collect(toImmutableList());
    // True when some parameter is already a block channel, i.e. a position parameter already exists.
    boolean hasInputChannel = parameterKinds.stream().anyMatch(kind -> kind == BLOCK_INPUT_CHANNEL || kind == NULLABLE_BLOCK_INPUT_CHANNEL);
    // Each non-lambda signature argument corresponds to exactly one input parameter.
    checkArgument(boundSignature.getArgumentTypes().size() - lambdaCount == inputArgumentKinds.size(), "Bound signature has %s arguments, but parameter kinds only has %s input arguments", boundSignature.getArgumentTypes().size() - lambdaCount, inputArgumentKinds.size());
    // Rebuild the only accepted layout (states, inputs, optional BLOCK_INDEX) and require an exact match.
    List<AggregationParameterKind> expectedInputArgumentKinds = new ArrayList<>();
    expectedInputArgumentKinds.addAll(stateArgumentKinds);
    expectedInputArgumentKinds.addAll(inputArgumentKinds);
    if (hasInputChannel) {
        expectedInputArgumentKinds.add(BLOCK_INDEX);
    }
    checkArgument(expectedInputArgumentKinds.equals(parameterKinds), "Expected input parameter kinds %s, but got %s", expectedInputArgumentKinds, parameterKinds);
    // Tracks the externally visible type of the adapted handle as parameters are rewritten.
    MethodType inputMethodType = inputMethod.type();
    for (int argumentIndex = 0; argumentIndex < inputArgumentKinds.size(); argumentIndex++) {
        // Input parameters follow the state parameters.
        int parameterIndex = stateArgumentKinds.size() + argumentIndex;
        AggregationParameterKind inputArgument = inputArgumentKinds.get(argumentIndex);
        // Block channels are already in (Block, position) form; only single-value inputs are adapted.
        if (inputArgument != INPUT_CHANNEL) {
            continue;
        }
        Type argumentType = boundSignature.getArgumentType(argumentIndex);
        // process argument through type value getter
        // NOTE(review): the *_TYPE_GETTER handles are defined elsewhere; the reorder logic below
        // assumes that, once bound to the Type, each takes (Block, int position).
        MethodHandle valueGetter;
        if (argumentType.getJavaType().equals(boolean.class)) {
            valueGetter = BOOLEAN_TYPE_GETTER.bindTo(argumentType);
        } else if (argumentType.getJavaType().equals(long.class)) {
            valueGetter = LONG_TYPE_GETTER.bindTo(argumentType);
        } else if (argumentType.getJavaType().equals(double.class)) {
            valueGetter = DOUBLE_TYPE_GETTER.bindTo(argumentType);
        } else {
            valueGetter = OBJECT_TYPE_GETTER.bindTo(argumentType);
            // Cast the generic object result to the exact parameter type the input method declares.
            valueGetter = valueGetter.asType(valueGetter.type().changeReturnType(inputMethodType.parameterType(parameterIndex)));
        }
        // Replace the value parameter with the getter's parameters (spliced in at parameterIndex).
        inputMethod = collectArguments(inputMethod, parameterIndex, valueGetter);
        // move the position argument to the end (and combine with other existing position argument)
        inputMethodType = inputMethodType.changeParameterType(parameterIndex, Block.class);
        // reorder[i] is the index in inputMethodType that feeds parameter i of inputMethod (permuteArguments contract).
        ArrayList<Integer> reorder;
        if (hasInputChannel) {
            // A position parameter already sits just before the lambdas; route it into the getter too.
            reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
            reorder.add(parameterIndex + 1, inputMethodType.parameterCount() - 1 - lambdaCount);
        } else {
            // No position parameter yet: append one before the trailing lambdas and route it into the getter.
            inputMethodType = inputMethodType.insertParameterTypes(inputMethodType.parameterCount() - lambdaCount, int.class);
            reorder = IntStream.range(0, inputMethodType.parameterCount()).boxed().collect(Collectors.toCollection(ArrayList::new));
            int positionParameterIndex = inputMethodType.parameterCount() - 1 - lambdaCount;
            // remove(int) removes by index; the entry there equals positionParameterIndex because reorder is an identity range.
            reorder.remove(positionParameterIndex);
            reorder.add(parameterIndex + 1, positionParameterIndex);
            // Later INPUT_CHANNEL arguments reuse the position parameter added here.
            hasInputChannel = true;
        }
        inputMethod = permuteArguments(inputMethod, inputMethodType, reorder.stream().mapToInt(Integer::intValue).toArray());
    }
    return inputMethod;
}
Also used : IntStream(java.util.stream.IntStream) MethodHandle(java.lang.invoke.MethodHandle) STATE(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.STATE) INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL) BLOCK_INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INPUT_CHANNEL) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Type(io.trino.spi.type.Type) MethodHandles.permuteArguments(java.lang.invoke.MethodHandles.permuteArguments) Collectors(java.util.stream.Collectors) MethodHandles.lookup(java.lang.invoke.MethodHandles.lookup) BLOCK_INDEX(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.BLOCK_INDEX) ArrayList(java.util.ArrayList) List(java.util.List) Preconditions.checkArgument(com.google.common.base.Preconditions.checkArgument) MethodType(java.lang.invoke.MethodType) ImmutableList(com.google.common.collect.ImmutableList) MethodHandles.collectArguments(java.lang.invoke.MethodHandles.collectArguments) Block(io.trino.spi.block.Block) NULLABLE_BLOCK_INPUT_CHANNEL(io.trino.operator.aggregation.AggregationFunctionAdapter.AggregationParameterKind.NULLABLE_BLOCK_INPUT_CHANNEL) BoundSignature(io.trino.metadata.BoundSignature) Objects.requireNonNull(java.util.Objects.requireNonNull) MethodType(java.lang.invoke.MethodType) Type(io.trino.spi.type.Type) MethodType(java.lang.invoke.MethodType) ArrayList(java.util.ArrayList) Block(io.trino.spi.block.Block) MethodHandle(java.lang.invoke.MethodHandle)

Example 4 with BoundSignature

Use of io.trino.metadata.BoundSignature in project trino by trinodb.

In class TestAnnotationEngineForAggregates, method testNotDecomposableAggregationParse.

/**
 * Parses {@code NotDecomposableAggregationFunction} and checks its function metadata, its
 * aggregation metadata, and that it can be specialized for a DOUBLE -> DOUBLE signature.
 */
@Test
public void testNotDecomposableAggregationParse() {
    ParametricAggregation parsed = getOnlyElement(parseFunctionDefinitions(NotDecomposableAggregationFunction.class));

    // Function-level metadata taken from the annotations
    assertEquals(parsed.getFunctionMetadata().getDescription(), "Aggregate with Decomposable=false");
    assertTrue(parsed.getFunctionMetadata().isDeterministic());
    Signature doubleToDouble = new Signature("custom_decomposable_aggregate", DoubleType.DOUBLE.getTypeSignature(), ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()));
    assertEquals(parsed.getFunctionMetadata().getSignature(), doubleToDouble);

    // Aggregation-level metadata: not order sensitive, no intermediate types
    AggregationFunctionMetadata metadata = parsed.getAggregationMetadata();
    assertFalse(metadata.isOrderSensitive());
    assertTrue(metadata.getIntermediateTypes().isEmpty());

    // Specializing for DOUBLE -> DOUBLE is expected to complete without throwing
    BoundSignature bound = new BoundSignature(parsed.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE));
    parsed.specialize(bound, NO_FUNCTION_DEPENDENCIES);
}
Also used : TypeSignature(io.trino.spi.type.TypeSignature) Signature(io.trino.metadata.Signature) BoundSignature(io.trino.metadata.BoundSignature) BoundSignature(io.trino.metadata.BoundSignature) ParametricAggregation(io.trino.operator.aggregation.ParametricAggregation) AggregationFunctionMetadata(io.trino.metadata.AggregationFunctionMetadata) Test(org.testng.annotations.Test)

Example 5 with BoundSignature

Use of io.trino.metadata.BoundSignature in project trino by trinodb.

In class TestAnnotationEngineForAggregates, method testSimpleExplicitSpecializedAggregationParse.

// TODO this is not yet supported
/**
 * Parses {@code ExplicitSpecializedAggregationFunction} and checks its metadata and the split of
 * its implementations into exact / specialized / generic groups.
 *
 * <p>NOTE(review): assumes {@code assertImplementationCount(group, exact, specialized, generic)}
 * argument order; with counts (0, 1, 1) the second implementation must come from the generic
 * list, not index 1 of the single-element specialized list.
 */
@Test(enabled = false)
public void testSimpleExplicitSpecializedAggregationParse() {
    Signature expectedSignature = new Signature("explicit_specialized_aggregate", ImmutableList.of(typeVariable("T")), ImmutableList.of(), new TypeSignature("T"), ImmutableList.of(new TypeSignature(ARRAY, TypeSignatureParameter.typeParameter(new TypeSignature("T")))), false);
    ParametricAggregation aggregation = getOnlyElement(parseFunctionDefinitions(ExplicitSpecializedAggregationFunction.class));
    assertEquals(aggregation.getFunctionMetadata().getDescription(), "Simple explicit specialized aggregate");
    assertTrue(aggregation.getFunctionMetadata().isDeterministic());
    assertEquals(aggregation.getFunctionMetadata().getSignature(), expectedSignature);
    ParametricImplementationsGroup<AggregationImplementation> implementations = aggregation.getImplementations();
    assertImplementationCount(implementations, 0, 1, 1);

    // The specialized implementation must report specialized type parameters.
    // (A contradictory assertFalse on the same call previously made this test unpassable.)
    AggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0);
    assertTrue(implementation1.hasSpecializedTypeParameters());
    assertEquals(implementation1.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL));

    // Only one specialized implementation exists, so the second one is the generic implementation,
    // which must not report specialized type parameters.
    AggregationImplementation implementation2 = implementations.getGenericImplementations().get(0);
    assertFalse(implementation2.hasSpecializedTypeParameters());
    assertEquals(implementation2.getInputParameterKinds(), ImmutableList.of(STATE, INPUT_CHANNEL));

    BoundSignature boundSignature = new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(new ArrayType(DoubleType.DOUBLE)));
    AggregationFunctionMetadata aggregationMetadata = aggregation.getAggregationMetadata();
    assertFalse(aggregationMetadata.isOrderSensitive());
    assertFalse(aggregationMetadata.getIntermediateTypes().isEmpty());
    aggregation.specialize(boundSignature, NO_FUNCTION_DEPENDENCIES);
}
Also used : AggregationImplementation(io.trino.operator.aggregation.AggregationImplementation) ArrayType(io.trino.spi.type.ArrayType) TypeSignature(io.trino.spi.type.TypeSignature) TypeSignature(io.trino.spi.type.TypeSignature) Signature(io.trino.metadata.Signature) BoundSignature(io.trino.metadata.BoundSignature) BoundSignature(io.trino.metadata.BoundSignature) ParametricAggregation(io.trino.operator.aggregation.ParametricAggregation) AggregationFunctionMetadata(io.trino.metadata.AggregationFunctionMetadata) Test(org.testng.annotations.Test)

Aggregations

BoundSignature (io.trino.metadata.BoundSignature)27 Signature (io.trino.metadata.Signature)20 TypeSignature (io.trino.spi.type.TypeSignature)18 Test (org.testng.annotations.Test)17 ParametricAggregation (io.trino.operator.aggregation.ParametricAggregation)13 AggregationFunctionMetadata (io.trino.metadata.AggregationFunctionMetadata)12 AggregationImplementation (io.trino.operator.aggregation.AggregationImplementation)12 MethodHandle (java.lang.invoke.MethodHandle)8 FunctionMetadata (io.trino.metadata.FunctionMetadata)7 ImmutableList (com.google.common.collect.ImmutableList)6 FunctionDependencies (io.trino.metadata.FunctionDependencies)6 List (java.util.List)6 SqlScalarFunction (io.trino.metadata.SqlScalarFunction)5 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)4 FunctionNullability (io.trino.metadata.FunctionNullability)4 Type (io.trino.spi.type.Type)4 Preconditions.checkArgument (com.google.common.base.Preconditions.checkArgument)3 FunctionDependencyDeclaration (io.trino.metadata.FunctionDependencyDeclaration)3 ChoicesScalarFunctionImplementation (io.trino.operator.scalar.ChoicesScalarFunctionImplementation)3 TrinoException (io.trino.spi.TrinoException)3