use of org.apache.beam.runners.core.StatefulDoFnRunner in project beam by apache.
the class ExecutableStageDoFnOperator method ensureStateDoFnRunner.
/**
 * Returns the runner to use for this stage, adding state cleanup when the stage is stateful.
 *
 * <p>When {@code isStateful} is false the given runner is returned unchanged. Otherwise it is
 * wrapped in a {@link StatefulDoFnRunner} that sets the current key on the keyed state backend
 * before every element and, after each bundle, runs any queued state cleanup.
 *
 * @param sdkHarnessRunner runner that is returned directly or wrapped
 * @param payload stage payload, passed on to {@code requiresTimeSortedInput}
 * @param stepContext step context handed to the {@link StatefulDoFnRunner}
 * @return {@code sdkHarnessRunner} itself, or a stateful runner delegating to it
 */
private DoFnRunner<InputT, OutputT> ensureStateDoFnRunner(SdkHarnessDoFnRunner<InputT, OutputT> sdkHarnessRunner, RunnerApi.ExecutableStagePayload payload, StepContext stepContext) {
    if (!isStateful) {
        return sdkHarnessRunner;
    }

    // Collaborators needed by StatefulDoFnRunner to clean up state.
    Coder windowCoder = windowingStrategy.getWindowFn().windowCoder();
    CleanupTimer<InputT> windowCleanupTimer =
        new CleanupTimer<>(
            timerInternals,
            stateBackendLock,
            windowingStrategy,
            keyCoder,
            windowCoder,
            getKeyedStateBackend());

    List<String> userStateNames =
        executableStage.getUserStates().stream()
            .map(UserStateReference::localName)
            .collect(Collectors.toList());

    KeyedStateBackend<ByteBuffer> keyedBackend = getKeyedStateBackend();
    StateCleaner userStateCleaner =
        new StateCleaner(
            userStateNames,
            windowCoder,
            keyedBackend::getCurrentKey,
            timerInternals::hasPendingEventTimeTimers,
            windowCleanupTimer);

    return new StatefulDoFnRunner<InputT, OutputT, BoundedWindow>(
        sdkHarnessRunner,
        getInputCoder(),
        stepContext,
        windowingStrategy,
        windowCleanupTimer,
        userStateCleaner,
        requiresTimeSortedInput(payload, true)) {

        @Override
        public void processElement(WindowedValue<InputT> input) {
            // The key is encoded and set on the backend while holding the state backend lock.
            try (Locker lock = Locker.locked(stateBackendLock)) {
                @SuppressWarnings({ "unchecked", "rawtypes" })
                final ByteBuffer encodedKey =
                    FlinkKeyUtils.encodeKey(((KV) input.getValue()).getKey(), (Coder) keyCoder);
                getKeyedStateBackend().setCurrentKey(encodedKey);
                super.processElement(input);
            }
        }

        @Override
        public void finishBundle() {
            // Finish the bundle of the wrapped runners first; cleanup only happens afterwards.
            super.finishBundle();
            if (userStateCleaner.cleanupQueue.isEmpty()) {
                return;
            }
            try (Locker lock = Locker.locked(stateBackendLock)) {
                userStateCleaner.cleanupState(keyedStateInternals, keyedBackend::setCurrentKey);
            } catch (Exception e) {
                throw new RuntimeException("Failed to cleanup state.", e);
            }
        }
    };
}
use of org.apache.beam.runners.core.StatefulDoFnRunner in project beam by apache.
the class DoFnOperator method earlyBindStateIfNeeded.
/**
 * Binds the DoFn's declared state specs to a {@link FlinkStateInternals.EarlyBinder}.
 *
 * <p>Does nothing when either {@code keyCoder} or {@code doFn} is {@code null}. Otherwise every
 * state declaration found in the DoFn's signature is read reflectively from the DoFn instance
 * and bound under its id. If the current runner is a {@link StatefulDoFnRunner}, its system
 * state tags are bound to the same binder as well.
 *
 * @throws IllegalArgumentException propagated from the reflective read of a state field
 * @throws IllegalAccessException if a declared state field cannot be read from the DoFn
 */
