use of org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline in project beam by apache.
the class SdkComponentsTest method registerTransformAfterChildren.
/**
 * Registers a nested transform first and its enclosing transform second, then checks the
 * resulting component protos: the parent must reference the child as its first subtransform and
 * the child must have no subtransforms of its own.
 */
@Test
public void registerTransformAfterChildren() throws IOException {
  Create.Values<Long> parentCreate = Create.of(1L, 2L, 3L);
  GenerateSequence nestedSequence = GenerateSequence.from(0);
  PCollection<Long> output = pipeline.apply(parentCreate);

  String parentName = "my_transform";
  String nestedName = "my_transform/my_nesting";

  // Both applications share the same input (pipeline begin) and the same output collection.
  AppliedPTransform<?, ?, ?> parent =
      AppliedPTransform.of(
          parentName,
          PValues.expandInput(pipeline.begin()),
          PValues.expandOutput(output),
          parentCreate,
          ResourceHints.create(),
          pipeline);
  AppliedPTransform<?, ?, ?> child =
      AppliedPTransform.of(
          nestedName,
          PValues.expandInput(pipeline.begin()),
          PValues.expandOutput(output),
          nestedSequence,
          ResourceHints.create(),
          pipeline);

  // The child is registered before the parent that lists it as a subtransform.
  String childId = this.components.registerPTransform(child, Collections.emptyList());
  String parentId = this.components.registerPTransform(parent, Collections.singletonList(child));

  Components proto = this.components.toComponents();
  assertThat(proto.getTransformsOrThrow(parentId).getSubtransforms(0), equalTo(childId));
  assertThat(proto.getTransformsOrThrow(childId).getSubtransformsCount(), equalTo(0));
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline in project beam by apache.
the class ProtoOverridesTest method replacesMultiple.
/**
 * Calls {@code ProtoOverrides.updateTransform} for the URN {@code "beam:repeated"} on a pipeline
 * in which two transforms ("second" and "third") carry that URN, and asserts that both come back
 * changed: each has the new payload, each references its own {@code <id>_sub} subtransform, and
 * each subtransform is present in the updated components under that id.
 */
@Test
public void replacesMultiple() {
  PTransform first =
      PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn("beam:first")).build();
  PTransform second =
      PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn("beam:repeated")).build();
  PTransform third =
      PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn("beam:repeated")).build();
  PCollection intermediate = PCollection.newBuilder().setUniqueName("intermediate").build();
  Coder coder = Coder.newBuilder().setSpec(FunctionSpec.getDefaultInstance()).build();

  Components.Builder originalComponents =
      Components.newBuilder()
          .putTransforms("first", first)
          .putTransforms("second", second)
          .putTransforms("third", third)
          .putPcollections("intermediatePc", intermediate)
          .putCoders("coder", coder);
  RunnerApi.Pipeline p =
      Pipeline.newBuilder()
          .addAllRootTransformIds(ImmutableList.of("first", "second"))
          .setComponents(originalComponents)
          .build();

  ByteString newPayload = ByteString.copyFrom("foo-bar-baz".getBytes(StandardCharsets.UTF_8));
  Pipeline updated =
      ProtoOverrides.updateTransform(
          "beam:repeated",
          p,
          (transformId, existingComponents) -> {
            // Each replaced transform gets a subtransform named after its own id.
            String subtransformId = String.format("%s_sub", transformId);
            PTransform.Builder replacement =
                PTransform.newBuilder()
                    .setSpec(
                        FunctionSpec.newBuilder()
                            .setUrn("beam:repeated:replacement")
                            .setPayload(newPayload))
                    .addSubtransforms(subtransformId);
            Components.Builder addedComponents =
                Components.newBuilder()
                    .putTransforms(
                        subtransformId,
                        PTransform.newBuilder().setUniqueName(subtransformId).build());
            return MessageWithComponents.newBuilder()
                .setPtransform(replacement)
                .setComponents(addedComponents)
                .build();
          });

  Components before = p.getComponents();
  Components after = updated.getComponents();
  PTransform updatedSecond = after.getTransformsOrThrow("second");
  PTransform updatedThird = after.getTransformsOrThrow("third");

  // Both transforms with the matching URN must differ from their originals.
  assertThat(updatedSecond, not(equalTo(before.getTransformsOrThrow("second"))));
  assertThat(updatedThird, not(equalTo(before.getTransformsOrThrow("third"))));

  assertThat(updatedSecond.getSubtransformsList(), contains("second_sub"));
  assertThat(updatedSecond.getSpec().getPayload(), equalTo(newPayload));
  assertThat(updatedThird.getSubtransformsList(), contains("third_sub"));
  assertThat(updatedThird.getSpec().getPayload(), equalTo(newPayload));

  // The components returned by the replacement function must be merged into the pipeline.
  assertThat(after.getTransformsMap(), hasKey("second_sub"));
  assertThat(after.getTransformsMap(), hasKey("third_sub"));
  assertThat(after.getTransformsOrThrow("second_sub").getUniqueName(), equalTo("second_sub"));
  assertThat(after.getTransformsOrThrow("third_sub").getUniqueName(), equalTo("third_sub"));
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline in project beam by apache.
the class QueryablePipelineTest method forTransformsWithSubgraph.
/**
 * Builds components with one producer ("root") and two consumers of its output ("consumer" and
 * "ignored"), creates a {@code QueryablePipeline} restricted to "root" and "consumer", and
 * asserts that only "root" is reported as a root transform and only "consumer" as a per-element
 * consumer of the shared PCollection.
 */
