Use of org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageFunction in the apache/beam project.
The example below is the translateExecutableStage method of the class FlinkBatchPortablePipelineTranslator, which translates a fused executable stage of a portable pipeline into a Flink batch operator:
private static <InputT> void translateExecutableStage(
    PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
  // TODO: Fail on splittable DoFns.
  // TODO: Special-case single outputs to avoid multiplexing PCollections.
  RunnerApi.Components components = pipeline.getComponents();
  Map<String, String> outputs = transform.getTransform().getOutputsMap();
  // Mapping from PCollection id to union tag id.
  BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
  // Collect all output Coders and create a UnionCoder for our tagged outputs.
  List<Coder<?>> unionCoders = Lists.newArrayList();
  // Enforce tuple tag sorting by union tag index.
  Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
  for (String collectionId : new TreeMap<>(outputMap.inverse()).values()) {
    PCollectionNode collectionNode =
        PipelineNode.pCollection(collectionId, components.getPcollectionsOrThrow(collectionId));
    Coder<WindowedValue<?>> coder;
    try {
      coder = (Coder) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    outputCoders.put(collectionId, coder);
    unionCoders.add(coder);
  }
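  // A Flink DataSet operator produces a single output type, so all stage outputs are
  // multiplexed into tagged RawUnionValue records here and demultiplexed again further down.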
  UnionCoder unionCoder = UnionCoder.of(unionCoders);
  TypeInformation<RawUnionValue> typeInformation =
      new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());
  RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
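  // An executable stage has a single runner-visible input PCollection, recorded in its payload.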
  String inputPCollectionId = stagePayload.getInput();
  Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
  DataSet<WindowedValue<InputT>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
  final FlinkExecutableStageFunction<InputT> function =
      new FlinkExecutableStageFunction<>(
          transform.getTransform().getUniqueName(),
          context.getPipelineOptions(),
          stagePayload,
          context.getJobInfo(),
          outputMap,
          FlinkExecutableStageContextFactory.getInstance(),
          getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder(),
          windowedInputCoder);
  final String operatorName = generateNameFromStagePayload(stagePayload);
  final SingleInputUdfOperator taggedDataset;
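  // State and timers are partitioned by key, so a stateful stage must see all values for a key
  // together: group the input by key and run the stage as a GroupReduce. A stateless stage can
  // run as a plain MapPartition instead.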
  if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
    Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
    // Stateful stages only allow KV input, so that elements can be grouped on the key.
    if (!(valueCoder instanceof KvCoder)) {
      throw new IllegalStateException(
          String.format(
              Locale.ENGLISH,
              "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
              inputPCollectionId,
              valueCoder.getClass().getSimpleName()));
    }
    Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
    Grouping<WindowedValue<InputT>> groupedInput =
        inputDataSet.groupBy(new KvKeySelector<>(keyCoder));
    boolean requiresTimeSortedInput = requiresTimeSortedInput(stagePayload, false);
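    // DoFns annotated with @RequiresTimeSortedInput additionally need the elements of each key
    // group sorted by their event-time timestamp.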
    if (requiresTimeSortedInput) {
      groupedInput =
          ((UnsortedGrouping<WindowedValue<InputT>>) groupedInput)
              .sortGroup(WindowedValue::getTimestamp, Order.ASCENDING);
    }
    taggedDataset = new GroupReduceOperator<>(groupedInput, typeInformation, function, operatorName);
  } else {
    taggedDataset = new MapPartitionOperator<>(inputDataSet, typeInformation, function, operatorName);
  }
  for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
    String collectionId =
        stagePayload
            .getComponents()
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    // Register under the global PCollection name. Only ExecutableStageFunction needs to know the
    // mapping from local name to global name and how to translate the broadcast data to a state
    // API view.
    taggedDataset.withBroadcastSet(context.getDataSetOrThrow(collectionId), collectionId);
  }
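  // Demultiplex the tagged union output: register one pruned dataset per output PCollection,
  // each keeping only the records that carry that collection's union tag.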
  for (String collectionId : outputs.values()) {
    pruneOutput(
        taggedDataset,
        context,
        outputMap.get(collectionId),
        outputCoders.get(collectionId),
        collectionId);
  }
  if (outputs.isEmpty()) {
    // NOTE: After pipeline translation, we traverse the set of unconsumed PCollections and add a
    // no-op sink to each to make sure they are materialized by Flink. However, some SDK-executed
    // stages have no runner-visible output after fusion. We handle this case by adding a sink
    // here.
    taggedDataset.output(new DiscardingOutputFormat<>()).name("DiscardingOutput");
  }
}
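The helpers createOutputMap and pruneOutput are referenced above but not shown on this page. As a rough sketch of how the multiplexing bookkeeping can work, assuming the tag assignment and the pruning step behave as their names and call sites suggest, the union tags could be assigned and consumed as below. The class UnionTagSketch and its members are illustrative stand-ins, not the project's actual helpers.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Beam ships a vendored copy of Guava; the plain Guava coordinates are used here for brevity.
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;

import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.util.Collector;

/** Illustrative sketch of union-tag assignment and demultiplexing; not Beam's actual code. */
public class UnionTagSketch {

  /**
   * Assigns each output PCollection id a stable union tag. Sorting the ids first keeps the
   * assignment deterministic across runs, which the real createOutputMap presumably also ensures.
   */
  static BiMap<String, Integer> createOutputMapSketch(Iterable<String> outputCollectionIds) {
    List<String> sorted = new ArrayList<>();
    outputCollectionIds.forEach(sorted::add);
    Collections.sort(sorted);
    ImmutableBiMap.Builder<String, Integer> builder = ImmutableBiMap.builder();
    for (int tag = 0; tag < sorted.size(); tag++) {
      builder.put(sorted.get(tag), tag);
    }
    return builder.build();
  }

  /**
   * Keeps only the records carrying the given union tag and unwraps them: the filtering that a
   * pruning step downstream of the multiplexed stage output has to perform for each PCollection.
   */
  static class PruneByTag implements FlatMapFunction<RawUnionValue, WindowedValue<?>> {
    private final int unionTag;

    PruneByTag(int unionTag) {
      this.unionTag = unionTag;
    }

    @Override
    public void flatMap(RawUnionValue value, Collector<WindowedValue<?>> out) {
      if (value.getUnionTag() == unionTag) {
        out.collect((WindowedValue<?>) value.getValue());
      }
    }
  }
}

Note that the real pruneOutput must also wrap the per-output coder in a TypeInformation so Flink can serialize the pruned dataset; that wiring is omitted from the sketch.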