Use of org.apache.beam.runners.flink.translation.types.CoderTypeInformation in project beam by apache.
In class FlinkBatchPortablePipelineTranslator, method translateExecutableStage:
private static <InputT> void translateExecutableStage(
    PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
  // TODO: Fail on splittable DoFns.
  // TODO: Special-case single outputs to avoid multiplexing PCollections.
  RunnerApi.Components components = pipeline.getComponents();
  Map<String, String> outputs = transform.getTransform().getOutputsMap();
  // Mapping from PCollection id to coder tag id.
  BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
  // Collect all output Coders and create a UnionCoder for our tagged outputs.
  List<Coder<?>> unionCoders = Lists.newArrayList();
  // Enforce tuple tag sorting by union tag index.
  Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
  for (String collectionId : new TreeMap<>(outputMap.inverse()).values()) {
    PCollectionNode collectionNode =
        PipelineNode.pCollection(collectionId, components.getPcollectionsOrThrow(collectionId));
    Coder<WindowedValue<?>> coder;
    try {
      coder = (Coder) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    outputCoders.put(collectionId, coder);
    unionCoders.add(coder);
  }
  UnionCoder unionCoder = UnionCoder.of(unionCoders);
  TypeInformation<RawUnionValue> typeInformation =
      new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());
  RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  String inputPCollectionId = stagePayload.getInput();
  Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
  DataSet<WindowedValue<InputT>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
  final FlinkExecutableStageFunction<InputT> function =
      new FlinkExecutableStageFunction<>(
          transform.getTransform().getUniqueName(),
          context.getPipelineOptions(),
          stagePayload,
          context.getJobInfo(),
          outputMap,
          FlinkExecutableStageContextFactory.getInstance(),
          getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder(),
          windowedInputCoder);
  final String operatorName = generateNameFromStagePayload(stagePayload);
  final SingleInputUdfOperator taggedDataset;
  if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
    Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
    // Stateful stages are only allowed for KV inputs so that we can group on the key.
    if (!(valueCoder instanceof KvCoder)) {
      throw new IllegalStateException(
          String.format(
              Locale.ENGLISH,
              "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
              inputPCollectionId,
              valueCoder.getClass().getSimpleName()));
    }
    Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
    Grouping<WindowedValue<InputT>> groupedInput =
        inputDataSet.groupBy(new KvKeySelector<>(keyCoder));
    boolean requiresTimeSortedInput = requiresTimeSortedInput(stagePayload, false);
    if (requiresTimeSortedInput) {
      groupedInput =
          ((UnsortedGrouping<WindowedValue<InputT>>) groupedInput)
              .sortGroup(WindowedValue::getTimestamp, Order.ASCENDING);
    }
    taggedDataset = new GroupReduceOperator<>(groupedInput, typeInformation, function, operatorName);
  } else {
    taggedDataset = new MapPartitionOperator<>(inputDataSet, typeInformation, function, operatorName);
  }
  for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
    String collectionId =
        stagePayload
            .getComponents()
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    // Register under the global PCollection name. Only ExecutableStageFunction needs to know the
    // mapping from local name to global name and how to translate the broadcast data to a state
    // API view.
    taggedDataset.withBroadcastSet(context.getDataSetOrThrow(collectionId), collectionId);
  }
  for (String collectionId : outputs.values()) {
    pruneOutput(
        taggedDataset, context, outputMap.get(collectionId), outputCoders.get(collectionId), collectionId);
  }
  if (outputs.isEmpty()) {
    // NOTE: After pipeline translation, we traverse the set of unconsumed PCollections and add a
    // no-op sink to each to make sure they are materialized by Flink. However, some SDK-executed
    // stages have no runner-visible output after fusion. We handle this case by adding a sink
    // here.
    taggedDataset.output(new DiscardingOutputFormat<>()).name("DiscardingOutput");
  }
}
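The common thread in these snippets is wrapping a Beam Coder in CoderTypeInformation so that Flink's type system serializes elements using Beam's own encoding. Below is a minimal, self-contained sketch of that idiom; the class name, coder choices, and import paths are illustrative assumptions for the Beam version these snippets come from, not part of the translator above.

import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.flink.api.common.typeinfo.TypeInformation;

public class CoderTypeInformationSketch {
  public static void main(String[] args) {
    // A full WindowedValue coder: a value coder paired with a window coder.
    Coder<WindowedValue<String>> coder =
        WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE);
    // CoderTypeInformation adapts the Beam coder to Flink's TypeInformation,
    // so downstream Flink operators (de)serialize elements via the Beam coder.
    TypeInformation<WindowedValue<String>> typeInfo =
        new CoderTypeInformation<>(coder, PipelineOptionsFactory.create());
    System.out.println(typeInfo);
  }
}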
Use of org.apache.beam.runners.flink.translation.types.CoderTypeInformation in project beam by apache.
In class FlinkBatchPortablePipelineTranslator, method translateGroupByKey:
private static <K, V> void translateGroupByKey(
    PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
  RunnerApi.Components components = pipeline.getComponents();
  String inputPCollectionId =
      Iterables.getOnlyElement(transform.getTransform().getInputsMap().values());
  PCollectionNode inputCollection =
      PipelineNode.pCollection(inputPCollectionId, components.getPcollectionsOrThrow(inputPCollectionId));
  DataSet<WindowedValue<KV<K, V>>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
  RunnerApi.WindowingStrategy windowingStrategyProto =
      pipeline
          .getComponents()
          .getWindowingStrategiesOrThrow(
              pipeline.getComponents().getPcollectionsOrThrow(inputPCollectionId).getWindowingStrategyId());
  RehydratedComponents rehydratedComponents =
      RehydratedComponents.forComponents(pipeline.getComponents());
  WindowingStrategy<Object, BoundedWindow> windowingStrategy;
  try {
    windowingStrategy =
        (WindowingStrategy<Object, BoundedWindow>)
            WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents);
  } catch (InvalidProtocolBufferException e) {
    throw new IllegalStateException(
        String.format("Unable to hydrate GroupByKey windowing strategy %s.", windowingStrategyProto), e);
  }
  WindowedValueCoder<KV<K, V>> inputCoder;
  try {
    inputCoder =
        (WindowedValueCoder) WireCoders.instantiateRunnerWireCoder(inputCollection, pipeline.getComponents());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  KvCoder<K, V> inputElementCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
  Concatenate<V> combineFn = new Concatenate<>();
  Coder<List<V>> accumulatorCoder =
      combineFn.getAccumulatorCoder(CoderRegistry.createDefault(), inputElementCoder.getValueCoder());
  Coder<WindowedValue<KV<K, List<V>>>> outputCoder =
      WindowedValue.getFullCoder(
          KvCoder.of(inputElementCoder.getKeyCoder(), accumulatorCoder),
          windowingStrategy.getWindowFn().windowCoder());
  TypeInformation<WindowedValue<KV<K, List<V>>>> partialReduceTypeInfo =
      new CoderTypeInformation<>(outputCoder, context.getPipelineOptions());
  Grouping<WindowedValue<KV<K, V>>> inputGrouping =
      inputDataSet.groupBy(new KvKeySelector<>(inputElementCoder.getKeyCoder()));
  FlinkPartialReduceFunction<K, V, List<V>, ?> partialReduceFunction =
      new FlinkPartialReduceFunction<>(
          combineFn, windowingStrategy, Collections.emptyMap(), context.getPipelineOptions());
  FlinkReduceFunction<K, List<V>, List<V>, ?> reduceFunction =
      new FlinkReduceFunction<>(
          combineFn, windowingStrategy, Collections.emptyMap(), context.getPipelineOptions());
  // Partially GroupReduce the values into the intermediate format AccumT (combine).
  GroupCombineOperator<WindowedValue<KV<K, V>>, WindowedValue<KV<K, List<V>>>> groupCombine =
      new GroupCombineOperator<>(
          inputGrouping,
          partialReduceTypeInfo,
          partialReduceFunction,
          "GroupCombine: " + transform.getTransform().getUniqueName());
  Grouping<WindowedValue<KV<K, List<V>>>> intermediateGrouping =
      groupCombine.groupBy(new KvKeySelector<>(inputElementCoder.getKeyCoder()));
  // Fully reduce the values into the output format. The accumulator and output types
  // coincide here (both List<V>), so the partial-reduce type information is reused.
  GroupReduceOperator<WindowedValue<KV<K, List<V>>>, WindowedValue<KV<K, List<V>>>> outputDataSet =
      new GroupReduceOperator<>(
          intermediateGrouping, partialReduceTypeInfo, reduceFunction, transform.getTransform().getUniqueName());
  context.addDataSet(
      Iterables.getOnlyElement(transform.getTransform().getOutputsMap().values()), outputDataSet);
}
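The translation above expresses GroupByKey as a two-phase combine whose accumulator is simply a List of the grouped values. The sketch below is a hypothetical stand-in for such a Concatenate-style combine fn, written against Beam's public Combine.CombineFn API; it is not Beam's internal Concatenate class, but it shows why the accumulator coder can be derived from the input value coder alone, as on the getAccumulatorCoder line above.

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.transforms.Combine;

// Buffers every value into a List, so AccumT == OutputT == List<V>.
public class ConcatenateSketch<V> extends Combine.CombineFn<V, List<V>, List<V>> {
  @Override
  public List<V> createAccumulator() {
    return new ArrayList<>();
  }

  @Override
  public List<V> addInput(List<V> accumulator, V input) {
    accumulator.add(input);
    return accumulator;
  }

  @Override
  public List<V> mergeAccumulators(Iterable<List<V>> accumulators) {
    List<V> merged = new ArrayList<>();
    for (List<V> accum : accumulators) {
      merged.addAll(accum);
    }
    return merged;
  }

  @Override
  public List<V> extractOutput(List<V> accumulator) {
    return accumulator;
  }

  @Override
  public Coder<List<V>> getAccumulatorCoder(CoderRegistry registry, Coder<V> inputCoder) {
    // A List accumulator is encodable as soon as the input values are.
    return ListCoder.of(inputCoder);
  }
}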
Use of org.apache.beam.runners.flink.translation.types.CoderTypeInformation in project beam by apache.
In class FlinkStreamingPortablePipelineTranslator, method addGBK:
private <K, V> SingleOutputStreamOperator<WindowedValue<KV<K, Iterable<V>>>> addGBK(
    DataStream<WindowedValue<KV<K, V>>> inputDataStream,
    WindowingStrategy<?, ?> windowingStrategy,
    WindowedValueCoder<KV<K, V>> windowedInputCoder,
    String operatorName,
    StreamingTranslationContext context) {
  KvCoder<K, V> inputElementCoder = (KvCoder<K, V>) windowedInputCoder.getValueCoder();
  SingletonKeyedWorkItemCoder<K, V> workItemCoder =
      SingletonKeyedWorkItemCoder.of(
          inputElementCoder.getKeyCoder(),
          inputElementCoder.getValueCoder(),
          windowingStrategy.getWindowFn().windowCoder());
  WindowedValue.FullWindowedValueCoder<KeyedWorkItem<K, V>> windowedWorkItemCoder =
      WindowedValue.getFullCoder(workItemCoder, windowingStrategy.getWindowFn().windowCoder());
  CoderTypeInformation<WindowedValue<KeyedWorkItem<K, V>>> workItemTypeInfo =
      new CoderTypeInformation<>(windowedWorkItemCoder, context.getPipelineOptions());
  DataStream<WindowedValue<KeyedWorkItem<K, V>>> workItemStream =
      inputDataStream
          .flatMap(new FlinkStreamingTransformTranslators.ToKeyedWorkItem<>(context.getPipelineOptions()))
          .returns(workItemTypeInfo)
          .name("ToKeyedWorkItem");
  WorkItemKeySelector<K, V> keySelector =
      new WorkItemKeySelector<>(
          inputElementCoder.getKeyCoder(), new SerializablePipelineOptions(context.getPipelineOptions()));
  KeyedStream<WindowedValue<KeyedWorkItem<K, V>>, ByteBuffer> keyedWorkItemStream =
      workItemStream.keyBy(keySelector);
  SystemReduceFn<K, V, Iterable<V>, Iterable<V>, BoundedWindow> reduceFn =
      SystemReduceFn.buffering(inputElementCoder.getValueCoder());
  Coder<Iterable<V>> accumulatorCoder = IterableCoder.of(inputElementCoder.getValueCoder());
  Coder<WindowedValue<KV<K, Iterable<V>>>> outputCoder =
      WindowedValue.getFullCoder(
          KvCoder.of(inputElementCoder.getKeyCoder(), accumulatorCoder),
          windowingStrategy.getWindowFn().windowCoder());
  TypeInformation<WindowedValue<KV<K, Iterable<V>>>> outputTypeInfo =
      new CoderTypeInformation<>(outputCoder, context.getPipelineOptions());
  TupleTag<KV<K, Iterable<V>>> mainTag = new TupleTag<>("main output");
  WindowDoFnOperator<K, V, Iterable<V>> doFnOperator =
      new WindowDoFnOperator<>(
          reduceFn,
          operatorName,
          windowedWorkItemCoder,
          mainTag,
          Collections.emptyList(),
          new DoFnOperator.MultiOutputOutputManagerFactory<>(
              mainTag, outputCoder, new SerializablePipelineOptions(context.getPipelineOptions())),
          windowingStrategy,
          new HashMap<>(), /* side-input mapping */
          Collections.emptyList(), /* side inputs */
          context.getPipelineOptions(),
          inputElementCoder.getKeyCoder(),
          keySelector);
  return keyedWorkItemStream.transform(operatorName, outputTypeInfo, doFnOperator);
}
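Note the .returns(workItemTypeInfo) call: Flink cannot infer the element type produced by a function over Beam types, so the translator pins it with a CoderTypeInformation. A minimal sketch of that pattern on a bare DataStream follows; the environment setup, values, and coder choice are illustrative assumptions, not part of addGBK.

import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class ReturnsSketch {
  public static void main(String[] args) throws Exception {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    CoderTypeInformation<WindowedValue<String>> typeInfo =
        new CoderTypeInformation<>(
            WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE),
            PipelineOptionsFactory.create());
    // Flink cannot infer WindowedValue<String> from the method reference,
    // so pin the type with .returns(...), as addGBK does for the work-item stream.
    DataStream<WindowedValue<String>> stream =
        env.fromElements("a", "b", "c")
            .map(WindowedValue::valueInGlobalWindow)
            .returns(typeInfo);
    stream.print();
    env.execute("returns-sketch");
  }
}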
Use of org.apache.beam.runners.flink.translation.types.CoderTypeInformation in project beam by apache.
In class FlinkStreamingPortablePipelineTranslator, method transformSideInputs:
private TransformedSideInputs transformSideInputs(
    RunnerApi.ExecutableStagePayload stagePayload,
    RunnerApi.Components components,
    StreamingTranslationContext context) {
  LinkedHashMap<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputs =
      getSideInputIdToPCollectionViewMap(stagePayload, components);
  Map<TupleTag<?>, Integer> tagToIntMapping = new HashMap<>();
  Map<Integer, PCollectionView<?>> intToViewMapping = new HashMap<>();
  List<WindowedValueCoder<KV<Void, Object>>> kvCoders = new ArrayList<>();
  List<Coder<?>> viewCoders = new ArrayList<>();
  int count = 0;
  for (Map.Entry<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInput :
      sideInputs.entrySet()) {
    TupleTag<?> tag = sideInput.getValue().getTagInternal();
    intToViewMapping.put(count, sideInput.getValue());
    tagToIntMapping.put(tag, count);
    count++;
    String collectionId =
        components
            .getTransformsOrThrow(sideInput.getKey().getTransformId())
            .getInputsOrThrow(sideInput.getKey().getLocalName());
    DataStream<Object> sideInputStream = context.getDataStreamOrThrow(collectionId);
    TypeInformation<Object> tpe = sideInputStream.getType();
    if (!(tpe instanceof CoderTypeInformation)) {
      throw new IllegalStateException("Input stream TypeInformation is not a CoderTypeInformation.");
    }
    WindowedValueCoder<Object> coder = (WindowedValueCoder) ((CoderTypeInformation) tpe).getCoder();
    Coder<KV<Void, Object>> kvCoder = KvCoder.of(VoidCoder.of(), coder.getValueCoder());
    kvCoders.add(coder.withValueCoder(kvCoder));
    // Coder for the materialized view, matching the GBK inserted below.
    WindowedValueCoder<KV<Void, Iterable<Object>>> viewCoder =
        coder.withValueCoder(KvCoder.of(VoidCoder.of(), IterableCoder.of(coder.getValueCoder())));
    viewCoders.add(viewCoder);
  }
  // Second pass, now that we have gathered the input coders.
  UnionCoder unionCoder = UnionCoder.of(viewCoders);
  CoderTypeInformation<RawUnionValue> unionTypeInformation =
      new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());
  // Transform each side input to a RawUnionValue and union them.
  DataStream<RawUnionValue> sideInputUnion = null;
  for (Map.Entry<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInput :
      sideInputs.entrySet()) {
    TupleTag<?> tag = sideInput.getValue().getTagInternal();
    final int intTag = tagToIntMapping.get(tag);
    RunnerApi.PTransform pTransform =
        components.getTransformsOrThrow(sideInput.getKey().getTransformId());
    String collectionId = pTransform.getInputsOrThrow(sideInput.getKey().getLocalName());
    DataStream<WindowedValue<?>> sideInputStream = context.getDataStreamOrThrow(collectionId);
    // Insert a GBK to materialize the side input view.
    String viewName = sideInput.getKey().getTransformId() + "-" + sideInput.getKey().getLocalName();
    WindowedValueCoder<KV<Void, Object>> kvCoder = kvCoders.get(intTag);
    DataStream<WindowedValue<KV<Void, Object>>> keyedSideInputStream =
        sideInputStream.map(new ToVoidKeyValue(context.getPipelineOptions()));
    SingleOutputStreamOperator<WindowedValue<KV<Void, Iterable<Object>>>> viewStream =
        addGBK(
            keyedSideInputStream,
            sideInput.getValue().getWindowingStrategyInternal(),
            kvCoder,
            viewName,
            context);
    // Assign a unique but consistent id so operator state can be re-mapped on restore.
    viewStream.uid(pTransform.getUniqueName() + "-" + sideInput.getKey().getLocalName());
    DataStream<RawUnionValue> unionValueStream =
        viewStream
            .map(new FlinkStreamingTransformTranslators.ToRawUnion<>(intTag, context.getPipelineOptions()))
            .returns(unionTypeInformation);
    if (sideInputUnion == null) {
      sideInputUnion = unionValueStream;
    } else {
      sideInputUnion = sideInputUnion.union(unionValueStream);
    }
  }
  return new TransformedSideInputs(intToViewMapping, sideInputUnion);
}
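The union step works because every side input stream is re-tagged as a RawUnionValue and the UnionCoder dispatches on the tag to the matching per-input coder. A minimal round-trip sketch of that multiplexing follows; the coder pair and values are illustrative, and the import paths assume Beam's public org.apache.beam.sdk.transforms.join package.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.transforms.join.UnionCoder;

public class UnionCoderSketch {
  public static void main(String[] args) throws Exception {
    // One coder per union tag, in tag order, as in transformSideInputs.
    UnionCoder union = UnionCoder.of(Arrays.asList(StringUtf8Coder.of(), VarLongCoder.of()));
    // Tag 0 carries a String, tag 1 a Long; both can travel in one stream.
    RawUnionValue tagged = new RawUnionValue(0, "hello side input");
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    union.encode(tagged, out);
    RawUnionValue roundTripped = union.decode(new ByteArrayInputStream(out.toByteArray()));
    System.out.println(roundTripped.getUnionTag() + " -> " + roundTripped.getValue());
  }
}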