use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project beam by apache.
the class SamzaDoFnRunners method createPortable.
/**
 * Builds the {@link DoFnRunner} that executes an {@link ExecutableStage} through the portable
 * runner.
 *
 * <p>The elements of a bundle are kept in a bag state registered under {@code bundleStateId} in
 * the global state namespace. The stage is run by an {@code SdkHarnessDoFnRunner}; that runner is
 * additionally wrapped by {@code DoFnRunnerWithMetrics} when
 * {@code pipelineOptions.getEnableMetrics()} is true.
 *
 * @param transformId id passed on to the state request handler factory
 * @param bundleStateId id of the bag state that buffers the elements of a bundle
 * @param windowedValueCoder coder for the buffered input elements
 * @param executableStage the stage to run; also supplies the input PCollection id and components
 *     used to look up the windowing strategy
 * @param sideInputMapping side input views; cast (unchecked) to a map keyed by
 *     {@code RunnerApi.ExecutableStagePayload.SideInputId}
 * @param sideInputHandler handler passed on to the state request handler factory
 * @param nonKeyedStateInternalsFactory source of the state internals holding the bundle bag
 * @param timerInternalsFactory passed on to the SDK harness runner
 * @param pipelineOptions options; {@code getEnableMetrics()} decides whether metrics are wrapped
 * @param outputManager passed on to the SDK harness runner
 * @param stageBundleFactory passed on to the state request handler and the SDK harness runner
 * @param mainOutputTag not read by this method
 * @param idToTupleTagMap passed on to the SDK harness runner
 * @param context supplies the task context and the application container context
 * @param transformFullName name handed to the metrics wrapper
 * @return the runner for the stage, metrics-wrapped if metrics are enabled
 */
@SuppressWarnings("unchecked")
public static <InT, FnOutT> DoFnRunner<InT, FnOutT> createPortable(
    String transformId,
    String bundleStateId,
    Coder<WindowedValue<InT>> windowedValueCoder,
    ExecutableStage executableStage,
    Map<?, PCollectionView<?>> sideInputMapping,
    SideInputHandler sideInputHandler,
    SamzaStoreStateInternals.Factory<?> nonKeyedStateInternalsFactory,
    SamzaTimerInternalsFactory<?> timerInternalsFactory,
    SamzaPipelineOptions pipelineOptions,
    DoFnRunners.OutputManager outputManager,
    StageBundleFactory stageBundleFactory,
    TupleTag<FnOutT> mainOutputTag,
    Map<String, TupleTag<?>> idToTupleTagMap,
    Context context,
    String transformFullName) {
  // Events belonging to the current bundle are buffered in this bag state.
  final BagState<WindowedValue<InT>> bufferedElements =
      nonKeyedStateInternalsFactory
          .stateInternalsForKey(null)
          .state(StateNamespaces.global(), StateTags.bag(bundleStateId, windowedValueCoder));

  // Unchecked view of the side input mapping with the key type the handler factory expects.
  final Map<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> typedSideInputs =
      (Map<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>>) sideInputMapping;
  final StateRequestHandler requestHandler =
      SamzaStateRequestHandlers.of(
          transformId,
          context.getTaskContext(),
          pipelineOptions,
          executableStage,
          stageBundleFactory,
          typedSideInputs,
          sideInputHandler);

  final SamzaExecutionContext samzaContext =
      (SamzaExecutionContext) context.getApplicationContainerContext();

  final String inputPCollectionId = executableStage.getInputPCollection().getId();
  final DoFnRunner<InT, FnOutT> sdkHarnessRunner =
      new SdkHarnessDoFnRunner<>(
          timerInternalsFactory,
          WindowUtils.getWindowStrategy(inputPCollectionId, executableStage.getComponents()),
          outputManager,
          stageBundleFactory,
          idToTupleTagMap,
          bufferedElements,
          requestHandler);

  if (pipelineOptions.getEnableMetrics()) {
    return DoFnRunnerWithMetrics.wrap(
        sdkHarnessRunner, samzaContext.getMetricsContainer(), transformFullName);
  }
  return sdkHarnessRunner;
}
use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project beam by apache.
the class FlinkExecutableStageFunctionTest method outputsAreTaggedCorrectly.
/**
 * Checks that every value emitted by the stage reaches the collector wrapped in a
 * {@link RawUnionValue} whose union tag is the one the output map assigns to that output id.
 */
