use of org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork in project beam by apache.
the class CreateRegisterFnOperationFunctionTest method testAllSdkGraph.
/**
 * Covers a graph in which both the read and the ParDo are located in the SDK harness.
 *
 * <p>The applied network is expected to contain nothing but the node handed back by the mocked
 * register function, and that function is expected to have been given a network equal to the
 * original one.
 */
@Test
public void testAllSdkGraph() {
  // The mocked register function records its argument and answers with this stand-in node.
  Node fusedSdkNode = TestNode.create("SdkPortion");
  @SuppressWarnings({ "unchecked", "rawtypes" })
  ArgumentCaptor<MutableNetwork<Node, Edge>> sdkNetworkCaptor = ArgumentCaptor.forClass((Class) MutableNetwork.class);
  when(registerFnOperationFunction.apply(sdkNetworkCaptor.capture())).thenReturn(fusedSdkNode);

  // Read -out-> ParDo
  Node read = createReadNode("Read", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge readToOut = DefaultEdge.create();
  Node readOut = createInstructionOutputNode("Read.out");
  Edge outToParDo = DefaultEdge.create();
  Node parDo = createParDoNode("ParDo", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge parDoToOut = DefaultEdge.create();
  Node parDoOut = createInstructionOutputNode("ParDo.out");

  MutableNetwork<Node, Edge> original = createEmptyNetwork();
  original.addNode(read);
  original.addNode(readOut);
  original.addNode(parDo);
  original.addNode(parDoOut);
  original.addEdge(read, readOut, readToOut);
  original.addEdge(readOut, parDo, outToParDo);
  original.addEdge(parDo, parDoOut, parDoToOut);

  // The whole graph should collapse into the single node produced by the register function.
  MutableNetwork<Node, Edge> collapsed = createEmptyNetwork();
  collapsed.addNode(fusedSdkNode);

  // The function under test operates on a copy so that 'original' stays available for comparison.
  MutableNetwork<Node, Edge> applied = createRegisterFnOperation.apply(Graphs.copyOf(original));
  assertNetworkMaintainsBipartiteStructure(applied);

  MutableNetwork<Node, Edge> capturedSdkNetwork = sdkNetworkCaptor.getValue();
  assertNetworkMaintainsBipartiteStructure(capturedSdkNetwork);

  String appliedMismatch = String.format("Expected network %s but got network %s", collapsed, applied);
  assertEquals(appliedMismatch, collapsed, applied);

  String capturedMismatch = String.format("Expected network %s but got network %s", original, capturedSdkNetwork);
  assertEquals(capturedMismatch, original, capturedSdkNetwork);
}
use of org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork in project beam by apache.
the class StreamingDataflowWorker method process.
/**
 * Processes a single Windmill work item for a computation and queues the resulting commit.
 *
 * <p>The method
 *
 * <ol>
 *   <li>marks the work as {@code PROCESSING} and tags the logging MDC with a work id and the
 *       computation id;
 *   <li>runs pending finalize callbacks and, when the work item only requests finalization,
 *       queues a minimal commit and returns;
 *   <li>obtains an {@code ExecutionState} from the computation's queue for this worker, or
 *       builds a new one from the map task network when the queue is empty;
 *   <li>starts the execution context, runs the work executor, flushes state and returns the
 *       execution state to the queue;
 *   <li>builds the commit request (replacing it with a truncation request when its serialized
 *       size overflows or exceeds {@code maxWorkItemCommitBytes}), queues it and records byte
 *       statistics.
 * </ol>
 *
 * <p>Any {@link Throwable} raised while processing is handled here: the execution state is
 * invalidated and closed, and the work is either rescheduled locally after a delay or marked
 * complete on the computation state. Processing-time counters are updated and the logging MDC
 * is cleared in all cases.
 *
 * @param worker harness whose execution state queue is used and whose control/data/state
 *     endpoints are handed to the map task executor factory
 * @param computationState state of the computation the work item belongs to
 * @param inputDataWatermark forwarded to the execution context when it is started
 * @param outputDataWatermark forwarded to the execution context when it is started; may be null
 * @param synchronizedProcessingTime forwarded to the execution context when it is started; may
 *     be null
 * @param work the unit of work wrapping the Windmill work item
 */
private void process(final SdkWorkerHarness worker, final ComputationState computationState, final Instant inputDataWatermark, @Nullable final Instant outputDataWatermark, @Nullable final Instant synchronizedProcessingTime, final Work work) {
  final Windmill.WorkItem workItem = work.getWorkItem();
  final String computationId = computationState.getComputationId();
  final ByteString key = workItem.getKey();
  work.setState(State.PROCESSING);
  {
    // Work id for logging: "<sharding key in hex>-<work token in hex>".
    StringBuilder workIdBuilder = new StringBuilder(33);
    workIdBuilder.append(Long.toHexString(workItem.getShardingKey()));
    workIdBuilder.append('-');
    workIdBuilder.append(Long.toHexString(workItem.getWorkToken()));
    DataflowWorkerLoggingMDC.setWorkId(workIdBuilder.toString());
  }
  DataflowWorkerLoggingMDC.setStageName(computationId);
  LOG.debug("Starting processing for {}:\n{}", computationId, work);
  Windmill.WorkItemCommitRequest.Builder outputBuilder = initializeOutputBuilder(key, workItem);
  // Before any processing starts, call any pending OnCommit callbacks. Nothing that requires
  // cleanup should be done before this, since we might exit early here.
  callFinalizeCallbacks(workItem);
  if (workItem.getSourceState().getOnlyFinalize()) {
    outputBuilder.setSourceStateUpdates(Windmill.SourceState.newBuilder().setOnlyFinalize(true));
    work.setState(State.COMMIT_QUEUED);
    commitQueue.put(new Commit(outputBuilder.build(), computationState, work));
    return;
  }
  long processingStartTimeNanos = System.nanoTime();
  final MapTask mapTask = computationState.getMapTask();
  StageInfo stageInfo = stageInfoMap.computeIfAbsent(mapTask.getStageName(), s -> new StageInfo(s, mapTask.getSystemName(), this));
  ExecutionState executionState = null;
  try {
    // Reuse a previously built execution state when one is available for this worker.
    executionState = computationState.getExecutionStateQueue(worker).poll();
    if (executionState == null) {
      MutableNetwork<Node, Edge> mapTaskNetwork = mapTaskToNetwork.apply(mapTask);
      if (LOG.isDebugEnabled()) {
        LOG.debug("Network as Graphviz .dot: {}", Networks.toDot(mapTaskNetwork));
      }
      // Locate the read instruction and its single output; the output's codec yields the coder.
      ParallelInstructionNode readNode = (ParallelInstructionNode) Iterables.find(mapTaskNetwork.nodes(), node -> node instanceof ParallelInstructionNode && ((ParallelInstructionNode) node).getParallelInstruction().getRead() != null);
      InstructionOutputNode readOutputNode = (InstructionOutputNode) Iterables.getOnlyElement(mapTaskNetwork.successors(readNode));
      DataflowExecutionContext.DataflowExecutionStateTracker executionStateTracker = new DataflowExecutionContext.DataflowExecutionStateTracker(ExecutionStateSampler.instance(), stageInfo.executionStateRegistry.getState(NameContext.forStage(mapTask.getStageName()), "other", null, ScopedProfiler.INSTANCE.emptyScope()), stageInfo.deltaCounters, options, computationId);
      StreamingModeExecutionContext context = new StreamingModeExecutionContext(pendingDeltaCounters, computationId, readerCache, !computationState.getTransformUserNameToStateFamily().isEmpty() ? computationState.getTransformUserNameToStateFamily() : stateNameMap, stateCache.forComputation(computationId), stageInfo.metricsContainerRegistry, executionStateTracker, stageInfo.executionStateRegistry, maxSinkBytes);
      DataflowMapTaskExecutor mapTaskExecutor = mapTaskExecutorFactory.create(worker.getControlClientHandler(), worker.getGrpcDataFnServer(), sdkHarnessRegistry.beamFnDataApiServiceDescriptor(), worker.getGrpcStateFnServer(), mapTaskNetwork, options, mapTask.getStageName(), readerRegistry, sinkRegistry, context, pendingDeltaCounters, idGenerator);
      ReadOperation readOperation = mapTaskExecutor.getReadOperation();
      // Disable progress updates since its results are unused for streaming
      // and involves starting a thread.
      readOperation.setProgressUpdatePeriodMs(ReadOperation.DONT_UPDATE_PERIODICALLY);
      Preconditions.checkState(mapTaskExecutor.supportsRestart(), "Streaming runner requires all operations support restart.");
      Coder<?> readCoder;
      readCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(readOutputNode.getInstructionOutput().getCodec()));
      Coder<?> keyCoder = extractKeyCoder(readCoder);
      // If using a custom source, count bytes read for autoscaling.
      if (CustomSources.class.getName().equals(readNode.getParallelInstruction().getRead().getSource().getSpec().get("@type"))) {
        NameContext nameContext = NameContext.create(mapTask.getStageName(), readNode.getParallelInstruction().getOriginalName(), readNode.getParallelInstruction().getSystemName(), readNode.getParallelInstruction().getName());
        readOperation.receivers[0].addOutputCounter(new OutputObjectAndByteCounter(new IntrinsicMapTaskExecutorFactory.ElementByteSizeObservableCoder<>(readCoder), mapTaskExecutor.getOutputCounters(), nameContext).setSamplingPeriod(100).countBytes("dataflow_input_size-" + mapTask.getSystemName()));
      }
      executionState = new ExecutionState(mapTaskExecutor, context, keyCoder, executionStateTracker);
    }
    WindmillStateReader stateReader = new WindmillStateReader(metricTrackingWindmillServer, computationId, key, workItem.getShardingKey(), workItem.getWorkToken());
    StateFetcher localStateFetcher = stateFetcher.byteTrackingView();
    // If the read output KVs, then we can decode Windmill's byte key into a userland
    // key object and provide it to the execution context for use with per-key state.
    // Otherwise, we pass null.
    //
    // The coder type that will be present is:
    // WindowedValueCoder(TimerOrElementCoder(KvCoder))
    @Nullable Coder<?> keyCoder = executionState.getKeyCoder();
    @Nullable Object executionKey = keyCoder == null ? null : keyCoder.decode(key.newInput(), Coder.Context.OUTER);
    if (workItem.hasHotKeyInfo()) {
      Windmill.HotKeyInfo hotKeyInfo = workItem.getHotKeyInfo();
      Duration hotKeyAge = Duration.millis(hotKeyInfo.getHotKeyAgeUsec() / 1000);
      // The MapTask instruction is ordered by dependencies, such that the first element is
      // always going to be the shuffle task.
      String stepName = computationState.getMapTask().getInstructions().get(0).getName();
      // The decoded key is only included in the log when enabled by options and a key coder exists.
      if (options.isHotKeyLoggingEnabled() && keyCoder != null) {
        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge, executionKey);
      } else {
        hotKeyLogger.logHotKeyDetection(stepName, hotKeyAge);
      }
    }
    executionState.getContext().start(executionKey, workItem, inputDataWatermark, outputDataWatermark, synchronizedProcessingTime, stateReader, localStateFetcher, outputBuilder);
    // Blocks while executing work.
    executionState.getWorkExecutor().execute();
    Iterables.addAll(this.pendingMonitoringInfos, executionState.getWorkExecutor().extractMetricUpdates());
    commitCallbacks.putAll(executionState.getContext().flushState());
    // Release the execution state for another thread to use.
    computationState.getExecutionStateQueue(worker).offer(executionState);
    // Cleared so the catch block below does not invalidate/close a state that was handed back.
    executionState = null;
    // Add the output to the commit queue.
    work.setState(State.COMMIT_QUEUED);
    WorkItemCommitRequest commitRequest = outputBuilder.build();
    int byteLimit = maxWorkItemCommitBytes;
    int commitSize = commitRequest.getSerializedSize();
    // A negative size is treated as an int overflow and reported as Integer.MAX_VALUE.
    int estimatedCommitSize = commitSize < 0 ? Integer.MAX_VALUE : commitSize;
    // Detect overflow of integer serialized size or if the byte limit was exceeded.
    windmillMaxObservedWorkItemCommitBytes.addValue(estimatedCommitSize);
    if (commitSize < 0 || commitSize > byteLimit) {
      KeyCommitTooLargeException e = KeyCommitTooLargeException.causedBy(computationId, byteLimit, commitRequest);
      reportFailure(computationId, workItem, e);
      LOG.error(e.toString());
      // Drop the current request in favor of a new, minimal one requesting truncation.
      // Messages, timers, counters, and other commit content will not be used by the service
      // so we're purposefully dropping them here
      commitRequest = buildWorkItemTruncationRequest(key, workItem, estimatedCommitSize);
    }
    commitQueue.put(new Commit(commitRequest, computationState, work));
    // Compute shuffle and state byte statistics these will be flushed asynchronously.
    long stateBytesWritten = outputBuilder.clearOutputMessages().build().getSerializedSize();
    long shuffleBytesRead = 0;
    for (Windmill.InputMessageBundle bundle : workItem.getMessageBundlesList()) {
      for (Windmill.Message message : bundle.getMessagesList()) {
        shuffleBytesRead += message.getSerializedSize();
      }
    }
    long stateBytesRead = stateReader.getBytesRead() + localStateFetcher.getBytesRead();
    windmillShuffleBytesRead.addValue(shuffleBytesRead);
    windmillStateBytesRead.addValue(stateBytesRead);
    windmillStateBytesWritten.addValue(stateBytesWritten);
    LOG.debug("Processing done for work token: {}", workItem.getWorkToken());
  } catch (Throwable t) {
    // A non-null execution state here means the failure happened before it was returned to the
    // queue; it is invalidated and closed rather than reused.
    if (executionState != null) {
      try {
        executionState.getContext().invalidateCache();
        executionState.getWorkExecutor().close();
      } catch (Exception e) {
        LOG.warn("Failed to close map task executor: ", e);
      } finally {
        // Release references to potentially large objects early.
        executionState = null;
      }
    }
    // Unwrap user code failures so the checks below see the underlying cause.
    t = t instanceof UserCodeException ? t.getCause() : t;
    boolean retryLocally = false;
    if (KeyTokenInvalidException.isKeyTokenInvalidException(t)) {
      LOG.debug("Execution of work for computation '{}' on key '{}' failed due to token expiration. " + "Work will not be retried locally.", computationId, key.toStringUtf8());
    } else {
      LastExceptionDataProvider.reportException(t);
      LOG.debug("Failed work: {}", work);
      // Duration(start, end) measures end - start, so the work's start time must come first.
      // With the arguments reversed the duration is negative and the retry time limit below
      // could never trigger.
      Duration elapsedTimeSinceStart = new Duration(work.getStartTime(), Instant.now());
      if (!reportFailure(computationId, workItem, t)) {
        LOG.error("Execution of work for computation '{}' on key '{}' failed with uncaught exception, " + "and Windmill indicated not to retry locally.", computationId, key.toStringUtf8(), t);
      } else if (isOutOfMemoryError(t)) {
        File heapDump = memoryMonitor.tryToDumpHeap();
        LOG.error("Execution of work for computation '{}' for key '{}' failed with out-of-memory. " + "Work will not be retried locally. Heap dump {}.", computationId, key.toStringUtf8(), heapDump == null ? "not written" : ("written to '" + heapDump + "'"), t);
      } else if (elapsedTimeSinceStart.isLongerThan(MAX_LOCAL_PROCESSING_RETRY_DURATION)) {
        LOG.error("Execution of work for computation '{}' for key '{}' failed with uncaught exception, " + "and it will not be retried locally because the elapsed time since start {} " + "exceeds {}.", computationId, key.toStringUtf8(), elapsedTimeSinceStart, MAX_LOCAL_PROCESSING_RETRY_DURATION, t);
      } else {
        LOG.error("Execution of work for computation '{}' on key '{}' failed with uncaught exception. " + "Work will be retried locally.", computationId, key.toStringUtf8(), t);
        retryLocally = true;
      }
    }
    if (retryLocally) {
      // Try again after some delay and at the end of the queue to avoid a tight loop.
      sleep(retryLocallyDelayMs);
      workUnitExecutor.forceExecute(work, work.getWorkItem().getSerializedSize());
    } else {
      // Consider the item invalid. It will eventually be retried by Windmill if it still needs to
      // be processed.
      computationState.completeWork(ShardedKey.create(key, workItem.getShardingKey()), workItem.getWorkToken());
    }
  } finally {
    // Update total processing time counters. Updating in finally clause ensures that
    // work items causing exceptions are also accounted in time spent.
    long processingTimeMsecs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - processingStartTimeNanos);
    stageInfo.totalProcessingMsecs.addValue(processingTimeMsecs);
    // Work items that carry timers also count toward the timer processing time.
    // NOTE(review): the original comment here read only "either here or in DFE." and appears to
    // be the tail of a longer note whose beginning is missing; confirm against upstream.
    if (work.getWorkItem().hasTimers()) {
      stageInfo.timerProcessingMsecs.addValue(processingTimeMsecs);
    }
    // Clear the logging context set at the top of this method.
    DataflowWorkerLoggingMDC.setWorkId(null);
    DataflowWorkerLoggingMDC.setStageName(null);
  }
}
use of org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork in project beam by apache.
the class CreateRegisterFnOperationFunctionTest method testRunnerAndSdkToRunnerAndSdkGraph.
/**
 * Covers a graph in which a runner read and an SDK read share one output node that fans out to a
 * runner ParDo and an SDK ParDo.
 *
 * <p>Two SDK portions are expected. Which mocked portion/port ends up on which side is not
 * deterministic, so the test derives the "A" and "B" roles from the resulting graph before
 * asserting on the wiring of the applied network and of both captured SDK subnetworks.
 */
