use of org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor in project beam by apache.
the class RemoteExecutionTest method testExecutionWithUserStateCaching.
@Test
public void testExecutionWithUserStateCaching() throws Exception {
Pipeline p = Pipeline.create();
launchSdkHarness(p.getOptions());
final String stateId = "foo";
final String stateId2 = "bar";
p.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], KV<String, String>>() {
@ProcessElement
public void process(ProcessContext ctxt) {
}
})).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("userState", ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {
@StateId(stateId)
private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());
@StateId(stateId2)
private final StateSpec<BagState<String>> bufferState2 = StateSpecs.bag(StringUtf8Coder.of());
@ProcessElement
public void processElement(@Element KV<String, String> element, @StateId(stateId) BagState<String> state, @StateId(stateId2) BagState<String> state2, OutputReceiver<KV<String, String>> r) {
for (String value : state.read()) {
r.output(KV.of(element.getKey(), value));
}
ReadableState<Boolean> isEmpty = state2.isEmpty();
if (isEmpty.read()) {
r.output(KV.of(element.getKey(), "Empty"));
} else {
state2.clear();
}
}
})).apply("gbk", GroupByKey.create());
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
Optional<ExecutableStage> optionalStage = Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
checkState(optionalStage.isPresent(), "Expected a stage with user state.");
ExecutableStage stage = optionalStage.get();
ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage("test_stage", stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
outputValues.put(remoteOutputCoder.getKey(), outputContents);
outputReceivers.put(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
}
Map<String, List<ByteString>> userStateData = ImmutableMap.of(stateId, new ArrayList(Arrays.asList(ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)), ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)), ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED)))), stateId2, new ArrayList(Arrays.asList(ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED)))));
StoringStateRequestHandler stateRequestHandler = new StoringStateRequestHandler(StateRequestHandlers.forBagUserStateHandlerFactory(descriptor, new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
@Override
public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(String pTransformId, String userStateId, Coder<ByteString> keyCoder, Coder<Object> valueCoder, Coder<BoundedWindow> windowCoder) {
return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
@Override
public Iterable<Object> get(ByteString key, BoundedWindow window) {
return (Iterable) userStateData.get(userStateId);
}
@Override
public void append(ByteString key, BoundedWindow window, Iterator<Object> values) {
Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
}
@Override
public void clear(ByteString key, BoundedWindow window) {
userStateData.get(userStateId).clear();
}
};
}
}));
try (RemoteBundle bundle = processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(KV.of("X", "Y")));
}
try (RemoteBundle bundle2 = processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
Iterables.getOnlyElement(bundle2.getInputReceivers().values()).accept(valueInGlobalWindow(KV.of("X", "Z")));
}
for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
assertThat(windowedValues, containsInAnyOrder(valueInGlobalWindow(KV.of("X", "A")), valueInGlobalWindow(KV.of("X", "B")), valueInGlobalWindow(KV.of("X", "C")), valueInGlobalWindow(KV.of("X", "A")), valueInGlobalWindow(KV.of("X", "B")), valueInGlobalWindow(KV.of("X", "C")), valueInGlobalWindow(KV.of("X", "Empty"))));
}
assertThat(userStateData.get(stateId), IsIterableContainingInOrder.contains(ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)), ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)), ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED))));
assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
// 3 Requests expected: state read, state2 read, and state2 clear
assertEquals(3, stateRequestHandler.getRequestCount());
ByteString.Output out = ByteString.newOutput();
StringUtf8Coder.of().encode("X", out);
assertEquals(stateId, stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getUserStateId());
assertEquals(stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey(), out.toByteString());
assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());
assertEquals(stateId2, stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getUserStateId());
assertEquals(stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey(), out.toByteString());
assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());
assertEquals(stateId2, stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getUserStateId());
assertEquals(stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey(), out.toByteString());
assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
}
use of org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor in project beam by apache.
the class RemoteExecutionTest method testExecutionWithMultipleStages.
@Test
public void testExecutionWithMultipleStages() throws Exception {
launchSdkHarness(PipelineOptionsFactory.create());
Pipeline p = Pipeline.create();
Function<String, PCollection<String>> pCollectionGenerator = suffix -> p.apply("impulse" + suffix, Impulse.create()).apply("create" + suffix, ParDo.of(new DoFn<byte[], String>() {
@ProcessElement
public void process(ProcessContext c) {
try {
c.output(CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), c.element()));
} catch (CoderException e) {
throw new RuntimeException(e);
}
}
})).setCoder(StringUtf8Coder.of()).apply(ParDo.of(new DoFn<String, String>() {
@ProcessElement
public void processElement(ProcessContext c) {
c.output("stream" + suffix + c.element());
}
}));
PCollection<String> input1 = pCollectionGenerator.apply("1");
PCollection<String> input2 = pCollectionGenerator.apply("2");
PCollection<String> outputMerged = PCollectionList.of(input1).and(input2).apply(Flatten.pCollections());
outputMerged.apply("createKV", ParDo.of(new DoFn<String, KV<String, String>>() {
@ProcessElement
public void process(ProcessContext c) {
c.output(KV.of(c.element(), ""));
}
})).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("gbk", GroupByKey.create());
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
Set<ExecutableStage> stages = fused.getFusedStages();
assertThat(stages.size(), equalTo(2));
List<WindowedValue<?>> outputValues = Collections.synchronizedList(new ArrayList<>());
for (ExecutableStage stage : stages) {
ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage(stage.toString(), stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
outputReceivers.putIfAbsent(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputValues::add));
}
try (RemoteBundle bundle = processor.newBundle(outputReceivers, StateRequestHandler.unsupported(), BundleProgressHandler.ignored())) {
Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
}
}
assertThat(outputValues, containsInAnyOrder(valueInGlobalWindow(KV.of("stream1X", "")), valueInGlobalWindow(KV.of("stream2X", ""))));
}
use of org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor in project beam by apache.
the class SdkHarnessClientTest method testClosingActiveBundleMultipleTimesIsNoop.
@Test
public void testClosingActiveBundleMultipleTimesIsNoop() throws Exception {
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class))).thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder = FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
BundleProcessor processor = sdkHarnessClient.getProcessor(descriptor, Collections.singletonList(RemoteInputDestination.of((FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
when(dataService.send(any(), eq(coder))).thenReturn(mock(CloseableFnDataReceiver.class));
RemoteBundle activeBundle = processor.newBundle(Collections.emptyMap(), BundleProgressHandler.ignored());
// Correlating the request and response is owned by the underlying
// FnApiControlClient. The SdkHarnessClient owns just wrapping the request and unwrapping
// the response.
//
// Currently there are no fields so there's nothing to check. This test is formulated
// to match the pattern it should have if/when the response is meaningful.
BeamFnApi.ProcessBundleResponse response = ProcessBundleResponse.getDefaultInstance();
processBundleResponseFuture.complete(BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response).build());
activeBundle.close();
activeBundle.close();
}
use of org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor in project beam by apache.
the class SdkHarnessClientTest method verifyCacheTokensAreUsedInNewBundleRequest.
@Test
public void verifyCacheTokensAreUsedInNewBundleRequest() throws InterruptedException {
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class))).thenReturn(CompletableFuture.<InstructionResponse>completedFuture(InstructionResponse.newBuilder().build()));
ProcessBundleDescriptor descriptor1 = ProcessBundleDescriptor.newBuilder().setId("descriptor1").build();
List<RemoteInputDestination> remoteInputs = Collections.singletonList(RemoteInputDestination.of(FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE), SDK_GRPC_READ_TRANSFORM));
BundleProcessor processor1 = sdkHarnessClient.getProcessor(descriptor1, remoteInputs);
when(dataService.send(any(), any())).thenReturn(mock(CloseableFnDataReceiver.class));
StateRequestHandler stateRequestHandler = Mockito.mock(StateRequestHandler.class);
List<BeamFnApi.ProcessBundleRequest.CacheToken> cacheTokens = Collections.singletonList(BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder().getDefaultInstanceForType());
when(stateRequestHandler.getCacheTokens()).thenReturn(cacheTokens);
processor1.newBundle(ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mock(RemoteOutputReceiver.class)), stateRequestHandler, BundleProgressHandler.ignored());
// Retrieve the requests made to the FnApiControlClient
ArgumentCaptor<BeamFnApi.InstructionRequest> reqCaptor = ArgumentCaptor.forClass(BeamFnApi.InstructionRequest.class);
Mockito.verify(fnApiControlClient, Mockito.times(1)).handle(reqCaptor.capture());
List<BeamFnApi.InstructionRequest> requests = reqCaptor.getAllValues();
// Verify that the cache tokens are included in the ProcessBundleRequest
assertThat(requests.get(0).getRequestCase(), is(BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE));
assertThat(requests.get(0).getProcessBundle().getCacheTokensList(), is(cacheTokens));
}
use of org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor in project beam by apache.
the class SdkHarnessClientTest method handleCleanupWhenAwaitingOnClosingOutputReceivers.
@Test
public void handleCleanupWhenAwaitingOnClosingOutputReceivers() throws Exception {
Exception testException = new Exception();
InboundDataClient mockOutputReceiver = mock(InboundDataClient.class);
CloseableFnDataReceiver mockInputSender = mock(CloseableFnDataReceiver.class);
CompletableFuture<InstructionResponse> processBundleResponseFuture = new CompletableFuture<>();
when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class))).thenReturn(processBundleResponseFuture);
FullWindowedValueCoder<String> coder = FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
BundleProcessor processor = sdkHarnessClient.getProcessor(descriptor, Collections.singletonList(RemoteInputDestination.of((FullWindowedValueCoder) coder, SDK_GRPC_READ_TRANSFORM)));
when(dataService.receive(any(), any(), any())).thenReturn(mockOutputReceiver);
when(dataService.send(any(), eq(coder))).thenReturn(mockInputSender);
doThrow(testException).when(mockOutputReceiver).awaitCompletion();
RemoteOutputReceiver mockRemoteOutputReceiver = mock(RemoteOutputReceiver.class);
BundleProgressHandler mockProgressHandler = mock(BundleProgressHandler.class);
try {
try (RemoteBundle activeBundle = processor.newBundle(ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mockRemoteOutputReceiver), mockProgressHandler)) {
// Correlating the ProcessBundleRequest and ProcessBundleResponse is owned by the underlying
// FnApiControlClient. The SdkHarnessClient owns just wrapping the request and unwrapping
// the response.
//
// Currently there are no fields so there's nothing to check. This test is formulated
// to match the pattern it should have if/when the response is meaningful.
BeamFnApi.ProcessBundleResponse response = BeamFnApi.ProcessBundleResponse.getDefaultInstance();
processBundleResponseFuture.complete(BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response).build());
}
fail("Exception expected");
} catch (Exception e) {
assertEquals(testException, e);
}
}
Aggregations