use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.
the class DeduceFlattenLocationsFunctionTest method assertSingleFlattenLocationDeduction.
/**
 * Checks the location deduced for a lone flatten. A single-flatten network is built from the
 * given aggregated predecessor and successor locations, run through {@code
 * DeduceFlattenLocationsFunction}, and the node named {@code "flatten"} is asserted to carry the
 * expected {@code ExecutionLocation}.
 */
private static void assertSingleFlattenLocationDeduction(ExecutionLocation predecessorLocations, ExecutionLocation successorLocations, ExecutionLocation expectedLocation) throws Exception {
  MutableNetwork<Node, Edge> deducedNetwork =
      new DeduceFlattenLocationsFunction()
          .apply(createSingleFlattenNetwork(predecessorLocations, successorLocations));
  ExecutionLocation actualLocation = getExecutionLocationOf("flatten", deducedNetwork);
  assertEquals(expectedLocation, actualLocation);
}
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.
the class DeduceFlattenLocationsFunctionTest method testDeductionOfChainedFlattens.
/**
 * Tests that when several flattens are chained together through PCollections, a location is
 * deduced for each of them: a flatten fed only by SDK nodes, a flatten fed only by runner nodes,
 * and a downstream flatten that merges both and feeds both an SDK node and a runner node.
 */