@Test
public void outputsAreTaggedCorrectly() throws Exception {
  WindowedValue<Integer> first = WindowedValue.valueInGlobalWindow(3);
  WindowedValue<Integer> second = WindowedValue.valueInGlobalWindow(4);
  WindowedValue<Integer> third = WindowedValue.valueInGlobalWindow(5);
  Map<String, Integer> tagsByOutputId = ImmutableMap.of("one", 1, "two", 2, "three", 3);

  // A hand-written StageBundleFactory (not a mock) so that the output receiver factory passed to
  // getBundle is actually exercised.
  StageBundleFactory emittingFactory =
      new StageBundleFactory() {
        // Set after the first successful close so the values are emitted only once.
        private boolean emitted;

        @Override
        public RemoteBundle getBundle(
            OutputReceiverFactory receiverFactory,
            TimerReceiverFactory timerReceiverFactory,
            StateRequestHandler stateRequestHandler,
            BundleProgressHandler progressHandler,
            BundleFinalizationHandler finalizationHandler,
            BundleCheckpointHandler checkpointHandler) {
          return new RemoteBundle() {
            @Override
            public String getId() {
              return "bundle-id";
            }

            @Override
            public Map<String, FnDataReceiver> getInputReceivers() {
              return ImmutableMap.of(
                  "input",
                  ignored -> {
                    // Inputs are deliberately dropped.
                  });
            }

            @Override
            public Map<KV<String, String>, FnDataReceiver<Timer>> getTimerReceivers() {
              return Collections.emptyMap();
            }

            @Override
            public void requestProgress() {
              throw new UnsupportedOperationException();
            }

            @Override
            public void split(double fractionOfRemainder) {
              throw new UnsupportedOperationException();
            }

            @Override
            public void close() throws Exception {
              if (!emitted) {
                // All values are handed to the runner at the moment the bundle is closed.
                receiverFactory.create("one").accept(first);
                receiverFactory.create("two").accept(second);
                receiverFactory.create("three").accept(third);
                emitted = true;
              }
            }
          };
        }

        @Override
        public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor
            getProcessBundleDescriptor() {
          return processBundleDescriptor;
        }

        @Override
        public InstructionRequestHandler getInstructionRequestHandler() {
          return null;
        }

        @Override
        public void close() throws Exception {}
      };

  // Make the stage context hand out the factory defined above.
  when(stageContext.getStageBundleFactory(any())).thenReturn(emittingFactory);

  FlinkExecutableStageFunction<Integer> stageFunction = getFunction(tagsByOutputId);
  stageFunction.open(new Configuration());
  if (isStateful) {
    stageFunction.reduce(Collections.emptyList(), collector);
  } else {
    stageFunction.mapPartition(Collections.emptyList(), collector);
  }

  // Each value must arrive with the union tag that the output map gives its output id, and
  // nothing else may be sent to the collector.
  verify(collector).collect(new RawUnionValue(1, first));
  verify(collector).collect(new RawUnionValue(2, second));
  verify(collector).collect(new RawUnionValue(3, third));
  verifyNoMoreInteractions(collector);
}
use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project beam by apache.
the class ExecutableStageDoFnOperatorTest method testCacheTokenHandling.
/**
 * Verifies that the single user-state cache token reported by the {@link StateRequestHandler}
 * stays the same across get, append and clear requests for different keys and state cells.
 */
