use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class RemoteExecutionTest method testExecutionWithSideInputCaching.
/**
 * Verifies that side input contents served through the runner's state API can be cached by the
 * SDK harness across bundles when cache tokens are supplied.
 *
 * <p>The pipeline reads one iterable and one multimap side input. Two bundles ("X" then "Y") are
 * processed with the same {@link StoringStateRequestHandler}, which carries a cache token for
 * each side input. All outputs must contain both bundles' results. The recorded state requests
 * must cover only the reads listed for the first bundle.
 */
@Test
public void testExecutionWithSideInputCaching() throws Exception {
  Pipeline p = Pipeline.create();
  addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
  // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
  addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
  launchSdkHarness(p.getOptions());
  PCollection<String> input =
      p.apply("impulse", Impulse.create())
          .apply(
              "create",
              ParDo.of(
                  new DoFn<byte[], String>() {
                    @ProcessElement
                    public void process(ProcessContext ctxt) {
                      ctxt.output("zero");
                      ctxt.output("one");
                      ctxt.output("two");
                    }
                  }))
          .setCoder(StringUtf8Coder.of());
  PCollectionView<Iterable<String>> iterableView =
      input.apply("createIterableSideInput", View.asIterable());
  PCollectionView<Map<String, Iterable<String>>> multimapView =
      input.apply(WithKeys.of("key")).apply("createMultimapSideInput", View.asMultimap());
  input
      .apply(
          "readSideInput",
          ParDo.of(
                  new DoFn<String, KV<String, String>>() {
                    @ProcessElement
                    public void processElement(ProcessContext context) {
                      for (String value : context.sideInput(iterableView)) {
                        context.output(KV.of(context.element(), value));
                      }
                      for (Map.Entry<String, Iterable<String>> entry :
                          context.sideInput(multimapView).entrySet()) {
                        for (String value : entry.getValue()) {
                          context.output(KV.of(context.element(), entry.getKey() + ":" + value));
                        }
                      }
                    }
                  })
              .withSideInputs(iterableView, multimapView))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply("gbk", GroupByKey.create());
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  // The stage under test is the fused stage that consumes side inputs.
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getSideInputs().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with side inputs.");
  ExecutableStage stage = optionalStage.get();
  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator);
  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    // Collect every remote output into a thread-safe list keyed by output id.
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }
  // Serve fixed side input contents: the iterable side input is ["A", "B", "C"] and the
  // multimap side input is {key1 -> [H, I, J], key2 -> [M, N, O]}. Every request is recorded
  // by the StoringStateRequestHandler wrapper.
  StoringStateRequestHandler stateRequestHandler =
      new StoringStateRequestHandler(
          StateRequestHandlers.forSideInputHandlerFactory(
              descriptor.getSideInputSpecs(),
              new SideInputHandlerFactory() {
                @Override
                public <V, W extends BoundedWindow>
                    IterableSideInputHandler<V, W> forIterableSideInput(
                        String pTransformId,
                        String sideInputId,
                        Coder<V> elementCoder,
                        Coder<W> windowCoder) {
                  return new IterableSideInputHandler<V, W>() {
                    @Override
                    public Iterable<V> get(W window) {
                      return (Iterable) Arrays.asList("A", "B", "C");
                    }

                    @Override
                    public Coder<V> elementCoder() {
                      return elementCoder;
                    }
                  };
                }

                @Override
                public <K, V, W extends BoundedWindow>
                    MultimapSideInputHandler<K, V, W> forMultimapSideInput(
                        String pTransformId,
                        String sideInputId,
                        KvCoder<K, V> elementCoder,
                        Coder<W> windowCoder) {
                  return new MultimapSideInputHandler<K, V, W>() {
                    @Override
                    public Iterable<K> get(W window) {
                      return (Iterable) Arrays.asList("key1", "key2");
                    }

                    @Override
                    public Iterable<V> get(K key, W window) {
                      if ("key1".equals(key)) {
                        return (Iterable) Arrays.asList("H", "I", "J");
                      } else if ("key2".equals(key)) {
                        return (Iterable) Arrays.asList("M", "N", "O");
                      }
                      return Collections.emptyList();
                    }

                    @Override
                    public Coder<K> keyCoder() {
                      return elementCoder.getKeyCoder();
                    }

                    @Override
                    public Coder<V> valueCoder() {
                      return elementCoder.getValueCoder();
                    }
                  };
                }
              }));
  // Attach a cache token to each side input so the SDK harness may reuse the contents it
  // read in the first bundle.
  String transformId = Iterables.get(stage.getSideInputs(), 0).transform().getId();
  stateRequestHandler.addCacheToken(
      BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder()
          .setSideInput(
              BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder()
                  .setSideInputId(iterableView.getTagInternal().getId())
                  .setTransformId(transformId)
                  .build())
          .setToken(ByteString.copyFromUtf8("IterableSideInputToken"))
          .build());
  stateRequestHandler.addCacheToken(
      BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder()
          .setSideInput(
              BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder()
                  .setSideInputId(multimapView.getTagInternal().getId())
                  .setTransformId(transformId)
                  .build())
          .setToken(ByteString.copyFromUtf8("MulitmapSideInputToken"))
          .build());
  BundleProgressHandler progressHandler = BundleProgressHandler.ignored();
  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow("X"));
  }
  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow("Y"));
  }
  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "key1:H")),
            valueInGlobalWindow(KV.of("X", "key1:I")),
            valueInGlobalWindow(KV.of("X", "key1:J")),
            valueInGlobalWindow(KV.of("X", "key2:M")),
            valueInGlobalWindow(KV.of("X", "key2:N")),
            valueInGlobalWindow(KV.of("X", "key2:O")),
            valueInGlobalWindow(KV.of("Y", "A")),
            valueInGlobalWindow(KV.of("Y", "B")),
            valueInGlobalWindow(KV.of("Y", "C")),
            valueInGlobalWindow(KV.of("Y", "key1:H")),
            valueInGlobalWindow(KV.of("Y", "key1:I")),
            valueInGlobalWindow(KV.of("Y", "key1:J")),
            valueInGlobalWindow(KV.of("Y", "key2:M")),
            valueInGlobalWindow(KV.of("Y", "key2:N")),
            valueInGlobalWindow(KV.of("Y", "key2:O"))));
  }
  // Expect the following requests for the first bundle:
  //  * one to read iterable side input
  //  * one to read keys from multimap side input
  //  * one to read key1 iterable from multimap side input
  //  * one to read key2 iterable from multimap side input
  // Only these four are expected in total, i.e. the second bundle issues no state requests.
  assertEquals(4, stateRequestHandler.receivedRequests.size());
  // JUnit convention: expected value first, actual value second.
  assertEquals(
      BeamFnApi.StateKey.IterableSideInput.newBuilder()
          .setSideInputId(iterableView.getTagInternal().getId())
          .setTransformId(transformId)
          .build(),
      stateRequestHandler.receivedRequests.get(0).getStateKey().getIterableSideInput());
  assertEquals(
      BeamFnApi.StateKey.MultimapKeysSideInput.newBuilder()
          .setSideInputId(multimapView.getTagInternal().getId())
          .setTransformId(transformId)
          .build(),
      stateRequestHandler.receivedRequests.get(1).getStateKey().getMultimapKeysSideInput());
  assertEquals(
      BeamFnApi.StateKey.MultimapSideInput.newBuilder()
          .setSideInputId(multimapView.getTagInternal().getId())
          .setTransformId(transformId)
          .setKey(encode("key1"))
          .build(),
      stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapSideInput());
  assertEquals(
      BeamFnApi.StateKey.MultimapSideInput.newBuilder()
          .setSideInputId(multimapView.getTagInternal().getId())
          .setTransformId(transformId)
          .setKey(encode("key2"))
          .build(),
      stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput());
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class RemoteExecutionTest method testExecutionWithTimer.
/**
 * Verifies that event-time, processing-time and window-expiration timers round-trip through the
 * SDK harness.
 *
 * <p>Timers set in {@code @ProcessElement} and {@code @OnTimer} callbacks must reach the
 * runner's timer receivers. Timers delivered by the runner must fire the matching callbacks,
 * whose outputs appear on the main output.
 */