@Test
public void testDeductionOfChainedFlattens() throws Exception {
  // sdk_node1 --> out --\
  // sdk_node2 --> out --> flatten1 --> out ----\                  /-> sdk_node3 --> out
  //                                              flatten3 --> out
  // runner_node1 --> out --> flatten2 --> out -/                  \-> runner_node3 --> out
  // runner_node2 --> out --/
  MutableNetwork<Node, Edge> network = createEmptyNetwork();
  Node sdkNode1 = createSdkNode("sdk_node1");
  Node sdkNode1Output = createPCollection("sdk_node1.out");
  Node sdkNode2 = createSdkNode("sdk_node2");
  Node sdkNode2Output = createPCollection("sdk_node2.out");
  Node sdkNode3 = createSdkNode("sdk_node3");
  Node sdkNode3Output = createPCollection("sdk_node3.out");
  Node runnerNode1 = createRunnerNode("runner_node1");
  Node runnerNode1Output = createPCollection("runner_node1.out");
  Node runnerNode2 = createRunnerNode("runner_node2");
  Node runnerNode2Output = createPCollection("runner_node2.out");
  Node runnerNode3 = createRunnerNode("runner_node3");
  Node runnerNode3Output = createPCollection("runner_node3.out");
  Node flatten1 = createFlatten("flatten1");
  Node flatten1Output = createPCollection("flatten1.out");
  Node flatten2 = createFlatten("flatten2");
  Node flatten2Output = createPCollection("flatten2.out");
  Node flatten3 = createFlatten("flatten3");
  Node flatten3Output = createPCollection("flatten3.out");

  // The SDK/runner output PCollections are not added explicitly here; they enter the network
  // as endpoints of the edges added below.
  network.addNode(sdkNode1);
  network.addNode(sdkNode2);
  network.addNode(sdkNode3);
  network.addNode(runnerNode1);
  network.addNode(runnerNode2);
  network.addNode(runnerNode3);
  network.addNode(flatten1);
  network.addNode(flatten1Output);
  network.addNode(flatten2);
  network.addNode(flatten2Output);
  network.addNode(flatten3);
  network.addNode(flatten3Output);

  // Producers to their output PCollections.
  network.addEdge(sdkNode1, sdkNode1Output, DefaultEdge.create());
  network.addEdge(sdkNode2, sdkNode2Output, DefaultEdge.create());
  network.addEdge(runnerNode1, runnerNode1Output, DefaultEdge.create());
  network.addEdge(runnerNode2, runnerNode2Output, DefaultEdge.create());
  // SDK outputs feed flatten1, runner outputs feed flatten2.
  network.addEdge(sdkNode1Output, flatten1, DefaultEdge.create());
  network.addEdge(sdkNode2Output, flatten1, DefaultEdge.create());
  network.addEdge(runnerNode1Output, flatten2, DefaultEdge.create());
  network.addEdge(runnerNode2Output, flatten2, DefaultEdge.create());
  network.addEdge(flatten1, flatten1Output, DefaultEdge.create());
  network.addEdge(flatten2, flatten2Output, DefaultEdge.create());
  // Both upstream flattens merge into flatten3, which fans out to an SDK and a runner node.
  network.addEdge(flatten1Output, flatten3, DefaultEdge.create());
  network.addEdge(flatten2Output, flatten3, DefaultEdge.create());
  network.addEdge(flatten3, flatten3Output, DefaultEdge.create());
  network.addEdge(flatten3Output, sdkNode3, DefaultEdge.create());
  network.addEdge(flatten3Output, runnerNode3, DefaultEdge.create());
  network.addEdge(sdkNode3, sdkNode3Output, DefaultEdge.create());
  network.addEdge(runnerNode3, runnerNode3Output, DefaultEdge.create());

  network = new DeduceFlattenLocationsFunction().apply(network);

  // assertEquals takes (expected, actual); keeping that order makes failure messages accurate.
  ExecutionLocation flatten1Location = getExecutionLocationOf("flatten1", network);
  assertEquals(ExecutionLocation.SDK_HARNESS, flatten1Location);
  ExecutionLocation flatten2Location = getExecutionLocationOf("flatten2", network);
  assertEquals(ExecutionLocation.RUNNER_HARNESS, flatten2Location);
  ExecutionLocation flatten3Location = getExecutionLocationOf("flatten3", network);
  assertEquals(ExecutionLocation.AMBIGUOUS, flatten3Location);
}
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.
the class DeduceNodeLocationsFunction method apply.
/**
 * Swaps every deducible node for an otherwise identical {@link ParallelInstructionNode} tagged
 * with its execution location: {@code SDK_HARNESS} when its instruction executes in the SDK
 * harness, {@code RUNNER_HARNESS} otherwise. Non-deducible nodes are left as they are.
 *
 * @param network the network whose nodes are rewritten
 * @return the same {@code network} reference that was passed in
 */
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
  Networks.replaceDirectedNetworkNodes(
      network,
      (Node candidate) -> {
        if (isDeducible(candidate)) {
          ParallelInstructionNode instructionNode = (ParallelInstructionNode) candidate;
          ExecutionLocation deducedLocation =
              executesInSdkHarness(instructionNode.getParallelInstruction())
                  ? ExecutionLocation.SDK_HARNESS
                  : ExecutionLocation.RUNNER_HARNESS;
          return ParallelInstructionNode.create(
              instructionNode.getParallelInstruction(), deducedLocation);
        }
        // Nothing to deduce for this node; keep the original instance.
        return candidate;
      });
  return network;
}
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.
the class DeduceFlattenLocationsFunction method getConnectedNodeLocations.
/**
 * Computes the aggregated location of the nodes adjacent to {@code node} in a single direction
 * (its successors or its predecessors, as selected by {@code direction}).
 *
 * <p>An adjacent {@link ParallelInstructionNode} whose location is not {@code UNKNOWN} contributes
 * its own location, with {@code AMBIGUOUS} counting as both SDK and runner. Any other adjacent
 * node (unknown location, or not a {@link ParallelInstructionNode} at all) contributes the
 * aggregated location obtained by recursing onto it in the same direction.
 *
 * <p>Results are memoized in {@code connectedLocationsMap}: a node already present in the map is
 * answered from it, and every freshly computed result is stored before returning.
 *
 * @param node the node whose neighbours are inspected
 * @param network the network supplying the successor/predecessor relation
 * @param connectedLocationsMap memoization map for this search direction; updated by this call
 * @param direction whether to walk successors or predecessors
 * @return the aggregated location of everything reachable through the inspected neighbours
 * @throws IllegalStateException if a location value not handled by the switches is encountered
 */
