Usage of org.apache.beam.runners.fnexecution.state.StateRequestHandler in the Apache Beam project.
Class RemoteExecutionTest, method testExecutionWithUserState.
/**
 * Runs the fused stage that contains a stateful ParDo against a live SDK harness and checks that
 * bag user state is read, appended to and cleared through the runner-side state handler.
 *
 * <p>State "foo" starts as [A, B, C] and state "foo2" as [D]. For the single input KV("X", "Y"),
 * the DoFn emits one output per existing "foo" value, appends "Y" to "foo" and clears "foo2".
 */
@Test
public void testExecutionWithUserState() throws Exception {
  launchSdkHarness(PipelineOptionsFactory.create());
  Pipeline p = Pipeline.create();
  final String stateId = "foo";
  final String stateId2 = "foo2";
  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "userState",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {
                @StateId(stateId)
                private final StateSpec<BagState<String>> bufferState =
                    StateSpecs.bag(StringUtf8Coder.of());

                @StateId(stateId2)
                private final StateSpec<BagState<String>> bufferState2 =
                    StateSpecs.bag(StringUtf8Coder.of());

                @ProcessElement
                public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> state,
                    @StateId(stateId2) BagState<String> state2,
                    OutputReceiver<KV<String, String>> r) {
                  for (String value : state.read()) {
                    r.output(KV.of(element.getKey(), value));
                  }
                  state.add(element.getValue());
                  state2.clear();
                }
              }))
      .apply("gbk", GroupByKey.create());

  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = optionalStage.get();

  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator);

  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }
  // Without at least one output the per-output assertions below would pass vacuously.
  checkState(!outputValues.isEmpty(), "Expected at least one remote output from the stage.");

  // Nested-context encodings of the bag elements, shared by the initial state and the expectations.
  ByteString encodedA =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
  ByteString encodedB =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
  ByteString encodedC =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
  ByteString encodedD =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));
  ByteString encodedY =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y", Coder.Context.NESTED));

  // The lists must be mutable: the handler below appends to and clears them in place.
  Map<String, List<ByteString>> userStateData =
      ImmutableMap.of(
          stateId, new ArrayList<>(Arrays.asList(encodedA, encodedB, encodedC)),
          stateId2, new ArrayList<>(Arrays.asList(encodedD)));

  StateRequestHandler stateRequestHandler =
      StateRequestHandlers.forBagUserStateHandlerFactory(
          descriptor,
          new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
            @Override
            public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                String pTransformId,
                String userStateId,
                Coder<ByteString> keyCoder,
                Coder<Object> valueCoder,
                Coder<BoundedWindow> windowCoder) {
              // Key and window are ignored: each state id maps to one backing list. The test
              // only sends a single element in the global window.
              return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
                @Override
                public Iterable<Object> get(ByteString key, BoundedWindow window) {
                  // Raw cast: stored ByteStrings are exposed as the handler's Object values.
                  return (Iterable) userStateData.get(userStateId);
                }

                @Override
                public void append(
                    ByteString key, BoundedWindow window, Iterator<Object> values) {
                  Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                }

                @Override
                public void clear(ByteString key, BoundedWindow window) {
                  userStateData.get(userStateId).clear();
                }
              };
            }
          });

  try (RemoteBundle bundle =
      processor.newBundle(
          outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Y")));
  }

  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C"))));
  }
  assertThat(
      userStateData.get(stateId),
      IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC, encodedY));
  assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
}
Usage of org.apache.beam.runners.fnexecution.state.StateRequestHandler in the Apache Beam project.
Class SdkHarnessClientTest, method handleCleanupWithStateWhenAwaitingOnClosingOutputReceivers.
/**
 * Checks the case where the ProcessBundle response completes successfully but awaiting the inbound
 * output data fails when the bundle is closed. That failure must reach the caller unchanged.
 */