@Test
public void testCacheTokenHandling() throws Exception {
  InMemoryStateInternals stateInternals = InMemoryStateInternals.forKey("test");
  KeyedStateBackend<ByteBuffer> keyedBackend = FlinkStateInternalsTest.createStateBackend();
  ExecutableStageDoFnOperator.BagUserStateFactory<Integer, GlobalWindow> userStateFactory =
      new ExecutableStageDoFnOperator.BagUserStateFactory<>(
          stateInternals, keyedBackend, NoopLock.get(), null);

  ByteString firstKey = ByteString.copyFrom("key1", Charsets.UTF_8);
  ByteString secondKey = ByteString.copyFrom("key2", Charsets.UTF_8);

  // Any transform id resolves to the same per-transform map, which knows two state cells.
  Map<String, Map<String, ProcessBundleDescriptors.BagUserStateSpec>> specsByTransform =
      Mockito.mock(Map.class);
  Map<String, ProcessBundleDescriptors.BagUserStateSpec> specsByStateId =
      Mockito.mock(Map.class);
  final String firstStateId = "userstate1";
  ProcessBundleDescriptors.BagUserStateSpec firstSpec = mockBagUserState(firstStateId);
  when(specsByStateId.get(firstStateId)).thenReturn(firstSpec);
  final String secondStateId = "userstate2";
  ProcessBundleDescriptors.BagUserStateSpec secondSpec = mockBagUserState(secondStateId);
  when(specsByStateId.get(secondStateId)).thenReturn(secondSpec);
  when(specsByTransform.get(anyString())).thenReturn(specsByStateId);
  when(processBundleDescriptor.getBagUserStateSpecs()).thenReturn(specsByTransform);

  StateRequestHandler handler =
      StateRequestHandlers.forBagUserStateHandlerFactory(
          processBundleDescriptor, userStateFactory);

  // The token read before any request is the one every later check is compared against.
  final BeamFnApi.ProcessBundleRequest.CacheToken expectedToken =
      Iterables.getOnlyElement(handler.getCacheTokens());

  // Issue a first request, then read the token again.
  handler.handle(getRequest(firstKey, firstStateId));
  BeamFnApi.ProcessBundleRequest.CacheToken tokenAfterFirstRequest =
      Iterables.getOnlyElement(handler.getCacheTokens());
  assertThat(tokenAfterFirstRequest.hasUserState(), is(true));
  assertThat(tokenAfterFirstRequest, is(expectedToken));

  List<RequestGenerator> requestKinds =
      Arrays.asList(
          ExecutableStageDoFnOperatorTest::getRequest,
          ExecutableStageDoFnOperatorTest::getAppend,
          ExecutableStageDoFnOperatorTest::getClear);
  for (RequestGenerator requestKind : requestKinds) {
    // Same key, same state cell: the token remains unchanged.
    handler.handle(requestKind.makeRequest(firstKey, firstStateId));
    assertThat(Iterables.getOnlyElement(handler.getCacheTokens()), is(expectedToken));

    // A different key on the same state cell: the token remains unchanged.
    handler.handle(requestKind.makeRequest(secondKey, firstStateId));
    assertThat(Iterables.getOnlyElement(handler.getCacheTokens()), is(expectedToken));

    // A different state cell: the token remains unchanged.
    handler.handle(requestKind.makeRequest(secondKey, secondStateId));
    assertThat(Iterables.getOnlyElement(handler.getCacheTokens()), is(expectedToken));
  }
}
use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project beam by apache.
the class DefaultJobBundleFactoryTest method expiresEnvironment.
/**
 * Runs two bundles against a {@link DefaultJobBundleFactory} whose environment expiration is set
 * to one millisecond, sleeping before each bundle, and expects the environment to have been
 * created three times and closed three times once the factory itself is closed.
 */
