use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.
the class CreateRegisterFnOperationFunction method rewireAcrossSdkRunnerPortNode.
/**
 * Rewires the given set of predecessors and successors across a gRPC port surrounded by output
 * nodes. Edges to the remaining successors are copied over to the new output node that is placed
 * before the port node. For example:
 *
 * <pre><code>
 * predecessors --> outputNode --> successors
 *                             \--> existingSuccessors
 * </code></pre> becomes:
 *
 * <pre><code>
 *
 *                                  outputNode -------------------------------\
 *                                       \                                     \
 *                                        |-> existingSuccessors                \
 *                                       /                                       \
 * predecessors --> newPredecessorOutputNode --> portNode --> portOutputNode --> successors
 * </code></pre>
 */
private Node rewireAcrossSdkRunnerPortNode(MutableNetwork<Node, Edge> network, InstructionOutputNode outputNode, Set<Node> predecessors, Set<Node> successors) {
InstructionOutputNode newPredecessorOutputNode = InstructionOutputNode.create(outputNode.getInstructionOutput(), outputNode.getPcollectionId());
InstructionOutputNode portOutputNode = InstructionOutputNode.create(outputNode.getInstructionOutput(), outputNode.getPcollectionId());
Node portNode = portSupplier.get();
network.addNode(newPredecessorOutputNode);
network.addNode(portNode);
for (Node predecessor : predecessors) {
for (Edge edge : ImmutableList.copyOf(network.edgesConnecting(predecessor, outputNode))) {
network.removeEdge(edge);
network.addEdge(predecessor, newPredecessorOutputNode, edge);
}
}
// Maintain edges for existing successors.
List<Node> existingSuccessors = ImmutableList.copyOf(Sets.difference(network.successors(outputNode), successors));
for (Node existingSuccessor : existingSuccessors) {
List<Edge> existingSuccessorEdges = ImmutableList.copyOf(network.edgesConnecting(outputNode, existingSuccessor));
for (Edge existingSuccessorEdge : existingSuccessorEdges) {
network.addEdge(newPredecessorOutputNode, existingSuccessor, existingSuccessorEdge.clone());
}
}
// Rewire the requested successors over the port node.
network.addEdge(newPredecessorOutputNode, portNode, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
network.addEdge(portNode, portOutputNode, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
for (Node successor : successors) {
for (Edge edge : ImmutableList.copyOf(network.edgesConnecting(outputNode, successor))) {
network.addEdge(portOutputNode, successor, edge.clone());
}
}
return portNode;
}
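The rewiring above leans on two Guava MutableNetwork idioms: snapshotting edgesConnecting(...) before mutation, and re-adding the same Edge object against a new endpoint. Below is a minimal, runnable sketch of that pattern with plain Strings standing in for Beam's Node and Edge types; all names in it are illustrative, not from the Beam codebase.

import com.google.common.collect.ImmutableList;
import com.google.common.graph.MutableNetwork;
import com.google.common.graph.NetworkBuilder;

public class RewireSketch {
  public static void main(String[] args) {
    MutableNetwork<String, String> network =
        NetworkBuilder.directed().allowsParallelEdges(true).build();
    network.addEdge("predecessor", "outputNode", "edge");

    // Copy the edge set before mutating to avoid ConcurrentModificationException,
    // then redirect each edge to the freshly created output node, reusing the
    // same edge object just as rewireAcrossSdkRunnerPortNode does.
    String newPredecessorOutputNode = "newPredecessorOutputNode";
    network.addNode(newPredecessorOutputNode);
    for (String edge :
        ImmutableList.copyOf(network.edgesConnecting("predecessor", "outputNode"))) {
      network.removeEdge(edge);
      network.addEdge("predecessor", newPredecessorOutputNode, edge);
    }
    System.out.println(network.successors("predecessor")); // [newPredecessorOutputNode]
  }
}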
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.
the class CreateExecutableStageNodeFunction method apply.
@Override
public Node apply(MutableNetwork<Node, Edge> input) {
for (Node node : input.nodes()) {
if (node instanceof RemoteGrpcPortNode || node instanceof ParallelInstructionNode || node instanceof InstructionOutputNode) {
continue;
}
throw new IllegalArgumentException(String.format("Network contains unknown type of node: %s", input));
}
// Fix all non-output nodes to have named edges.
for (Node node : input.nodes()) {
if (node instanceof InstructionOutputNode) {
continue;
}
for (Node successor : input.successors(node)) {
for (Edge edge : input.edgesConnecting(node, successor)) {
if (edge instanceof DefaultEdge) {
input.removeEdge(edge);
input.addEdge(node, successor, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
}
}
}
}
RunnerApi.Components.Builder componentsBuilder = RunnerApi.Components.newBuilder();
componentsBuilder.mergeFrom(this.pipeline.getComponents());
// Default to the Java environment if the pipeline doesn't have an environment specified.
if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) {
String envId = Environments.JAVA_SDK_HARNESS_ENVIRONMENT.getUrn() + idGenerator.getId();
componentsBuilder.putEnvironments(envId, Environments.JAVA_SDK_HARNESS_ENVIRONMENT);
}
// By default, use GlobalWindow for all languages.
// For Java, if there is an IntervalWindowCoder, then use FixedWindows instead.
// TODO: should get the real WindowingStrategy from the pipeline proto.
String globalWindowingStrategyId = "generatedGlobalWindowingStrategy" + idGenerator.getId();
String intervalWindowEncodingWindowingStrategyId = "generatedIntervalWindowEncodingWindowingStrategy" + idGenerator.getId();
SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents(), null);
try {
registerWindowingStrategy(globalWindowingStrategyId, WindowingStrategy.globalDefault(), componentsBuilder, sdkComponents);
registerWindowingStrategy(intervalWindowEncodingWindowingStrategyId, WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(1))), componentsBuilder, sdkComponents);
} catch (IOException exc) {
throw new RuntimeException("Could not convert default windowing stratey to proto", exc);
}
Map<Node, String> nodesToPCollections = new HashMap<>();
ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder();
ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos = ImmutableMap.builder();
ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews = ImmutableMap.builder();
// A field of ExecutableStage: the PCollections which go to the worker side.
Set<PCollectionNode> executableStageOutputs = new HashSet<>();
// A field of ExecutableStage: the PCollections which go to the runner side.
Set<PCollectionNode> executableStageInputs = new HashSet<>();
for (InstructionOutputNode node : Iterables.filter(input.nodes(), InstructionOutputNode.class)) {
InstructionOutput instructionOutput = node.getInstructionOutput();
String coderId = "generatedCoder" + idGenerator.getId();
String windowingStrategyId;
try (ByteString.Output output = ByteString.newOutput()) {
try {
Coder<?> javaCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec()));
Coder<?> elementCoder = ((WindowedValueCoder<?>) javaCoder).getValueCoder();
sdkComponents.registerCoder(elementCoder);
RunnerApi.Coder coderProto = CoderTranslation.toProto(elementCoder, sdkComponents);
componentsBuilder.putCoders(coderId, coderProto);
// For now, the Dataflow runner harness only deals with FixedWindows.
if (javaCoder instanceof FullWindowedValueCoder) {
FullWindowedValueCoder<?> windowedValueCoder = (FullWindowedValueCoder<?>) javaCoder;
Coder<?> windowCoder = windowedValueCoder.getWindowCoder();
if (windowCoder instanceof IntervalWindowCoder) {
windowingStrategyId = intervalWindowEncodingWindowingStrategyId;
} else if (windowCoder instanceof GlobalWindow.Coder) {
windowingStrategyId = globalWindowingStrategyId;
} else {
throw new UnsupportedOperationException(String.format("Dataflow portable runner harness doesn't support windowing with %s", windowCoder));
}
} else {
throw new UnsupportedOperationException("Dataflow portable runner harness only supports FullWindowedValueCoder");
}
} catch (IOException e) {
throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e);
} catch (Exception e) {
// The coder probably wasn't a Java coder.
OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec());
componentsBuilder.putCoders(coderId, RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString())).build());
// For a non-Java coder, assume GlobalWindows by default.
// TODO(BEAM-6231): Actually discover the right windowing strategy.
windowingStrategyId = globalWindowingStrategyId;
}
} catch (IOException e) {
throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e);
}
// TODO(BEAM-6275): Set correct IsBounded on generated PCollections
String pcollectionId = node.getPcollectionId();
RunnerApi.PCollection pCollection =
    RunnerApi.PCollection.newBuilder()
        .setCoderId(coderId)
        .setWindowingStrategyId(windowingStrategyId)
        .setIsBounded(RunnerApi.IsBounded.Enum.BOUNDED)
        .build();
nodesToPCollections.put(node, pcollectionId);
componentsBuilder.putPcollections(pcollectionId, pCollection);
// Record whether this PCollection is an output and/or an input of the executable stage.
if (isExecutableStageOutputPCollection(input, node)) {
executableStageOutputs.add(PipelineNode.pCollection(pcollectionId, pCollection));
}
if (isExecutableStageInputPCollection(input, node)) {
executableStageInputs.add(PipelineNode.pCollection(pcollectionId, pCollection));
}
}
componentsBuilder.putAllCoders(sdkComponents.toComponents().getCodersMap());
Set<PTransformNode> executableStageTransforms = new HashSet<>();
Set<TimerReference> executableStageTimers = new HashSet<>();
List<UserStateId> userStateIds = new ArrayList<>();
Set<SideInputReference> executableStageSideInputs = new HashSet<>();
for (ParallelInstructionNode node : Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
ImmutableMap.Builder<String, PCollectionNode> sideInputIds = ImmutableMap.builder();
ParallelInstruction parallelInstruction = node.getParallelInstruction();
String ptransformId = "generatedPtransform" + idGenerator.getId();
ptransformIdToNameContexts.put(ptransformId, NameContext.create(null, parallelInstruction.getOriginalName(), parallelInstruction.getSystemName(), parallelInstruction.getName()));
RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder();
List<String> timerIds = new ArrayList<>();
if (parallelInstruction.getParDo() != null) {
ParDoInstruction parDoInstruction = parallelInstruction.getParDo();
CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
String userFnClassName = userFnSpec.getClassName();
if (userFnClassName.equals("CombineValuesFn") || userFnClassName.equals("KeyedCombineFn")) {
transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec);
ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList());
} else {
String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);
RunnerApi.PTransform parDoPTransform = pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);
// TODO: only the non-null branch should exist; for migration ease only
if (parDoPTransform != null) {
checkArgument(
    parDoPTransform.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN),
    "Found transform \"%s\" for ParallelDo instruction, "
        + "but that transform had unexpected URN \"%s\" (expected \"%s\")",
    parDoPTransformId,
    parDoPTransform.getSpec().getUrn(),
    PTransformTranslation.PAR_DO_TRANSFORM_URN);
RunnerApi.ParDoPayload parDoPayload;
try {
parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
} catch (InvalidProtocolBufferException exc) {
throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
}
// Collect the ParDo's user timers and user state.
for (Map.Entry<String, RunnerApi.TimerFamilySpec> entry : parDoPayload.getTimerFamilySpecsMap().entrySet()) {
timerIds.add(entry.getKey());
}
for (Map.Entry<String, RunnerApi.StateSpec> entry : parDoPayload.getStateSpecsMap().entrySet()) {
UserStateId.Builder builder = UserStateId.newBuilder();
builder.setTransformId(parDoPTransformId);
builder.setLocalName(entry.getKey());
userStateIds.add(builder.build());
}
// To facilitate the creation of the executableStageSideInputs set.
for (String sideInputTag : parDoPayload.getSideInputsMap().keySet()) {
String sideInputPCollectionId = parDoPTransform.getInputsOrThrow(sideInputTag);
RunnerApi.PCollection sideInputPCollection = pipeline.getComponents().getPcollectionsOrThrow(sideInputPCollectionId);
pTransform.putInputs(sideInputTag, sideInputPCollectionId);
PCollectionNode pCollectionNode = PipelineNode.pCollection(sideInputPCollectionId, sideInputPCollection);
sideInputIds.put(sideInputTag, pCollectionNode);
}
// To facilitate the creation of the Map(ptransformId -> pCollectionView), which is
// required for constructing an ExecutableStageNode.
ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder();
for (Map.Entry<String, RunnerApi.SideInput> sideInputEntry : parDoPayload.getSideInputsMap().entrySet()) {
pcollectionViews.add(RegisterNodeFunction.transformSideInputForRunner(pipeline, parDoPTransform, sideInputEntry.getKey(), sideInputEntry.getValue()));
}
ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build());
transformSpec.setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(parDoPayload.toByteString());
} else {
// legacy path - bytes are the FunctionSpec's payload field, basically, and
// SDKs expect it in the PTransform's payload field
byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN);
transformSpec.setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN).setPayload(ByteString.copyFrom(userFnBytes));
}
if (parDoInstruction.getSideInputs() != null) {
ptransformIdToSideInputInfos.put(ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true));
}
}
} else if (parallelInstruction.getRead() != null) {
ReadInstruction readInstruction = parallelInstruction.getRead();
CloudObject sourceSpec = CloudObject.fromSpec(CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec());
// TODO: Need to plumb through the SDK specific function spec.
transformSpec.setUrn(JAVA_SOURCE_URN);
try {
byte[] serializedSource = Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE));
ByteString sourceByteString = ByteString.copyFrom(serializedSource);
transformSpec.setPayload(sourceByteString);
} catch (Exception e) {
throw new IllegalArgumentException(String.format("Unable to process Read %s", parallelInstruction), e);
}
} else if (parallelInstruction.getFlatten() != null) {
transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN);
} else {
throw new IllegalArgumentException(String.format("Unknown type of ParallelInstruction %s", parallelInstruction));
}
// Register each predecessor PCollection as an input. For a ParDo, the non-side-input
// predecessor is called the "main input".
for (Node predecessorOutput : input.predecessors(node)) {
pTransform.putInputs("generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput));
}
for (Edge edge : input.outEdges(node)) {
Node nodeOutput = input.incidentNodes(edge).target();
MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge;
pTransform.putOutputs(edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput));
}
pTransform.setSpec(transformSpec);
PTransformNode pTransformNode = PipelineNode.pTransform(ptransformId, pTransform.build());
executableStageTransforms.add(pTransformNode);
for (String timerId : timerIds) {
executableStageTimers.add(TimerReference.of(pTransformNode, timerId));
}
ImmutableMap<String, PCollectionNode> sideInputIdToPCollectionNodes = sideInputIds.build();
for (String sideInputTag : sideInputIdToPCollectionNodes.keySet()) {
SideInputReference sideInputReference = SideInputReference.of(pTransformNode, sideInputTag, sideInputIdToPCollectionNodes.get(sideInputTag));
executableStageSideInputs.add(sideInputReference);
}
executableStageTransforms.add(pTransformNode);
}
if (executableStageInputs.size() != 1) {
throw new UnsupportedOperationException("ExecutableStage only support one input PCollection");
}
PCollectionNode executableInput = executableStageInputs.iterator().next();
RunnerApi.Components executableStageComponents = componentsBuilder.build();
// Get the Environment from the PTransforms; otherwise, use JAVA_SDK_HARNESS_ENVIRONMENT as the default.
Environment executableStageEnv = getEnvironmentFromPTransform(executableStageComponents, executableStageTransforms);
if (executableStageEnv == null) {
executableStageEnv = Environments.JAVA_SDK_HARNESS_ENVIRONMENT;
}
Set<UserStateReference> executableStageUserStateReference = new HashSet<>();
for (UserStateId userStateId : userStateIds) {
executableStageUserStateReference.add(UserStateReference.fromUserStateId(userStateId, executableStageComponents));
}
ExecutableStage executableStage =
    ImmutableExecutableStage.ofFullComponents(
        executableStageComponents,
        executableStageEnv,
        executableInput,
        executableStageSideInputs,
        executableStageUserStateReference,
        executableStageTimers,
        executableStageTransforms,
        executableStageOutputs,
        DEFAULT_WIRE_CODER_SETTINGS);
return ExecutableStageNode.create(
    executableStage,
    ptransformIdToNameContexts.build(),
    ptransformIdToSideInputInfos.build(),
    ptransformIdToPCollectionViews.build());
}
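One detail worth isolating from the loop above is how an element coder becomes a RunnerApi proto: register it with SdkComponents, then run CoderTranslation.toProto, exactly the calls made per instruction output. Below is a minimal sketch assuming beam-runners-core-construction-java and the Beam Java SDK are on the classpath; the class name and the empty starting components are illustrative.

import java.io.IOException;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.sdk.coders.StringUtf8Coder;

public class CoderProtoSketch {
  public static void main(String[] args) throws IOException {
    // Start from empty components with no default environment, mirroring
    // SdkComponents.create(pipeline.getComponents(), null) in the method above.
    SdkComponents sdkComponents =
        SdkComponents.create(RunnerApi.Components.getDefaultInstance(), null);
    StringUtf8Coder elementCoder = StringUtf8Coder.of();

    // Registration assigns the coder an id within the components; toProto then
    // produces the RunnerApi.Coder that gets put into the Components builder.
    sdkComponents.registerCoder(elementCoder);
    RunnerApi.Coder coderProto = CoderTranslation.toProto(elementCoder, sdkComponents);
    System.out.println(coderProto.getSpec().getUrn()); // beam:coder:string_utf8:v1
  }
}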
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.
the class CreateRegisterFnOperationFunction method apply.
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
// Record all SDK nodes, and all root nodes.
Set<Node> runnerRootNodes = new HashSet<>();
Set<Node> sdkNodes = new HashSet<>();
Set<Node> sdkRootNodes = new HashSet<>();
for (ParallelInstructionNode node : Iterables.filter(network.nodes(), ParallelInstructionNode.class)) {
if (executesInSdkHarness(node)) {
sdkNodes.add(node);
if (network.inDegree(node) == 0) {
sdkRootNodes.add(node);
}
} else if (network.inDegree(node) == 0) {
runnerRootNodes.add(node);
}
}
// If nothing executes within the SDK harness, return the original network.
if (sdkNodes.isEmpty()) {
return network;
}
// Represents the set of nodes which represent gRPC boundaries from the Runner to the SDK.
Set<Node> runnerToSdkBoundaries = new HashSet<>();
// Represents the set of nodes which represent gRPC boundaries from the SDK to the Runner.
Set<Node> sdkToRunnerBoundaries = new HashSet<>();
ImmutableNetwork<Node, Edge> originalNetwork = ImmutableNetwork.copyOf(network);
// Rewire the network around each original output node, inserting port nodes for each
// direction of data flow: runner to SDK and SDK to runner.
for (InstructionOutputNode outputNode : Iterables.filter(originalNetwork.nodes(), InstructionOutputNode.class)) {
// Categorize all predecessor instructions
Set<Node> predecessorRunnerInstructions = new HashSet<>();
Set<Node> predecessorSdkInstructions = new HashSet<>();
for (Node predecessorInstruction : originalNetwork.predecessors(outputNode)) {
if (sdkNodes.contains(predecessorInstruction)) {
predecessorSdkInstructions.add(predecessorInstruction);
} else {
predecessorRunnerInstructions.add(predecessorInstruction);
}
}
// Categorize all successor instructions
Set<Node> successorRunnerInstructions = new HashSet<>();
Set<Node> successorSdkInstructions = new HashSet<>();
for (Node successorInstruction : originalNetwork.successors(outputNode)) {
if (sdkNodes.contains(successorInstruction)) {
successorSdkInstructions.add(successorInstruction);
} else {
successorRunnerInstructions.add(successorInstruction);
}
}
// If data flows from the Runner into the SDK harness, rewire those
// nodes connected across a gRPC node. Also add the gRPC node as an SDK root.
if (!predecessorRunnerInstructions.isEmpty() && !successorSdkInstructions.isEmpty()) {
runnerToSdkBoundaries.add(rewireAcrossSdkRunnerPortNode(network, outputNode, predecessorRunnerInstructions, successorSdkInstructions));
}
// If data flows from the SDK harness back to the Runner, rewire those
// nodes connected across a gRPC node.
if (!predecessorSdkInstructions.isEmpty() && !successorRunnerInstructions.isEmpty()) {
sdkToRunnerBoundaries.add(rewireAcrossSdkRunnerPortNode(network, outputNode, predecessorSdkInstructions, successorRunnerInstructions));
}
// Remove the original output node once nothing feeds it; its data now flows
// through the new output node(s).
if (network.inDegree(outputNode) == 0) {
network.removeNode(outputNode);
}
}
// Create the subnetworks that represent potentially multiple fused SDK portions and a single
// fused Runner portion, replacing the SDK portion that is embedded within the Runner portion
// with a RegisterFnOperation and adding edges to maintain proper happens-before relationships.
Set<Node> allRunnerNodes = Networks.reachableNodes(network, Sets.union(runnerRootNodes, sdkToRunnerBoundaries), runnerToSdkBoundaries);
if (this.useExecutableStageBundleExecution) {
// When using the shared library, there is no gRPC node in the runner graph.
allRunnerNodes = Sets.difference(allRunnerNodes, Sets.union(runnerToSdkBoundaries, sdkToRunnerBoundaries));
}
MutableNetwork<Node, Edge> runnerNetwork = Graphs.inducedSubgraph(network, allRunnerNodes);
// Build each SDK subnetwork rooted at an SDK root node or a runner-to-SDK boundary,
// and represent it within the runner network as a single RegisterFn node.
for (Node sdkRoot : Sets.union(sdkRootNodes, runnerToSdkBoundaries)) {
Set<Node> sdkSubnetworkNodes = Networks.reachableNodes(network, ImmutableSet.of(sdkRoot), sdkToRunnerBoundaries);
MutableNetwork<Node, Edge> sdkNetwork = Graphs.inducedSubgraph(network, sdkSubnetworkNodes);
Node registerFnNode = registerFnOperationFunction.apply(sdkNetwork);
runnerNetwork.addNode(registerFnNode);
// Wire the RegisterFn node into the runner network: each runner-to-SDK boundary
// becomes a predecessor and each SDK-to-runner boundary becomes a successor.
if (this.useExecutableStageBundleExecution) {
// With shared-library execution, the node should be linked directly to the
// two InstructionOutputNodes, skipping over the gRPC port nodes.
for (Node predecessor : Sets.intersection(sdkSubnetworkNodes, runnerToSdkBoundaries)) {
predecessor = network.predecessors(predecessor).iterator().next();
runnerNetwork.addEdge(predecessor, registerFnNode, HappensBeforeEdge.create());
}
for (Node successor : Sets.intersection(sdkSubnetworkNodes, sdkToRunnerBoundaries)) {
successor = network.successors(successor).iterator().next();
runnerNetwork.addEdge(registerFnNode, successor, HappensBeforeEdge.create());
}
} else {
for (Node predecessor : Sets.intersection(sdkSubnetworkNodes, runnerToSdkBoundaries)) {
runnerNetwork.addEdge(predecessor, registerFnNode, HappensBeforeEdge.create());
}
for (Node successor : Sets.intersection(sdkSubnetworkNodes, sdkToRunnerBoundaries)) {
runnerNetwork.addEdge(registerFnNode, successor, HappensBeforeEdge.create());
}
}
}
return runnerNetwork;
}
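The split into runner and SDK portions bottoms out in Guava's Graphs.inducedSubgraph, which copies the sub-network spanned by a node set and silently drops edges whose endpoints fall outside it. A minimal sketch of that behavior, with String nodes and Integer edges as illustrative stand-ins:

import com.google.common.collect.ImmutableSet;
import com.google.common.graph.Graphs;
import com.google.common.graph.MutableNetwork;
import com.google.common.graph.NetworkBuilder;

public class SubgraphSketch {
  public static void main(String[] args) {
    MutableNetwork<String, Integer> network = NetworkBuilder.directed().build();
    network.addEdge("read", "parDo", 1);
    network.addEdge("parDo", "write", 2);

    // The induced subgraph keeps edge 1 (both endpoints retained) and drops
    // edge 2, just as the runner network drops edges into the SDK portion.
    MutableNetwork<String, Integer> runnerNetwork =
        Graphs.inducedSubgraph(network, ImmutableSet.of("read", "parDo"));
    System.out.println(runnerNetwork.edges()); // [1]
  }
}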
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.
the class InsertFetchAndFilterStreamingSideInputNodes method forNetwork.
public MutableNetwork<Node, Edge> forNetwork(MutableNetwork<Node, Edge> network) {
if (pipeline == null) {
return network;
}
RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(pipeline.getComponents());
for (ParallelInstructionNode node : ImmutableList.copyOf(Iterables.filter(network.nodes(), ParallelInstructionNode.class))) {
// If this instruction is not a ParDo executing in the SDK harness, we don't have
// to worry about it.
if (node.getParallelInstruction().getParDo() == null || !ExecutionLocation.SDK_HARNESS.equals(node.getExecutionLocation())) {
continue;
}
ParDoInstruction parDoInstruction = node.getParallelInstruction().getParDo();
CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);
// Skip ParDoInstruction nodes that contain payloads without side inputs.
String userFnClassName = userFnSpec.getClassName();
if ("CombineValuesFn".equals(userFnClassName) || "KeyedCombineFn".equals(userFnClassName)) {
// These nodes have CombinePayloads which have no side inputs.
continue;
}
RunnerApi.PTransform parDoPTransform = pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);
// TODO: only the non-null branch should exist; for migration ease only
if (parDoPTransform == null) {
continue;
}
RunnerApi.ParDoPayload parDoPayload;
try {
parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
} catch (InvalidProtocolBufferException exc) {
throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
}
// Skip any ParDo that doesn't have a side input.
if (parDoPayload.getSideInputsMap().isEmpty()) {
continue;
}
String mainInputPCollectionLocalName =
    Iterables.getOnlyElement(
        Sets.difference(
            parDoPTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
RunnerApi.WindowingStrategy windowingStrategyProto =
    pipeline
        .getComponents()
        .getWindowingStrategiesOrThrow(
            pipeline
                .getComponents()
                .getPcollectionsOrThrow(
                    parDoPTransform.getInputsOrThrow(mainInputPCollectionLocalName))
                .getWindowingStrategyId());
WindowingStrategy windowingStrategy;
try {
windowingStrategy = WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents);
} catch (InvalidProtocolBufferException e) {
throw new IllegalStateException(String.format("Unable to decode windowing strategy %s.", windowingStrategyProto), e);
}
// Gather all the side input window mapping fns which we need to request the SDK to map.
ImmutableMap.Builder<PCollectionView<?>, RunnerApi.FunctionSpec> pCollectionViewsToWindowMappingFns =
    ImmutableMap.builder();
parDoPayload
    .getSideInputsMap()
    .forEach(
        (sideInputTag, sideInput) ->
            pCollectionViewsToWindowMappingFns.put(
                RegisterNodeFunction.transformSideInputForRunner(
                    pipeline, parDoPTransform, sideInputTag, sideInput),
                sideInput.getWindowMappingFn()));
Node streamingSideInputWindowHandlerNode =
    FetchAndFilterStreamingSideInputsNode.create(
        windowingStrategy,
        pCollectionViewsToWindowMappingFns.build(),
        NameContext.create(
            null,
            node.getParallelInstruction().getOriginalName(),
            node.getParallelInstruction().getSystemName(),
            node.getParallelInstruction().getName()));
// Rewire the graph such that streaming side input ParDos are preceded by a
// node which filters any side inputs that aren't ready and fetches any ready side inputs.
Edge mainInput = Iterables.getOnlyElement(network.inEdges(node));
InstructionOutputNode predecessor = (InstructionOutputNode) network.incidentNodes(mainInput).source();
InstructionOutputNode predecessorCopy = InstructionOutputNode.create(predecessor.getInstructionOutput(), predecessor.getPcollectionId());
network.removeEdge(mainInput);
network.addNode(streamingSideInputWindowHandlerNode);
network.addNode(predecessorCopy);
network.addEdge(predecessor, streamingSideInputWindowHandlerNode, mainInput.clone());
network.addEdge(streamingSideInputWindowHandlerNode, predecessorCopy, mainInput.clone());
network.addEdge(predecessorCopy, node, mainInput.clone());
}
return network;
}
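The rewiring at the end is a splice: the single main-input edge is removed and replaced by a three-hop chain through the new handler node and a copy of the predecessor output node. A minimal sketch with String stand-ins for Beam's node and edge types, where suffixed strings play the role of Edge.clone():

import com.google.common.collect.Iterables;
import com.google.common.graph.MutableNetwork;
import com.google.common.graph.NetworkBuilder;

public class SpliceSketch {
  public static void main(String[] args) {
    MutableNetwork<String, String> network =
        NetworkBuilder.directed().allowsParallelEdges(true).build();
    network.addEdge("predecessor", "parDo", "mainInput");

    // Locate the only inbound edge and its source node, as forNetwork() does.
    String mainInput = Iterables.getOnlyElement(network.inEdges("parDo"));
    String predecessor = network.incidentNodes(mainInput).source();

    // Splice the handler and the copied predecessor output node into the path.
    network.removeEdge(mainInput);
    network.addEdge(predecessor, "handler", mainInput + "-1");
    network.addEdge("handler", "predecessorCopy", mainInput + "-2");
    network.addEdge("predecessorCopy", "parDo", mainInput + "-3");
    System.out.println(network.predecessors("parDo")); // [predecessorCopy]
  }
}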
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.
the class MapTaskToNetworkFunction method apply.
@Override
public MutableNetwork<Node, Edge> apply(MapTask mapTask) {
List<ParallelInstruction> parallelInstructions = Apiary.listOrEmpty(mapTask.getInstructions());
MutableNetwork<Node, Edge> network =
    NetworkBuilder.directed()
        .allowsSelfLoops(false)
        .allowsParallelEdges(true)
        .expectedNodeCount(parallelInstructions.size() * 2)
        .build();
// Add all the instruction nodes and output nodes
ParallelInstructionNode[] instructionNodes = new ParallelInstructionNode[parallelInstructions.size()];
InstructionOutputNode[][] outputNodes = new InstructionOutputNode[parallelInstructions.size()][];
for (int i = 0; i < parallelInstructions.size(); ++i) {
// InstructionOutputNodes are the source of truth on instruction outputs.
// Clear the instruction's outputs to reduce chance for confusion.
List<InstructionOutput> outputs = Apiary.listOrEmpty(parallelInstructions.get(i).getOutputs());
outputNodes[i] = new InstructionOutputNode[outputs.size()];
JsonFactory factory = MoreObjects.firstNonNull(mapTask.getFactory(), Transport.getJsonFactory());
ParallelInstruction parallelInstruction = clone(factory, parallelInstructions.get(i)).setOutputs(null);
ParallelInstructionNode instructionNode = ParallelInstructionNode.create(parallelInstruction, Nodes.ExecutionLocation.UNKNOWN);
instructionNodes[i] = instructionNode;
network.addNode(instructionNode);
// Connect the instruction node output to the output PCollection node
for (int j = 0; j < outputs.size(); ++j) {
InstructionOutput instructionOutput = outputs.get(j);
InstructionOutputNode outputNode = InstructionOutputNode.create(instructionOutput, "generatedPcollection" + this.idGenerator.getId());
network.addNode(outputNode);
if (parallelInstruction.getParDo() != null) {
network.addEdge(instructionNode, outputNode, MultiOutputInfoEdge.create(parallelInstruction.getParDo().getMultiOutputInfos().get(j)));
} else {
network.addEdge(instructionNode, outputNode, DefaultEdge.create());
}
outputNodes[i][j] = outputNode;
}
}
// Connect PCollections as inputs to instructions
for (ParallelInstructionNode instructionNode : instructionNodes) {
ParallelInstruction parallelInstruction = instructionNode.getParallelInstruction();
if (parallelInstruction.getFlatten() != null) {
for (InstructionInput input : Apiary.listOrEmpty(parallelInstruction.getFlatten().getInputs())) {
attachInput(input, network, instructionNode, outputNodes);
}
} else if (parallelInstruction.getParDo() != null) {
attachInput(parallelInstruction.getParDo().getInput(), network, instructionNode, outputNodes);
} else if (parallelInstruction.getPartialGroupByKey() != null) {
attachInput(parallelInstruction.getPartialGroupByKey().getInput(), network, instructionNode, outputNodes);
} else if (parallelInstruction.getRead() != null) {
// Reads have no inputs so nothing to do
} else if (parallelInstruction.getWrite() != null) {
attachInput(parallelInstruction.getWrite().getInput(), network, instructionNode, outputNodes);
} else {
throw new IllegalArgumentException(String.format("Unknown type of instruction %s for map task %s", parallelInstruction, mapTask));
}
}
return network;
}
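The resulting network always alternates instruction nodes with output nodes: consumers attach to an instruction's InstructionOutputNode rather than to the instruction itself. A minimal sketch of that shape using the same NetworkBuilder configuration, with Strings standing in for the node and edge types (all names illustrative):

import com.google.common.graph.MutableNetwork;
import com.google.common.graph.NetworkBuilder;

public class MapTaskShapeSketch {
  public static void main(String[] args) {
    MutableNetwork<String, String> network =
        NetworkBuilder.directed()
            .allowsSelfLoops(false)
            .allowsParallelEdges(true)
            .expectedNodeCount(4)
            .build();

    // read --> read.out --> parDo --> parDo.out, mirroring how apply() wires
    // each ParallelInstructionNode to its InstructionOutputNodes.
    network.addEdge("read", "read.out", "defaultEdge");
    network.addEdge("read.out", "parDo", "inputEdge");
    network.addEdge("parDo", "parDo.out", "multiOutputInfoEdge");
    System.out.println(network.nodes()); // e.g. [read, read.out, parDo, parDo.out]
  }
}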