Use of org.apache.beam.sdk.extensions.sql.impl.UdafImpl in the Apache Beam project.
Class AggregateScanConverter, method convertAggCall.
/**
 * Converts the aggregate function call carried by {@code computedColumn} into a Calcite
 * {@link AggregateCall}.
 *
 * <p>User-defined Java aggregate functions get a freshly created UDAF operator; every other
 * function is looked up through {@code SqlOperatorMappingTable}.
 *
 * @param computedColumn computed column whose expression is a resolved aggregate function call
 * @param columnRefOff input column index recorded as the call's single argument when the first
 *     argument is accepted
 * @param groupCount forwarded unchanged to {@code AggregateCall.create}
 * @param input forwarded unchanged to {@code AggregateCall.create}
 * @return the Calcite aggregate call, named after the resolved alias of the computed column
 * @throws UnsupportedOperationException for AVG whose first argument type is INT64, for DISTINCT
 *     aggregates, for builtin functions without an operator mapping, and for arguments that do
 *     not match the accepted shapes
 */
private AggregateCall convertAggCall(
    ResolvedComputedColumn computedColumn, int columnRefOff, int groupCount, RelNode input) {
  final ResolvedAggregateFunctionCall call =
      (ResolvedAggregateFunctionCall) computedColumn.getExpr();
  final String functionName = call.getFunction().getName();

  // AVG with an INT64 first argument is rejected outright.
  if (functionName.equals("avg")
      && call.getSignature()
          .getFunctionArgumentList()
          .get(0)
          .getType()
          .getKind()
          .equals(TypeKind.TYPE_INT64)) {
    throw new UnsupportedOperationException(AVG_ILLEGAL_LONG_INPUT_TYPE);
  }

  // DISTINCT aggregation is rejected as well; the message points at the workaround.
  if (call.getDistinct()) {
    throw new UnsupportedOperationException(
        "Does not support "
            + call.getFunction().getSqlName()
            + " DISTINCT. 'SELECT DISTINCT' syntax could be used to deduplicate before aggregation.");
  }

  final boolean isJavaUdaf =
      call.getFunction()
          .getGroup()
          .equals(BeamZetaSqlCatalog.USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS);
  final SqlAggFunction aggOperator;
  if (!isJavaUdaf) {
    // Builtin function: resolve it through SqlOperatorMappingTable.
    aggOperator = (SqlAggFunction) SqlOperatorMappingTable.create(call);
    if (aggOperator == null) {
      throw new UnsupportedOperationException(
          "Does not support ZetaSQL aggregate function: " + functionName);
    }
  } else {
    // User-defined function: build a dedicated operator whose return type comes from the
    // function's first signature.
    SqlReturnTypeInference returnTypeInference =
        opBinding ->
            ZetaSqlCalciteTranslationUtils.toCalciteType(
                call.getFunction().getSignatureList().get(0).getResultType().getType(),
                // TODO(BEAM-9514) set nullable=true
                false,
                getCluster().getRexBuilder());
    UdafImpl<?, ?, ?> udaf =
        new UdafImpl<>(
            getExpressionConverter()
                .userFunctionDefinitions
                .javaAggregateFunctions()
                .get(call.getFunction().getNamePath()));
    aggOperator = SqlOperators.createUdafOperator(functionName, returnTypeInference, udaf);
  }

  // Node kinds accepted for the first argument.
  final List<ZetaSQLResolvedNodeKind.ResolvedNodeKind> firstArgumentKinds =
      Arrays.asList(RESOLVED_CAST, RESOLVED_COLUMN_REF, RESOLVED_GET_STRUCT_FIELD);
  final List<Integer> argList = new ArrayList<>();
  // TODO: is there a general way to handle aggregation calls conversion?
  for (int position = 0; position < call.getArgumentList().size(); position++) {
    ZetaSQLResolvedNodeKind.ResolvedNodeKind kind =
        call.getArgumentList().get(position).nodeKind();
    // First argument: one of the kinds listed above. Later arguments: literals only, and they
    // are not added to the argument list.
    final boolean accepted =
        position == 0 ? firstArgumentKinds.contains(kind) : kind == RESOLVED_LITERAL;
    if (!accepted) {
      throw new UnsupportedOperationException(
          "Aggregate function only accepts Column Reference or CAST(Column Reference) as the first argument and Literals as subsequent arguments as its inputs");
    }
    if (position == 0) {
      argList.add(columnRefOff);
    }
  }

  final String alias = getTrait().resolveAlias(computedColumn.getColumn());
  return AggregateCall.create(
      aggOperator,
      false,
      false,
      false,
      argList,
      -1,
      null,
      RelCollations.EMPTY,
      groupCount,
      input,
      // When we pass null as the return type, Calcite infers it for us.
      null,
      alias);
}
Use of org.apache.beam.sdk.extensions.sql.impl.UdafImpl in the Apache Beam project.
Class TypedCombineFnDelegateTest, method testParameterExtractionFromCombineFn_CombineFnDelegate_WithListInsteadOfArray.
/**
 * A UDAF wrapping a delegate typed over {@code List<List<String>>} must expose exactly one
 * parameter, and that parameter's SQL type must be ARRAY.
 */
