Use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
From the class SparkStreamingPortablePipelineTranslator, method translateExecutableStage.
private static <InputT, OutputT, SideInputT> void translateExecutableStage(
    PTransformNode transformNode,
    RunnerApi.Pipeline pipeline,
    SparkStreamingTranslationContext context) {
  RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(
            transformNode.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  String inputPCollectionId = stagePayload.getInput();
  UnboundedDataset<InputT> inputDataset =
      (UnboundedDataset<InputT>) context.popDataset(inputPCollectionId);
  List<Integer> streamSources = inputDataset.getStreamSources();
  JavaDStream<WindowedValue<InputT>> inputDStream = inputDataset.getDStream();
  Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
  BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
  RunnerApi.Components components = pipeline.getComponents();
  Coder windowCoder =
      getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder();
  // TODO (BEAM-10712): handle side inputs.
  if (stagePayload.getSideInputsCount() > 0) {
    throw new UnsupportedOperationException(
        "Side inputs to executable stage are currently unsupported.");
  }
  ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValue.WindowedValueCoder<SideInputT>>>
      broadcastVariables = ImmutableMap.copyOf(new HashMap<>());
  SparkExecutableStageFunction<InputT, SideInputT> function =
      new SparkExecutableStageFunction<>(
          context.getSerializableOptions(),
          stagePayload,
          context.jobInfo,
          outputMap,
          SparkExecutableStageContextFactory.getInstance(),
          broadcastVariables,
          MetricsAccumulator.getInstance(),
          windowCoder);
  JavaDStream<RawUnionValue> staged = inputDStream.mapPartitions(function);
  String intermediateId = getExecutableStageIntermediateId(transformNode);
  context.pushDataset(
      intermediateId,
      new Dataset() {
        @Override
        public void cache(String storageLevel, Coder<?> coder) {
          StorageLevel level = StorageLevel.fromString(storageLevel);
          staged.persist(level);
        }

        @Override
        public void action() {
          // Empty function to force computation of RDD.
          staged.foreachRDD(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
          // ignore
        }
      });
  // Pop dataset to mark DStream as used
  context.popDataset(intermediateId);
  for (String outputId : outputs.values()) {
    JavaDStream<WindowedValue<OutputT>> outStream =
        staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputMap.get(outputId)));
    context.pushDataset(outputId, new UnboundedDataset<>(outStream, streamSources));
  }
  // Add sink to ensure stage is executed
  if (outputs.isEmpty()) {
    JavaDStream<WindowedValue<OutputT>> outStream =
        staged.flatMap((rawUnionValue) -> Collections.emptyIterator());
    context.pushDataset(
        String.format("EmptyOutputSink_%d", context.nextSinkId()),
        new UnboundedDataset<>(outStream, streamSources));
  }
}
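A note on the proto itself: in the Fn API pipeline representation, an org.apache.beam.model.pipeline.v1.RunnerApi.Coder is simply a FunctionSpec (coder URN plus optional payload) together with the ids of its component coders, registered by id in the pipeline's Components. The following is a minimal, hypothetical sketch of constructing and registering one; the ids and the choice of the standard KV coder URN are illustrative and not taken from the translator above.

// Hypothetical ids for illustration only; assumes the RunnerApi model classes are on the classpath.
RunnerApi.Coder kvCoder =
    RunnerApi.Coder.newBuilder()
        .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:coder:kv:v1"))
        .addComponentCoderIds("keyCoderId")
        .addComponentCoderIds("valueCoderId")
        .build();
RunnerApi.Components componentsWithCoder =
    RunnerApi.Components.newBuilder().putCoders("kvCoderId", kvCoder).build();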
Use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
From the class ProtoOverridesTest, method replacesMultiple.
@Test
public void replacesMultiple() {
  RunnerApi.Pipeline p =
      Pipeline.newBuilder()
          .addAllRootTransformIds(ImmutableList.of("first", "second"))
          .setComponents(
              Components.newBuilder()
                  .putTransforms("first", PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn("beam:first")).build())
                  .putTransforms("second", PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn("beam:repeated")).build())
                  .putTransforms("third", PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn("beam:repeated")).build())
                  .putPcollections("intermediatePc", PCollection.newBuilder().setUniqueName("intermediate").build())
                  .putCoders("coder", Coder.newBuilder().setSpec(FunctionSpec.getDefaultInstance()).build()))
          .build();
  ByteString newPayload = ByteString.copyFrom("foo-bar-baz".getBytes(StandardCharsets.UTF_8));
  Pipeline updated =
      ProtoOverrides.updateTransform(
          "beam:repeated",
          p,
          (transformId, existingComponents) -> {
            String subtransform = String.format("%s_sub", transformId);
            return MessageWithComponents.newBuilder()
                .setPtransform(
                    PTransform.newBuilder()
                        .setSpec(FunctionSpec.newBuilder().setUrn("beam:repeated:replacement").setPayload(newPayload))
                        .addSubtransforms(subtransform))
                .setComponents(
                    Components.newBuilder()
                        .putTransforms(subtransform, PTransform.newBuilder().setUniqueName(subtransform).build()))
                .build();
          });
  PTransform updatedSecond = updated.getComponents().getTransformsOrThrow("second");
  PTransform updatedThird = updated.getComponents().getTransformsOrThrow("third");
  assertThat(updatedSecond, not(equalTo(p.getComponents().getTransformsOrThrow("second"))));
  assertThat(updatedThird, not(equalTo(p.getComponents().getTransformsOrThrow("third"))));
  assertThat(updatedSecond.getSubtransformsList(), contains("second_sub"));
  assertThat(updatedSecond.getSpec().getPayload(), equalTo(newPayload));
  assertThat(updatedThird.getSubtransformsList(), contains("third_sub"));
  assertThat(updatedThird.getSpec().getPayload(), equalTo(newPayload));
  assertThat(updated.getComponents().getTransformsMap(), hasKey("second_sub"));
  assertThat(updated.getComponents().getTransformsMap(), hasKey("third_sub"));
  assertThat(updated.getComponents().getTransformsOrThrow("second_sub").getUniqueName(), equalTo("second_sub"));
  assertThat(updated.getComponents().getTransformsOrThrow("third_sub").getUniqueName(), equalTo("third_sub"));
}
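A side note on the Coder component above: ProtoOverrides.updateTransform only rewrites transforms whose spec matches the given URN, so unrelated components such as the registered "coder" entry and the "first" transform should pass through unchanged. A hedged sketch of the kind of additional assertions that could express this (not part of the original test):

  // Illustrative only: components untouched by the override are expected to be preserved.
  assertThat(updated.getComponents().getCodersOrThrow("coder"), equalTo(p.getComponents().getCodersOrThrow("coder")));
  assertThat(updated.getComponents().getTransformsOrThrow("first"), equalTo(p.getComponents().getTransformsOrThrow("first")));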
Use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
From the class GreedyPipelineFuserTest, method sideInputRootsNewStage.
/*
 * impulseA -> .out -> read -> .out -> leftParDo -> .out
 *                                  \ -> rightParDo -> .out
 *                                   ------> sideInputParDo -> .out
 *                                          /
 * impulseB -> .out -> side_read -> .out   /
 *
 * becomes
 * (impulseA.out) -> read -> (read.out)
 * (read.out) -> leftParDo
 *            \
 *             -> rightParDo
 * (read.out) -> sideInputParDo
 * (impulseB.out) -> side_read
 */
@Test
public void sideInputRootsNewStage() {
  Components components =
      Components.newBuilder()
          .putCoders("coder", Coder.newBuilder().build())
          .putCoders("windowCoder", Coder.newBuilder().build())
          .putWindowingStrategies("ws", WindowingStrategy.newBuilder().setWindowCoderId("windowCoder").build())
          .putTransforms("mainImpulse", PTransform.newBuilder().setUniqueName("MainImpulse").putOutputs("output", "mainImpulse.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)).build())
          .putPcollections("mainImpulse.out", pc("mainImpulse.out"))
          .putTransforms("read", PTransform.newBuilder().setUniqueName("Read").putInputs("input", "mainImpulse.out").putOutputs("output", "read.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(ParDoPayload.newBuilder().setDoFn(FunctionSpec.newBuilder()).build().toByteString())).setEnvironmentId("py").build())
          .putPcollections("read.out", pc("read.out"))
          .putTransforms("sideImpulse", PTransform.newBuilder().setUniqueName("SideImpulse").putOutputs("output", "sideImpulse.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)).build())
          .putPcollections("sideImpulse.out", pc("sideImpulse.out"))
          .putTransforms("sideRead", PTransform.newBuilder().setUniqueName("SideRead").putInputs("input", "sideImpulse.out").putOutputs("output", "sideRead.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(ParDoPayload.newBuilder().setDoFn(FunctionSpec.newBuilder()).build().toByteString())).setEnvironmentId("py").build())
          .putPcollections("sideRead.out", pc("sideRead.out"))
          .putTransforms("leftParDo", PTransform.newBuilder().setUniqueName("LeftParDo").putInputs("main", "read.out").putOutputs("output", "leftParDo.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(ParDoPayload.newBuilder().setDoFn(FunctionSpec.newBuilder()).build().toByteString()).build()).setEnvironmentId("py").build())
          .putPcollections("leftParDo.out", pc("leftParDo.out"))
          .putTransforms("rightParDo", PTransform.newBuilder().setUniqueName("RightParDo").putInputs("main", "read.out").putOutputs("output", "rightParDo.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(ParDoPayload.newBuilder().setDoFn(FunctionSpec.newBuilder()).build().toByteString()).build()).setEnvironmentId("py").build())
          .putPcollections("rightParDo.out", pc("rightParDo.out"))
          .putTransforms("sideParDo", PTransform.newBuilder().setUniqueName("SideParDo").putInputs("main", "read.out").putInputs("side", "sideRead.out").putOutputs("output", "sideParDo.out").setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(ParDoPayload.newBuilder().setDoFn(FunctionSpec.newBuilder()).putSideInputs("side", SideInput.getDefaultInstance()).build().toByteString()).build()).setEnvironmentId("py").build())
          .putPcollections("sideParDo.out", pc("sideParDo.out"))
          .putEnvironments("py", Environments.createDockerEnvironment("py"))
          .build();
  FusedPipeline fused = GreedyPipelineFuser.fuse(Pipeline.newBuilder().setComponents(components).build());
  assertThat(
      fused.getRunnerExecutedTransforms(),
      containsInAnyOrder(
          PipelineNode.pTransform("mainImpulse", components.getTransformsOrThrow("mainImpulse")),
          PipelineNode.pTransform("sideImpulse", components.getTransformsOrThrow("sideImpulse"))));
  assertThat(
      fused.getFusedStages(),
      containsInAnyOrder(
          ExecutableStageMatcher.withInput("mainImpulse.out").withOutputs("read.out").withTransforms("read"),
          ExecutableStageMatcher.withInput("read.out").withNoOutputs().withTransforms("leftParDo", "rightParDo"),
          ExecutableStageMatcher.withInput("read.out")
              .withSideInputs(RunnerApi.ExecutableStagePayload.SideInputId.newBuilder().setTransformId("sideParDo").setLocalName("side").build())
              .withNoOutputs()
              .withTransforms("sideParDo"),
          ExecutableStageMatcher.withInput("sideImpulse.out").withOutputs("sideRead.out").withTransforms("sideRead")));
}
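The pc(...) calls above use a helper defined elsewhere in GreedyPipelineFuserTest. A plausible reconstruction follows, assuming it simply builds a PCollection proto wired to the shared "coder" and "ws" ids registered in the Components above and that the RunnerApi nested types are imported as in the test; the real helper may differ in detail.

// Hypothetical reconstruction of the pc(...) test helper.
private static PCollection pc(String name) {
  return PCollection.newBuilder()
      .setUniqueName(name)
      .setCoderId("coder")
      .setWindowingStrategyId("ws")
      .setIsBounded(IsBounded.Enum.BOUNDED)
      .build();
}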
Use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
From the class RemoteExecutionTest, method testExecutionWithUserStateCaching.
@Test
public void testExecutionWithUserStateCaching() throws Exception {
  Pipeline p = Pipeline.create();
  launchSdkHarness(p.getOptions());
  final String stateId = "foo";
  final String stateId2 = "bar";
  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "userState",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {

                @StateId(stateId)
                private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());

                @StateId(stateId2)
                private final StateSpec<BagState<String>> bufferState2 = StateSpecs.bag(StringUtf8Coder.of());

                @ProcessElement
                public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> state,
                    @StateId(stateId2) BagState<String> state2,
                    OutputReceiver<KV<String, String>> r) {
                  for (String value : state.read()) {
                    r.output(KV.of(element.getKey(), value));
                  }
                  ReadableState<Boolean> isEmpty = state2.isEmpty();
                  if (isEmpty.read()) {
                    r.output(KV.of(element.getKey(), "Empty"));
                  } else {
                    state2.clear();
                  }
                }
              }))
      .apply("gbk", GroupByKey.create());
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = optionalStage.get();
  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage", stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }
  Map<String, List<ByteString>> userStateData =
      ImmutableMap.of(
          stateId,
          new ArrayList(
              Arrays.asList(
                  ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)),
                  ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)),
                  ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED)))),
          stateId2,
          new ArrayList(
              Arrays.asList(
                  ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED)))));
  StoringStateRequestHandler stateRequestHandler =
      new StoringStateRequestHandler(
          StateRequestHandlers.forBagUserStateHandlerFactory(
              descriptor,
              new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
                @Override
                public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                    String pTransformId,
                    String userStateId,
                    Coder<ByteString> keyCoder,
                    Coder<Object> valueCoder,
                    Coder<BoundedWindow> windowCoder) {
                  return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
                    @Override
                    public Iterable<Object> get(ByteString key, BoundedWindow window) {
                      return (Iterable) userStateData.get(userStateId);
                    }

                    @Override
                    public void append(ByteString key, BoundedWindow window, Iterator<Object> values) {
                      Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                    }

                    @Override
                    public void clear(ByteString key, BoundedWindow window) {
                      userStateData.get(userStateId).clear();
                    }
                  };
                }
              }));
  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Y")));
  }
  try (RemoteBundle bundle2 =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle2.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Z")));
  }
  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "Empty"))));
  }
  assertThat(
      userStateData.get(stateId),
      IsIterableContainingInOrder.contains(
          ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)),
          ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)),
          ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED))));
  assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
  // 3 Requests expected: state read, state2 read, and state2 clear
  assertEquals(3, stateRequestHandler.getRequestCount());
  ByteString.Output out = ByteString.newOutput();
  StringUtf8Coder.of().encode("X", out);
  assertEquals(stateId, stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getUserStateId());
  assertEquals(stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey(), out.toByteString());
  assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());
  assertEquals(stateId2, stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getUserStateId());
  assertEquals(stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey(), out.toByteString());
  assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());
  assertEquals(stateId2, stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getUserStateId());
  assertEquals(stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey(), out.toByteString());
  assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
}
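Reading the state fixture above: each bag element is just the value encoded with the SDK's StringUtf8Coder in the nested context and wrapped in a ByteString, which is why the assertions can compare raw bytes against out.toByteString(). A small hypothetical helper (not in the original test) that captures the repeated encoding step:

  // Hypothetical helper; assumes CoderUtils, StringUtf8Coder, and ByteString as imported above.
  private static ByteString encodeStateValue(String value) throws CoderException {
    return ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), value, Coder.Context.NESTED));
  }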
Use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
From the class RemoteExecutionTest, method testExecutionWithMultipleStages.
@Test
public void testExecutionWithMultipleStages() throws Exception {
  launchSdkHarness(PipelineOptionsFactory.create());
  Pipeline p = Pipeline.create();
  Function<String, PCollection<String>> pCollectionGenerator =
      suffix ->
          p.apply("impulse" + suffix, Impulse.create())
              .apply(
                  "create" + suffix,
                  ParDo.of(
                      new DoFn<byte[], String>() {
                        @ProcessElement
                        public void process(ProcessContext c) {
                          try {
                            c.output(CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), c.element()));
                          } catch (CoderException e) {
                            throw new RuntimeException(e);
                          }
                        }
                      }))
              .setCoder(StringUtf8Coder.of())
              .apply(
                  ParDo.of(
                      new DoFn<String, String>() {
                        @ProcessElement
                        public void processElement(ProcessContext c) {
                          c.output("stream" + suffix + c.element());
                        }
                      }));
  PCollection<String> input1 = pCollectionGenerator.apply("1");
  PCollection<String> input2 = pCollectionGenerator.apply("2");
  PCollection<String> outputMerged = PCollectionList.of(input1).and(input2).apply(Flatten.pCollections());
  outputMerged
      .apply(
          "createKV",
          ParDo.of(
              new DoFn<String, KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext c) {
                  c.output(KV.of(c.element(), ""));
                }
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply("gbk", GroupByKey.create());
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Set<ExecutableStage> stages = fused.getFusedStages();
  assertThat(stages.size(), equalTo(2));
  List<WindowedValue<?>> outputValues = Collections.synchronizedList(new ArrayList<>());
  for (ExecutableStage stage : stages) {
    ExecutableProcessBundleDescriptor descriptor =
        ProcessBundleDescriptors.fromExecutableStage(
            stage.toString(), stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
    BundleProcessor processor =
        controlClient.getProcessor(
            descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
    Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
    for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
      outputReceivers.putIfAbsent(
          remoteOutputCoder.getKey(),
          RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputValues::add));
    }
    try (RemoteBundle bundle =
        processor.newBundle(outputReceivers, StateRequestHandler.unsupported(), BundleProgressHandler.ignored())) {
      Iterables.getOnlyElement(bundle.getInputReceivers().values())
          .accept(valueInGlobalWindow(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
    }
  }
  assertThat(
      outputValues,
      containsInAnyOrder(valueInGlobalWindow(KV.of("stream1X", "")), valueInGlobalWindow(KV.of("stream2X", ""))));
}
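Tying this back to RunnerApi.Coder: once PipelineTranslation.toProto(p) has run, the coders the SDK attached (for example via setCoder above) are registered by id as RunnerApi.Coder messages in the proto's components, and runners look them up there when building executable stages. A minimal, hypothetical inspection sketch, for illustration only:

  // Illustrative only: list the coder ids and URNs registered in the pipeline proto.
  for (Map.Entry<String, RunnerApi.Coder> coderEntry : pipelineProto.getComponents().getCodersMap().entrySet()) {
    System.out.println(coderEntry.getKey() + " -> " + coderEntry.getValue().getSpec().getUrn());
  }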