private void earlyBindStateIfNeeded() throws IllegalArgumentException, IllegalAccessException {
    if (keyCoder == null || doFn == null) {
        return;
    }

    DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
    FlinkStateInternals.EarlyBinder earlyBinder = new FlinkStateInternals.EarlyBinder(getKeyedStateBackend(), serializedOptions);

    for (DoFnSignature.StateDeclaration declaration : signature.stateDeclarations().values()) {
        // Use the iterated declaration directly; looking it up again by its own id in the
        // same map is redundant.
        StateSpec<?> spec = (StateSpec<?>) declaration.field().get(doFn);
        spec.bind(declaration.id(), earlyBinder);
    }

    if (doFnRunner instanceof StatefulDoFnRunner) {
        ((StatefulDoFnRunner<InputT, OutputT, BoundedWindow>) doFnRunner).getSystemStateTags().forEach(tag -> tag.getSpec().bind(tag.getId(), earlyBinder));
    }
}
use of org.apache.beam.runners.core.StatefulDoFnRunner in project beam by apache.
the class ExecutableStageDoFnOperatorTest method testEnsureStateCleanupWithKeyedInput.
@Test
@SuppressWarnings("unchecked")
public void testEnsureStateCleanupWithKeyedInput() throws Exception {
    // Build the operator with a key coder so that it runs on keyed input.
    TupleTag<Integer> mainOutputTag = new TupleTag<>("main-output");
    DoFnOperator.MultiOutputOutputManagerFactory<Integer> managerFactory =
        new DoFnOperator.MultiOutputOutputManagerFactory(
            mainOutputTag,
            VarIntCoder.of(),
            new SerializablePipelineOptions(FlinkPipelineOptions.defaults()));
    VarIntCoder intKeyCoder = VarIntCoder.of();
    ExecutableStageDoFnOperator<Integer, Integer> stageOperator =
        getOperator(
            mainOutputTag,
            Collections.emptyList(),
            managerFactory,
            WindowingStrategy.globalDefault(),
            intKeyCoder,
            WindowedValue.getFullCoder(intKeyCoder, GlobalWindow.Coder.INSTANCE));
    KeyedOneInputStreamOperatorTestHarness<Integer, WindowedValue<Integer>, WindowedValue<Integer>> harness =
        new KeyedOneInputStreamOperatorTestHarness(
            stageOperator,
            val -> val,
            new CoderTypeInformation<>(intKeyCoder, FlinkPipelineOptions.defaults()));

    // Stub the bundle handed out by the stage bundle factory with a single mocked "input" receiver.
    RemoteBundle mockedBundle = Mockito.mock(RemoteBundle.class);
    when(mockedBundle.getInputReceivers())
        .thenReturn(
            ImmutableMap.<String, FnDataReceiver<WindowedValue>>builder()
                .put("input", Mockito.mock(FnDataReceiver.class))
                .build());
    when(stageBundleFactory.getBundle(any(), any(), any(), any(), any(), any()))
        .thenReturn(mockedBundle);

    harness.open();

    // The outermost runner is the metrics wrapper ...
    Object outerRunner = Whitebox.getInternalState(stageOperator, "doFnRunner");
    assertThat(outerRunner, instanceOf(DoFnRunnerWithMetricsUpdate.class));

    // ... and its delegate must be the StatefulDoFnRunner responsible for clearing state.
    Object wrappedRunner = Whitebox.getInternalState(outerRunner, "delegate");
    assertThat(wrappedRunner, instanceOf(StatefulDoFnRunner.class));
}
use of org.apache.beam.runners.core.StatefulDoFnRunner in project beam by apache.
the class SamzaDoFnRunners method create.
/**
 * Creates the {@link DoFnRunner} for the java runner.
 *
 * <p>A simple runner is built around {@code doFn} and, when
 * {@code pipelineOptions.getEnableMetrics()} is true, wrapped with {@link DoFnRunnerWithMetrics}.
 * If {@code StateUtils.isStateful(doFn)} holds, state and timer internals come from a
 * {@link KeyedInternals} instance and the result is additionally wrapped in a stateful runner
 * and a {@link DoFnRunnerWithKeyedInternals}. Otherwise the internals are obtained from the
 * factories for the {@code null} key and the (optionally metrics-wrapped) simple runner is
 * returned.
 *
 * @return the fully wrapped runner for {@code doFn}
 */
