Use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project Beam by Apache.
The class SdkHarnessClientTest, method verifyCacheTokensAreUsedInNewBundleRequest.
/**
 * Verifies that the cache tokens returned by the {@link StateRequestHandler} are copied into the
 * {@code ProcessBundleRequest} that is sent when a new bundle is started.
 */
@Test
public void verifyCacheTokensAreUsedInNewBundleRequest() throws InterruptedException {
  // Every instruction sent to the SDK harness completes immediately with an empty response.
  when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
      .thenReturn(
          CompletableFuture.<InstructionResponse>completedFuture(
              InstructionResponse.newBuilder().build()));

  ProcessBundleDescriptor descriptor1 =
      ProcessBundleDescriptor.newBuilder().setId("descriptor1").build();
  List<RemoteInputDestination> remoteInputs =
      Collections.singletonList(
          RemoteInputDestination.of(
              FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE),
              SDK_GRPC_READ_TRANSFORM));
  BundleProcessor processor1 = sdkHarnessClient.getProcessor(descriptor1, remoteInputs);
  when(dataService.send(any(), any())).thenReturn(mock(CloseableFnDataReceiver.class));

  StateRequestHandler stateRequestHandler = Mockito.mock(StateRequestHandler.class);
  // The protobuf default instance is sufficient: the test only checks that the tokens are
  // forwarded unchanged, not what they contain.
  List<BeamFnApi.ProcessBundleRequest.CacheToken> cacheTokens =
      Collections.singletonList(BeamFnApi.ProcessBundleRequest.CacheToken.getDefaultInstance());
  when(stateRequestHandler.getCacheTokens()).thenReturn(cacheTokens);

  processor1.newBundle(
      ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, mock(RemoteOutputReceiver.class)),
      stateRequestHandler,
      BundleProgressHandler.ignored());

  // Retrieve the requests made to the FnApiControlClient; exactly one is expected.
  ArgumentCaptor<BeamFnApi.InstructionRequest> reqCaptor =
      ArgumentCaptor.forClass(BeamFnApi.InstructionRequest.class);
  Mockito.verify(fnApiControlClient, Mockito.times(1)).handle(reqCaptor.capture());
  List<BeamFnApi.InstructionRequest> requests = reqCaptor.getAllValues();

  // Verify that the cache tokens are included in the ProcessBundleRequest.
  assertThat(
      requests.get(0).getRequestCase(),
      is(BeamFnApi.InstructionRequest.RequestCase.PROCESS_BUNDLE));
  assertThat(requests.get(0).getProcessBundle().getCacheTokensList(), is(cacheTokens));
}
Use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project Beam by Apache.
The class SdkHarnessClientTest, method handleCleanupWithStateWhenInputSenderFails.
/**
 * Checks that a bundle is still cleaned up when closing its input sender fails. The state
 * registration must be aborted and the inbound data client cancelled, with no other calls on
 * either. The exception raised by the sender's {@code close()} must be the one that escapes
 * from closing the bundle.
 */
@Test
public void handleCleanupWithStateWhenInputSenderFails() throws Exception {
  Exception expectedFailure = new Exception();
  InboundDataClient inboundClient = mock(InboundDataClient.class);
  CloseableFnDataReceiver failingInputSender = mock(CloseableFnDataReceiver.class);

  StateDelegator stateDelegator = mock(StateDelegator.class);
  StateDelegator.Registration stateRegistration = mock(StateDelegator.Registration.class);
  when(stateDelegator.registerForProcessBundleInstructionId(any(), any()))
      .thenReturn(stateRegistration);

  StateRequestHandler stateHandler = mock(StateRequestHandler.class);
  when(stateHandler.getCacheTokens()).thenReturn(Collections.emptyList());
  BundleProgressHandler progressHandler = mock(BundleProgressHandler.class);

  // The process-bundle response stays pending: cleanup must not wait for it.
  CompletableFuture<InstructionResponse> pendingResponse = new CompletableFuture<>();
  when(fnApiControlClient.handle(any(BeamFnApi.InstructionRequest.class)))
      .thenReturn(pendingResponse);

  FullWindowedValueCoder<String> stringCoder =
      FullWindowedValueCoder.of(StringUtf8Coder.of(), Coder.INSTANCE);
  BundleProcessor bundleProcessor =
      sdkHarnessClient.getProcessor(
          descriptor,
          Collections.singletonList(
              RemoteInputDestination.of(
                  (FullWindowedValueCoder) stringCoder, SDK_GRPC_READ_TRANSFORM)),
          stateDelegator);
  when(dataService.receive(any(), any(), any())).thenReturn(inboundClient);
  when(dataService.send(any(), eq(stringCoder))).thenReturn(failingInputSender);
  doThrow(expectedFailure).when(failingInputSender).close();
  RemoteOutputReceiver outputReceiver = mock(RemoteOutputReceiver.class);

  try {
    try (RemoteBundle bundle =
        bundleProcessor.newBundle(
            ImmutableMap.of(SDK_GRPC_WRITE_TRANSFORM, outputReceiver),
            stateHandler,
            progressHandler)) {
      // Nothing to do here: the failure is triggered by closing the bundle, and the pending
      // process-bundle response never needs to be completed.
    }
    fail("Exception expected");
  } catch (Exception thrown) {
    assertEquals(expectedFailure, thrown);
    verify(stateRegistration).abort();
    verify(inboundClient).cancel();
    verifyNoMoreInteractions(stateRegistration, inboundClient);
  }
}
Use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project Beam by Apache.
The class DefaultJobBundleFactoryTest, method loadBalancesBundles.
/**
 * Verifies bundle load balancing when {@code loadBalanceBundles} is enabled and SDK worker
 * parallelism is 2.
 *
 * <p>The first two bundles each create an environment. A third bundle blocks until one of the
 * first two is closed, and it does not create a third environment.
 */
