Use of org.apache.beam.sdk.extensions.sql.udf.ScalarFn in the Apache Beam project.
From the class SqlCreateFunction, method execute.
/**
 * Executes this CREATE FUNCTION statement by registering a Java-backed function in the schema
 * resolved from {@code context}.
 *
 * <p>The function is registered under the last component of {@code functionName}. A context
 * exception is raised when a function with that name already exists in the schema, or when the
 * jar path node is not a {@code SqlCharStringLiteral}. Depending on {@code isAggregate}, either
 * an aggregate function (wrapped in a {@code LazyAggregateCombineFn}) or a scalar function
 * (resolved through its apply method) is loaded from the jar and added to the schema.
 *
 * @param context the Calcite prepare context used to resolve the target schema
 */
@Override
public void execute(CalcitePrepare.Context context) {
  final Pair<CalciteSchema, String> schemaAndName =
      SqlDdlNodes.schema(context, true, functionName);
  final SchemaPlus targetSchema = schemaAndName.left.plus();
  final String simpleName = schemaAndName.right;

  // Refuse to register a second function under a name that is already taken.
  final boolean alreadyDefined = !targetSchema.getFunctions(simpleName).isEmpty();
  if (alreadyDefined) {
    throw SqlUtil.newContextException(
        functionName.getParserPosition(),
        RESOURCE.internal(String.format("Function %s is already defined.", simpleName)));
  }

  final JavaUdfLoader loader = new JavaUdfLoader();
  // TODO(BEAM-12355) Support qualified function names.
  final List<String> lookupPath = ImmutableList.of(simpleName);

  // The jar location must be a plain character string literal so its value can be extracted.
  if (!(jarPath instanceof SqlCharStringLiteral)) {
    throw SqlUtil.newContextException(
        jarPath.getParserPosition(),
        RESOURCE.internal("Jar path is not instanceof SqlCharStringLiteral."));
  }
  final SqlCharStringLiteral jarLiteral = (SqlCharStringLiteral) jarPath;
  final String jarLocation = jarLiteral.getNlsString().getValue();

  if (!isAggregate) {
    // Scalar function: locate the apply method and wrap it for Calcite.
    final ScalarFn scalarFn = loader.loadScalarFunction(lookupPath, jarLocation);
    final Method applyMethod = ScalarFnReflector.getApplyMethod(scalarFn);
    final Function scalarImpl = ScalarFunctionImpl.create(applyMethod, jarLocation);
    targetSchema.add(simpleName, scalarImpl);
    return;
  }

  // Try loading the aggregate function just to make sure it exists. LazyAggregateCombineFn will
  // need to fetch it again at runtime.
  loader.loadAggregateFunction(lookupPath, jarLocation);
  final LazyAggregateCombineFn<?, ?, ?> lazyCombineFn =
      new LazyAggregateCombineFn<>(lookupPath, jarLocation);
  targetSchema.add(simpleName, lazyCombineFn.getUdafImpl());
}
Use of org.apache.beam.sdk.extensions.sql.udf.ScalarFn in the Apache Beam project.
From the class BeamZetaSqlCatalog, method addFunction.
/**
 * Registers the function described by a resolved CREATE FUNCTION statement, both in the local
 * bookkeeping map that matches its function group and in the ZetaSQL catalog.
 *
 * <ul>
 *   <li>SQL UDFs: the statement itself is stored under its name path.
 *   <li>Java scalar UDFs: every argument type and the return type are validated, then the
 *       function is loaded from its jar and stored together with its apply method.
 *   <li>Java aggregate UDFs: the function is loaded once to verify that it exists, then stored as
 *       a {@code LazyAggregateCombineFn}.
 * </ul>
 *
 * @param createFunctionStmt the resolved CREATE FUNCTION statement to register
 * @throws UnsupportedOperationException if a Java scalar UDF argument has a null (templated) type
 * @throws IllegalArgumentException if a Java scalar UDF has a null return type, or if the
 *     function group is not one of the recognized groups
 */
