Search in sources :

Example 1 with AggregateFn

use of org.apache.beam.sdk.extensions.sql.udf.AggregateFn in project beam by apache.

the class JavaUdfLoader method loadJar.

/**
 * Loads every user-defined scalar and aggregate function exposed by the {@link UdfProvider}
 * implementations found in the given jar, caching the result per jar path.
 *
 * <p>Function names are split on {@code '.'} to form the function path used as the lookup key,
 * e.g. {@code "a.b.fn"} becomes {@code ["a", "b", "fn"]}.
 *
 * @param jarPath path of the jar to load; also used as the cache key
 * @return the scalar and aggregate function definitions collected from all providers in the jar
 * @throws IOException declared by this method; presumably raised while opening the jar in
 *     {@code createClassLoader} (not visible here -- confirm against that helper)
 * @throws IllegalArgumentException if two providers in the jar define a scalar function, or an
 *     aggregate function, with the same path
 * @throws ProviderNotFoundException if the jar yields no {@link UdfProvider} implementation
 */
private FunctionDefinitions loadJar(String jarPath) throws IOException {
    // One lookup instead of containsKey + get; only non-null definitions are ever cached below.
    FunctionDefinitions cachedDefinitions = functionCache.get(jarPath);
    if (cachedDefinitions != null) {
        LOG.debug("Using cached function definitions from {}", jarPath);
        return cachedDefinitions;
    }
    ClassLoader classLoader = createClassLoader(jarPath);
    Map<List<String>, ScalarFn> scalarFunctions = new HashMap<>();
    Map<List<String>, AggregateFn> aggregateFunctions = new HashMap<>();
    Iterator<UdfProvider> providers = getUdfProviders(classLoader);
    int providersCount = 0;
    while (providers.hasNext()) {
        providersCount++;
        UdfProvider provider = providers.next();
        // Scalar and aggregate functions are tracked in separate maps, so the same path may be
        // used once for each kind; a repeat within one kind (across providers) is rejected.
        provider.userDefinedScalarFunctions().forEach((functionName, implementation) -> {
            List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
            if (scalarFunctions.containsKey(functionPath)) {
                throw new IllegalArgumentException(String.format("Found multiple definitions of scalar function %s in %s.", functionName, jarPath));
            }
            scalarFunctions.put(functionPath, implementation);
        });
        provider.userDefinedAggregateFunctions().forEach((functionName, implementation) -> {
            List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
            if (aggregateFunctions.containsKey(functionPath)) {
                throw new IllegalArgumentException(String.format("Found multiple definitions of aggregate function %s in %s.", functionName, jarPath));
            }
            aggregateFunctions.put(functionPath, implementation);
        });
    }
    if (providersCount == 0) {
        throw new ProviderNotFoundException(String.format("No %s implementation found in %s. Create a class implementing %s and annotate it with @AutoService(%s.class).", UdfProvider.class.getSimpleName(), jarPath, UdfProvider.class.getSimpleName(), UdfProvider.class.getSimpleName()));
    }
    // Report both kinds of functions; previously aggregate functions were loaded but never logged.
    LOG.info("Loaded {} implementations of {} from {} with {} scalar function(s) and {} aggregate function(s).", providersCount, UdfProvider.class.getSimpleName(), jarPath, scalarFunctions.size(), aggregateFunctions.size());
    // Snapshot the mutable working maps so the cached definitions cannot change afterwards.
    FunctionDefinitions userFunctionDefinitions = FunctionDefinitions.newBuilder().setScalarFunctions(ImmutableMap.copyOf(scalarFunctions)).setAggregateFunctions(ImmutableMap.copyOf(aggregateFunctions)).build();
    functionCache.put(jarPath, userFunctionDefinitions);
    return userFunctionDefinitions;
}
Also used : ScalarFn(org.apache.beam.sdk.extensions.sql.udf.ScalarFn) HashMap(java.util.HashMap) UdfProvider(org.apache.beam.sdk.extensions.sql.udf.UdfProvider) ProviderNotFoundException(java.nio.file.ProviderNotFoundException) AggregateFn(org.apache.beam.sdk.extensions.sql.udf.AggregateFn) URLClassLoader(java.net.URLClassLoader) ArrayList(java.util.ArrayList) List(java.util.List) ImmutableList(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList)

Aggregations

URLClassLoader (java.net.URLClassLoader)1 ProviderNotFoundException (java.nio.file.ProviderNotFoundException)1 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 List (java.util.List)1 AggregateFn (org.apache.beam.sdk.extensions.sql.udf.AggregateFn)1 ScalarFn (org.apache.beam.sdk.extensions.sql.udf.ScalarFn)1 UdfProvider (org.apache.beam.sdk.extensions.sql.udf.UdfProvider)1 ImmutableList (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList)1