@Test
public void testExecutionWithTimer() throws Exception {
  launchSdkHarness(PipelineOptionsFactory.create());
  Pipeline p = Pipeline.create();
  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "timer",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {
                @TimerId("event")
                private final TimerSpec eventTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

                @TimerId("processing")
                private final TimerSpec processingTimerSpec =
                    TimerSpecs.timer(TimeDomain.PROCESSING_TIME);

                @ProcessElement
                public void processElement(
                    ProcessContext context,
                    @TimerId("event") Timer eventTimeTimer,
                    @TimerId("processing") Timer processingTimeTimer) {
                  context.output(KV.of("main" + context.element().getKey(), ""));
                  eventTimeTimer
                      .withOutputTimestamp(context.timestamp())
                      .set(context.timestamp().plus(Duration.millis(1L)));
                  processingTimeTimer.offset(Duration.millis(2L));
                  processingTimeTimer.setRelative();
                }

                @OnTimer("event")
                public void eventTimer(
                    OnTimerContext context,
                    @Key String key,
                    @TimerId("event") Timer eventTimeTimer,
                    @TimerId("processing") Timer processingTimeTimer) {
                  context.output(KV.of("event", key));
                  eventTimeTimer
                      .withOutputTimestamp(context.timestamp())
                      .set(context.fireTimestamp().plus(Duration.millis(11L)));
                  processingTimeTimer.offset(Duration.millis(12L));
                  processingTimeTimer.setRelative();
                }

                @OnTimer("processing")
                public void processingTimer(
                    OnTimerContext context,
                    @Key String key,
                    @TimerId("event") Timer eventTimeTimer,
                    @TimerId("processing") Timer processingTimeTimer) {
                  context.output(KV.of("processing", key));
                  eventTimeTimer
                      .withOutputTimestamp(context.timestamp())
                      .set(context.fireTimestamp().plus(Duration.millis(21L)));
                  processingTimeTimer.offset(Duration.millis(22L));
                  processingTimeTimer.setRelative();
                }

                @OnWindowExpiration
                public void onWindowExpiration(
                    @Key String key, OutputReceiver<KV<String, String>> outputReceiver) {
                  outputReceiver.output(KV.of("onWindowExpiration", key));
                }
              }))
      .apply("gbk", GroupByKey.create());
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  // The stage under test is the fused stage that declares timers.
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getTimers().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with timers.");
  ExecutableStage stage = optionalStage.get();
  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator,
          descriptor.getTimerSpecs());
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : descriptor.getRemoteOutputCoders().entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }
  // Timers emitted by the SDK harness are collected per (transformId, timerId).
  Map<KV<String, String>, Collection<org.apache.beam.runners.core.construction.Timer<?>>>
      timerValues = new HashMap<>();
  Map<KV<String, String>, RemoteOutputReceiver<org.apache.beam.runners.core.construction.Timer<?>>>
      timerReceivers = new HashMap<>();
  for (Map.Entry<String, Map<String, ProcessBundleDescriptors.TimerSpec>> transformTimerSpecs :
      descriptor.getTimerSpecs().entrySet()) {
    for (ProcessBundleDescriptors.TimerSpec timerSpec : transformTimerSpecs.getValue().values()) {
      KV<String, String> key = KV.of(timerSpec.transformId(), timerSpec.timerId());
      List<org.apache.beam.runners.core.construction.Timer<?>> outputContents =
          Collections.synchronizedList(new ArrayList<>());
      timerValues.put(key, outputContents);
      timerReceivers.put(
          key,
          RemoteOutputReceiver.of(
              (Coder<org.apache.beam.runners.core.construction.Timer<?>>) timerSpec.coder(),
              outputContents::add));
    }
  }
  // Identify each timer spec: the window-expiration timer is matched by its id "onWindowExpiration0",
  // the user timers by time domain.
  ProcessBundleDescriptors.TimerSpec eventTimerSpec = null;
  ProcessBundleDescriptors.TimerSpec processingTimerSpec = null;
  ProcessBundleDescriptors.TimerSpec onWindowExpirationSpec = null;
  for (Map<String, ProcessBundleDescriptors.TimerSpec> timerSpecs :
      descriptor.getTimerSpecs().values()) {
    for (ProcessBundleDescriptors.TimerSpec timerSpec : timerSpecs.values()) {
      if ("onWindowExpiration0".equals(timerSpec.timerId())) {
        onWindowExpirationSpec = timerSpec;
      } else if (TimeDomain.EVENT_TIME.equals(timerSpec.getTimerSpec().getTimeDomain())) {
        eventTimerSpec = timerSpec;
      } else if (TimeDomain.PROCESSING_TIME.equals(timerSpec.getTimerSpec().getTimeDomain())) {
        processingTimerSpec = timerSpec;
      } else {
        fail(String.format("Unknown timer specification %s", timerSpec));
      }
    }
  }
  // Fail with a descriptive message rather than an NPE below if a spec was not discovered.
  checkState(eventTimerSpec != null, "Expected an event-time timer spec.");
  checkState(processingTimerSpec != null, "Expected a processing-time timer spec.");
  checkState(onWindowExpirationSpec != null, "Expected an onWindowExpiration timer spec.");
  // Set the current system time to a fixed value to get stable values for processing time timer
  // output.
  DateTimeUtils.setCurrentMillisFixed(BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis() + 10000L);
  try {
    try (RemoteBundle bundle =
        processor.newBundle(
            outputReceivers,
            timerReceivers,
            StateRequestHandler.unsupported(),
            BundleProgressHandler.ignored(),
            null,
            null)) {
      Iterables.getOnlyElement(bundle.getInputReceivers().values())
          .accept(valueInGlobalWindow(KV.of("X", "X")));
      bundle
          .getTimerReceivers()
          .get(KV.of(eventTimerSpec.transformId(), eventTimerSpec.timerId()))
          .accept(timerForTest("Y", 1000L, 100L));
      bundle
          .getTimerReceivers()
          .get(KV.of(processingTimerSpec.transformId(), processingTimerSpec.timerId()))
          .accept(timerForTest("Z", 2000L, 200L));
      bundle
          .getTimerReceivers()
          .get(KV.of(onWindowExpirationSpec.transformId(), onWindowExpirationSpec.timerId()))
          .accept(timerForTest("key", 5001L, 5000L));
    }
    String mainOutputTransform =
        Iterables.getOnlyElement(descriptor.getRemoteOutputCoders().keySet());
    assertThat(
        outputValues.get(mainOutputTransform),
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("mainX", "")),
            WindowedValue.timestampedValueInGlobalWindow(
                KV.of("event", "Y"),
                BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(100L))),
            WindowedValue.timestampedValueInGlobalWindow(
                KV.of("processing", "Z"),
                BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(200L))),
            WindowedValue.timestampedValueInGlobalWindow(
                KV.of("onWindowExpiration", "key"),
                BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(5000L)))));
    assertThat(
        timerValues.get(KV.of(eventTimerSpec.transformId(), eventTimerSpec.timerId())),
        containsInAnyOrder(
            timerForTest("X", 1L, 0L),
            timerForTest("Y", 1011L, 100L),
            timerForTest("Z", 2021L, 200L)));
    assertThat(
        timerValues.get(KV.of(processingTimerSpec.transformId(), processingTimerSpec.timerId())),
        containsInAnyOrder(
            timerForTest("X", 10002L, 0L),
            timerForTest("Y", 10012L, 100L),
            timerForTest("Z", 10022L, 200L)));
  } finally {
    // Always restore the real clock so other tests are unaffected.
    DateTimeUtils.setCurrentMillisSystem();
  }
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class ParDoBoundMultiTranslator method doTranslate.
/**
 * Translates a Java-SDK {@link ParDo.MultiOutput} into a Samza {@link DoFnOp}.
 *
 * <p>The main input stream is merged with the side input streams (each passed through {@code
 * SideInputWatermarkFn}) and fed to the op. The op's union-tagged output is then split back into
 * one registered stream per output {@link PCollection}. The method is static so that the
 * anonymous functions it creates do not capture an enclosing instance when serialized.
 *
 * @throws UnsupportedOperationException if the DoFn is splittable or requires time-sorted input
 * @throws IllegalArgumentException if an output of {@code node} is not a {@link PCollection}
 */
