Use of org.apache.beam.sdk.transforms.join.RawUnionValue in the Apache Beam project.
Class SparkStreamingPortablePipelineTranslator, method translateExecutableStage.
/**
 * Translates an executable stage into a Spark streaming computation.
 *
 * <p>The stage payload is parsed from the transform spec, the input dataset is popped from the
 * context, and the stage function is applied per partition to produce a stream of
 * {@link RawUnionValue}s. One output dataset is then pushed per declared output; when the stage
 * declares no outputs, a sink dataset that emits nothing is pushed instead so the stage still runs.
 *
 * @param transformNode node holding the executable stage transform
 * @param pipeline pipeline whose components are used to resolve the input windowing strategy
 * @param context translation context that supplies the input dataset and receives the outputs
 * @throws RuntimeException if the stage payload cannot be parsed
 * @throws UnsupportedOperationException if the stage declares any side inputs
 */
private static <InputT, OutputT, SideInputT> void translateExecutableStage(
    PTransformNode transformNode,
    RunnerApi.Pipeline pipeline,
    SparkStreamingTranslationContext context) {
  final RunnerApi.ExecutableStagePayload payload;
  try {
    payload =
        RunnerApi.ExecutableStagePayload.parseFrom(
            transformNode.getTransform().getSpec().getPayload());
  } catch (IOException parseFailure) {
    throw new RuntimeException(parseFailure);
  }

  final String inputId = payload.getInput();
  UnboundedDataset<InputT> source = (UnboundedDataset<InputT>) context.popDataset(inputId);
  JavaDStream<WindowedValue<InputT>> sourceStream = source.getDStream();
  List<Integer> sourceIds = source.getStreamSources();

  Map<String, String> outputIdsByName = transformNode.getTransform().getOutputsMap();
  BiMap<String, Integer> unionTagsByOutputId = createOutputMap(outputIdsByName.values());
  Coder windowCoder =
      getWindowingStrategy(inputId, pipeline.getComponents()).getWindowFn().windowCoder();

  // TODO (BEAM-10712): handle side inputs.
  final boolean hasSideInputs = payload.getSideInputsCount() > 0;
  if (hasSideInputs) {
    throw new UnsupportedOperationException(
        "Side inputs to executable stage are currently unsupported.");
  }
  // Side inputs were rejected above, so the stage function gets an empty broadcast map.
  ImmutableMap<
          String,
          Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>>
      noBroadcasts = ImmutableMap.of();

  SparkExecutableStageFunction<InputT, SideInputT> stageFunction =
      new SparkExecutableStageFunction<>(
          context.getSerializableOptions(),
          payload,
          context.jobInfo,
          unionTagsByOutputId,
          SparkExecutableStageContextFactory.getInstance(),
          noBroadcasts,
          MetricsAccumulator.getInstance(),
          windowCoder);
  JavaDStream<RawUnionValue> unionStream = sourceStream.mapPartitions(stageFunction);

  // Register the intermediate union stream as a dataset, then pop it right away so the
  // DStream is marked as used.
  String intermediateId = getExecutableStageIntermediateId(transformNode);
  context.pushDataset(
      intermediateId,
      new Dataset() {
        @Override
        public void cache(String storageLevel, Coder<?> coder) {
          unionStream.persist(StorageLevel.fromString(storageLevel));
        }

        @Override
        public void action() {
          // A no-op per-RDD function forces the RDDs to be computed.
          unionStream.foreachRDD(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
          // Names are deliberately ignored for the intermediate dataset.
        }
      });
  context.popDataset(intermediateId);

  if (outputIdsByName.isEmpty()) {
    // No declared outputs: push a sink that emits nothing so the stage is still executed.
    JavaDStream<WindowedValue<OutputT>> emptySink =
        unionStream.flatMap((ignored) -> Collections.emptyIterator());
    context.pushDataset(
        String.format("EmptyOutputSink_%d", context.nextSinkId()),
        new UnboundedDataset<>(emptySink, sourceIds));
    return;
  }

  // Split the union stream into one dataset per declared output.
  for (String outputId : outputIdsByName.values()) {
    JavaDStream<WindowedValue<OutputT>> extracted =
        unionStream.flatMap(
            new SparkExecutableStageExtractionFunction<>(unionTagsByOutputId.get(outputId)));
    context.pushDataset(outputId, new UnboundedDataset<>(extracted, sourceIds));
  }
}
Use of org.apache.beam.sdk.transforms.join.RawUnionValue in the Apache Beam project.
Class FlinkBatchPortablePipelineTranslator, method pruneOutput.
/**
 * Builds a flat-map operator over {@code taggedDataset} using a pruning function configured with
 * {@code unionTag}, and registers the operator in the context under {@code collectionId}.
 *
 * @param taggedDataset dataset of tagged union values produced by an executable stage
 * @param context batch translation context that receives the resulting dataset
 * @param unionTag union tag passed to the pruning function and used in the operator name
 * @param outputCoder coder used to build the type information of the output
 * @param collectionId id under which the resulting dataset is registered
 */
private static void pruneOutput(
    DataSet<RawUnionValue> taggedDataset,
    BatchTranslationContext context,
    int unionTag,
    Coder<WindowedValue<?>> outputCoder,
    String collectionId) {
  final String operatorName = String.format("ExtractOutput[%s]", unionTag);
  TypeInformation<WindowedValue<?>> resultType =
      new CoderTypeInformation<>(outputCoder, context.getPipelineOptions());
  FlinkExecutableStagePruningFunction pruner =
      new FlinkExecutableStagePruningFunction(unionTag, context.getPipelineOptions());
  FlatMapOperator<RawUnionValue, WindowedValue<?>> extracted =
      new FlatMapOperator<>(taggedDataset, resultType, pruner, operatorName);
  context.addDataSet(collectionId, extracted);
}
Use of org.apache.beam.sdk.transforms.join.RawUnionValue in the Apache Beam project.
Class FlinkStreamingPortablePipelineTranslator, method translateExecutableStage.
/**
 * Translates the executable stage transform identified by {@code id} into a Flink streaming
 * operator and registers the resulting output streams with {@code context}.
 *
 * <p>Steps, in order: parse the stage payload; transform side inputs (if any); build the
 * tag/coder/id lookup tables for every output; key the input stream when the stage uses state,
 * timers or a splittable process fn; construct the {@code ExecutableStageDoFnOperator}; attach it
 * to the input stream (with or without a broadcast side-input stream); and finally publish the
 * main and additional outputs.
 *
 * @param id id of the transform, looked up in the pipeline's components
 * @param pipeline pipeline whose components supply the transform, coders and windowing strategy
 * @param context streaming translation context providing the input stream and pipeline options,
 *     and receiving the output streams
 * @throws RuntimeException if the stage payload cannot be parsed (wraps the IOException)
 * @throws IllegalStateException if the stage is stateful or contains a splittable process fn and
 *     the input value coder does not have the required KvCoder shape
 */
private <InputT, OutputT> void translateExecutableStage(String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) {
// TODO: Fail on splittable DoFns.
// TODO: Special-case single outputs to avoid multiplexing PCollections.
RunnerApi.Components components = pipeline.getComponents();
RunnerApi.PTransform transform = components.getTransformsOrThrow(id);
// Maps local output name -> output PCollection id.
Map<String, String> outputs = transform.getOutputsMap();
final RunnerApi.ExecutableStagePayload stagePayload;
try {
stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
} catch (IOException e) {
throw new RuntimeException(e);
}
String inputPCollectionId = stagePayload.getInput();
// Without side inputs an empty view map and a null unioned stream are used; the null stream is
// only dereferenced further down when the view map is non-empty.
final TransformedSideInputs transformedSideInputs;
if (stagePayload.getSideInputsCount() > 0) {
transformedSideInputs = transformSideInputs(stagePayload, components, context);
} else {
transformedSideInputs = new TransformedSideInputs(Collections.emptyMap(), null);
}
Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags = Maps.newLinkedHashMap();
Map<TupleTag<?>, Coder<WindowedValue<?>>> tagsToCoders = Maps.newLinkedHashMap();
// TODO: does it matter which output we designate as "main"
// The first local output name is designated as main; null when the stage has no outputs.
final TupleTag<OutputT> mainOutputTag = outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());
// associate output tags with ids, output manager uses these Integer ids to serialize state
BiMap<String, Integer> outputIndexMap = createOutputMap(outputs.keySet());
Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
Map<TupleTag<?>, Integer> tagsToIds = Maps.newHashMap();
Map<String, TupleTag<?>> collectionIdToTupleTag = Maps.newHashMap();
// order output names for deterministic mapping
for (String localOutputName : new TreeMap<>(outputIndexMap).keySet()) {
String collectionId = outputs.get(localOutputName);
Coder<WindowedValue<?>> windowCoder = (Coder) instantiateCoder(collectionId, components);
outputCoders.put(localOutputName, windowCoder);
TupleTag<?> tupleTag = new TupleTag<>(localOutputName);
CoderTypeInformation<WindowedValue<?>> typeInformation = new CoderTypeInformation(windowCoder, context.getPipelineOptions());
tagsToOutputTags.put(tupleTag, new OutputTag<>(localOutputName, typeInformation));
tagsToCoders.put(tupleTag, windowCoder);
tagsToIds.put(tupleTag, outputIndexMap.get(localOutputName));
collectionIdToTupleTag.put(collectionId, tupleTag);
}
final SingleOutputStreamOperator<WindowedValue<OutputT>> outputStream;
DataStream<WindowedValue<InputT>> inputDataStream = context.getDataStreamOrThrow(inputPCollectionId);
// Type information of the main output; null when the stage has no outputs.
CoderTypeInformation<WindowedValue<OutputT>> outputTypeInformation = !outputs.isEmpty() ? new CoderTypeInformation(outputCoders.get(mainOutputTag.getId()), context.getPipelineOptions()) : null;
// Every tag other than the main one becomes an additional (side) output.
// NOTE(review): mainOutputTag is dereferenced here without a null check; this presumably relies
// on tagsToCoders being empty whenever outputs is empty - confirm against createOutputMap.
ArrayList<TupleTag<?>> additionalOutputTags = Lists.newArrayList();
for (TupleTag<?> tupleTag : tagsToCoders.keySet()) {
if (!mainOutputTag.getId().equals(tupleTag.getId())) {
additionalOutputTags.add(tupleTag);
}
}
final Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
// A stage counts as stateful when its payload declares user states or timers.
final boolean stateful = stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
// True when any transform inside the stage has the splittable process-sized-elements URN.
final boolean hasSdfProcessFn = stagePayload.getComponents().getTransformsMap().values().stream().anyMatch(pTransform -> pTransform.getSpec().getUrn().equals(PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
// Both stay null for stages that are neither stateful nor SDF; the input is then left unkeyed.
Coder keyCoder = null;
KeySelector<WindowedValue<InputT>, ?> keySelector = null;
if (stateful || hasSdfProcessFn) {
// Stateful/SDF stages are only allowed of KV input.
Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
if (!(valueCoder instanceof KvCoder)) {
throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
}
if (stateful) {
// Stateful stage: key the stream by the key of the input KV.
keyCoder = ((KvCoder) valueCoder).getKeyCoder();
keySelector = new KvToByteBufferKeySelector(keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
} else {
// SDF stage: the key of the outer KV must itself be a KV (the error message below describes
// the expected shape as KVCoder(KvCoder, DoubleCoder)); the key of that inner KV is used
// as the key.
if (!(((KvCoder) valueCoder).getKeyCoder() instanceof KvCoder)) {
throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for splittable DoFn '%s' must be KVCoder(KvCoder, DoubleCoder) but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
}
keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder();
keySelector = new SdfByteBufferKeySelector(keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
}
// From here on inputDataStream is a KeyedStream; the casts further down depend on that.
inputDataStream = inputDataStream.keyBy(keySelector);
}
DoFnOperator.MultiOutputOutputManagerFactory<OutputT> outputManagerFactory = new DoFnOperator.MultiOutputOutputManagerFactory<>(mainOutputTag, tagsToOutputTags, tagsToCoders, tagsToIds, new SerializablePipelineOptions(context.getPipelineOptions()));
DoFnOperator<InputT, OutputT> doFnOperator = new ExecutableStageDoFnOperator<>(transform.getUniqueName(), windowedInputCoder, Collections.emptyMap(), mainOutputTag, additionalOutputTags, outputManagerFactory, transformedSideInputs.unionTagToView, new ArrayList<>(transformedSideInputs.unionTagToView.values()), getSideInputIdToPCollectionViewMap(stagePayload, components), context.getPipelineOptions(), stagePayload, context.getJobInfo(), FlinkExecutableStageContextFactory.getInstance(), collectionIdToTupleTag, getWindowingStrategy(inputPCollectionId, components), keyCoder, keySelector);
final String operatorName = generateNameFromStagePayload(stagePayload);
if (transformedSideInputs.unionTagToView.isEmpty()) {
// No side inputs: plain single-input operator.
outputStream = inputDataStream.transform(operatorName, outputTypeInformation, doFnOperator);
} else {
// Side inputs are delivered as a second, broadcast input of unioned RawUnionValues.
DataStream<RawUnionValue> sideInputStream = transformedSideInputs.unionedSideInputs.broadcast();
if (stateful || hasSdfProcessFn) {
// We have to manually construct the two-input transform because we're not
// allowed to have only one input keyed, normally. Since Flink 1.5.0 it's
// possible to use the Broadcast State Pattern which provides a more elegant
// way to process keyed main input with broadcast state, but it's not feasible
// here because it breaks the DoFnOperator abstraction.
// Note: this branch names the transformation with transform.getUniqueName(), whereas the
// other branches use operatorName.
TwoInputTransformation<WindowedValue<KV<?, InputT>>, RawUnionValue, WindowedValue<OutputT>> rawFlinkTransform = new TwoInputTransformation(inputDataStream.getTransformation(), sideInputStream.getTransformation(), transform.getUniqueName(), doFnOperator, outputTypeInformation, inputDataStream.getParallelism());
// Only the first (main) input is keyed; the second key selector is null.
rawFlinkTransform.setStateKeyType(((KeyedStream) inputDataStream).getKeyType());
rawFlinkTransform.setStateKeySelectors(((KeyedStream) inputDataStream).getKeySelector(), null);
outputStream = new SingleOutputStreamOperator(inputDataStream.getExecutionEnvironment(), // we have to cheat around the ctor being protected
rawFlinkTransform) {
};
} else {
// Unkeyed main input: the regular connect API is sufficient.
outputStream = inputDataStream.connect(sideInputStream).transform(operatorName, outputTypeInformation, doFnOperator);
}
}
// Assign a unique but consistent id to re-map operator state
outputStream.uid(transform.getUniqueName());
// Publish the main output (if any) and each additional output as a side output of the operator.
if (mainOutputTag != null) {
context.addDataStream(outputs.get(mainOutputTag.getId()), outputStream);
}
for (TupleTag<?> tupleTag : additionalOutputTags) {
context.addDataStream(outputs.get(tupleTag.getId()), outputStream.getSideOutput(tagsToOutputTags.get(tupleTag)));
}
}
Use of org.apache.beam.sdk.transforms.join.RawUnionValue in the Apache Beam project.
Class FlinkStreamingTransformTranslators, method transformSideInputs.
/**
 * Unions the data streams of all given side inputs into one stream of {@link RawUnionValue}s.
 *
 * <p>Each side input is assigned an integer index in iteration order. Its stream is mapped with a
 * {@code ToRawUnion} function carrying the index looked up for its tag, and the mapped streams are
 * unioned together.
 *
 * @param sideInputs side input views to union
 * @param context translation context that supplies the side input streams and pipeline options
 * @return a tuple of the index-to-view mapping and the unioned side input stream
 * @throws IllegalStateException if a side input stream's type information is not a
 *     {@code CoderTypeInformation}, or if no unioned stream was produced
 */
private static Tuple2<Map<Integer, PCollectionView<?>>, DataStream<RawUnionValue>> transformSideInputs(
    Collection<PCollectionView<?>> sideInputs, FlinkStreamingTranslationContext context) {
  // Number the side inputs, keeping lookups in both directions.
  Map<TupleTag<?>, Integer> indexByTag = new HashMap<>();
  Map<Integer, PCollectionView<?>> viewByIndex = new HashMap<>();
  int nextIndex = 0;
  for (PCollectionView<?> view : sideInputs) {
    TupleTag<?> viewTag = view.getTagInternal();
    viewByIndex.put(nextIndex, view);
    indexByTag.put(viewTag, nextIndex);
    nextIndex++;
  }

  // Gather the coder of every side input stream, in iteration order, for the union coder.
  List<Coder<?>> memberCoders = new ArrayList<>();
  for (PCollectionView<?> view : sideInputs) {
    DataStream<Object> viewStream = context.getInputDataStream(view);
    TypeInformation<Object> typeInfo = viewStream.getType();
    if (!(typeInfo instanceof CoderTypeInformation)) {
      throw new IllegalStateException("Input Stream TypeInformation is no CoderTypeInformation.");
    }
    memberCoders.add(((CoderTypeInformation) typeInfo).getCoder());
  }
  UnionCoder unionCoder = UnionCoder.of(memberCoders);
  CoderTypeInformation<RawUnionValue> unionTypeInfo =
      new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());

  // Wrap each side input's elements as RawUnionValue and fold the streams into one union.
  DataStream<RawUnionValue> unioned = null;
  for (PCollectionView<?> view : sideInputs) {
    TupleTag<?> viewTag = view.getTagInternal();
    final int unionIndex = indexByTag.get(viewTag);
    DataStream<Object> viewStream = context.getInputDataStream(view);
    DataStream<RawUnionValue> wrapped =
        viewStream
            .map(new ToRawUnion<>(unionIndex, context.getPipelineOptions()))
            .returns(unionTypeInfo);
    unioned = (unioned == null) ? wrapped : unioned.union(wrapped);
  }

  if (unioned != null) {
    return new Tuple2<>(viewByIndex, unioned);
  }
  throw new IllegalStateException("No unioned side inputs, this indicates a bug.");
}
Use of org.apache.beam.sdk.transforms.join.RawUnionValue in the Apache Beam project.
Class DoFnOperatorTest, method sideInputCheckpointing.
/**
 * Checks that side inputs pushed before a snapshot are still matched with main-input elements
 * after the operator is recreated and restored from that snapshot.
 *
 * @param harnessFactory factory that creates a fresh two-input test harness on each call
 * @throws Exception if any harness operation fails
 */
void sideInputCheckpointing(
    TestHarnessFactory<
            TwoInputStreamOperatorTestHarness<
                WindowedValue<String>, RawUnionValue, WindowedValue<String>>>
        harnessFactory)
    throws Exception {
  TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, WindowedValue<String>>
      harnessBeforeSnapshot = harnessFactory.create();
  harnessBeforeSnapshot.open();

  Instant epoch = new Instant(0);
  IntervalWindow firstWindow = new IntervalWindow(epoch, new Instant(100));
  IntervalWindow secondWindow = new IntervalWindow(epoch, new Instant(500));

  // Feed one side input per window through the second (side-input) channel.
  RawUnionValue firstSideInput =
      new RawUnionValue(
          1,
          valuesInWindow(
              PCollectionViewTesting.materializeValuesFor(
                  view1.getPipeline().getOptions(), View.asIterable(), "hello", "ciao"),
              epoch,
              firstWindow));
  RawUnionValue secondSideInput =
      new RawUnionValue(
          2,
          valuesInWindow(
              PCollectionViewTesting.materializeValuesFor(
                  view2.getPipeline().getOptions(), View.asIterable(), "foo", "bar"),
              epoch,
              secondWindow));
  harnessBeforeSnapshot.processElement2(new StreamRecord<>(firstSideInput));
  harnessBeforeSnapshot.processElement2(new StreamRecord<>(secondSideInput));

  // Snapshot, discard the first operator, and bring up a new one from the snapshot.
  OperatorSubtaskState snapshot = harnessBeforeSnapshot.snapshot(0, 0);
  TwoInputStreamOperatorTestHarness<WindowedValue<String>, RawUnionValue, WindowedValue<String>>
      restoredHarness = harnessFactory.create();
  restoredHarness.initializeState(snapshot);
  restoredHarness.open();

  // Main-input elements sent after the restore must still come out of the operator.
  WindowedValue<String> helloElement = valueInWindow("Hello", epoch, firstWindow);
  WindowedValue<String> worldElement = valueInWindow("World", new Instant(1000), firstWindow);
  restoredHarness.processElement1(new StreamRecord<>(helloElement));
  restoredHarness.processElement1(new StreamRecord<>(worldElement));

  assertThat(
      stripStreamRecordFromWindowedValue(restoredHarness.getOutput()),
      contains(helloElement, worldElement));
  restoredHarness.close();
}
Aggregations