Use of org.apache.flink.table.functions.python.PythonFunction in project flink by apache.
From the class CommonPythonUtil, the method extractPythonAggregateFunctionInfosFromAggregateCall.
/**
 * Builds one {@link PythonAggregateFunctionInfo} per aggregate call and collects the distinct
 * argument indexes referenced by those calls.
 *
 * <p>Every distinct argument index is assigned an offset equal to the number of distinct
 * indexes seen before it; each function info refers to its arguments through these offsets.
 *
 * @param aggCalls the aggregate calls to convert
 * @return a tuple of (the distinct argument indexes in first-seen order, one function info per
 *     aggregate call in the same order as {@code aggCalls})
 */
public static Tuple2<int[], PythonFunctionInfo[]> extractPythonAggregateFunctionInfosFromAggregateCall(AggregateCall[] aggCalls) {
    // Argument index -> offset; insertion order is what defines the offsets.
    final Map<Integer, Integer> offsetByArg = new LinkedHashMap<>();
    final PythonFunctionInfo[] functionInfos = new PythonFunctionInfo[aggCalls.length];

    for (int callIdx = 0; callIdx < aggCalls.length; callIdx++) {
        final AggregateCall call = aggCalls[callIdx];

        // Translate the call's argument indexes into offsets, registering unseen ones.
        final List<Integer> remappedArgs = new ArrayList<>();
        for (Integer arg : call.getArgList()) {
            Integer offset = offsetByArg.get(arg);
            if (offset == null) {
                offset = offsetByArg.size();
                offsetByArg.put(arg, offset);
            }
            remappedArgs.add(offset);
        }

        // Unwrap the Python function; it stays null when the aggregation is of neither
        // recognised kind and is then handed to the constructor as-is.
        final SqlAggFunction sqlAggFunction = call.getAggregation();
        final PythonFunction unwrapped;
        if (sqlAggFunction instanceof AggSqlFunction) {
            unwrapped = (PythonFunction) ((AggSqlFunction) sqlAggFunction).aggregateFunction();
        } else if (sqlAggFunction instanceof BridgingSqlAggFunction) {
            unwrapped = (PythonFunction) ((BridgingSqlAggFunction) sqlAggFunction).getDefinition();
        } else {
            unwrapped = null;
        }

        functionInfos[callIdx] =
                new PythonAggregateFunctionInfo(
                        unwrapped, remappedArgs.toArray(), call.filterArg, call.isDistinct());
    }

    final int[] distinctArgIndexes =
            offsetByArg.keySet().stream().mapToInt(Integer::intValue).toArray();
    return Tuple2.of(distinctArgIndexes, functionInfos);
}
Use of org.apache.flink.table.functions.python.PythonFunction in project flink by apache.
From the class CommonPythonUtil, the method extractPythonAggregateFunctionInfos.
/**
 * Converts every entry of the given aggregate info list into a {@link
 * PythonAggregateFunctionInfo} and a matching array of {@link DataViewSpec}s.
 *
 * <p>Functions that are themselves a {@link PythonFunction} are used directly and their data
 * view specs are extracted from the inferred accumulator type. Any other function is mapped
 * through {@code getBuiltInPythonAggregateFunction} and gets an empty spec array.
 *
 * <p>The filter argument and distinct flag are read from {@code aggCalls[i]} when such a call
 * exists; an aggregate info without a matching call falls back to {@code -1} / {@code false}.
 * This guard is applied to both kinds of function.
 *
 * @param pythonAggregateInfoList the aggregate infos to convert
 * @param aggCalls the aggregate calls, matched to the aggregate infos by index; may be shorter
 *     than the aggregate info array
 * @return a tuple of (function infos, data view specs), both indexed like the aggregate infos
 */
public static Tuple2<PythonAggregateFunctionInfo[], DataViewSpec[][]> extractPythonAggregateFunctionInfos(AggregateInfoList pythonAggregateInfoList, AggregateCall[] aggCalls) {
    List<PythonAggregateFunctionInfo> pythonAggregateFunctionInfoList = new ArrayList<>();
    List<DataViewSpec[]> dataViewSpecList = new ArrayList<>();
    AggregateInfo[] aggInfos = pythonAggregateInfoList.aggInfos();
    for (int i = 0; i < aggInfos.length; i++) {
        AggregateInfo aggInfo = aggInfos[i];
        UserDefinedFunction function = aggInfo.function();

        // There may be more aggregate infos than aggregate calls, so guard the lookup once
        // for both branches instead of only for the built-in one.
        boolean hasAggCall = i < aggCalls.length;
        int filterArg = hasAggCall ? aggCalls[i].filterArg : -1;
        boolean distinct = hasAggCall && aggCalls[i].isDistinct();
        Object[] argIndexes = Arrays.stream(aggInfo.argIndexes()).boxed().toArray();

        if (function instanceof PythonFunction) {
            pythonAggregateFunctionInfoList.add(
                    new PythonAggregateFunctionInfo(
                            (PythonFunction) function, argIndexes, filterArg, distinct));
            TypeInference typeInference = function.getTypeInference(null);
            // NOTE(review): both get() calls assume the accumulator strategy and the inferred
            // type are present — confirm this always holds for Python functions.
            dataViewSpecList.add(
                    extractDataViewSpecs(
                            i,
                            typeInference.getAccumulatorTypeStrategy().get().inferType(null).get()));
        } else {
            pythonAggregateFunctionInfoList.add(
                    new PythonAggregateFunctionInfo(
                            getBuiltInPythonAggregateFunction(function),
                            argIndexes,
                            filterArg,
                            distinct));
            // The data views of the built in Python Aggregate Function are different from Java
            // side, we will create the spec at Python side.
            dataViewSpecList.add(new DataViewSpec[0]);
        }
    }
    return Tuple2.of(
            pythonAggregateFunctionInfoList.toArray(new PythonAggregateFunctionInfo[0]),
            dataViewSpecList.toArray(new DataViewSpec[0][0]));
}
Aggregations