private static <InT, OutT> void doTranslate(
    ParDo.MultiOutput<InT, OutT> transform, TransformHierarchy.Node node, TranslationContext ctx) {
  final PCollection<? extends InT> mainInput = ctx.getInput(transform);
  final Map<TupleTag<?>, Coder<?>> outputCoders =
      ctx.getCurrentTransform().getOutputs().entrySet().stream()
          .filter(entry -> entry.getValue() instanceof PCollection)
          .collect(
              Collectors.toMap(
                  entry -> entry.getKey(),
                  entry -> ((PCollection<?>) entry.getValue()).getCoder()));
  // A stateful DoFn uses the key coder of its KV input; stateless DoFns need no key coder.
  final Coder<?> keyCoder =
      StateUtils.isStateful(transform.getFn())
          ? ((KvCoder<?, ?>) mainInput.getCoder()).getKeyCoder()
          : null;

  if (DoFnSignatures.isSplittable(transform.getFn())) {
    throw new UnsupportedOperationException("Splittable DoFn is not currently supported");
  }
  if (DoFnSignatures.requiresTimeSortedInput(transform.getFn())) {
    throw new UnsupportedOperationException(
        "@RequiresTimeSortedInput annotation is not currently supported");
  }

  final MessageStream<OpMessage<InT>> mainInputStream = ctx.getMessageStream(mainInput);
  final List<MessageStream<OpMessage<InT>>> sideInputStreams =
      transform.getSideInputs().values().stream()
          .map(ctx::<InT>getViewStream)
          .collect(Collectors.toList());

  // Give every output a sequential index. Elements are later routed to an output by comparing
  // their union tag against this index.
  final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
  final Map<Integer, PCollection<?>> indexToPCollectionMap = new HashMap<>();
  int nextIndex = 0;
  for (Map.Entry<TupleTag<?>, PCollection<?>> taggedOutput : node.getOutputs().entrySet()) {
    tagToIndexMap.put(taggedOutput.getKey(), nextIndex);
    if (!(taggedOutput.getValue() instanceof PCollection)) {
      throw new IllegalArgumentException(
          "Expected side output to be PCollection, but was: " + taggedOutput.getValue());
    }
    indexToPCollectionMap.put(nextIndex, taggedOutput.getValue());
    nextIndex++;
  }

  final HashMap<String, PCollectionView<?>> idToPValueMap = new HashMap<>();
  transform.getSideInputs().values().forEach(view -> idToPValueMap.put(ctx.getViewId(view), view));

  final DoFnSchemaInformation doFnSchemaInformation =
      ParDoTranslation.getSchemaInformation(ctx.getCurrentTransform());
  final Map<String, PCollectionView<?>> sideInputMapping =
      ParDoTranslation.getSideInputMapping(ctx.getCurrentTransform());

  final DoFnOp<InT, OutT, RawUnionValue> op =
      new DoFnOp<>(
          transform.getMainOutputTag(),
          transform.getFn(),
          keyCoder,
          (Coder<InT>) mainInput.getCoder(),
          null,
          outputCoders,
          transform.getSideInputs().values(),
          transform.getAdditionalOutputTags().getAll(),
          mainInput.getWindowingStrategy(),
          idToPValueMap,
          new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
          ctx.getTransformFullName(),
          ctx.getTransformId(),
          mainInput.isBounded(),
          false,
          null,
          null,
          Collections.emptyMap(),
          doFnSchemaInformation,
          sideInputMapping);

  // Without side inputs the main input feeds the op directly.
  MessageStream<OpMessage<InT>> mergedStreams = mainInputStream;
  if (!sideInputStreams.isEmpty()) {
    final MessageStream<OpMessage<InT>> watermarkedSideInputs =
        MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
    mergedStreams = mainInputStream.merge(Collections.singletonList(watermarkedSideInputs));
  }

  final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
      mergedStreams.flatMapAsync(OpAdapter.adapt(op));
  for (int outputIndex : tagToIndexMap.values()) {
    // Non-element messages go to every output; elements go only to the output whose union tag
    // matches.
    @SuppressWarnings("unchecked")
    final MessageStream<OpMessage<OutT>> outputStream =
        taggedOutputStream
            .filter(
                message ->
                    message.getType() != OpMessage.Type.ELEMENT
                        || message.getElement().getValue().getUnionTag() == outputIndex)
            .flatMapAsync(OpAdapter.adapt(new RawUnionValueToValue()));
    ctx.registerMessageStream(indexToPCollectionMap.get(outputIndex), outputStream);
  }
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class ParDoBoundMultiTranslator method doTranslatePortable.
/**
 * Translates a portable {@code ExecutableStage} transform into a Samza {@link DoFnOp}.
 *
 * <p>Steps:
 *
 * <ol>
 *   <li>Parse the {@link RunnerApi.ExecutableStagePayload} from the transform spec.
 *   <li>Create a runner-side view and a grouped/broadcast stream for each side input.
 *   <li>Assign each stage output an index, treating the first output as the main output.
 *   <li>Merge the main input with the side input streams and run them through the op.
 *   <li>Split the union-tagged result back into one registered stream per output collection id.
 * </ol>
 *
 * <p>The method is static so that the anonymous functions it creates do not capture an enclosing
 * instance when serialized.
 *
 * @throws RuntimeException wrapping the {@link IOException} if the stage payload cannot be parsed
 */
private static <InT, OutT> void doTranslatePortable(
    PipelineNode.PTransformNode transform,
    QueryablePipeline pipeline,
    PortableTranslationContext ctx) {
  Map<String, String> outputs = transform.getTransform().getOutputsMap();

  final RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }

  String inputId = stagePayload.getInput();
  final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStreamById(inputId);

  // Analyze side inputs
  final List<MessageStream<OpMessage<Iterable<?>>>> sideInputStreams = new ArrayList<>();
  final Map<SideInputId, PCollectionView<?>> sideInputMapping = new HashMap<>();
  final Map<String, PCollectionView<?>> idToViewMapping = new HashMap<>();
  final RunnerApi.Components components = stagePayload.getComponents();
  for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
    final String sideInputCollectionId =
        components
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    final WindowingStrategy<?, BoundedWindow> windowingStrategy =
        WindowUtils.getWindowStrategy(sideInputCollectionId, components);
    final WindowedValue.WindowedValueCoder<?> coder =
        (WindowedValue.WindowedValueCoder) instantiateCoder(sideInputCollectionId, components);

    // Create a runner-side view
    final PCollectionView<?> view = createPCollectionView(sideInputId, coder, windowingStrategy);

    // Use GBK to aggregate the side inputs and then broadcast it out
    final MessageStream<OpMessage<Iterable<?>>> broadcastSideInput =
        groupAndBroadcastSideInput(
            sideInputId,
            sideInputCollectionId,
            components.getPcollectionsOrThrow(sideInputCollectionId),
            (WindowingStrategy) windowingStrategy,
            coder,
            ctx);

    sideInputStreams.add(broadcastSideInput);
    sideInputMapping.put(sideInputId, view);
    idToViewMapping.put(getSideInputUniqueId(sideInputId), view);
  }

  final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
  final Map<Integer, String> indexToIdMap = new HashMap<>();
  final Map<String, TupleTag<?>> idToTupleTagMap = new HashMap<>();

  // first output as the main output
  final TupleTag<OutT> mainOutputTag =
      outputs.isEmpty() ? null : new TupleTag<>(outputs.keySet().iterator().next());

  // Index outputs in map iteration order, so the main output (first key) gets index 0.
  int nextOutputIndex = 0;
  for (Map.Entry<String, String> output : outputs.entrySet()) {
    final TupleTag<?> tupleTag = new TupleTag<>(output.getKey());
    final String collectionId = output.getValue();
    tagToIndexMap.put(tupleTag, nextOutputIndex);
    indexToIdMap.put(nextOutputIndex, collectionId);
    idToTupleTagMap.put(collectionId, tupleTag);
    nextOutputIndex++;
  }

  WindowedValue.WindowedValueCoder<InT> windowedInputCoder =
      WindowUtils.instantiateWindowedCoder(inputId, pipeline.getComponents());

  // TODO: support schema and side inputs for portable runner
  // Note: transform.getTransform() is an ExecutableStage, not ParDo, so we need to extract
  // these info from its components.
  final DoFnSchemaInformation doFnSchemaInformation = null;

  final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId);
  final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input);
  // A stateful stage uses the key coder of its KV input; otherwise no key coder is needed.
  final Coder<?> keyCoder =
      StateUtils.isStateful(stagePayload)
          ? ((KvCoder) ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder())
              .getKeyCoder()
          : null;

  final DoFnOp<InT, OutT, RawUnionValue> op =
      new DoFnOp<>(
          mainOutputTag,
          new NoOpDoFn<>(),
          keyCoder,
          // input coder not in use
          windowedInputCoder.getValueCoder(),
          windowedInputCoder,
          // output coders not in use
          Collections.emptyMap(),
          new ArrayList<>(sideInputMapping.values()),
          // used by java runner only
          new ArrayList<>(idToTupleTagMap.values()),
          WindowUtils.getWindowStrategy(inputId, stagePayload.getComponents()),
          idToViewMapping,
          new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
          ctx.getTransformFullName(),
          ctx.getTransformId(),
          isBounded,
          true,
          stagePayload,
          ctx.getJobInfo(),
          idToTupleTagMap,
          doFnSchemaInformation,
          sideInputMapping);

  final MessageStream<OpMessage<InT>> mergedStreams;
  if (sideInputStreams.isEmpty()) {
    mergedStreams = inputStream;
  } else {
    MessageStream<OpMessage<InT>> mergedSideInputStreams =
        MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
    mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams));
  }

  final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
      mergedStreams.flatMapAsync(OpAdapter.adapt(op));
  for (int outputIndex : tagToIndexMap.values()) {
    // Non-element messages go to every output; elements go only to the output whose union tag
    // matches.
    @SuppressWarnings("unchecked")
    final MessageStream<OpMessage<OutT>> outputStream =
        taggedOutputStream
            .filter(
                message ->
                    message.getType() != OpMessage.Type.ELEMENT
                        || message.getElement().getValue().getUnionTag() == outputIndex)
            .flatMapAsync(OpAdapter.adapt(new RawUnionValueToValue()));
    ctx.registerMessageStream(indexToIdMap.get(outputIndex), outputStream);
  }
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class SamzaTestStreamTranslator method translate.
/**
 * Registers a {@link TestStream} as a Samza input stream.
 *
 * <p>The test stream is encoded to a Base64 string with a {@link TestStream.TestStreamCoder} for
 * its value coder. That string is passed to the input descriptor together with a serializable
 * function that decodes it back into a {@link TestStream}.
 *
 * @throws RuntimeException if the test stream cannot be encoded
 */
@Override
public void translate(
    TestStream<T> testStream, TransformHierarchy.Node node, TranslationContext ctx) {
  final PCollection<T> outputCollection = ctx.getOutput(testStream);
  final String streamId = ctx.getIdForPValue(outputCollection);
  final Coder<T> elementCoder = testStream.getValueCoder();

  final String serializedStream;
  try {
    serializedStream =
        CoderUtils.encodeToBase64(TestStream.TestStreamCoder.of(elementCoder), testStream);
  } catch (CoderException e) {
    throw new RuntimeException("Could not encode TestStream.", e);
  }

  // The decoder captures only the element coder and rebuilds the TestStreamCoder on each call.
  final SerializableFunction<String, TestStream<T>> decoder =
      encoded -> {
        try {
          return CoderUtils.decodeFromBase64(TestStream.TestStreamCoder.of(elementCoder), encoded);
        } catch (CoderException e) {
          throw new RuntimeException("Could not decode TestStream.", e);
        }
      };

  ctx.registerInputMessageStream(
      outputCollection, createInputDescriptor(streamId, serializedStream, decoder));
}
Aggregations