Use of org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction in project beam by apache.
From the class FlinkBatchPortablePipelineTranslator, method translateGroupByKey.
/**
 * Translates a GroupByKey transform into two chained Flink operators: a partial group-combine
 * followed by a full group-reduce, both grouped with a {@code KvKeySelector} built from the input
 * key coder. The resulting data set is registered in {@code context} under the transform's output
 * id.
 *
 * <p>Both stages use a {@code Concatenate} combine fn whose accumulator is a {@code List<V>}.
 *
 * @param transform the GroupByKey transform node; its input and output ids are each read with
 *     {@code Iterables.getOnlyElement}
 * @param pipeline pipeline proto whose components supply the input PCollection and its windowing
 *     strategy
 * @param context translation context that provides the input data set and pipeline options and
 *     receives the output data set
 * @throws IllegalStateException if the windowing strategy proto cannot be hydrated
 * @throws RuntimeException wrapping the {@code IOException} raised while instantiating the
 *     runner wire coder
 */
private static <K, V> void translateGroupByKey(
    PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
  RunnerApi.Components components = pipeline.getComponents();

  // Resolve the input PCollection and the data set already registered for it.
  String inputId = Iterables.getOnlyElement(transform.getTransform().getInputsMap().values());
  PCollectionNode inputNode =
      PipelineNode.pCollection(inputId, components.getPcollectionsOrThrow(inputId));
  DataSet<WindowedValue<KV<K, V>>> input = context.getDataSetOrThrow(inputId);

  // Look up and hydrate the windowing strategy referenced by the input PCollection.
  RunnerApi.WindowingStrategy strategyProto =
      components.getWindowingStrategiesOrThrow(
          components.getPcollectionsOrThrow(inputId).getWindowingStrategyId());
  RehydratedComponents rehydrated = RehydratedComponents.forComponents(components);
  WindowingStrategy<Object, BoundedWindow> strategy;
  try {
    strategy =
        (WindowingStrategy<Object, BoundedWindow>)
            WindowingStrategyTranslation.fromProto(strategyProto, rehydrated);
  } catch (InvalidProtocolBufferException e) {
    throw new IllegalStateException(
        String.format("Unable to hydrate GroupByKey windowing strategy %s.", strategyProto), e);
  }

  // Instantiate the wire coder of the input and pull out its key/value coders.
  WindowedValueCoder<KV<K, V>> windowedInputCoder;
  try {
    windowedInputCoder =
        (WindowedValueCoder) WireCoders.instantiateRunnerWireCoder(inputNode, components);
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  KvCoder<K, V> kvCoder = (KvCoder<K, V>) windowedInputCoder.getValueCoder();
  Coder<K> keyCoder = kvCoder.getKeyCoder();
  Coder<V> valueCoder = kvCoder.getValueCoder();

  // The same Concatenate instance drives both the partial and the final reduce function.
  Concatenate<V> concatenate = new Concatenate<>();
  Coder<List<V>> listCoder =
      concatenate.getAccumulatorCoder(CoderRegistry.createDefault(), valueCoder);
  Coder<WindowedValue<KV<K, List<V>>>> groupedCoder =
      WindowedValue.getFullCoder(
          KvCoder.of(keyCoder, listCoder), strategy.getWindowFn().windowCoder());
  // One type information object is shared by the intermediate and the final operator.
  TypeInformation<WindowedValue<KV<K, List<V>>>> groupedTypeInfo =
      new CoderTypeInformation<>(groupedCoder, context.getPipelineOptions());

  Grouping<WindowedValue<KV<K, V>>> groupedInput =
      input.groupBy(new KvKeySelector<>(keyCoder));
  FlinkPartialReduceFunction<K, V, List<V>, ?> partialFn =
      new FlinkPartialReduceFunction<>(
          concatenate, strategy, Collections.emptyMap(), context.getPipelineOptions());
  FlinkReduceFunction<K, List<V>, List<V>, ?> finalFn =
      new FlinkReduceFunction<>(
          concatenate, strategy, Collections.emptyMap(), context.getPipelineOptions());

  // Stage 1: combine the raw values of each group into List<V> accumulators.
  GroupCombineOperator<WindowedValue<KV<K, V>>, WindowedValue<KV<K, List<V>>>> combined =
      new GroupCombineOperator<>(
          groupedInput,
          groupedTypeInfo,
          partialFn,
          "GroupCombine: " + transform.getTransform().getUniqueName());
  Grouping<WindowedValue<KV<K, List<V>>>> groupedAccumulators =
      combined.groupBy(new KvKeySelector<>(keyCoder));

  // Stage 2: reduce the accumulators of each group into the final output elements.
  GroupReduceOperator<WindowedValue<KV<K, List<V>>>, WindowedValue<KV<K, List<V>>>> reduced =
      new GroupReduceOperator<>(
          groupedAccumulators,
          groupedTypeInfo,
          finalFn,
          transform.getTransform().getUniqueName());

  String outputId = Iterables.getOnlyElement(transform.getTransform().getOutputsMap().values());
  context.addDataSet(outputId, reduced);
}
Aggregations