void addFunction(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt) {
  final String functionGroup = getFunctionGroup(createFunctionStmt);

  switch (functionGroup) {
    case USER_DEFINED_SQL_FUNCTIONS:
      {
        sqlScalarUdfs.put(createFunctionStmt.getNamePath(), createFunctionStmt);
        break;
      }
    case USER_DEFINED_JAVA_SCALAR_FUNCTIONS:
      {
        final String functionName = String.join(".", createFunctionStmt.getNamePath());

        // Validate every declared argument type before touching the jar.
        for (FunctionArgumentType argument :
            createFunctionStmt.getSignature().getFunctionArgumentList()) {
          final Type argumentType = argument.getType();
          if (argumentType == null) {
            throw new UnsupportedOperationException(
                "UDF templated argument types are not supported.");
          }
          validateJavaUdfZetaSqlType(argumentType, functionName);
        }

        // The return type is validated the same way as the arguments.
        if (createFunctionStmt.getReturnType() == null) {
          throw new IllegalArgumentException("UDF return type must not be null.");
        }
        validateJavaUdfZetaSqlType(createFunctionStmt.getReturnType(), functionName);

        final String jarPath = getJarPath(createFunctionStmt);
        final ScalarFn scalarFn =
            javaUdfLoader.loadScalarFunction(createFunctionStmt.getNamePath(), jarPath);
        final Method applyMethod = ScalarFnReflector.getApplyMethod(scalarFn);
        javaScalarUdfs.put(
            createFunctionStmt.getNamePath(),
            UserFunctionDefinitions.JavaScalarFunction.create(applyMethod, jarPath));
        break;
      }
    case USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS:
      {
        final String jarPath = getJarPath(createFunctionStmt);
        // Try loading the aggregate function just to make sure it exists. LazyAggregateCombineFn
        // will need to fetch it again at runtime.
        javaUdfLoader.loadAggregateFunction(createFunctionStmt.getNamePath(), jarPath);
        final Combine.CombineFn<?, ?, ?> lazyCombineFn =
            new LazyAggregateCombineFn<>(createFunctionStmt.getNamePath(), jarPath);
        javaUdafs.put(createFunctionStmt.getNamePath(), lazyCombineFn);
        break;
      }
    default:
      throw new IllegalArgumentException(
          String.format("Encountered unrecognized function group %s.", functionGroup));
  }

  // Whatever the group, the function must also be visible to the ZetaSQL catalog.
  final ZetaSQLFunctions.FunctionEnums.Mode mode;
  if (createFunctionStmt.getIsAggregate()) {
    mode = ZetaSQLFunctions.FunctionEnums.Mode.AGGREGATE;
  } else {
    mode = ZetaSQLFunctions.FunctionEnums.Mode.SCALAR;
  }
  zetaSqlCatalog.addFunction(
      new Function(
          createFunctionStmt.getNamePath(),
          functionGroup,
          mode,
          ImmutableList.of(createFunctionStmt.getSignature())));
}
Use of org.apache.beam.sdk.extensions.sql.udf.ScalarFn in the Apache Beam project.
From the class ScalarFnReflector, method getApplyMethod.
/**
 * Looks up the method of {@code scalarFn} that carries the {@link
 * org.apache.beam.sdk.extensions.sql.udf.ScalarFn.ApplyMethod} annotation.
 *
 * <p>Exactly one such method is expected, and it must be public. Several annotated methods are
 * tolerated only when they all share the same name and parameter types, i.e. when they are
 * overrides of one another along the class hierarchy.
 *
 * @param scalarFn the function instance whose class is inspected
 * @return the annotated apply method
 * @throws IllegalArgumentException if no annotated method exists, if annotated methods with
 *     differing signatures are found, or if the annotated method is not public
 */
