use of org.apache.beam.runners.dataflow.worker.windmill.Windmill in project beam by apache.
the class StreamingDataflowWorker method refreshActiveWork.
/**
 * Sends a GetData request to Windmill for all sufficiently old active work.
 *
 * <p>This informs Windmill that processing is ongoing and the work should not be retried. The age
 * threshold is determined by {@link
 * StreamingDataflowWorkerOptions#getActiveWorkRefreshPeriodMillis}.
 */
private void refreshActiveWork() {
// NOTE(review): this excerpt is elided. The initializer of refreshDeadline, the code that consumes
// `active` after the loop, and the closing braces are not visible here.
// Keyed by the computationMap key (presumably the computation id); each value is the list of
// keyed GetData requests that computation reports as needing a refresh.
Map<String, List<Windmill.KeyedGetDataRequest>> active = new HashMap<>();
// Cut-off handed to every computation below. NOTE(review): the right-hand side is missing in this
// excerpt, so how the deadline is derived cannot be confirmed from here.
Instant refreshDeadline =;
// Ask each known computation for the keys to refresh relative to the same deadline.
for (Map.Entry<String, ComputationState> entry : computationMap.entrySet()) {
active.put(entry.getKey(), entry.getValue().getKeysToRefresh(refreshDeadline));
the class StreamingDataflowWorker method process.
/**
 * Processes one unit of Windmill work for a computation.
 *
 * <p>Visible flow: obtain an execution state (polled from the computation's per-worker queue, or
 * built from the computation's MapTask when none is available), start the execution context for
 * the work item, put the resulting commit on the commit queue, and gather byte statistics. If
 * anything is thrown, the failure is reported and the work is either re-submitted to the local
 * executor or marked complete on the computation state.
 *
 * <p>NOTE(review): this excerpt is elided in several places (missing expressions, missing
 * statements and missing closing braces). Comments below only describe what the visible lines
 * show and call out the gaps.
 *
 * @param worker SDK worker harness; selects the execution state queue and supplies the gRPC
 *     handlers passed to the map task executor factory
 * @param computationState state of the computation the work belongs to
 * @param inputDataWatermark forwarded to the execution context when it is started
 * @param outputDataWatermark forwarded to the execution context when it is started; may be null
 * @param synchronizedProcessingTime forwarded to the execution context when it is started; may be
 *     null
 * @param work the work to process; wraps the Windmill work item
 */
private void process(final SdkWorkerHarness worker, final ComputationState computationState, final Instant inputDataWatermark, @Nullable final Instant outputDataWatermark, @Nullable final Instant synchronizedProcessingTime, final Work work) {
final Windmill.WorkItem workItem = work.getWorkItem();
final String computationId = computationState.getComputationId();
final ByteString key = workItem.getKey();
// NOTE(review): workIdBuilder is never used in the visible excerpt; presumably the lines that
// populate it were elided.
StringBuilder workIdBuilder = new StringBuilder(33);
LOG.debug("Starting processing for {}:\n{}", computationId, work);
Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
// Before any processing starts, call any pending OnCommit callbacks. Nothing that requires
// cleanup should be done before this, since we might exit early here.
// NOTE(review): the callback invocation itself is not visible in this excerpt.
if (workItem.getSourceState().getOnlyFinalize()) {
// Finalize-only work: a commit is queued right away. NOTE(review): the first Commit argument,
// the presumable early return and the closing brace are missing in this excerpt.
commitQueue.put(new Commit(, computationState, work));
// Start of the interval measured for the processing-time accounting in the finally clause.
long processingStartTimeNanos = System.nanoTime();
final MapTask mapTask = computationState.getMapTask();
// One StageInfo per stage name, created on first use.
StageInfo stageInfo = stageInfoMap.computeIfAbsent(mapTask.getStageName(), s -> new StageInfo(s, mapTask.getSystemName(), this));
ExecutionState executionState = null;
try {
// Reuse an execution state from the computation's queue for this worker when one is available.
executionState = computationState.getExecutionStateQueue(worker).poll();
if (executionState == null) {
// Nothing to reuse: build the network, execution context and executor from the MapTask.
MutableNetwork<Node, Edge> mapTaskNetwork = mapTaskToNetwork.apply(mapTask);
if (LOG.isDebugEnabled()) {
LOG.debug("Network as Graphviz .dot: {}", Networks.toDot(mapTaskNetwork));
// Locate the instruction node that carries a Read, then its single successor (the read output).
ParallelInstructionNode readNode = (ParallelInstructionNode) Iterables.find(mapTaskNetwork.nodes(), node -> node instanceof ParallelInstructionNode && ((ParallelInstructionNode) node).getParallelInstruction().getRead() != null);
InstructionOutputNode readOutputNode = (InstructionOutputNode) Iterables.getOnlyElement(mapTaskNetwork.successors(readNode));
DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = new DataflowExecutionContext.DataflowExecutionStateTracker(ExecutionStateSampler.instance(), stageInfo.executionStateRegistry.getState(NameContext.forStage(mapTask.getStageName()), "other", null, ScopedProfiler.INSTANCE.emptyScope()), stageInfo.deltaCounters, options, computationId);
// The state-family map comes from the computation state when it is non-empty; otherwise the
// worker-level stateNameMap is used.
StreamingModeExecutionContext context = new StreamingModeExecutionContext(pendingDeltaCounters, computationId, readerCache, !computationState.getTransformUserNameToStateFamily().isEmpty() ? computationState.getTransformUserNameToStateFamily() : stateNameMap, stateCache.forComputation(computationId), stageInfo.metricsContainerRegistry, executionStateTracker, stageInfo.executionStateRegistry, maxSinkBytes);
DataflowMapTaskExecutor mapTaskExecutor = mapTaskExecutorFactory.create(worker.getControlClientHandler(), worker.getGrpcDataFnServer(), sdkHarnessRegistry.beamFnDataApiServiceDescriptor(), worker.getGrpcStateFnServer(), mapTaskNetwork, options, mapTask.getStageName(), readerRegistry, sinkRegistry, context, pendingDeltaCounters, idGenerator);
ReadOperation readOperation = mapTaskExecutor.getReadOperation();
// Disable progress updates since its results are unused for streaming
// and involves starting a thread.
// NOTE(review): the statement that disables progress updates is not visible in this excerpt.
Preconditions.checkState(mapTaskExecutor.supportsRestart(), "Streaming runner requires all operations support restart.");
Coder<?> readCoder;
readCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(readOutputNode.getInstructionOutput().getCodec()));
Coder<?> keyCoder = extractKeyCoder(readCoder);
// If using a custom source, count bytes read for autoscaling.
if (CustomSources.class.getName().equals(readNode.getParallelInstruction().getRead().getSource().getSpec().get("@type"))) {
NameContext nameContext = NameContext.create(mapTask.getStageName(), readNode.getParallelInstruction().getOriginalName(), readNode.getParallelInstruction().getSystemName(), readNode.getParallelInstruction().getName());
// Attach a byte counter named after the MapTask's system name to the read operation's first receiver.
readOperation.receivers[0].addOutputCounter(new OutputObjectAndByteCounter(new IntrinsicMapTaskExecutorFactory.ElementByteSizeObservableCoder<>(readCoder), mapTaskExecutor.getOutputCounters(), nameContext).setSamplingPeriod(100).countBytes("dataflow_input_size-" + mapTask.getSystemName()));
executionState = new ExecutionState(mapTaskExecutor, context, keyCoder, executionStateTracker);
// NOTE(review): the brace closing the `executionState == null` branch is missing in this
// excerpt; the statements below presumably run for both reused and new execution states.
WindmillStateReader stateReader = new WindmillStateReader(metricTrackingWindmillServer, computationId, key, workItem.getShardingKey(), workItem.getWorkToken());
StateFetcher localStateFetcher = stateFetcher.byteTrackingView();
// If the read output KVs, then we can decode Windmill's byte key into a userland
// key object and provide it to the execution context for use with per-key state.
// Otherwise, we pass null.
// The coder type that will be present is:
// WindowedValueCoder(TimerOrElementCoder(KvCoder))
@Nullable Coder<?> keyCoder = executionState.getKeyCoder();
@Nullable Object executionKey = keyCoder == null ? null : keyCoder.decode(key.newInput(), Coder.Context.OUTER);
if (workItem.hasHotKeyInfo()) {
Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
// The hot key age is reported in microseconds; convert to milliseconds for Duration.
Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000);
// The MapTask instruction is ordered by dependencies, such that the first element is
// always going to be the shuffle task.
String stepName = computationState.getMapTask().getInstructions().get(0).getName();
// The decoded key is only included in the log when hot key logging is enabled and a key coder
// exists; otherwise the detection is logged without the key.
if (options.isHotKeyLoggingEnabled() && keyCoder != null) {
hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
} else {
hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
executionState.getContext().start(executionKey, workItem, inputDataWatermark, outputDataWatermark, synchronizedProcessingTime, stateReader, localStateFetcher, outputBuilder);
// Blocks while executing work.
// NOTE(review): the call that actually executes the work is not visible in this excerpt.
Iterables.addAll(this.pendingMonitoringInfos, executionState.getWorkExecutor().extractMetricUpdates());
// Release the execution state for another thread to use.
// NOTE(review): the statement returning it to the queue is not visible in this excerpt.
executionState = null;
// Add the output to the commit queue.
// NOTE(review): the initializer of commitRequest is missing in this excerpt.
WorkItemCommitRequest commitRequest =;
int byteLimit = maxWorkItemCommitBytes;
int commitSize = commitRequest.getSerializedSize();
// A negative size is treated as integer overflow and reported as Integer.MAX_VALUE.
int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize;
// Detect overflow of integer serialized size or if the byte limit was exceeded.
if (commitSize < 0 || commitSize > byteLimit) {
KeyCommitTooLargeException e = KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
reportFailure(computationId, workItem, e);
// Drop the current request in favor of a new, minimal one requesting truncation.
// Messages, timers, counters, and other commit content will not be used by the service
// so we're purposefully dropping them here
commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize);
commitQueue.put(new Commit(commitRequest, computationState, work));
// Compute shuffle and state byte statistics; these will be flushed asynchronously.
// State bytes written: serialized size of the output after its output messages are cleared.
long stateBytesWritten = outputBuilder.clearOutputMessages().build().getSerializedSize();
// Shuffle bytes read: sum of the serialized sizes of every input message in the work item.
long shuffleBytesRead = 0;
for (Windmill.InputMessageBundle bundle : workItem.getMessageBundlesList()) {
for (Windmill.Message message : bundle.getMessagesList()) {
shuffleBytesRead += message.getSerializedSize();
long stateBytesRead = stateReader.getBytesRead() + localStateFetcher.getBytesRead();
LOG.debug("Processing done for work token: {}", workItem.getWorkToken());
} catch (Throwable t) {
// A non-null executionState here means the failure happened before it was released above.
if (executionState != null) {
try {
// NOTE(review): the body of this try is elided; judging by the warning below it presumably
// closes the map task executor.
} catch (Exception e) {
LOG.warn("Failed to close map task executor: ", e);
} finally {
// Release references to potentially large objects early.
executionState = null;
// Unwrap UserCodeException so the checks below look at the underlying cause.
t = t instanceof UserCodeException ? t.getCause() : t;
boolean retryLocally = false;
if (KeyTokenInvalidException.isKeyTokenInvalidException(t)) {
LOG.debug("Execution of work for computation '{}' on key '{}' failed due to token expiration. " + "Work will not be retried locally.", computationId, key.toStringUtf8());
} else {
LOG.debug("Failed work: {}", work);
// NOTE(review): the first Duration argument is missing in this excerpt; the second is the
// work's start time.
Duration elapsedTimeSinceStart = new Duration(, work.getStartTime());
// Retry locally only if reportFailure returns true, the error is not an out-of-memory error,
// and the elapsed time does not exceed MAX_LOCAL_PROCESSING_RETRY_DURATION.
if (!reportFailure(computationId, workItem, t)) {
LOG.error("Execution of work for computation '{}' on key '{}' failed with uncaught exception, " + "and Windmill indicated not to retry locally.", computationId, key.toStringUtf8(), t);
} else if (isOutOfMemoryError(t)) {
File heapDump = memoryMonitor.tryToDumpHeap();
LOG.error("Execution of work for computation '{}' for key '{}' failed with out-of-memory. " + "Work will not be retried locally. Heap dump {}.", computationId, key.toStringUtf8(), heapDump == null ? "not written" : ("written to '" + heapDump + "'"), t);
} else if (elapsedTimeSinceStart.isLongerThan(MAX_LOCAL_PROCESSING_RETRY_DURATION)) {
LOG.error("Execution of work for computation '{}' for key '{}' failed with uncaught exception, " + "and it will not be retried locally because the elapsed time since start {} " + "exceeds {}.", computationId, key.toStringUtf8(), elapsedTimeSinceStart, MAX_LOCAL_PROCESSING_RETRY_DURATION, t);
} else {
LOG.error("Execution of work for computation '{}' on key '{}' failed with uncaught exception. " + "Work will be retried locally.", computationId, key.toStringUtf8(), t);
retryLocally = true;
if (retryLocally) {
// Try again after some delay and at the end of the queue to avoid a tight loop.
// NOTE(review): the delay itself is not visible in this excerpt.
workUnitExecutor.forceExecute(work, work.getWorkItem().getSerializedSize());
} else {
// Consider the item invalid. It will eventually be retried by Windmill if it still needs to
// be processed.
computationState.completeWork(ShardedKey.create(key, workItem.getShardingKey()), workItem.getWorkToken());
} finally {
// Update total processing time counters. Updating in finally clause ensures that
// work items causing exceptions are also accounted in time spent.
long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos);
// NOTE(review): the comment below is the tail of a longer comment whose beginning is elided in
// this excerpt.
// either here or in DFE.
// NOTE(review): the body of this branch is not visible in this excerpt.
if (work.getWorkItem().hasTimers()) {
the class StreamingDataflowWorkerTest method testLimitOnOutputBundleSize.
// Checks that each commit produced for a large custom-source shard has a serialized size within
// roughly 10% of StreamingDataflowWorker.MAX_SINK_BYTES.
// NOTE(review): this excerpt is elided (no closing brace; any annotation, worker start/stop and
// use of finalizeTracker are not visible here).
public void testLimitOnOutputBundleSize() throws Exception {
// This verifies that ReadOperation, StreamingModeExecutionContext, and windmill sinks
// coordinate to limit size of an output bundle.
// NOTE(review): finalizeTracker is not referenced again in the visible excerpt.
List<Integer> finalizeTracker = Lists.newArrayList();
// 100K input messages.
final int numMessagesInCustomSourceShard = 100000;
// x10k => 1GB total output size.
final int inflatedSizePerMessage = 10000;
FakeWindmillServer server = new FakeWindmillServer(errorCollector);
// Worker over an unbounded-source pipeline whose DoFn is built with the per-message inflation size.
StreamingDataflowWorker worker = makeWorker(makeUnboundedSourcePipeline(numMessagesInCustomSourceShard, new InflateDoFn(inflatedSizePerMessage)), createTestingPipelineOptions(server), false);
// Test new key.
// First work item: work_token 1 for key "0000000000000001".
server.addWorkToOffer(buildInput("work {" + " computation_id: \"computation\"" + " input_data_watermark: 0" + " work {" + " key: \"0000000000000001\"" + " sharding_key: 1" + " work_token: 1" + " cache_token: 1" + " }" + "}", null));
// Matcher to ensure that commit size is within 10% of max bundle size.
// Bounds are exclusive: strictly greater than 90% and strictly less than 110% of MAX_SINK_BYTES.
Matcher<Integer> isWithinBundleSizeLimits = both(greaterThan(StreamingDataflowWorker.MAX_SINK_BYTES * 9 / 10)).and(lessThan(StreamingDataflowWorker.MAX_SINK_BYTES * 11 / 10));
// Wait for one commit and look it up by the work token offered above (1).
Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(1);
Windmill.WorkItemCommitRequest commit = result.get(1L);
assertThat(commit.getSerializedSize(), isWithinBundleSizeLimits);
// Try another bundle
// Same key and cache token, new work_token 2; the commit is looked up under that token.
server.addWorkToOffer(buildInput("work {" + " computation_id: \"computation\"" + " input_data_watermark: 0" + " work {" + " key: \"0000000000000001\"" + " sharding_key: 1" + " work_token: 2" + " cache_token: 1" + " }" + "}", null));
result = server.waitForAndGetCommits(1);
commit = result.get(2L);
assertThat(commit.getSerializedSize(), isWithinBundleSizeLimits);