@Test
public void expiresEnvironment() throws Exception {
  ServerFactory defaultServerFactory = ServerFactory.createDefault();

  Environment env = Environment.newBuilder().setUrn("env:urn:a").build();
  EnvironmentFactory envFactory = mock(EnvironmentFactory.class);
  when(envFactory.createEnvironment(eq(env), any())).thenReturn(remoteEnvironment);

  EnvironmentFactory.Provider envFactoryProvider = mock(EnvironmentFactory.Provider.class);
  when(envFactoryProvider.createEnvironmentFactory(any(), any(), any(), any(), any(), any()))
      .thenReturn(envFactory);
  when(envFactoryProvider.getServerFactory()).thenReturn(defaultServerFactory);
  Map<String, Provider> providersByUrn = ImmutableMap.of(env.getUrn(), envFactoryProvider);

  PortablePipelineOptions expiringOptions =
      PipelineOptionsFactory.as(PortablePipelineOptions.class);
  expiringOptions.setEnvironmentExpirationMillis(1);
  Struct optionsProto = PipelineOptionsTranslation.toProto(expiringOptions);

  try (DefaultJobBundleFactory jobBundleFactory =
      new DefaultJobBundleFactory(
          JobInfo.create("testJob", "testJob", "token", optionsProto),
          providersByUrn,
          stageIdGenerator,
          serverInfo)) {
    OutputReceiverFactory receiverFactory = mock(OutputReceiverFactory.class);
    StateRequestHandler stateHandler = mock(StateRequestHandler.class);
    when(stateHandler.getCacheTokens()).thenReturn(Collections.emptyList());
    StageBundleFactory stageBundleFactory = jobBundleFactory.forStage(getExecutableStage(env));

    for (int round = 0; round < 2; round++) {
      // Sleep longer than the expiration time so the environment can expire before this bundle.
      Thread.sleep(10);
      stageBundleFactory
          .getBundle(receiverFactory, stateHandler, BundleProgressHandler.ignored())
          .close();
    }
  }

  // NOTE(review): presumably one initial creation plus one re-creation per bundle requested
  // after expiry — confirm against DefaultJobBundleFactory.
  verify(envFactory, Mockito.times(3)).createEnvironment(eq(env), any());
  verify(remoteEnvironment, Mockito.times(3)).close();
}
use of org.apache.beam.runners.fnexecution.state.StateRequestHandler in project beam by apache.
the class RemoteExecutionTest method testExecutionWithSideInput.
/**
 * Executes a fused stage that reads an iterable side input and a multimap side input, answering
 * the side input state requests from a stubbed {@code SideInputHandlerFactory}, and asserts that
 * each input element is paired with every stubbed side input value.
 */