@Test
public void handleCleanupWithStateWhenAwaitingOnClosingOutputReceivers() throws Exception {
  Exception failure = new Exception();
  InboundDataClient outputClient = mock(InboundDataClient.class);
  CloseableFnDataReceiver inputSender = mock(CloseableFnDataReceiver.class);

  StateDelegator delegator = mock(StateDelegator.class);
  StateDelegator.Registration registration = mock(StateDelegator.Registration.class);
  when(delegator.registerForProcessBundleInstructionId(any(), any())).thenReturn(registration);

  StateRequestHandler stateHandler = mock(StateRequestHandler.class);
  when(stateHandler.getCacheTokens()).thenReturn(Collections.emptyList());
  BundleProgressHandler progressHandler = mock(BundleProgressHandler.class);

  // Every instruction sent through fnApiControlClient is answered by this single future.
  CompletableFuture<InstructionResponse> responseFuture = new CompletableFuture<>();
  when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
      .thenReturn(responseFuture);

  FullWindowedValueCoder<String> stringCoder =
      FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
  BundleProcessor bundleProcessor =
      sdkHarnessClient.getProcessor(
          descriptor,
          Collections.singletonList(
              RemoteInputDestination.of(
                  (FullWindowedValueCoder) stringCoder, SDK_GRPC_READ_TRANSFORM)),
          delegator);
  when(dataService.receive(any(), any(), any())).thenReturn(outputClient);
  when(dataService.send(any(), eq(stringCoder))).thenReturn(inputSender);
  // In this scenario, awaiting the inbound output data is the step that fails.
  doThrow(failure).when(outputClient).awaitCompletion();
  RemoteOutputReceiver outputReceiver = mock(RemoteOutputReceiver.class);

  Exception caught = null;
  try (RemoteBundle activeBundle =
      bundleProcessor.newBundle(
          ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, outputReceiver),
          stateHandler,
          progressHandler)) {
    // Correlating the ProcessBundleRequest and ProcessBundleResponse is owned by the underlying
    // FnApiControlClient; the SdkHarnessClient only wraps the request and unwraps the response.
    // The response currently has no fields worth checking, so an empty one is sufficient.
    BeamFnApi.ProcessBundleResponse response =
        BeamFnApi.ProcessBundleResponse.getDefaultInstance();
    responseFuture.complete(
        BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response).build());
  } catch (Exception e) {
    caught = e;
  }
  if (caught == null) {
    fail("Exception expected");
  }
  assertEquals(failure, caught);
}
Usage of org.apache.beam.runners.fnexecution.state.StateRequestHandler in the Apache Beam project.
Class SdkHarnessClientTest, method handleCleanupWithStateWhenProcessingBundleFails.
/**
 * Checks the case where the ProcessBundle response completes exceptionally. Closing the bundle must
 * throw an {@link ExecutionException} wrapping that failure. Cleanup must abort the state
 * registration and cancel the inbound data client, with no other interactions on either.
 */
@Test
public void handleCleanupWithStateWhenProcessingBundleFails() throws Exception {
  Exception failure = new Exception();
  InboundDataClient outputClient = mock(InboundDataClient.class);
  CloseableFnDataReceiver inputSender = mock(CloseableFnDataReceiver.class);

  StateDelegator delegator = mock(StateDelegator.class);
  StateDelegator.Registration registration = mock(StateDelegator.Registration.class);
  when(delegator.registerForProcessBundleInstructionId(any(), any())).thenReturn(registration);

  StateRequestHandler stateHandler = mock(StateRequestHandler.class);
  when(stateHandler.getCacheTokens()).thenReturn(Collections.emptyList());
  BundleProgressHandler progressHandler = mock(BundleProgressHandler.class);

  // Every instruction sent through fnApiControlClient is answered by this single future.
  CompletableFuture<InstructionResponse> responseFuture = new CompletableFuture<>();
  when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
      .thenReturn(responseFuture);

  FullWindowedValueCoder<String> stringCoder =
      FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
  BundleProcessor bundleProcessor =
      sdkHarnessClient.getProcessor(
          descriptor,
          Collections.singletonList(
              RemoteInputDestination.of(
                  (FullWindowedValueCoder) stringCoder, SDK_GRPC_READ_TRANSFORM)),
          delegator);
  when(dataService.receive(any(), any(), any())).thenReturn(outputClient);
  when(dataService.send(any(), eq(stringCoder))).thenReturn(inputSender);
  RemoteOutputReceiver outputReceiver = mock(RemoteOutputReceiver.class);

  ExecutionException caught = null;
  try (RemoteBundle activeBundle =
      bundleProcessor.newBundle(
          ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, outputReceiver),
          stateHandler,
          progressHandler)) {
    // Fail the ProcessBundle response while the bundle is still open.
    responseFuture.completeExceptionally(failure);
  } catch (ExecutionException e) {
    caught = e;
  }
  if (caught == null) {
    fail("Exception expected");
  }
  assertEquals(failure, caught.getCause());
  verify(registration).abort();
  verify(outputClient).cancel();
  verifyNoMoreInteractions(registration, outputClient);
}
Usage of org.apache.beam.runners.fnexecution.state.StateRequestHandler in the Apache Beam project.
Class SamzaStateRequestHandlers, method of.
/**
 * Builds the state request handler for a Samza executable stage, dispatching on the state key
 * type.
 *
 * <p>Iterable, multimap and multimap-keys side input requests share one handler built from the
 * side input ids and {@code sideInputHandler}. Bag user state requests go to a handler built from
 * the transform id, task context, options and stage bundle factory.
 */