public static Method getApplyMethod(ScalarFn scalarFn) {
  final Class<? extends ScalarFn> fnClass = scalarFn.getClass();
  final Collection<Method> annotated =
      ReflectHelpers.declaredMethodsWithAnnotation(
          ScalarFn.ApplyMethod.class, fnClass, ScalarFn.class);
  final String annotationName = ScalarFn.ApplyMethod.class.getSimpleName();

  if (annotated.isEmpty()) {
    throw new IllegalArgumentException(
        String.format(
            "No method annotated with @%s found in class %s.",
            annotationName, fnClass.getName()));
  }

  // With one or more matches, every match has to agree with the first one on name and
  // parameter types; otherwise the class declares more than one distinct apply method.
  final Method applyMethod = annotated.iterator().next();
  for (Method candidate : annotated) {
    final boolean sameSignature =
        applyMethod.getName().equals(candidate.getName())
            && Arrays.equals(applyMethod.getParameterTypes(), candidate.getParameterTypes());
    if (!sameSignature) {
      throw new IllegalArgumentException(
          String.format(
              "Found multiple methods annotated with @%s. [%s] and [%s]",
              annotationName,
              ReflectHelpers.formatMethod(applyMethod),
              ReflectHelpers.formatMethod(candidate)));
    }
  }

  // Method must be public.
  if (!Modifier.isPublic(applyMethod.getModifiers())) {
    throw new IllegalArgumentException(
        String.format("Method %s is not public.", ReflectHelpers.formatMethod(applyMethod)));
  }
  return applyMethod;
}
Use of org.apache.beam.sdk.extensions.sql.udf.ScalarFn in the Apache Beam project.
From the class JavaUdfLoader, method loadJar.
/**
 * Returns the scalar and aggregate function definitions provided by the jar at {@code jarPath},
 * loading them on first use and serving them from {@code functionCache} afterwards.
 *
 * <p>Every {@link UdfProvider} discovered through the jar's class loader contributes its
 * functions. Function names are split on {@code '.'} to form the function path used as the key.
 *
 * @param jarPath location of the jar to load; also used as the cache key
 * @return the function definitions collected from all providers in the jar
 * @throws IOException presumably when the jar cannot be opened by {@code createClassLoader} or
 *     {@code getUdfProviders} (not visible from this method; confirm against those helpers)
 * @throws IllegalArgumentException if two providers define a scalar function, or an aggregate
 *     function, with the same path
 * @throws ProviderNotFoundException if the jar contains no {@link UdfProvider} implementation
 */
private FunctionDefinitions loadJar(String jarPath) throws IOException {
  // Single lookup instead of containsKey + get: avoids hashing the key twice and cannot observe
  // the entry disappearing between the two calls.
  FunctionDefinitions cachedDefinitions = functionCache.get(jarPath);
  if (cachedDefinitions != null) {
    LOG.debug("Using cached function definitions from {}", jarPath);
    return cachedDefinitions;
  }

  ClassLoader classLoader = createClassLoader(jarPath);
  Map<List<String>, ScalarFn> scalarFunctions = new HashMap<>();
  Map<List<String>, AggregateFn> aggregateFunctions = new HashMap<>();
  Iterator<UdfProvider> providers = getUdfProviders(classLoader);
  int providersCount = 0;
  while (providers.hasNext()) {
    providersCount++;
    UdfProvider provider = providers.next();
    provider
        .userDefinedScalarFunctions()
        .forEach(
            (functionName, implementation) -> {
              List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
              // The same path coming from two providers is ambiguous, so fail loudly.
              if (scalarFunctions.containsKey(functionPath)) {
                throw new IllegalArgumentException(
                    String.format(
                        "Found multiple definitions of scalar function %s in %s.",
                        functionName, jarPath));
              }
              scalarFunctions.put(functionPath, implementation);
            });
    provider
        .userDefinedAggregateFunctions()
        .forEach(
            (functionName, implementation) -> {
              List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
              if (aggregateFunctions.containsKey(functionPath)) {
                throw new IllegalArgumentException(
                    String.format(
                        "Found multiple definitions of aggregate function %s in %s.",
                        functionName, jarPath));
              }
              aggregateFunctions.put(functionPath, implementation);
            });
  }
  if (providersCount == 0) {
    throw new ProviderNotFoundException(
        String.format(
            "No %s implementation found in %s. Create a class implementing %s and annotate it with @AutoService(%s.class).",
            UdfProvider.class.getSimpleName(),
            jarPath,
            UdfProvider.class.getSimpleName(),
            UdfProvider.class.getSimpleName()));
  }
  // Report both kinds of functions; the jar may contain only aggregate functions, in which case
  // a scalar-only count would misleadingly suggest nothing was loaded.
  LOG.info(
      "Loaded {} implementations of {} from {} with {} scalar function(s) and {} aggregate function(s).",
      providersCount,
      UdfProvider.class.getSimpleName(),
      jarPath,
      scalarFunctions.size(),
      aggregateFunctions.size());
  FunctionDefinitions userFunctionDefinitions =
      FunctionDefinitions.newBuilder()
          .setScalarFunctions(ImmutableMap.copyOf(scalarFunctions))
          .setAggregateFunctions(ImmutableMap.copyOf(aggregateFunctions))
          .build();
  functionCache.put(jarPath, userFunctionDefinitions);
  return userFunctionDefinitions;
}
Aggregations