use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode in project beam by apache.
the class MapTaskToNetworkFunction method apply.
@Override
public MutableNetwork<Node, Edge> apply(MapTask mapTask) {
  List<ParallelInstruction> parallelInstructions = Apiary.listOrEmpty(mapTask.getInstructions());
  MutableNetwork<Node, Edge> network =
      NetworkBuilder.directed()
          .allowsSelfLoops(false)
          .allowsParallelEdges(true)
          .expectedNodeCount(parallelInstructions.size() * 2)
          .build();
  // Add all the instruction nodes and output nodes
  ParallelInstructionNode[] instructionNodes =
      new ParallelInstructionNode[parallelInstructions.size()];
  InstructionOutputNode[][] outputNodes =
      new InstructionOutputNode[parallelInstructions.size()][];
  for (int i = 0; i < parallelInstructions.size(); ++i) {
    // InstructionOutputNode's are the source of truth on instruction outputs.
    // Clear the instruction's outputs to reduce chance for confusion.
    List<InstructionOutput> outputs =
        Apiary.listOrEmpty(parallelInstructions.get(i).getOutputs());
    outputNodes[i] = new InstructionOutputNode[outputs.size()];
    JsonFactory factory =
        MoreObjects.firstNonNull(mapTask.getFactory(), Transport.getJsonFactory());
    ParallelInstruction parallelInstruction =
        clone(factory, parallelInstructions.get(i)).setOutputs(null);
    ParallelInstructionNode instructionNode =
        ParallelInstructionNode.create(parallelInstruction, Nodes.ExecutionLocation.UNKNOWN);
    instructionNodes[i] = instructionNode;
    network.addNode(instructionNode);
    // Connect the instruction node output to the output PCollection node
    for (int j = 0; j < outputs.size(); ++j) {
      InstructionOutput instructionOutput = outputs.get(j);
      InstructionOutputNode outputNode =
          InstructionOutputNode.create(
              instructionOutput, "generatedPcollection" + this.idGenerator.getId());
      network.addNode(outputNode);
      if (parallelInstruction.getParDo() != null) {
        network.addEdge(
            instructionNode,
            outputNode,
            MultiOutputInfoEdge.create(
                parallelInstruction.getParDo().getMultiOutputInfos().get(j)));
      } else {
        network.addEdge(instructionNode, outputNode, DefaultEdge.create());
      }
      outputNodes[i][j] = outputNode;
    }
  }
  // Connect PCollections as inputs to instructions
  for (ParallelInstructionNode instructionNode : instructionNodes) {
    ParallelInstruction parallelInstruction = instructionNode.getParallelInstruction();
    if (parallelInstruction.getFlatten() != null) {
      for (InstructionInput input :
          Apiary.listOrEmpty(parallelInstruction.getFlatten().getInputs())) {
        attachInput(input, network, instructionNode, outputNodes);
      }
    } else if (parallelInstruction.getParDo() != null) {
      attachInput(
          parallelInstruction.getParDo().getInput(), network, instructionNode, outputNodes);
    } else if (parallelInstruction.getPartialGroupByKey() != null) {
      attachInput(
          parallelInstruction.getPartialGroupByKey().getInput(),
          network,
          instructionNode,
          outputNodes);
    } else if (parallelInstruction.getRead() != null) {
      // Reads have no inputs so nothing to do
    } else if (parallelInstruction.getWrite() != null) {
      attachInput(
          parallelInstruction.getWrite().getInput(), network, instructionNode, outputNodes);
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Unknown type of instruction %s for map task %s", parallelInstruction, mapTask));
    }
  }
  return network;
}
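A minimal usage sketch follows. Two assumptions for illustration: that MapTaskToNetworkFunction exposes the single-argument IdGenerator constructor implied by the this.idGenerator field above, and that IdGenerators.incrementingLongs() stands in for whatever generator the worker actually wires in.

import com.google.api.services.dataflow.model.MapTask;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.sdk.fn.IdGenerators;
import com.google.common.graph.MutableNetwork;

// Sketch only: builds the instruction/output network for a MapTask obtained from the
// Dataflow service. The IdGenerator choice is an assumption, and the worker's actual
// build vendors Guava under a different package.
MutableNetwork<Node, Edge> toNetwork(MapTask mapTask) {
  return new MapTaskToNetworkFunction(IdGenerators.incrementingLongs()).apply(mapTask);
}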
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode in project beam by apache.
the class DeduceFlattenLocationsFunction method getConnectedNodeLocations.
/**
 * A function which retrieves the aggregated location of a node's connecting nodes in one
 * direction, checking either the target node's successors or its predecessors. This is done by
 * checking all the connected nodes' locations. For nodes that do not have locations embedded in
 * the actual node (they may have an unknown location or might not even be {@link
 * ParallelInstructionNode}s) the location can be deduced by recursively checking that node's
 * own connections in the same direction. To prevent a large amount of needless recursion, a map
 * is used for memoization: the results of this function are stored in the map so that they can
 * be retrieved later if needed without having to perform the recursion again.
 */
private AggregatedLocation getConnectedNodeLocations(
    Node node,
    MutableNetwork<Node, Edge> network,
    Map<Node, AggregatedLocation> connectedLocationsMap,
    SearchDirection direction) {
  // First check the map for a memoized result.
  if (connectedLocationsMap.containsKey(node)) {
    return connectedLocationsMap.get(node);
  }
  boolean hasSdkConnections = false;
  boolean hasRunnerConnections = false;
  Set<Node> connectedNodes;
  if (direction == SearchDirection.SUCCESSORS) {
    connectedNodes = network.successors(node);
  } else {
    connectedNodes = network.predecessors(node);
  }
  // Check the locations of all connected nodes; for any connected node whose location is
  // unknown, recurse this function into that node.
  for (Node connectedNode : connectedNodes) {
    if (connectedNode instanceof ParallelInstructionNode
        && ((ParallelInstructionNode) connectedNode).getExecutionLocation()
            != ExecutionLocation.UNKNOWN) {
      ExecutionLocation executionLocation =
          ((ParallelInstructionNode) connectedNode).getExecutionLocation();
      switch (executionLocation) {
        case SDK_HARNESS:
          hasSdkConnections = true;
          break;
        case RUNNER_HARNESS:
          hasRunnerConnections = true;
          break;
        case AMBIGUOUS:
          hasSdkConnections = true;
          hasRunnerConnections = true;
          break;
        default:
          throw new IllegalStateException("Unknown case " + executionLocation);
      }
    } else {
      AggregatedLocation connectedLocation =
          getConnectedNodeLocations(connectedNode, network, connectedLocationsMap, direction);
      switch (connectedLocation) {
        case SDK_HARNESS:
          hasSdkConnections = true;
          break;
        case RUNNER_HARNESS:
          hasRunnerConnections = true;
          break;
        case BOTH:
          hasSdkConnections = true;
          hasRunnerConnections = true;
          break;
        case NEITHER:
          break;
        default:
          throw new IllegalStateException("Unknown case " + connectedLocation);
      }
    }
    // Once connections to both the SDK and runner harnesses have been found, there is no
    // need to continue checking.
    if (hasSdkConnections && hasRunnerConnections) {
      break;
    }
  }
  // Return the aggregated location for this node's connections and store it in the map.
  AggregatedLocation aggregatedLocation;
  if (hasSdkConnections && hasRunnerConnections) {
    aggregatedLocation = AggregatedLocation.BOTH;
  } else if (hasSdkConnections) {
    aggregatedLocation = AggregatedLocation.SDK_HARNESS;
  } else if (hasRunnerConnections) {
    aggregatedLocation = AggregatedLocation.RUNNER_HARNESS;
  } else {
    aggregatedLocation = AggregatedLocation.NEITHER;
  }
  connectedLocationsMap.put(node, aggregatedLocation);
  return aggregatedLocation;
}
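The getPredecessorLocations and getSuccessorLocations helpers used by apply below are, judging from their call sites, thin wrappers over this method; a minimal sketch under that assumption (the SearchDirection.PREDECESSORS constant is likewise inferred from the direction check above):

// Assumed shape, inferred from the call sites in apply below.
private AggregatedLocation getPredecessorLocations(
    Node node, MutableNetwork<Node, Edge> network, Map<Node, AggregatedLocation> memo) {
  return getConnectedNodeLocations(node, network, memo, SearchDirection.PREDECESSORS);
}

private AggregatedLocation getSuccessorLocations(
    Node node, MutableNetwork<Node, Edge> network, Map<Node, AggregatedLocation> memo) {
  return getConnectedNodeLocations(node, network, memo, SearchDirection.SUCCESSORS);
}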
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode in project beam by apache.
the class DeduceFlattenLocationsFunction method apply.
/**
 * Deduces an {@link ExecutionLocation} for each flatten by first checking the locations of all
 * the predecessors and successors to each node. These locations are aggregated to a single result
 * representing all successors/predecessors. Once the aggregated locations for both successors and
 * predecessors are found, they are used to determine the execution location of the flatten node
 * itself, and the flattens are replaced by copies that include the updated {@link
 * ExecutionLocation}.
 */
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
  Map<Node, AggregatedLocation> predecessorLocationsMap = new HashMap<>();
  Map<Node, AggregatedLocation> successorLocationsMap = new HashMap<>();
  Map<Node, ExecutionLocation> deducedLocationsMap = new HashMap<>();
  ImmutableList<Node> flattens =
      ImmutableList.copyOf(Iterables.filter(network.nodes(), IsFlatten.INSTANCE));
  // Find all predecessor and successor locations for every flatten.
  for (Node flatten : flattens) {
    AggregatedLocation predecessorLocations =
        getPredecessorLocations(flatten, network, predecessorLocationsMap);
    AggregatedLocation successorLocations =
        getSuccessorLocations(flatten, network, successorLocationsMap);
    deducedLocationsMap.put(
        flatten, DEDUCTION_TABLE.get(predecessorLocations, successorLocations));
  }
  // Actually set the locations of the flattens permanently.
  Networks.replaceDirectedNetworkNodes(
      network,
      (Node node) -> {
        if (!deducedLocationsMap.containsKey(node)) {
          return node;
        }
        ParallelInstructionNode castNode = ((ParallelInstructionNode) node);
        ExecutionLocation deducedLocation = deducedLocationsMap.get(node);
        return ParallelInstructionNode.create(castNode.getParallelInstruction(), deducedLocation);
      });
  return network;
}
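DEDUCTION_TABLE maps the pair (aggregated predecessor location, aggregated successor location) to a deduced ExecutionLocation. Its real contents are not part of this snippet; the sketch below only illustrates the shape of such a lookup, with two assumed entries (the name EXAMPLE_DEDUCTION_TABLE and both entries are hypothetical, not the real mapping):

// Illustrative only: a Guava table keyed (predecessor location, successor location),
// yielding the deduced ExecutionLocation. Both entries below are assumptions for
// demonstration; the real DEDUCTION_TABLE in DeduceFlattenLocationsFunction is larger.
static final ImmutableTable<AggregatedLocation, AggregatedLocation, ExecutionLocation>
    EXAMPLE_DEDUCTION_TABLE =
        ImmutableTable.<AggregatedLocation, AggregatedLocation, ExecutionLocation>builder()
            .put(
                AggregatedLocation.SDK_HARNESS,
                AggregatedLocation.SDK_HARNESS,
                ExecutionLocation.SDK_HARNESS)
            .put(
                AggregatedLocation.RUNNER_HARNESS,
                AggregatedLocation.RUNNER_HARNESS,
                ExecutionLocation.RUNNER_HARNESS)
            .build();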
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode in project beam by apache.
the class LengthPrefixUnknownCoders method andReplaceForParallelInstructionNode.
/**
 * Replace unknown coders on the given {@link ParallelInstructionNode} with a {@link
 * org.apache.beam.sdk.coders.LengthPrefixCoder} wrapping a {@link
 * org.apache.beam.sdk.coders.ByteArrayCoder}.
 */
private static Function<Node, Node> andReplaceForParallelInstructionNode() {
  return new TypeSafeNodeFunction<ParallelInstructionNode>(ParallelInstructionNode.class) {
    @Override
    public Node typedApply(ParallelInstructionNode input) {
      ParallelInstruction instruction = input.getParallelInstruction();
      Nodes.ExecutionLocation location = input.getExecutionLocation();
      try {
        instruction = forParallelInstruction(instruction, true);
      } catch (Exception e) {
        throw new RuntimeException(
            String.format(
                "Failed to replace unknown coder with LengthPrefixCoder for: {%s}",
                input.getParallelInstruction()),
            e);
      }
      return ParallelInstructionNode.create(instruction, location);
    }
  };
}
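For intuition about what the replacement coder does on the wire, here is a self-contained sketch using the public Beam SDK coders (separate from the worker code above): a LengthPrefixCoder writes a varint length followed by the inner coder's raw encoding, which is what lets the runner skip over elements it cannot decode.

import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.util.CoderUtils;

public class LengthPrefixDemo {
  public static void main(String[] args) throws Exception {
    // The kind of coder substituted above when replaceWithByteArrayCoder is true:
    // a length-prefixed ByteArrayCoder.
    LengthPrefixCoder<byte[]> coder = LengthPrefixCoder.of(ByteArrayCoder.of());
    byte[] encoded = CoderUtils.encodeToByteArray(coder, new byte[] {1, 2, 3});
    // encoded is a varint length prefix (3) followed by the payload bytes.
    byte[] decoded = CoderUtils.decodeFromByteArray(coder, encoded);
  }
}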
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode in project beam by apache.
the class RegisterNodeFunction method apply.
@Override
public Node apply(MutableNetwork<Node, Edge> input) {
  for (Node node : input.nodes()) {
    if (node instanceof RemoteGrpcPortNode
        || node instanceof ParallelInstructionNode
        || node instanceof InstructionOutputNode) {
      continue;
    }
    throw new IllegalArgumentException(
        String.format("Network contains unknown type of node: %s", input));
  }
  // Fix all non output nodes to have named edges.
  for (Node node : input.nodes()) {
    if (node instanceof InstructionOutputNode) {
      continue;
    }
    for (Node successor : input.successors(node)) {
      for (Edge edge : input.edgesConnecting(node, successor)) {
        if (edge instanceof DefaultEdge) {
          input.removeEdge(edge);
          input.addEdge(
              node,
              successor,
              MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
        }
      }
    }
  }
  // We start off by replacing all edges within the graph with edges that have the named
  // outputs from the predecessor step. For ParallelInstruction Source nodes and RemoteGrpcPort
  // nodes this is a generated port id. All ParDoInstructions will have already been given
  // named output edges by the fix-up loop above.
  ProcessBundleDescriptor.Builder processBundleDescriptor =
      ProcessBundleDescriptor.newBuilder()
          .setId(idGenerator.getId())
          .setStateApiServiceDescriptor(stateApiServiceDescriptor);
  // For intermediate PCollections we fabricate, we make a bogus WindowingStrategy
  // TODO: create a correct windowing strategy, including coders and environment
  SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents(), null);
  // Default to use the Java environment if pipeline doesn't have environment specified.
  if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) {
    sdkComponents.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT);
  }
  String fakeWindowingStrategyId = "fakeWindowingStrategy" + idGenerator.getId();
  try {
    RunnerApi.MessageWithComponents fakeWindowingStrategyProto =
        WindowingStrategyTranslation.toMessageProto(
            WindowingStrategy.globalDefault(), sdkComponents);
    processBundleDescriptor
        .putWindowingStrategies(
            fakeWindowingStrategyId, fakeWindowingStrategyProto.getWindowingStrategy())
        .putAllCoders(fakeWindowingStrategyProto.getComponents().getCodersMap())
        .putAllEnvironments(fakeWindowingStrategyProto.getComponents().getEnvironmentsMap());
  } catch (IOException exc) {
    throw new RuntimeException("Could not convert default windowing strategy to proto", exc);
  }
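  // Accumulators for the metadata that will accompany the RegisterRequestNode produced at
  // the end: name contexts, side input info, and PCollection views keyed by generated ids.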
  Map<Node, String> nodesToPCollections = new HashMap<>();
  ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder();
  ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos =
      ImmutableMap.builder();
  ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews =
      ImmutableMap.builder();
  ImmutableMap.Builder<String, NameContext> pcollectionIdToNameContexts = ImmutableMap.builder();
  ImmutableMap.Builder<InstructionOutputNode, String> instructionOutputNodeToCoderIdBuilder =
      ImmutableMap.builder();
  // 2. Generate new PCollectionId and register it with ProcessBundleDescriptor.
  for (InstructionOutputNode node :
      Iterables.filter(input.nodes(), InstructionOutputNode.class)) {
    InstructionOutput instructionOutput = node.getInstructionOutput();
    String coderId = "generatedCoder" + idGenerator.getId();
    instructionOutputNodeToCoderIdBuilder.put(node, coderId);
    try (ByteString.Output output = ByteString.newOutput()) {
      try {
        Coder<?> javaCoder =
            CloudObjects.coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec()));
        sdkComponents.registerCoder(javaCoder);
        RunnerApi.Coder coderProto = CoderTranslation.toProto(javaCoder, sdkComponents);
        processBundleDescriptor.putCoders(coderId, coderProto);
      } catch (IOException e) {
        throw new IllegalArgumentException(
            String.format(
                "Unable to encode coder %s for output %s",
                instructionOutput.getCodec(), instructionOutput),
            e);
      } catch (Exception e) {
        // Coder probably wasn't a java coder
        OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec());
        processBundleDescriptor.putCoders(
            coderId,
            RunnerApi.Coder.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString()))
                .build());
      }
    } catch (IOException e) {
      throw new IllegalArgumentException(
          String.format(
              "Unable to encode coder %s for output %s",
              instructionOutput.getCodec(), instructionOutput),
          e);
    }
    // Generate new PCollection ID and map it to relevant node.
    // Will later be used to fill PTransform inputs/outputs information.
    String pcollectionId = "generatedPcollection" + idGenerator.getId();
    processBundleDescriptor.putPcollections(
        pcollectionId,
        RunnerApi.PCollection.newBuilder()
            .setCoderId(coderId)
            .setWindowingStrategyId(fakeWindowingStrategyId)
            .build());
    nodesToPCollections.put(node, pcollectionId);
    pcollectionIdToNameContexts.put(
        pcollectionId,
        NameContext.create(
            null,
            instructionOutput.getOriginalName(),
            instructionOutput.getSystemName(),
            instructionOutput.getName()));
  }
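  // Carry over any coders registered through SdkComponents and freeze the
  // node-to-coder-id mapping before building PTransforms.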
  processBundleDescriptor.putAllCoders(sdkComponents.toComponents().getCodersMap());
  Map<InstructionOutputNode, String> instructionOutputNodeToCoderIdMap =
      instructionOutputNodeToCoderIdBuilder.build();
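  // Convert each ParallelInstructionNode into a PTransform in the ProcessBundleDescriptor,
  // dispatching on which instruction type (ParDo, Read, Flatten) is set.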
  for (ParallelInstructionNode node :
      Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
    ParallelInstruction parallelInstruction = node.getParallelInstruction();
    String ptransformId = "generatedPtransform" + idGenerator.getId();
    ptransformIdToNameContexts.put(
        ptransformId,
        NameContext.create(
            null,
            parallelInstruction.getOriginalName(),
            parallelInstruction.getSystemName(),
            parallelInstruction.getName()));
    RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
    RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder();
    if (parallelInstruction.getParDo() != null) {
      ParDoInstruction parDoInstruction = parallelInstruction.getParDo();
      CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
      String userFnClassName = userFnSpec.getClassName();
      if ("CombineValuesFn".equals(userFnClassName) || "KeyedCombineFn".equals(userFnClassName)) {
        transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec);
        ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList());
      } else {
        String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);
        RunnerApi.PTransform parDoPTransform =
            pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);
        // TODO: only the non-null branch should exist; for migration ease only
        if (parDoPTransform != null) {
          checkArgument(
              parDoPTransform
                  .getSpec()
                  .getUrn()
                  .equals(PTransformTranslation.PAR_DO_TRANSFORM_URN),
              "Found transform \"%s\" for ParallelDo instruction, "
                  + " but that transform had unexpected URN \"%s\" (expected \"%s\")",
              parDoPTransformId,
              parDoPTransform.getSpec().getUrn(),
              PTransformTranslation.PAR_DO_TRANSFORM_URN);
          RunnerApi.ParDoPayload parDoPayload;
          try {
            parDoPayload =
                RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
          } catch (InvalidProtocolBufferException exc) {
            throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
          }
          ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder();
          for (Map.Entry<String, SideInput> sideInputEntry :
              parDoPayload.getSideInputsMap().entrySet()) {
            pcollectionViews.add(
                transformSideInputForRunner(
                    pipeline,
                    parDoPTransform,
                    sideInputEntry.getKey(),
                    sideInputEntry.getValue()));
            transformSideInputForSdk(
                pipeline,
                parDoPTransform,
                sideInputEntry.getKey(),
                processBundleDescriptor,
                pTransform);
          }
          ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build());
          transformSpec
              .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
              .setPayload(parDoPayload.toByteString());
        } else {
          // legacy path - bytes are the FunctionSpec's payload field, basically, and
          // SDKs expect it in the PTransform's payload field
          byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN);
          transformSpec
              .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
              .setPayload(ByteString.copyFrom(userFnBytes));
        }
        // Add side input information for batch pipelines
        if (parDoInstruction.getSideInputs() != null) {
          ptransformIdToSideInputInfos.put(
              ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true));
        }
      }
    } else if (parallelInstruction.getRead() != null) {
      ReadInstruction readInstruction = parallelInstruction.getRead();
      CloudObject sourceSpec =
          CloudObject.fromSpec(
              CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec());
      // TODO: Need to plumb through the SDK specific function spec.
      transformSpec.setUrn(JAVA_SOURCE_URN);
      try {
        byte[] serializedSource =
            Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE));
        ByteString sourceByteString = ByteString.copyFrom(serializedSource);
        transformSpec.setPayload(sourceByteString);
      } catch (Exception e) {
        throw new IllegalArgumentException(
            String.format("Unable to process Read %s", parallelInstruction), e);
      }
    } else if (parallelInstruction.getFlatten() != null) {
      transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN);
    } else {
      throw new IllegalArgumentException(
          String.format("Unknown type of ParallelInstruction %s", parallelInstruction));
    }
    for (Node predecessorOutput : input.predecessors(node)) {
      pTransform.putInputs(
          "generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput));
    }
    for (Edge edge : input.outEdges(node)) {
      Node nodeOutput = input.incidentNodes(edge).target();
      MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge;
      pTransform.putOutputs(
          edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput));
    }
    pTransform.setSpec(transformSpec);
    processBundleDescriptor.putTransforms(ptransformId, pTransform.build());
  }
  // Add the PTransforms representing the remote gRPC nodes
  for (RemoteGrpcPortNode node : Iterables.filter(input.nodes(), RemoteGrpcPortNode.class)) {
    RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
    Set<Node> predecessors = input.predecessors(node);
    Set<Node> successors = input.successors(node);
    if (predecessors.isEmpty() && !successors.isEmpty()) {
      Node instructionOutputNode = Iterables.getOnlyElement(successors);
      pTransform.putOutputs(
          "generatedOutput" + idGenerator.getId(),
          nodesToPCollections.get(instructionOutputNode));
      pTransform.setSpec(
          RunnerApi.FunctionSpec.newBuilder()
              .setUrn(DATA_INPUT_URN)
              .setPayload(
                  node.getRemoteGrpcPort()
                      .toBuilder()
                      .setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode))
                      .build()
                      .toByteString())
              .build());
    } else if (!predecessors.isEmpty() && successors.isEmpty()) {
      Node instructionOutputNode = Iterables.getOnlyElement(predecessors);
      pTransform.putInputs(
          "generatedInput" + idGenerator.getId(),
          nodesToPCollections.get(instructionOutputNode));
      pTransform.setSpec(
          RunnerApi.FunctionSpec.newBuilder()
              .setUrn(DATA_OUTPUT_URN)
              .setPayload(
                  node.getRemoteGrpcPort()
                      .toBuilder()
                      .setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode))
                      .build()
                      .toByteString())
              .build());
    } else {
      throw new IllegalStateException(
          "Expected either one input OR one output InstructionOutputNode for this"
              + " RemoteGrpcPortNode");
    }
    processBundleDescriptor.putTransforms(node.getPrimitiveTransformId(), pTransform.build());
  }
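  // Bundle the completed ProcessBundleDescriptor and the collected metadata maps into a
  // single RegisterRequestNode.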
  return RegisterRequestNode.create(
      RegisterRequest.newBuilder().addProcessBundleDescriptor(processBundleDescriptor).build(),
      ptransformIdToNameContexts.build(),
      ptransformIdToSideInputInfos.build(),
      ptransformIdToPCollectionViews.build(),
      pcollectionIdToNameContexts.build());
}