@Test
public void testParameterExtractionFromCombineFn_CombineFnDelegate_WithListInsteadOfArray() {
  // Max combiner that orders values by the length of the first string of the first inner list.
  Combine.BinaryCombineFn<List<List<String>>> longestFirstString =
      Max.of(
          (Comparator<List<List<String>>> & Serializable)
              (left, right) ->
                  Integer.compare(
                      left.get(0).get(0).length(), right.get(0).get(0).length()));
  // Wrapped through an anonymous subclass, as in the sibling tests.
  TypedCombineFnDelegate<
          List<List<String>>, Combine.Holder<List<List<String>>>, List<List<String>>>
      typedDelegate =
          new TypedCombineFnDelegate<
              List<List<String>>, Combine.Holder<List<List<String>>>, List<List<String>>>(
              longestFirstString) {};
  UdafImpl<List<List<String>>, Combine.Holder<List<List<String>>>, List<List<String>>>
      udafUnderTest = new UdafImpl<>(typedDelegate);

  RelDataTypeFactory relTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
  List<FunctionParameter> extracted = udafUnderTest.getParameters();

  assertEquals(1, extracted.size());
  SqlTypeName parameterSqlType = extracted.get(0).getType(relTypeFactory).getSqlTypeName();
  assertEquals(SqlTypeName.ARRAY, parameterSqlType);
}
Use of org.apache.beam.sdk.extensions.sql.impl.UdafImpl in the Apache Beam project.
Class TypedCombineFnDelegateTest, method testParameterExtractionFromCombineFn_CombineFnDelegate.
/**
 * A UDAF wrapping a delegate typed over {@code String} must expose exactly one parameter, and
 * that parameter's SQL type must be VARCHAR.
 */
@Test
public void testParameterExtractionFromCombineFn_CombineFnDelegate() {
  // Max combiner that orders strings by their length.
  Combine.BinaryCombineFn<String> longestString =
      Max.of(
          (Comparator<String> & Serializable)
              (left, right) -> Integer.compare(left.length(), right.length()));
  // Wrapped through an anonymous subclass, as in the sibling tests.
  TypedCombineFnDelegate<String, Combine.Holder<String>, String> typedDelegate =
      new TypedCombineFnDelegate<String, Combine.Holder<String>, String>(longestString) {};
  UdafImpl<String, Combine.Holder<String>, String> udafUnderTest =
      new UdafImpl<>(typedDelegate);

  RelDataTypeFactory relTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
  List<FunctionParameter> extracted = udafUnderTest.getParameters();

  assertEquals(1, extracted.size());
  SqlTypeName parameterSqlType = extracted.get(0).getType(relTypeFactory).getSqlTypeName();
  assertEquals(SqlTypeName.VARCHAR, parameterSqlType);
}
Use of org.apache.beam.sdk.extensions.sql.impl.UdafImpl in the Apache Beam project.
Class TypedCombineFnDelegateTest, method testParameterExtractionFromCombineFn_CombineFnDelegate_WithGenericArray.
/**
 * A UDAF wrapping a delegate typed over the generic array {@code List<String>[]} is expected to
 * fail with {@link IllegalArgumentException} once the parameter's type is requested.
 */