public static <InT, FnOutT> DoFnRunner<InT, FnOutT> create(SamzaPipelineOptions pipelineOptions, DoFn<InT, FnOutT> doFn, WindowingStrategy<?, ?> windowingStrategy, String transformFullName, String transformId, Context context, TupleTag<FnOutT> mainOutputTag, SideInputHandler sideInputHandler, SamzaTimerInternalsFactory<?> timerInternalsFactory, Coder<?> keyCoder, DoFnRunners.OutputManager outputManager, Coder<InT> inputCoder, List<TupleTag<?>> sideOutputTags, Map<TupleTag<?>, Coder<?>> outputCoders, DoFnSchemaInformation doFnSchemaInformation, Map<String, PCollectionView<?>> sideInputMapping) {
    final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
    final SamzaStoreStateInternals.Factory<?> stateInternalsFactory =
        SamzaStoreStateInternals.createStateInternalsFactory(
            transformId, keyCoder, context.getTaskContext(), pipelineOptions, signature);
    final SamzaExecutionContext executionContext =
        (SamzaExecutionContext) context.getApplicationContainerContext();

    // Keyed internals exist only for stateful DoFns; null marks the stateless case below.
    final KeyedInternals keyedInternals =
        StateUtils.isStateful(doFn)
            ? new KeyedInternals(stateInternalsFactory, timerInternalsFactory)
            : null;
    final StateInternals stateInternals =
        keyedInternals != null
            ? keyedInternals.stateInternals()
            : stateInternalsFactory.stateInternalsForKey(null);
    final TimerInternals timerInternals =
        keyedInternals != null
            ? keyedInternals.timerInternals()
            : timerInternalsFactory.timerInternalsForKey(null);

    final StepContext stepContext = createStepContext(stateInternals, timerInternals);
    final DoFnRunner<InT, FnOutT> underlyingRunner =
        DoFnRunners.simpleRunner(
            pipelineOptions,
            doFn,
            sideInputHandler,
            outputManager,
            mainOutputTag,
            sideOutputTags,
            stepContext,
            inputCoder,
            outputCoders,
            windowingStrategy,
            doFnSchemaInformation,
            sideInputMapping);

    // Metrics wrapping is optional and controlled by the pipeline options.
    final DoFnRunner<InT, FnOutT> doFnRunnerWithMetrics;
    if (pipelineOptions.getEnableMetrics()) {
        doFnRunnerWithMetrics =
            DoFnRunnerWithMetrics.wrap(
                underlyingRunner, executionContext.getMetricsContainer(), transformFullName);
    } else {
        doFnRunnerWithMetrics = underlyingRunner;
    }

    if (keyedInternals == null) {
        return doFnRunnerWithMetrics;
    }

    final DoFnRunner<InT, FnOutT> statefulDoFnRunner =
        DoFnRunners.defaultStatefulDoFnRunner(
            doFn,
            inputCoder,
            doFnRunnerWithMetrics,
            stepContext,
            windowingStrategy,
            new StatefulDoFnRunner.TimeInternalsCleanupTimer(timerInternals, windowingStrategy),
            createStateCleaner(doFn, windowingStrategy, keyedInternals.stateInternals()));
    return new DoFnRunnerWithKeyedInternals<>(statefulDoFnRunner, keyedInternals);
}
Aggregations