@Test
public void loadBalancesBundles() throws Exception {
  PortablePipelineOptions portableOptions =
      PipelineOptionsFactory.as(PortablePipelineOptions.class);
  portableOptions.setSdkWorkerParallelism(2);
  portableOptions.setLoadBalanceBundles(true);
  Struct pipelineOptions = PipelineOptionsTranslation.toProto(portableOptions);

  try (DefaultJobBundleFactory bundleFactory =
      new DefaultJobBundleFactory(
          JobInfo.create("testJob", "testJob", "token", pipelineOptions),
          envFactoryProviderMap,
          stageIdGenerator,
          serverInfo)) {
    OutputReceiverFactory orf = mock(OutputReceiverFactory.class);
    StateRequestHandler srh = mock(StateRequestHandler.class);
    when(srh.getCacheTokens()).thenReturn(Collections.emptyList());
    StageBundleFactory sbf = bundleFactory.forStage(getExecutableStage(environment));

    RemoteBundle b1 = sbf.getBundle(orf, srh, BundleProgressHandler.ignored());
    verify(envFactory, Mockito.times(1)).createEnvironment(eq(environment), any());
    final RemoteBundle b2 = sbf.getBundle(orf, srh, BundleProgressHandler.ignored());
    verify(envFactory, Mockito.times(2)).createEnvironment(eq(environment), any());

    AtomicBoolean b2Closing = new AtomicBoolean(false);
    ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    try {
      ScheduledFuture<Optional<Exception>> closingFuture =
          executor.schedule(
              () -> {
                try {
                  b2Closing.compareAndSet(false, true);
                  b2.close();
                  return Optional.empty();
                } catch (Exception e) {
                  return Optional.of(e);
                }
              },
              100,
              TimeUnit.MILLISECONDS);
      assertThat(b2Closing.get(), equalTo(false));

      // This call should block until closingFuture has finished closing b2 (100ms).
      RemoteBundle b3 = sbf.getBundle(orf, srh, BundleProgressHandler.ignored());
      // b2Closing is set just before b2.close() is invoked, so this confirms that getBundle
      // did not return before the scheduled close had started.
      assertThat(b2Closing.get(), equalTo(true));

      // Join closingFuture and check whether closing b2 raised an exception.
      Optional<Exception> closingException = closingFuture.get();
      if (closingException.isPresent()) {
        throw new AssertionError("Exception occurred while closing b2", closingException.get());
      }
      // No new environment was created for b3: it reused the slot freed by b2.
      verify(envFactory, Mockito.times(2)).createEnvironment(eq(environment), any());
      b3.close();
      b1.close();
    } finally {
      // Release the executor thread even when an assertion above fails; previously the
      // executor was only shut down on the success path and leaked on failure.
      executor.shutdownNow();
    }
  }
}
Use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project Beam by Apache.
The class SparkExecutableStageFunction, method getStateRequestHandler.
/**
 * Builds the {@link StateRequestHandler} for a bundle of {@code executableStage}.
 *
 * <p>Side-input requests ({@code ITERABLE_SIDE_INPUT}, {@code MULTIMAP_SIDE_INPUT} and {@code
 * MULTIMAP_KEYS_SIDE_INPUT}) are answered by decoding the broadcast byte arrays held in {@code
 * sideInputs}. {@code BAG_USER_STATE} requests are routed to a handler built from {@code
 * bagUserStateHandlerFactory} when the stage declares user state. Otherwise they go to {@link
 * StateRequestHandler#unsupported()}.
 *
 * <p>Side effects: lazily initializes {@code bagUserStateHandlerFactory} with an in-memory
 * factory. When the stage has user state, it also calls {@code resetForNewKey()} on that
 * factory.
 *
 * @param executableStage the stage whose side inputs and user state are wired up
 * @param processBundleDescriptor descriptor used to build the bag user state handler
 * @return a handler that dispatches each request according to its {@link StateKey.TypeCase}
 * @throws RuntimeException if the side input handler cannot be created
 * @throws IllegalStateException (when a side input is later requested) if no broadcast is
 *     registered for the requested PCollection id
 */