public static StateRequestHandler of(
    String transformId,
    TaskContext context,
    SamzaPipelineOptions pipelineOptions,
    ExecutableStage executableStage,
    StageBundleFactory stageBundleFactory,
    Map<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputIds,
    SideInputHandler sideInputHandler) {
  final EnumMap<BeamFnApi.StateKey.TypeCase, StateRequestHandler> handlersByType =
      new EnumMap<>(BeamFnApi.StateKey.TypeCase.class);

  final StateRequestHandler sideInputs =
      createSideInputStateHandler(executableStage, sideInputIds, sideInputHandler);
  final BeamFnApi.StateKey.TypeCase[] sideInputTypes = {
    BeamFnApi.StateKey.TypeCase.ITERABLE_SIDE_INPUT,
    BeamFnApi.StateKey.TypeCase.MULTIMAP_SIDE_INPUT,
    BeamFnApi.StateKey.TypeCase.MULTIMAP_KEYS_SIDE_INPUT
  };
  for (BeamFnApi.StateKey.TypeCase sideInputType : sideInputTypes) {
    handlersByType.put(sideInputType, sideInputs);
  }

  final StateRequestHandler userState =
      createUserStateRequestHandler(
          transformId, executableStage, context, pipelineOptions, stageBundleFactory);
  handlersByType.put(BeamFnApi.StateKey.TypeCase.BAG_USER_STATE, userState);

  return StateRequestHandlers.delegateBasedUponType(handlersByType);
}
Usage of org.apache.beam.runners.fnexecution.state.StateRequestHandler in the Apache Beam project.
Class SparkExecutableStageFunction, method call.
/**
 * Runs the executable stage over one partition of input elements and returns everything the
 * stage emitted into the output collector.
 *
 * <p>If the stage declares timers, the input watermark, processing time and synchronized
 * processing time are advanced to {@link BoundedWindow#TIMESTAMP_MAX_VALUE} after all inputs are
 * processed. Timers are then fired in further bundles until none remain pending.
 *
 * @param inputs the elements of this partition
 * @return an iterator over the collected outputs; empty when {@code inputs} is empty
 */
@Override
public Iterator<RawUnionValue> call(Iterator<WindowedValue<InputT>> inputs) throws Exception {
  SparkPipelineOptions options = pipelineOptions.get().as(SparkPipelineOptions.class);
  // Register standard file systems.
  FileSystems.setDefaultPipelineOptions(options);
  // Do not call processElements if there are no inputs.
  // Otherwise, this may cause validation errors (e.g. ParDoTest)
  if (!inputs.hasNext()) {
    return Collections.emptyIterator();
  }
  try (ExecutableStageContext stageContext = contextFactory.get(jobInfo)) {
    ExecutableStage executableStage = ExecutableStage.fromPayload(stagePayload);
    try (StageBundleFactory stageBundleFactory =
        stageContext.getStageBundleFactory(executableStage)) {
      ConcurrentLinkedQueue<RawUnionValue> collector = new ConcurrentLinkedQueue<>();
      StateRequestHandler stateRequestHandler =
          getStateRequestHandler(
              executableStage, stageBundleFactory.getProcessBundleDescriptor());
      // Both the timer-less path and the timer path emit into the same collector.
      ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);

      if (executableStage.getTimers().isEmpty()) {
        processElements(stateRequestHandler, receiverFactory, null, stageBundleFactory, inputs);
        return collector.iterator();
      }

      // Used with Batch, we know that all the data is available for this key. We can't use the
      // timer manager from the context because it doesn't exist. So we create one and advance
      // time to the end after processing all elements.
      final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
      // Start both processing-time domains at the same instant.
      final Instant startTime = Instant.now();
      timerInternals.advanceProcessingTime(startTime);
      timerInternals.advanceSynchronizedProcessingTime(startTime);
      TimerReceiverFactory timerReceiverFactory =
          new TimerReceiverFactory(
              stageBundleFactory,
              (Timer<?> timer, TimerInternals.TimerData timerData) -> {
                // Remember the key of the most recently received timer; it is handed to
                // fireEligibleTimers below.
                currentTimerKey = timer.getUserKey();
                if (timer.getClearBit()) {
                  timerInternals.deleteTimer(timerData);
                } else {
                  timerInternals.setTimer(timerData);
                }
              },
              windowCoder);

      // Process inputs.
      processElements(
          stateRequestHandler, receiverFactory, timerReceiverFactory, stageBundleFactory, inputs);

      // Finish any pending windows by advancing the input watermark to infinity.
      timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
      // Finally, advance the processing time to infinity to fire any timers.
      timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
      timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);

      // Fire timers in fresh bundles. Timers emitted while a bundle runs are presumably routed
      // to the callback above via timerReceiverFactory, which may set new timers. So keep
      // going until none remain pending.
      while (timerInternals.hasPendingTimers()) {
        try (RemoteBundle bundle =
            stageBundleFactory.getBundle(
                receiverFactory,
                timerReceiverFactory,
                stateRequestHandler,
                getBundleProgressHandler())) {
          PipelineTranslatorUtils.fireEligibleTimers(
              timerInternals, bundle.getTimerReceivers(), currentTimerKey);
        }
      }
      return collector.iterator();
    }
  }
}
Aggregations