@Test
public void testExecutionWithSideInput() throws Exception {
launchSdkHarness(PipelineOptionsFactory.create());
Pipeline p = Pipeline.create();
addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
// TODO(BEAM-10097): Remove experiment once all portable runners support this view type
addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
// The "create" step outputs three strings ("zero", "one", "two") for each impulse element.
PCollection<String> input = p.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], String>() {
@ProcessElement
public void process(ProcessContext ctxt) {
ctxt.output("zero");
ctxt.output("one");
ctxt.output("two");
}
})).setCoder(StringUtf8Coder.of());
// Two views over the same input: a plain iterable, and a multimap where every element gets the key "key".
PCollectionView<Iterable<String>> iterableView = input.apply("createIterableSideInput", View.asIterable());
PCollectionView<Map<String, Iterable<String>>> multimapView = input.apply(WithKeys.of("key")).apply("createMultimapSideInput", View.asMultimap());
// "readSideInput" pairs each element with every iterable value and with every "key:value" of the multimap.
input.apply("readSideInput", ParDo.of(new DoFn<String, KV<String, String>>() {
@ProcessElement
public void processElement(ProcessContext context) {
for (String value : context.sideInput(iterableView)) {
context.output(KV.of(context.element(), value));
}
for (Map.Entry<String, Iterable<String>> entry : context.sideInput(multimapView).entrySet()) {
for (String value : entry.getValue()) {
context.output(KV.of(context.element(), entry.getKey() + ":" + value));
}
}
}
}).withSideInputs(iterableView, multimapView)).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("gbk", GroupByKey.create());
// Translate the pipeline to its proto form, fuse it, and pick a fused stage that declares side inputs.
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
Optional<ExecutableStage> optionalStage = Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> !stage.getSideInputs().isEmpty());
checkState(optionalStage.isPresent(), "Expected a stage with side inputs.");
ExecutableStage stage = optionalStage.get();
ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage("test_stage", stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
// For every remote output, register a receiver that appends to a synchronized list stored under the same key.
Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
outputValues.put(remoteOutputCoder.getKey(), outputContents);
outputReceivers.put(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
}
// The handlers below return fixed values and ignore the window argument, so the expected
// output does not depend on what the "create" step emitted.
StateRequestHandler stateRequestHandler = StateRequestHandlers.forSideInputHandlerFactory(descriptor.getSideInputSpecs(), new SideInputHandlerFactory() {
@Override
public <V, W extends BoundedWindow> IterableSideInputHandler<V, W> forIterableSideInput(String pTransformId, String sideInputId, Coder<V> elementCoder, Coder<W> windowCoder) {
return new IterableSideInputHandler<V, W>() {
@Override
public Iterable<V> get(W window) {
// Raw cast: the stub always serves strings, whatever V is.
return (Iterable) Arrays.asList("A", "B", "C");
}
@Override
public Coder<V> elementCoder() {
return elementCoder;
}
};
}
@Override
public <K, V, W extends BoundedWindow> MultimapSideInputHandler<K, V, W> forMultimapSideInput(String pTransformId, String sideInputId, KvCoder<K, V> elementCoder, Coder<W> windowCoder) {
return new MultimapSideInputHandler<K, V, W>() {
@Override
public Iterable<K> get(W window) {
// The stub multimap has exactly two keys.
return (Iterable) Arrays.asList("key1", "key2");
}
@Override
public Iterable<V> get(K key, W window) {
if ("key1".equals(key)) {
return (Iterable) Arrays.asList("H", "I", "J");
} else if ("key2".equals(key)) {
return (Iterable) Arrays.asList("M", "N", "O");
}
// Any other key has no values.
return Collections.emptyList();
}
@Override
public Coder<K> keyCoder() {
return elementCoder.getKeyCoder();
}
@Override
public Coder<V> valueCoder() {
return elementCoder.getValueCoder();
}
};
}
});
BundleProgressHandler progressHandler = BundleProgressHandler.ignored();
// Send two elements, "X" and "Y"; try-with-resources closes the bundle before the outputs are checked.
try (RemoteBundle bundle = processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow("X"));
Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow("Y"));
}
// Every collected output must contain, in any order, X and Y each paired with the three
// iterable values and the six "key:value" strings of the stub multimap.
for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
assertThat(windowedValues, containsInAnyOrder(valueInGlobalWindow(KV.of("X", "A")), valueInGlobalWindow(KV.of("X", "B")), valueInGlobalWindow(KV.of("X", "C")), valueInGlobalWindow(KV.of("X", "key1:H")), valueInGlobalWindow(KV.of("X", "key1:I")), valueInGlobalWindow(KV.of("X", "key1:J")), valueInGlobalWindow(KV.of("X", "key2:M")), valueInGlobalWindow(KV.of("X", "key2:N")), valueInGlobalWindow(KV.of("X", "key2:O")), valueInGlobalWindow(KV.of("Y", "A")), valueInGlobalWindow(KV.of("Y", "B")), valueInGlobalWindow(KV.of("Y", "C")), valueInGlobalWindow(KV.of("Y", "key1:H")), valueInGlobalWindow(KV.of("Y", "key1:I")), valueInGlobalWindow(KV.of("Y", "key1:J")), valueInGlobalWindow(KV.of("Y", "key2:M")), valueInGlobalWindow(KV.of("Y", "key2:N")), valueInGlobalWindow(KV.of("Y", "key2:O"))));
}
}
Aggregations