private AggregatedLocation getConnectedNodeLocations(Node node, MutableNetwork<Node, Edge> network, Map<Node, AggregatedLocation> connectedLocationsMap, SearchDirection direction) {
  // Serve a previously computed answer if there is one.
  if (connectedLocationsMap.containsKey(node)) {
    return connectedLocationsMap.get(node);
  }

  Set<Node> neighbors =
      direction == SearchDirection.SUCCESSORS
          ? network.successors(node)
          : network.predecessors(node);

  boolean sdkSeen = false;
  boolean runnerSeen = false;

  // A neighbour with a known location is used directly; otherwise the search recurses through it.
  for (Node neighbor : neighbors) {
    // Non-instruction nodes are treated the same as instruction nodes with an unknown location.
    ExecutionLocation knownLocation =
        neighbor instanceof ParallelInstructionNode
            ? ((ParallelInstructionNode) neighbor).getExecutionLocation()
            : ExecutionLocation.UNKNOWN;

    if (knownLocation != ExecutionLocation.UNKNOWN) {
      switch (knownLocation) {
        case SDK_HARNESS:
          sdkSeen = true;
          break;
        case RUNNER_HARNESS:
          runnerSeen = true;
          break;
        case AMBIGUOUS:
          sdkSeen = true;
          runnerSeen = true;
          break;
        default:
          throw new IllegalStateException("Unknown case " + knownLocation);
      }
    } else {
      AggregatedLocation nested =
          getConnectedNodeLocations(neighbor, network, connectedLocationsMap, direction);
      switch (nested) {
        case SDK_HARNESS:
          sdkSeen = true;
          break;
        case RUNNER_HARNESS:
          runnerSeen = true;
          break;
        case BOTH:
          sdkSeen = true;
          runnerSeen = true;
          break;
        case NEITHER:
          break;
        default:
          throw new IllegalStateException("Unknown case " + nested);
      }
    }

    // Once both kinds have been seen the answer cannot change, so stop scanning neighbours.
    if (sdkSeen && runnerSeen) {
      break;
    }
  }

  // Fold the two flags into a single aggregated value, memoize it, and return it.
  AggregatedLocation result;
  if (sdkSeen) {
    result = runnerSeen ? AggregatedLocation.BOTH : AggregatedLocation.SDK_HARNESS;
  } else {
    result = runnerSeen ? AggregatedLocation.RUNNER_HARNESS : AggregatedLocation.NEITHER;
  }
  connectedLocationsMap.put(node, result);
  return result;
}
use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.
the class DeduceFlattenLocationsFunction method apply.
/**
 * Deduces an {@link ExecutionLocation} for every flatten in the network. For each flatten the
 * aggregated locations of its predecessors and of its successors are computed, the pair is
 * looked up in {@code DEDUCTION_TABLE}, and the flatten is then replaced by a copy that carries
 * the deduced {@link ExecutionLocation}.
 *
 * @param network the network containing the flattens to update
 * @return the same {@code network} reference that was passed in
 */
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
  // Separate memoization maps for each search direction, shared across all flattens.
  Map<Node, AggregatedLocation> upstreamMemo = new HashMap<>();
  Map<Node, AggregatedLocation> downstreamMemo = new HashMap<>();
  Map<Node, ExecutionLocation> locationByFlatten = new HashMap<>();

  // Copy the flattens out of the network's node view before computing anything.
  ImmutableList<Node> flattenNodes =
      ImmutableList.copyOf(Iterables.filter(network.nodes(), IsFlatten.INSTANCE));

  // First pass: work out the location of each flatten from its surroundings.
  for (Node flattenNode : flattenNodes) {
    AggregatedLocation upstream = getPredecessorLocations(flattenNode, network, upstreamMemo);
    AggregatedLocation downstream = getSuccessorLocations(flattenNode, network, downstreamMemo);
    locationByFlatten.put(flattenNode, DEDUCTION_TABLE.get(upstream, downstream));
  }

  // Second pass: swap each flatten for a copy carrying its deduced location.
  Networks.replaceDirectedNetworkNodes(
      network,
      (Node candidate) -> {
        if (locationByFlatten.containsKey(candidate)) {
          ParallelInstructionNode flattenInstruction = (ParallelInstructionNode) candidate;
          return ParallelInstructionNode.create(
              flattenInstruction.getParallelInstruction(), locationByFlatten.get(candidate));
        }
        // Not a flatten we deduced a location for; leave it untouched.
        return candidate;
      });
  return network;
}
Aggregations