private StateRequestHandler getStateRequestHandler(
    ExecutableStage executableStage,
    ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor) {
  EnumMap<StateKey.TypeCase, StateRequestHandler> handlerMap =
      new EnumMap<>(StateKey.TypeCase.class);

  // An anonymous class is required here: getSideInput is a generic method, so
  // SideInputGetter cannot be implemented with a lambda.
  StateRequestHandlers.SideInputHandlerFactory sideInputHandlerFactory =
      BatchSideInputHandlerFactory.forStage(
          executableStage,
          new BatchSideInputHandlerFactory.SideInputGetter() {
            @Override
            @SuppressWarnings("unchecked") // caller chooses T; elements come from the side input coder
            public <T> List<T> getSideInput(String pCollectionId) {
              Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
                  sideInputs.get(pCollectionId);
              if (tuple2 == null) {
                // Fail with a descriptive message instead of a bare NullPointerException.
                throw new IllegalStateException(
                    "No broadcast side input registered for PCollection " + pCollectionId);
              }
              Broadcast<List<byte[]>> broadcast = tuple2._1;
              WindowedValueCoder<SideInputT> coder = tuple2._2;
              return (List<T>)
                  broadcast.value().stream()
                      .map(bytes -> CoderHelpers.fromByteArray(bytes, coder))
                      .collect(Collectors.toList());
            }
          });

  final StateRequestHandler sideInputHandler;
  try {
    sideInputHandler =
        StateRequestHandlers.forSideInputHandlerFactory(
            ProcessBundleDescriptors.getSideInputs(executableStage), sideInputHandlerFactory);
  } catch (IOException e) {
    throw new RuntimeException("Failed to setup state handler", e);
  }

  if (bagUserStateHandlerFactory == null) {
    bagUserStateHandlerFactory = new InMemoryBagUserStateFactory();
  }

  final StateRequestHandler userStateHandler;
  if (!executableStage.getUserStates().isEmpty()) {
    // Need to discard the old key's state.
    bagUserStateHandlerFactory.resetForNewKey();
    userStateHandler =
        StateRequestHandlers.forBagUserStateHandlerFactory(
            processBundleDescriptor, bagUserStateHandlerFactory);
  } else {
    userStateHandler = StateRequestHandler.unsupported();
  }

  handlerMap.put(StateKey.TypeCase.ITERABLE_SIDE_INPUT, sideInputHandler);
  handlerMap.put(StateKey.TypeCase.MULTIMAP_SIDE_INPUT, sideInputHandler);
  handlerMap.put(StateKey.TypeCase.MULTIMAP_KEYS_SIDE_INPUT, sideInputHandler);
  handlerMap.put(StateKey.TypeCase.BAG_USER_STATE, userStateHandler);
  return StateRequestHandlers.delegateBasedUponType(handlerMap);
}
Use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project Beam by Apache.
The class ExecutableStageDoFnOperatorTest, method outputsAreTaggedCorrectly.
/**
 * Verifies that values emitted for the main output and for two additional outputs reach the
 * right places. The main-output value lands in the harness's main output. Each additional-output
 * value lands in the Flink side output registered for its {@link TupleTag}.
 */
