Use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId in project beam by apache.
The class ExecutableStage, method fromPayload.
/**
 * Return an {@link ExecutableStage} constructed from the provided {@link ExecutableStagePayload}
 * representation.
*
* <p>See {@link #toPTransform} for how the payload is constructed.
*
* <p>Note: The payload contains some information redundant with the {@link PTransform} it is the
* payload of. The {@link ExecutableStagePayload} should be sufficiently rich to construct a
* {@code ProcessBundleDescriptor} using only the payload.
*/
static ExecutableStage fromPayload(ExecutableStagePayload payload) {
  Components components = payload.getComponents();
  Environment environment = payload.getEnvironment();
  Collection<WireCoderSetting> wireCoderSettings = payload.getWireCoderSettingsList();
  PCollectionNode input =
      PipelineNode.pCollection(
          payload.getInput(), components.getPcollectionsOrThrow(payload.getInput()));
  List<SideInputReference> sideInputs =
      payload.getSideInputsList().stream()
          .map(sideInputId -> SideInputReference.fromSideInputId(sideInputId, components))
          .collect(Collectors.toList());
  List<UserStateReference> userStates =
      payload.getUserStatesList().stream()
          .map(userStateId -> UserStateReference.fromUserStateId(userStateId, components))
          .collect(Collectors.toList());
  List<TimerReference> timers =
      payload.getTimersList().stream()
          .map(timerId -> TimerReference.fromTimerId(timerId, components))
          .collect(Collectors.toList());
  List<PTransformNode> transforms =
      payload.getTransformsList().stream()
          .map(id -> PipelineNode.pTransform(id, components.getTransformsOrThrow(id)))
          .collect(Collectors.toList());
  List<PCollectionNode> outputs =
      payload.getOutputsList().stream()
          .map(id -> PipelineNode.pCollection(id, components.getPcollectionsOrThrow(id)))
          .collect(Collectors.toList());
  return ImmutableExecutableStage.of(
      components,
      environment,
      input,
      sideInputs,
      userStates,
      timers,
      transforms,
      outputs,
      wireCoderSettings);
}
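For orientation, a hedged usage sketch (the helper name printStageSideInputs and its wiring are illustrative, not part of Beam): a runner that receives a fused PTransform produced by toPTransform can parse the payload back out and rehydrate the stage without consulting the rest of the pipeline.

static void printStageSideInputs(RunnerApi.PTransform stageTransform)
    throws InvalidProtocolBufferException {
  // Parse the payload embedded in the transform's FunctionSpec.
  ExecutableStagePayload payload =
      ExecutableStagePayload.parseFrom(stageTransform.getSpec().getPayload());
  // The payload carries its own Components, so the stage is self-describing.
  ExecutableStage stage = ExecutableStage.fromPayload(payload);
  for (SideInputReference sideInput : stage.getSideInputs()) {
    System.out.println(sideInput.transform().getId() + "/" + sideInput.localName());
  }
}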
Use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId in project beam by apache.
The class PipelineValidator, method validateParDo.
private static void validateParDo(
    String id, PTransform transform, Components components, Set<String> requirements)
    throws Exception {
  ParDoPayload payload = ParDoPayload.parseFrom(transform.getSpec().getPayload());
  // side_inputs
  for (String sideInputId : payload.getSideInputsMap().keySet()) {
    checkArgument(
        transform.containsInputs(sideInputId),
        "Transform %s side input %s is not listed in the transform's inputs",
        id,
        sideInputId);
  }
  if (payload.getStateSpecsCount() > 0 || payload.getTimerFamilySpecsCount() > 0) {
    checkArgument(requirements.contains(ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN));
    // TODO: Validate state_specs and timer_specs
  }
  if (!payload.getRestrictionCoderId().isEmpty()) {
    checkArgument(components.containsCoders(payload.getRestrictionCoderId()));
    checkArgument(requirements.contains(ParDoTranslation.REQUIRES_SPLITTABLE_DOFN_URN));
  }
  if (payload.getRequestsFinalization()) {
    checkArgument(requirements.contains(ParDoTranslation.REQUIRES_BUNDLE_FINALIZATION_URN));
  }
  if (payload.getRequiresStableInput()) {
    checkArgument(requirements.contains(ParDoTranslation.REQUIRES_STABLE_INPUT_URN));
  }
  if (payload.getRequiresTimeSortedInput()) {
    checkArgument(requirements.contains(ParDoTranslation.REQUIRES_TIME_SORTED_INPUT_URN));
  }
}
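As a hedged illustration of the contract this validator enforces (the helper withStatefulRequirement is hypothetical, not Beam API): a pipeline proto that carries stateful ParDos must list the stateful-processing requirement URN, or the checkArgument above fails.

static RunnerApi.Pipeline withStatefulRequirement(RunnerApi.Pipeline pipeline) {
  // Idempotently declare the requirement that validateParDo checks for.
  if (pipeline.getRequirementsList().contains(ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN)) {
    return pipeline;
  }
  return pipeline.toBuilder()
      .addRequirements(ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN)
      .build();
}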
Use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId in project beam by apache.
The class SideInputReference, method fromSideInputId.
/**
* Create a side input reference from a SideInputId proto and components.
*/
public static SideInputReference fromSideInputId(
    SideInputId sideInputId, RunnerApi.Components components) {
  String transformId = sideInputId.getTransformId();
  String localName = sideInputId.getLocalName();
  PTransform transform = components.getTransformsOrThrow(transformId);
  String collectionId = transform.getInputsOrThrow(localName);
  PCollection collection = components.getPcollectionsOrThrow(collectionId);
  return SideInputReference.of(
      PipelineNode.pTransform(transformId, transform),
      localName,
      PipelineNode.pCollection(collectionId, collection));
}
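A hedged usage sketch (the transform and local names are illustrative placeholders, assumed to exist in the given Components): build a SideInputId proto by hand and resolve it.

static SideInputReference resolveExampleSideInput(RunnerApi.Components components) {
  // "parDoTransform" and "sideInput0" are hypothetical ids for illustration only.
  SideInputId sideInputId =
      SideInputId.newBuilder()
          .setTransformId("parDoTransform")
          .setLocalName("sideInput0")
          .build();
  SideInputReference ref = SideInputReference.fromSideInputId(sideInputId, components);
  // ref.collection() names the PCollection feeding the side input, resolved through the
  // transform's inputs map.
  return ref;
}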
Use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId in project beam by apache.
The class FlinkBatchPortablePipelineTranslator, method translateExecutableStage.
private static <InputT> void translateExecutableStage(
    PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
  // TODO: Fail on splittable DoFns.
  // TODO: Special-case single outputs to avoid multiplexing PCollections.
  RunnerApi.Components components = pipeline.getComponents();
  Map<String, String> outputs = transform.getTransform().getOutputsMap();
  // Mapping from PCollection id to coder tag id.
  BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
  // Collect all output Coders and create a UnionCoder for our tagged outputs.
  List<Coder<?>> unionCoders = Lists.newArrayList();
  // Enforce tuple tag sorting by union tag index.
  Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
  for (String collectionId : new TreeMap<>(outputMap.inverse()).values()) {
    PCollectionNode collectionNode =
        PipelineNode.pCollection(collectionId, components.getPcollectionsOrThrow(collectionId));
    Coder<WindowedValue<?>> coder;
    try {
      coder = (Coder) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    outputCoders.put(collectionId, coder);
    unionCoders.add(coder);
  }
  UnionCoder unionCoder = UnionCoder.of(unionCoders);
  TypeInformation<RawUnionValue> typeInformation =
      new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());
  RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(
            transform.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  String inputPCollectionId = stagePayload.getInput();
  Coder<WindowedValue<InputT>> windowedInputCoder =
      instantiateCoder(inputPCollectionId, components);
  DataSet<WindowedValue<InputT>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
  final FlinkExecutableStageFunction<InputT> function =
      new FlinkExecutableStageFunction<>(
          transform.getTransform().getUniqueName(),
          context.getPipelineOptions(),
          stagePayload,
          context.getJobInfo(),
          outputMap,
          FlinkExecutableStageContextFactory.getInstance(),
          getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder(),
          windowedInputCoder);
  final String operatorName = generateNameFromStagePayload(stagePayload);
  final SingleInputUdfOperator taggedDataset;
  if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
    Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
    // Stateful stages are only allowed for KV inputs, so that we can group on the key.
    if (!(valueCoder instanceof KvCoder)) {
      throw new IllegalStateException(
          String.format(
              Locale.ENGLISH,
              "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
              inputPCollectionId,
              valueCoder.getClass().getSimpleName()));
    }
    Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
    Grouping<WindowedValue<InputT>> groupedInput =
        inputDataSet.groupBy(new KvKeySelector<>(keyCoder));
    boolean requiresTimeSortedInput = requiresTimeSortedInput(stagePayload, false);
    if (requiresTimeSortedInput) {
      groupedInput =
          ((UnsortedGrouping<WindowedValue<InputT>>) groupedInput)
              .sortGroup(WindowedValue::getTimestamp, Order.ASCENDING);
    }
    taggedDataset =
        new GroupReduceOperator<>(groupedInput, typeInformation, function, operatorName);
  } else {
    taggedDataset =
        new MapPartitionOperator<>(inputDataSet, typeInformation, function, operatorName);
  }
  for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
    String collectionId =
        stagePayload
            .getComponents()
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    // Register under the global PCollection name. Only ExecutableStageFunction needs to know
    // the mapping from local name to global name and how to translate the broadcast data to a
    // state API view.
    taggedDataset.withBroadcastSet(context.getDataSetOrThrow(collectionId), collectionId);
  }
  for (String collectionId : outputs.values()) {
    pruneOutput(
        taggedDataset,
        context,
        outputMap.get(collectionId),
        outputCoders.get(collectionId),
        collectionId);
  }
  if (outputs.isEmpty()) {
    // NOTE: After pipeline translation, we traverse the set of unconsumed PCollections and add
    // a no-op sink to each to make sure they are materialized by Flink. However, some
    // SDK-executed stages have no runner-visible output after fusion. We handle this case by
    // adding a sink here.
    taggedDataset.output(new DiscardingOutputFormat<>()).name("DiscardingOutput");
  }
}
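The translation above leans on a createOutputMap helper that is referenced but not shown. A plausible sketch, assuming its only job is to assign each output PCollection id a deterministic union-tag index (the actual Beam implementation may differ):

private static BiMap<String, Integer> createOutputMap(Collection<String> localOutputs) {
  // Sort the ids so union tag assignment is stable across translation runs.
  ImmutableBiMap.Builder<String, Integer> builder = ImmutableBiMap.builder();
  int unionTag = 0;
  for (String collectionId : new TreeSet<>(localOutputs)) {
    builder.put(collectionId, unionTag++);
  }
  return builder.build();
}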
Use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId in project beam by apache.
The class FlinkStreamingPortablePipelineTranslator, method getSideInputIdToPCollectionViewMap.
private static LinkedHashMap<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>>
    getSideInputIdToPCollectionViewMap(
        RunnerApi.ExecutableStagePayload stagePayload, RunnerApi.Components components) {
  RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components);
  LinkedHashMap<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputs =
      new LinkedHashMap<>();
  // Needed for PCollectionView compatibility; not used for the transform's materialization.
  ViewFn<Iterable<WindowedValue<?>>, ?> viewFn =
      (ViewFn)
          new PCollectionViews.MultimapViewFn<>(
              (PCollectionViews.TypeDescriptorSupplier<Iterable<WindowedValue<Void>>>)
                  () -> TypeDescriptors.iterables(new TypeDescriptor<WindowedValue<Void>>() {}),
              (PCollectionViews.TypeDescriptorSupplier<Void>) TypeDescriptors::voids);
  for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId :
      stagePayload.getSideInputsList()) {
    // TODO: The local name is only unique as long as a stage contains at most one transform
    // with side inputs.
    String sideInputTag = sideInputId.getLocalName();
    String collectionId =
        components
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    RunnerApi.WindowingStrategy windowingStrategyProto =
        components.getWindowingStrategiesOrThrow(
            components.getPcollectionsOrThrow(collectionId).getWindowingStrategyId());
    final WindowingStrategy<?, ?> windowingStrategy;
    try {
      windowingStrategy =
          WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents);
    } catch (InvalidProtocolBufferException e) {
      throw new IllegalStateException(
          String.format(
              "Unable to hydrate side input windowing strategy %s.", windowingStrategyProto),
          e);
    }
    Coder<WindowedValue<Object>> coder = instantiateCoder(collectionId, components);
    // Side input materialization via GBK (T -> Iterable<T>).
    WindowedValueCoder wvCoder = (WindowedValueCoder) coder;
    coder = wvCoder.withValueCoder(IterableCoder.of(wvCoder.getValueCoder()));
    sideInputs.put(
        sideInputId,
        new RunnerPCollectionView<>(
            null,
            new TupleTag<>(sideInputTag),
            viewFn,
            // TODO: support custom mapping fn
            windowingStrategy.getWindowFn().getDefaultWindowMappingFn(),
            windowingStrategy,
            coder));
  }
  return sideInputs;
}
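A hedged consumption sketch (invertSideInputMap is illustrative, not part of Beam): the translator also needs the reverse direction of this mapping when wiring broadcast streams, and inverting into another LinkedHashMap preserves the insertion order established above.

static Map<PCollectionView<?>, RunnerApi.ExecutableStagePayload.SideInputId> invertSideInputMap(
    LinkedHashMap<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> byId) {
  // Each side input yields a distinct PCollectionView here, so the inversion is lossless.
  Map<PCollectionView<?>, RunnerApi.ExecutableStagePayload.SideInputId> byView =
      new LinkedHashMap<>();
  byId.forEach((id, view) -> byView.put(view, id));
  return byView;
}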