@Test
public void forTransformsWithSubgraph() {
  PTransform root = PTransform.newBuilder().putOutputs("output", "output.out").build();
  RunnerApi.PCollection rootOutput =
      RunnerApi.PCollection.newBuilder().setUniqueName("output.out").build();
  PTransform consumer = PTransform.newBuilder().putInputs("input", "output.out").build();
  PTransform ignored = PTransform.newBuilder().putInputs("input", "output.out").build();

  Components components =
      Components.newBuilder()
          .putTransforms("root", root)
          .putPcollections("output.out", rootOutput)
          .putTransforms("consumer", consumer)
          .putTransforms("ignored", ignored)
          .build();

  // "ignored" is deliberately left out of the transform id set.
  QueryablePipeline queryable =
      QueryablePipeline.forTransforms(ImmutableSet.of("root", "consumer"), components);

  PTransformNode rootNode = PipelineNode.pTransform("root", components.getTransformsOrThrow("root"));
  assertThat(queryable.getRootTransforms(), contains(rootNode));

  Set<PTransformNode> consumers =
      queryable.getPerElementConsumers(
          PipelineNode.pCollection("output.out", components.getPcollectionsOrThrow("output.out")));
  PTransformNode consumerNode =
      PipelineNode.pTransform("consumer", components.getTransformsOrThrow("consumer"));
  assertThat(consumers, contains(consumerNode));
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline in project beam by apache.
the class FusedPipelineTest method testToProto.
/**
 * Translates a five-step pipeline (impulse, map, key, gbk, values) to its proto form, fuses it
 * with {@code GreedyPipelineFuser}, converts the fused result back to a pipeline proto, and
 * checks the root transform ids and components of that proto against the unfused original.
 */
@Test
public void testToProto() {
  Pipeline p = Pipeline.create();
  p.apply("impulse", Impulse.create())
      .apply("map", MapElements.into(TypeDescriptors.integers()).via(bytes -> bytes.length))
      .apply("key", WithKeys.of("foo"))
      .apply("gbk", GroupByKey.create())
      .apply("values", Values.create());

  RunnerApi.Pipeline protoPipeline = PipelineTranslation.toProto(p);
  // Preconditions on the test setup itself, hence checkState rather than assertions.
  checkState(
      protoPipeline
          .getRootTransformIdsList()
          .containsAll(ImmutableList.of("impulse", "map", "key", "gbk", "values")),
      "Unexpected Root Transform IDs %s",
      protoPipeline.getRootTransformIdsList());

  FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(protoPipeline);
  checkState(
      fusedPipeline.getRunnerExecutedTransforms().size() == 2,
      "Unexpected number of runner transforms %s",
      fusedPipeline.getRunnerExecutedTransforms());
  checkState(
      fusedPipeline.getFusedStages().size() == 2,
      "Unexpected number of fused stages %s",
      fusedPipeline.getFusedStages());

  RunnerApi.Pipeline fusedProto = fusedPipeline.toPipeline();
  RunnerApi.Components fusedComponents = fusedProto.getComponents();
  RunnerApi.Components originalComponents = protoPipeline.getComponents();

  assertThat(
      "Root Transforms should all be present in the Pipeline Components",
      fusedComponents.getTransformsMap().keySet(),
      hasItems(fusedProto.getRootTransformIdsList().toArray(new String[0])));
  assertThat(
      "Should contain Impulse, GroupByKey, and two Environment Stages",
      fusedProto.getRootTransformIdsCount(),
      equalTo(4));
  assertThat(fusedProto.getRootTransformIdsList(), hasItems("impulse", "gbk"));
  assertRootsInTopologicalOrder(fusedProto);

  // MapElements, WithKeys, and Values are each composites wrapping a ParDo, so match on the id
  // prefix rather than inspecting what they expand to.
  assertThat(
      "Fused transforms should be present in the components",
      fusedComponents.getTransformsMap(),
      allOf(hasKey(startsWith("map")), hasKey(startsWith("key")), hasKey(startsWith("values"))));
  assertThat(
      "Fused transforms shouldn't be present in the root IDs",
      fusedProto.getRootTransformIdsList(),
      not(hasItems(startsWith("map"), startsWith("key"), startsWith("values"))));

  // Everything other than transforms must be carried over unchanged from the original pipeline.
  assertThat(fusedComponents.getCodersMap(), equalTo(originalComponents.getCodersMap()));
  assertThat(
      fusedComponents.getWindowingStrategiesMap(),
      equalTo(originalComponents.getWindowingStrategiesMap()));
  assertThat(
      fusedComponents.getEnvironmentsMap(), equalTo(originalComponents.getEnvironmentsMap()));
  assertThat(
      fusedComponents.getPcollectionsMap(), equalTo(originalComponents.getPcollectionsMap()));
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline in project beam by apache.
the class RemoteExecutionTest method testExecutionWithUserStateCaching.
/**
 * Runs a fused stage containing a stateful {@code DoFn} for two consecutive bundles against an
 * in-memory bag user state backend and verifies both the emitted elements and the state requests
 * that reached the runner side.
 *
 * <p>The stateful {@code DoFn} reads bag state "foo" and re-emits every value under the element's
 * key, then checks bag state "bar": if it is empty it emits "Empty", otherwise it clears it.
 *
 * <p>Two bundles are processed but only three state requests are expected in total (read "foo",
 * read "bar", clear "bar"). NOTE(review): presumably the second bundle is served from a
 * harness-side state cache, as the test name suggests; confirm against the harness setup.
 *
 * @throws Exception if encoding test data, launching the harness, or processing a bundle fails
 */
@Test
public void testExecutionWithUserStateCaching() throws Exception {
  Pipeline p = Pipeline.create();
  launchSdkHarness(p.getOptions());
  // Must stay compile-time constants: they are used as annotation values below.
  final String stateId = "foo";
  final String stateId2 = "bar";
  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "userState",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {
                @StateId(stateId)
                private final StateSpec<BagState<String>> bufferState =
                    StateSpecs.bag(StringUtf8Coder.of());

                @StateId(stateId2)
                private final StateSpec<BagState<String>> bufferState2 =
                    StateSpecs.bag(StringUtf8Coder.of());

                @ProcessElement
                public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> state,
                    @StateId(stateId2) BagState<String> state2,
                    OutputReceiver<KV<String, String>> r) {
                  for (String value : state.read()) {
                    r.output(KV.of(element.getKey(), value));
                  }
                  ReadableState<Boolean> isEmpty = state2.isEmpty();
                  if (isEmpty.read()) {
                    r.output(KV.of(element.getKey(), "Empty"));
                  } else {
                    state2.clear();
                  }
                }
              }))
      .apply("gbk", GroupByKey.create());

  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = optionalStage.get();

  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator);

  // Collect everything the stage emits, keyed by output id.
  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }

  // Encode each test value once; the same encodings seed the state and back the assertions.
  ByteString encodedA =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
  ByteString encodedB =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
  ByteString encodedC =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
  ByteString encodedD =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));

  // The lists must be mutable: the handler below appends to and clears them.
  List<ByteString> initialState = new ArrayList<>(Arrays.asList(encodedA, encodedB, encodedC));
  List<ByteString> initialState2 = new ArrayList<>(Arrays.asList(encodedD));
  Map<String, List<ByteString>> userStateData =
      ImmutableMap.of(stateId, initialState, stateId2, initialState2);

  StoringStateRequestHandler stateRequestHandler =
      new StoringStateRequestHandler(
          StateRequestHandlers.forBagUserStateHandlerFactory(
              descriptor,
              new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
                @Override
                public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                    String pTransformId,
                    String userStateId,
                    Coder<ByteString> keyCoder,
                    Coder<Object> valueCoder,
                    Coder<BoundedWindow> windowCoder) {
                  // State is looked up by user state id only; key and window are ignored.
                  return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
                    @Override
                    public Iterable<Object> get(ByteString key, BoundedWindow window) {
                      return (Iterable) userStateData.get(userStateId);
                    }

                    @Override
                    public void append(
                        ByteString key, BoundedWindow window, Iterator<Object> values) {
                      Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                    }

                    @Override
                    public void clear(ByteString key, BoundedWindow window) {
                      userStateData.get(userStateId).clear();
                    }
                  };
                }
              }));

  // Two bundles for the same key "X"; closing each bundle waits for it to finish.
  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Y")));
  }
  try (RemoteBundle bundle2 =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle2.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Z")));
  }

  // A, B, C are emitted by both bundles; "Empty" only by the one that sees state2 cleared.
  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "Empty"))));
  }
  assertThat(
      userStateData.get(stateId),
      IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC));
  assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());

  // 3 Requests expected: state read, state2 read, and state2 clear
  assertEquals(3, stateRequestHandler.getRequestCount());
  ByteString.Output out = ByteString.newOutput();
  StringUtf8Coder.of().encode("X", out);
  ByteString expectedKey = out.toByteString();

  // assertEquals takes (expected, actual); keep that order so failure messages read correctly.
  assertEquals(
      stateId,
      stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      expectedKey,
      stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());

  assertEquals(
      stateId2,
      stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      expectedKey,
      stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());

  assertEquals(
      stateId2,
      stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      expectedKey,
      stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
}
Aggregations