@Test
public void outputsAreTaggedCorrectly() throws Exception {
  WindowedValue.ValueOnlyWindowedValueCoder<Integer> coder =
      WindowedValue.getValueOnlyCoder(VarIntCoder.of());

  TupleTag<Integer> mainOutput = new TupleTag<>("main-output");
  TupleTag<Integer> additionalOutput1 = new TupleTag<>("output-1");
  TupleTag<Integer> additionalOutput2 = new TupleTag<>("output-2");
  // The additional outputs carry WindowedValue<Integer> (see `coder` above), so the Flink
  // OutputTags are typed accordingly. They were previously declared as WindowedValue<String>.
  ImmutableMap<TupleTag<?>, OutputTag<?>> tagsToOutputTags =
      ImmutableMap.<TupleTag<?>, OutputTag<?>>builder()
          .put(
              additionalOutput1,
              new OutputTag<WindowedValue<Integer>>(additionalOutput1.getId()) {})
          .put(
              additionalOutput2,
              new OutputTag<WindowedValue<Integer>>(additionalOutput2.getId()) {})
          .build();
  ImmutableMap<TupleTag<?>, Coder<WindowedValue<?>>> tagsToCoders =
      ImmutableMap.<TupleTag<?>, Coder<WindowedValue<?>>>builder()
          .put(mainOutput, (Coder) coder)
          .put(additionalOutput1, coder)
          .put(additionalOutput2, coder)
          .build();
  ImmutableMap<TupleTag<?>, Integer> tagsToIds =
      ImmutableMap.<TupleTag<?>, Integer>builder()
          .put(mainOutput, 0)
          .put(additionalOutput1, 1)
          .put(additionalOutput2, 2)
          .build();
  DoFnOperator.MultiOutputOutputManagerFactory<Integer> outputManagerFactory =
      new DoFnOperator.MultiOutputOutputManagerFactory(
          mainOutput,
          tagsToOutputTags,
          tagsToCoders,
          tagsToIds,
          new SerializablePipelineOptions(FlinkPipelineOptions.defaults()));

  WindowedValue<Integer> zero = WindowedValue.valueInGlobalWindow(0);
  WindowedValue<Integer> three = WindowedValue.valueInGlobalWindow(3);
  WindowedValue<Integer> four = WindowedValue.valueInGlobalWindow(4);
  WindowedValue<Integer> five = WindowedValue.valueInGlobalWindow(5);

  // A hand-written (non-mock) StageBundleFactory, so that the output receiver factory supplied
  // by the operator is actually exercised.
  StageBundleFactory stageBundleFactory =
      new StageBundleFactory() {
        private boolean onceEmitted;

        @Override
        public RemoteBundle getBundle(
            OutputReceiverFactory receiverFactory,
            TimerReceiverFactory timerReceiverFactory,
            StateRequestHandler stateRequestHandler,
            BundleProgressHandler progressHandler,
            BundleFinalizationHandler finalizationHandler,
            BundleCheckpointHandler checkpointHandler) {
          return new RemoteBundle() {
            @Override
            public String getId() {
              return "bundle-id";
            }

            @Override
            public Map<String, FnDataReceiver> getInputReceivers() {
              return ImmutableMap.of(
                  "input",
                  input -> {
                    /* Ignore input*/
                  });
            }

            @Override
            public Map<KV<String, String>, FnDataReceiver<Timer>> getTimerReceivers() {
              return Collections.emptyMap();
            }

            @Override
            public void requestProgress() {
              throw new UnsupportedOperationException();
            }

            @Override
            public void split(double fractionOfRemainder) {
              throw new UnsupportedOperationException();
            }

            @Override
            public void close() throws Exception {
              // Emit at most once across all bundles created by this factory.
              if (onceEmitted) {
                return;
              }
              // Emit all values to the runner when the bundle is closed.
              receiverFactory.create(mainOutput.getId()).accept(three);
              receiverFactory.create(additionalOutput1.getId()).accept(four);
              receiverFactory.create(additionalOutput2.getId()).accept(five);
              onceEmitted = true;
            }
          };
        }

        @Override
        public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor
            getProcessBundleDescriptor() {
          return processBundleDescriptor;
        }

        @Override
        public InstructionRequestHandler getInstructionRequestHandler() {
          return null;
        }

        @Override
        public void close() {}
      };

  // Wire the stage bundle factory into our context.
  when(stageContext.getStageBundleFactory(any())).thenReturn(stageBundleFactory);
  ExecutableStageDoFnOperator<Integer, Integer> operator =
      getOperator(
          mainOutput, ImmutableList.of(additionalOutput1, additionalOutput2), outputManagerFactory);

  OneInputStreamOperatorTestHarness<WindowedValue<Integer>, WindowedValue<Integer>> testHarness =
      new OneInputStreamOperatorTestHarness<>(operator);

  long watermark = testHarness.getCurrentWatermark() + 1;
  testHarness.open();
  testHarness.processElement(new StreamRecord<>(zero));
  testHarness.processWatermark(watermark);
  watermark++;
  testHarness.processWatermark(watermark);
  assertEquals(watermark, testHarness.getCurrentWatermark());
  // Watermark hold until the bundle completes: the stub bundle emits nothing before close().
  assertEquals(0, testHarness.getOutput().size());

  // Triggers finish bundle.
  testHarness.close();

  assertThat(stripStreamRecordFromWindowedValue(testHarness.getOutput()), contains(three));
  assertThat(
      testHarness.getSideOutput(tagsToOutputTags.get(additionalOutput1)),
      contains(new StreamRecord<>(four)));
  assertThat(
      testHarness.getSideOutput(tagsToOutputTags.get(additionalOutput2)),
      contains(new StreamRecord<>(five)));
}
Aggregations