@Test
public void testParameterExtractionFromCombineFn_CombineFnDelegate_WithGenericArray() {
  // Max combiner that orders arrays by the length of the first string of their first element.
  Combine.BinaryCombineFn<List<String>[]> longestFirstString =
      Max.of(
          (Comparator<List<String>[]> & Serializable)
              (left, right) ->
                  Integer.compare(left[0].get(0).length(), right[0].get(0).length()));
  // Wrapped through an anonymous subclass, as in the sibling tests.
  TypedCombineFnDelegate<List<String>[], Combine.Holder<List<String>[]>, List<String>[]>
      typedDelegate =
          new TypedCombineFnDelegate<
              List<String>[], Combine.Holder<List<String>[]>, List<String>[]>(
              longestFirstString) {};
  UdafImpl<List<String>[], Combine.Holder<List<String>[]>, List<String>[]> udafUnderTest =
      new UdafImpl<>(typedDelegate);

  // Register the expectation before running the statements that are supposed to throw.
  exceptions.expect(IllegalArgumentException.class);
  RelDataTypeFactory relTypeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);
  udafUnderTest.getParameters().get(0).getType(relTypeFactory);
}
Use of org.apache.beam.sdk.extensions.sql.impl.UdafImpl in the Apache Beam project.
Class BeamZetaSqlCatalog, method addUdfsFromSchema.
/**
 * Registers every function exposed by {@code calciteSchema} with {@code zetaSqlCatalog}.
 *
 * <p>Scalar functions ({@code ScalarFunctionImpl}) are recorded in {@code javaScalarUdfs} and
 * added to the catalog in SCALAR mode. Aggregate functions ({@code UdafImpl}) are recorded in
 * {@code javaUdafs} and added in AGGREGATE mode. The catalog path of a function is its name
 * split on {@code '.'}.
 *
 * @throws IllegalArgumentException if a function name does not have exactly one definition, or
 *     if the definition is neither a {@code ScalarFunctionImpl} nor a {@code UdafImpl}
 */
private void addUdfsFromSchema() {
  for (String functionName : calciteSchema.getFunctionNames()) {
    Collection<org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.schema.Function>
        functions = calciteSchema.getFunctions(functionName);
    if (functions.size() != 1) {
      throw new IllegalArgumentException(
          String.format(
              "Expected exactly 1 definition for function '%s', but found %d."
                  + " Beam ZetaSQL supports only a single function definition per function name (BEAM-12073).",
              functionName, functions.size()));
    }
    // The size check above guarantees a single definition, so no loop is needed.
    org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.schema.Function function =
        functions.iterator().next();
    List<String> path = Arrays.asList(functionName.split("\\."));

    if (function instanceof ScalarFunctionImpl) {
      ScalarFunctionImpl scalarFunction = (ScalarFunctionImpl) function;
      // Check every parameter type and the return type before translating them to ZetaSQL;
      // validateJavaUdfCalciteType is presumably where unsupported types are rejected.
      for (FunctionParameter parameter : scalarFunction.getParameters()) {
        validateJavaUdfCalciteType(parameter.getType(typeFactory), functionName);
      }
      validateJavaUdfCalciteType(scalarFunction.getReturnType(typeFactory), functionName);
      Method method = scalarFunction.method;
      javaScalarUdfs.put(path, UserFunctionDefinitions.JavaScalarFunction.create(method, ""));
      FunctionArgumentType resultType =
          new FunctionArgumentType(
              ZetaSqlCalciteTranslationUtils.toZetaSqlType(
                  scalarFunction.getReturnType(typeFactory)));
      FunctionSignature functionSignature =
          new FunctionSignature(resultType, getArgumentTypes(scalarFunction), 0L);
      zetaSqlCatalog.addFunction(
          new Function(
              path,
              USER_DEFINED_JAVA_SCALAR_FUNCTIONS,
              ZetaSQLFunctions.FunctionEnums.Mode.SCALAR,
              ImmutableList.of(functionSignature)));
    } else if (function instanceof UdafImpl) {
      // Wildcard cast instead of a raw-type cast: same runtime check, no raw type.
      UdafImpl<?, ?, ?> udaf = (UdafImpl<?, ?, ?>) function;
      javaUdafs.put(path, udaf.getCombineFn());
      FunctionArgumentType resultType =
          new FunctionArgumentType(
              ZetaSqlCalciteTranslationUtils.toZetaSqlType(udaf.getReturnType(typeFactory)));
      FunctionSignature functionSignature =
          new FunctionSignature(resultType, getArgumentTypes(udaf), 0L);
      zetaSqlCatalog.addFunction(
          new Function(
              path,
              USER_DEFINED_JAVA_AGGREGATE_FUNCTIONS,
              ZetaSQLFunctions.FunctionEnums.Mode.AGGREGATE,
              ImmutableList.of(functionSignature)));
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Function %s has unrecognized implementation type %s.",
              functionName, function.getClass().getName()));
    }
  }
}
Aggregations