@Test
public void testRunnerAndSdkToRunnerAndSdkGraph() {
  // RunnerSource --\ /--> RunnerParDo
  // out
  // CustomSource --/ \--> SdkParDo
  //
  // Should produce:
  // PortB --> out --\
  // RunnerSource --> out --> RunnerParDo
  // \--> PortA
  // PortA --> out --\
  // CustomSource --> out --> SdkParDo
  // \--> PortB
  Node firstSdkPortion = TestNode.create("FirstSdkPortion");
  Node secondSdkPortion = TestNode.create("SecondSdkPortion");
  @SuppressWarnings({ "unchecked", "rawtypes" })
  ArgumentCaptor<MutableNetwork<Node, Edge>> networkCapture = ArgumentCaptor.forClass((Class) MutableNetwork.class);
  when(registerFnOperationFunction.apply(networkCapture.capture())).thenReturn(firstSdkPortion, secondSdkPortion);
  Node firstPort = TestNode.create("FirstPort");
  Node secondPort = TestNode.create("SecondPort");
  when(portSupplier.get()).thenReturn(firstPort, secondPort);
  Node runnerReadNode = createReadNode("RunnerRead", Nodes.ExecutionLocation.RUNNER_HARNESS);
  Edge runnerReadNodeEdge = DefaultEdge.create();
  Node sdkReadNode = createReadNode("SdkRead", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge sdkReadNodeEdge = DefaultEdge.create();
  Node readNodeOut = createInstructionOutputNode("Read.out");
  Edge readNodeOutToRunnerEdge = DefaultEdge.create();
  Edge readNodeOutToSdkEdge = DefaultEdge.create();
  Node runnerParDoNode = createParDoNode("RunnerParDo", Nodes.ExecutionLocation.RUNNER_HARNESS);
  Edge runnerParDoNodeEdge = DefaultEdge.create();
  Node runnerParDoNodeOut = createInstructionOutputNode("RunnerParDo.out");
  Node sdkParDoNode = createParDoNode("SdkParDo", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge sdkParDoNodeEdge = DefaultEdge.create();
  Node sdkParDoNodeOut = createInstructionOutputNode("SdkParDo.out");
  // SdkRead and RunnerRead both feed Read.out, which fans out to RunnerParDo and SdkParDo.
  MutableNetwork<Node, Edge> network = createEmptyNetwork();
  network.addNode(sdkReadNode);
  network.addNode(runnerReadNode);
  network.addNode(readNodeOut);
  network.addNode(runnerParDoNode);
  network.addNode(runnerParDoNodeOut);
  // Add the SDK ParDo explicitly instead of relying on addEdge to insert it implicitly.
  network.addNode(sdkParDoNode);
  network.addNode(sdkParDoNodeOut);
  network.addEdge(sdkReadNode, readNodeOut, sdkReadNodeEdge);
  network.addEdge(runnerReadNode, readNodeOut, runnerReadNodeEdge);
  network.addEdge(readNodeOut, runnerParDoNode, readNodeOutToRunnerEdge);
  network.addEdge(readNodeOut, sdkParDoNode, readNodeOutToSdkEdge);
  network.addEdge(runnerParDoNode, runnerParDoNodeOut, runnerParDoNodeEdge);
  network.addEdge(sdkParDoNode, sdkParDoNodeOut, sdkParDoNodeEdge);
  MutableNetwork<Node, Edge> appliedNetwork = createRegisterFnOperation.apply(Graphs.copyOf(network));
  assertNetworkMaintainsBipartiteStructure(appliedNetwork);
  // Node wiring is indeterministic, must be detected from generated graph.
  // Portion "A" is whichever SDK portion has no incoming edges in the applied network.
  Node sdkPortionA;
  Node sdkPortionB;
  if (appliedNetwork.inDegree(firstSdkPortion) == 0) {
    sdkPortionA = firstSdkPortion;
    sdkPortionB = secondSdkPortion;
  } else {
    sdkPortionA = secondSdkPortion;
    sdkPortionB = firstSdkPortion;
  }
  Node portA = Iterables.getOnlyElement(appliedNetwork.successors(sdkPortionA));
  Node portB = Iterables.getOnlyElement(appliedNetwork.predecessors(sdkPortionB));
  // On each rewire between runner and SDK, we use a new output node
  Node newOutA = Iterables.getOnlyElement(appliedNetwork.successors(portA));
  Node newOutB = Iterables.getOnlyElement(appliedNetwork.predecessors(portB));
  // sdkPortionA -> portA -newOutA-> runnerParDoNode -> runnerParDoNodeOut
  // runnerReadNode -newOutB-/
  // \--> portB -> sdkPortionB
  assertThat(appliedNetwork.nodes(), containsInAnyOrder(runnerReadNode, firstSdkPortion, secondSdkPortion, portA, newOutA, portB, newOutB, runnerParDoNode, runnerParDoNodeOut));
  assertThat(appliedNetwork.successors(runnerReadNode), containsInAnyOrder(newOutB));
  assertThat(appliedNetwork.successors(newOutB), containsInAnyOrder(runnerParDoNode, portB));
  assertThat(appliedNetwork.successors(portB), containsInAnyOrder(sdkPortionB));
  assertThat(appliedNetwork.successors(sdkPortionA), containsInAnyOrder(portA));
  assertThat(appliedNetwork.successors(portA), containsInAnyOrder(newOutA));
  assertThat(appliedNetwork.successors(newOutA), containsInAnyOrder(runnerParDoNode));
  assertThat(appliedNetwork.successors(runnerParDoNode), containsInAnyOrder(runnerParDoNodeOut));
  assertThat(appliedNetwork.edgesConnecting(sdkPortionA, portA), everyItem(Matchers.<Edges.Edge>instanceOf(HappensBeforeEdge.class)));
  assertThat(appliedNetwork.edgesConnecting(portB, sdkPortionB), everyItem(Matchers.<Edges.Edge>instanceOf(HappensBeforeEdge.class)));
  // Argument captor call order can be indeterministic
  // Subnetwork "A" is the captured network that contains the SDK read.
  List<MutableNetwork<Node, Edge>> sdkSubnetworks = networkCapture.getAllValues();
  MutableNetwork<Node, Edge> sdkSubnetworkA;
  MutableNetwork<Node, Edge> sdkSubnetworkB;
  if (sdkSubnetworks.get(0).nodes().contains(sdkReadNode)) {
    sdkSubnetworkA = sdkSubnetworks.get(0);
    sdkSubnetworkB = sdkSubnetworks.get(1);
  } else {
    sdkSubnetworkA = sdkSubnetworks.get(1);
    sdkSubnetworkB = sdkSubnetworks.get(0);
  }
  assertNetworkMaintainsBipartiteStructure(sdkSubnetworkA);
  assertNetworkMaintainsBipartiteStructure(sdkSubnetworkB);
  // /-> portA
  // sdkReadNode -sdkNewOutA-> sdkParDoNode -> sdkParDoNodeOut
  Node sdkNewOutA = Iterables.getOnlyElement(sdkSubnetworkA.predecessors(portA));
  assertThat(sdkSubnetworkA.nodes(), containsInAnyOrder(sdkReadNode, portA, sdkNewOutA, sdkParDoNode, sdkParDoNodeOut));
  assertThat(sdkSubnetworkA.successors(sdkReadNode), containsInAnyOrder(sdkNewOutA));
  assertThat(sdkSubnetworkA.successors(sdkNewOutA), containsInAnyOrder(portA, sdkParDoNode));
  assertThat(sdkSubnetworkA.successors(sdkParDoNode), containsInAnyOrder(sdkParDoNodeOut));
  // portB -sdkNewOutB-> sdkParDoNode -> sdkParDoNodeOut
  Node sdkNewOutB = Iterables.getOnlyElement(sdkSubnetworkB.successors(portB));
  assertThat(sdkSubnetworkB.nodes(), containsInAnyOrder(portB, sdkNewOutB, sdkParDoNode, sdkParDoNodeOut));
  assertThat(sdkSubnetworkB.successors(portB), containsInAnyOrder(sdkNewOutB));
  assertThat(sdkSubnetworkB.successors(sdkNewOutB), containsInAnyOrder(sdkParDoNode));
  assertThat(sdkSubnetworkB.successors(sdkParDoNode), containsInAnyOrder(sdkParDoNodeOut));
}
use of org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork in project beam by apache.
the class CreateRegisterFnOperationFunctionTest method testRunnerToSdkToRunnerGraph.
/**
 * Covers a linear graph whose read runs in the runner harness, followed by an SDK ParDo and then
 * a runner ParDo.
 *
 * <p>The SDK ParDo is expected to be replaced in the applied network by the node returned from
 * the mocked register function, bracketed by the two mocked ports, with fresh output nodes on
 * each runner/SDK crossing. The network captured by the register function is expected to hold
 * the SDK ParDo between the same two ports.
 */
@Test
public void testRunnerToSdkToRunnerGraph() {
  Node sdkPortion = TestNode.create("SdkPortion");
  @SuppressWarnings({ "unchecked", "rawtypes" })
  ArgumentCaptor<MutableNetwork<Node, Edge>> networkCapture = ArgumentCaptor.forClass((Class) MutableNetwork.class);
  when(registerFnOperationFunction.apply(networkCapture.capture())).thenReturn(sdkPortion);
  Node firstPort = TestNode.create("FirstPort");
  Node secondPort = TestNode.create("SecondPort");
  when(portSupplier.get()).thenReturn(firstPort, secondPort);
  Node readNode = createReadNode("Read", Nodes.ExecutionLocation.RUNNER_HARNESS);
  Edge readNodeEdge = DefaultEdge.create();
  Node readNodeOut = createInstructionOutputNode("Read.out");
  Edge readNodeOutEdge = DefaultEdge.create();
  Node sdkParDoNode = createParDoNode("SdkParDo", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge sdkParDoNodeEdge = DefaultEdge.create();
  Node sdkParDoNodeOut = createInstructionOutputNode("SdkParDo.out");
  Edge sdkParDoNodeOutEdge = DefaultEdge.create();
  Node runnerParDoNode = createParDoNode("RunnerParDo", Nodes.ExecutionLocation.RUNNER_HARNESS);
  Edge runnerParDoNodeEdge = DefaultEdge.create();
  Node runnerParDoNodeOut = createInstructionOutputNode("RunnerParDo.out");
  // Read -out-> SdkParDo -out-> RunnerParDo
  MutableNetwork<Node, Edge> network = createEmptyNetwork();
  network.addNode(readNode);
  network.addNode(readNodeOut);
  // Add the SDK ParDo explicitly instead of relying on addEdge to insert it implicitly.
  network.addNode(sdkParDoNode);
  network.addNode(sdkParDoNodeOut);
  network.addNode(runnerParDoNode);
  network.addNode(runnerParDoNodeOut);
  network.addEdge(readNode, readNodeOut, readNodeEdge);
  network.addEdge(readNodeOut, sdkParDoNode, readNodeOutEdge);
  network.addEdge(sdkParDoNode, sdkParDoNodeOut, sdkParDoNodeEdge);
  network.addEdge(sdkParDoNodeOut, runnerParDoNode, sdkParDoNodeOutEdge);
  network.addEdge(runnerParDoNode, runnerParDoNodeOut, runnerParDoNodeEdge);
  MutableNetwork<Node, Edge> appliedNetwork = createRegisterFnOperation.apply(Graphs.copyOf(network));
  assertNetworkMaintainsBipartiteStructure(appliedNetwork);
  // On each rewire between runner and SDK and vice versa, we use a new output node
  Node newOutA = Iterables.getOnlyElement(appliedNetwork.predecessors(firstPort));
  Node newOutB = Iterables.getOnlyElement(appliedNetwork.successors(secondPort));
  // readNode -newOutA-> firstPort --> sdkPortion --> secondPort -newOutB-> runnerParDoNode
  assertThat(appliedNetwork.nodes(), containsInAnyOrder(readNode, newOutA, firstPort, sdkPortion, secondPort, newOutB, runnerParDoNode, runnerParDoNodeOut));
  assertThat(appliedNetwork.successors(readNode), containsInAnyOrder(newOutA));
  assertThat(appliedNetwork.successors(newOutA), containsInAnyOrder(firstPort));
  assertThat(appliedNetwork.successors(firstPort), containsInAnyOrder(sdkPortion));
  assertThat(appliedNetwork.successors(sdkPortion), containsInAnyOrder(secondPort));
  assertThat(appliedNetwork.successors(secondPort), containsInAnyOrder(newOutB));
  assertThat(appliedNetwork.successors(newOutB), containsInAnyOrder(runnerParDoNode));
  assertThat(appliedNetwork.successors(runnerParDoNode), containsInAnyOrder(runnerParDoNodeOut));
  assertThat(appliedNetwork.edgesConnecting(firstPort, sdkPortion), everyItem(Matchers.<Edges.Edge>instanceOf(HappensBeforeEdge.class)));
  assertThat(appliedNetwork.edgesConnecting(sdkPortion, secondPort), everyItem(Matchers.<Edges.Edge>instanceOf(HappensBeforeEdge.class)));
  // The network handed to the register function is the SDK side of the split.
  MutableNetwork<Node, Edge> sdkSubnetwork = networkCapture.getValue();
  assertNetworkMaintainsBipartiteStructure(sdkSubnetwork);
  Node sdkNewOutA = Iterables.getOnlyElement(sdkSubnetwork.successors(firstPort));
  Node sdkNewOutB = Iterables.getOnlyElement(sdkSubnetwork.predecessors(secondPort));
  // firstPort -sdkNewOutA-> sdkParDoNode -sdkNewOutB-> secondPort
  assertThat(sdkSubnetwork.nodes(), containsInAnyOrder(firstPort, sdkNewOutA, sdkParDoNode, sdkNewOutB, secondPort));
  assertThat(sdkSubnetwork.successors(firstPort), containsInAnyOrder(sdkNewOutA));
  assertThat(sdkSubnetwork.successors(sdkNewOutA), containsInAnyOrder(sdkParDoNode));
  assertThat(sdkSubnetwork.successors(sdkParDoNode), containsInAnyOrder(sdkNewOutB));
  assertThat(sdkSubnetwork.successors(sdkNewOutB), containsInAnyOrder(secondPort));
}
use of org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork in project beam by apache.
the class CreateRegisterFnOperationFunctionTest method testSdkToRunnerToSdkGraph.
/**
 * Covers a linear graph whose read runs in the SDK harness, followed by a runner ParDo and then
 * an SDK ParDo.
 *
 * <p>Two SDK portions are expected, one on each side of the runner ParDo, each connected to it
 * through one of the mocked ports and a fresh output node. The order in which the register
 * function is invoked is not fixed, so the captured subnetworks are told apart by which one
 * contains the read node.
 */
@Test
public void testSdkToRunnerToSdkGraph() {
  Node firstSdkPortion = TestNode.create("FirstSdkPortion");
  Node secondSdkPortion = TestNode.create("SecondSdkPortion");
  @SuppressWarnings({ "unchecked", "rawtypes" })
  ArgumentCaptor<MutableNetwork<Node, Edge>> networkCapture = ArgumentCaptor.forClass((Class) MutableNetwork.class);
  when(registerFnOperationFunction.apply(networkCapture.capture())).thenReturn(firstSdkPortion, secondSdkPortion);
  Node firstPort = TestNode.create("FirstPort");
  Node secondPort = TestNode.create("SecondPort");
  when(portSupplier.get()).thenReturn(firstPort, secondPort);
  Node readNode = createReadNode("Read", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge readNodeEdge = DefaultEdge.create();
  Node readNodeOut = createInstructionOutputNode("Read.out");
  Edge readNodeOutEdge = DefaultEdge.create();
  Node runnerParDoNode = createParDoNode("RunnerParDo", Nodes.ExecutionLocation.RUNNER_HARNESS);
  Edge runnerParDoNodeEdge = DefaultEdge.create();
  Node runnerParDoNodeOut = createInstructionOutputNode("RunnerParDo.out");
  Edge runnerParDoNodeOutEdge = DefaultEdge.create();
  Node sdkParDoNode = createParDoNode("SdkParDo", Nodes.ExecutionLocation.SDK_HARNESS);
  Edge sdkParDoNodeEdge = DefaultEdge.create();
  Node sdkParDoNodeOut = createInstructionOutputNode("SdkParDo.out");
  // Read -out-> RunnerParDo -out-> SdkParDo
  MutableNetwork<Node, Edge> network = createEmptyNetwork();
  network.addNode(readNode);
  network.addNode(readNodeOut);
  network.addNode(runnerParDoNode);
  network.addNode(runnerParDoNodeOut);
  // Add the SDK ParDo explicitly instead of relying on addEdge to insert it implicitly.
  network.addNode(sdkParDoNode);
  network.addNode(sdkParDoNodeOut);
  network.addEdge(readNode, readNodeOut, readNodeEdge);
  network.addEdge(readNodeOut, runnerParDoNode, readNodeOutEdge);
  network.addEdge(runnerParDoNode, runnerParDoNodeOut, runnerParDoNodeEdge);
  network.addEdge(runnerParDoNodeOut, sdkParDoNode, runnerParDoNodeOutEdge);
  network.addEdge(sdkParDoNode, sdkParDoNodeOut, sdkParDoNodeEdge);
  MutableNetwork<Node, Edge> appliedNetwork = createRegisterFnOperation.apply(Graphs.copyOf(network));
  assertNetworkMaintainsBipartiteStructure(appliedNetwork);
  // On each rewire between runner and SDK, we use a new output node
  Node newOutA = Iterables.getOnlyElement(appliedNetwork.successors(firstPort));
  Node newOutB = Iterables.getOnlyElement(appliedNetwork.predecessors(secondPort));
  // firstSdkPortion -> firstPort -newOutA-> RunnerParDo -newOutB-> secondPort -> secondSdkPortion
  assertThat(appliedNetwork.nodes(), containsInAnyOrder(firstSdkPortion, firstPort, newOutA, runnerParDoNode, newOutB, secondPort, secondSdkPortion));
  assertThat(appliedNetwork.successors(firstSdkPortion), containsInAnyOrder(firstPort));
  assertThat(appliedNetwork.successors(firstPort), containsInAnyOrder(newOutA));
  assertThat(appliedNetwork.successors(newOutA), containsInAnyOrder(runnerParDoNode));
  assertThat(appliedNetwork.successors(runnerParDoNode), containsInAnyOrder(newOutB));
  assertThat(appliedNetwork.successors(newOutB), containsInAnyOrder(secondPort));
  assertThat(appliedNetwork.successors(secondPort), containsInAnyOrder(secondSdkPortion));
  assertThat(appliedNetwork.edgesConnecting(firstSdkPortion, firstPort), everyItem(Matchers.<Edges.Edge>instanceOf(HappensBeforeEdge.class)));
  assertThat(appliedNetwork.edgesConnecting(secondPort, secondSdkPortion), everyItem(Matchers.<Edges.Edge>instanceOf(HappensBeforeEdge.class)));
  // The order of the calls to create the SDK subnetworks is indeterminate
  // The "first" subnetwork is whichever captured network contains the read node.
  List<MutableNetwork<Node, Edge>> sdkSubnetworks = networkCapture.getAllValues();
  MutableNetwork<Node, Edge> firstSdkSubnetwork;
  MutableNetwork<Node, Edge> secondSdkSubnetwork;
  if (sdkSubnetworks.get(0).nodes().contains(readNode)) {
    firstSdkSubnetwork = sdkSubnetworks.get(0);
    secondSdkSubnetwork = sdkSubnetworks.get(1);
  } else {
    firstSdkSubnetwork = sdkSubnetworks.get(1);
    secondSdkSubnetwork = sdkSubnetworks.get(0);
  }
  assertNetworkMaintainsBipartiteStructure(firstSdkSubnetwork);
  assertNetworkMaintainsBipartiteStructure(secondSdkSubnetwork);
  Node sdkNewOutA = Iterables.getOnlyElement(firstSdkSubnetwork.predecessors(firstPort));
  // readNode -sdkNewOutA-> firstPort
  assertThat(firstSdkSubnetwork.nodes(), containsInAnyOrder(readNode, sdkNewOutA, firstPort));
  assertThat(firstSdkSubnetwork.successors(readNode), containsInAnyOrder(sdkNewOutA));
  assertThat(firstSdkSubnetwork.successors(sdkNewOutA), containsInAnyOrder(firstPort));
  Node sdkNewOutB = Iterables.getOnlyElement(secondSdkSubnetwork.successors(secondPort));
  // secondPort -sdkNewOutB-> sdkParDoNode -> sdkParDoNodeOut
  assertThat(secondSdkSubnetwork.nodes(), containsInAnyOrder(secondPort, sdkNewOutB, sdkParDoNode, sdkParDoNodeOut));
  assertThat(secondSdkSubnetwork.successors(secondPort), containsInAnyOrder(sdkNewOutB));
  assertThat(secondSdkSubnetwork.successors(sdkNewOutB), containsInAnyOrder(sdkParDoNode));
  assertThat(secondSdkSubnetwork.successors(sdkParDoNode), containsInAnyOrder(sdkParDoNodeOut));
}
Aggregations