Search in sources :

Example 1 with ExecutionLocation

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.

the class DeduceFlattenLocationsFunctionTest method assertSingleFlattenLocationDeduction.

/**
 * Helper for single-flatten deduction tests. Builds a network via {@code
 * createSingleFlattenNetwork} from the given aggregated predecessor and successor locations, runs
 * {@link DeduceFlattenLocationsFunction} over it, and asserts that the node looked up under the
 * name {@code "flatten"} ends up with the expected {@code ExecutionLocation}.
 */
private static void assertSingleFlattenLocationDeduction(ExecutionLocation predecessorLocations, ExecutionLocation successorLocations, ExecutionLocation expectedLocation) throws Exception {
    MutableNetwork<Node, Edge> deducedNetwork = new DeduceFlattenLocationsFunction().apply(createSingleFlattenNetwork(predecessorLocations, successorLocations));
    assertEquals(expectedLocation, getExecutionLocationOf("flatten", deducedNetwork));
}
Also used : Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ExecutionLocation(org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge)

Example 2 with ExecutionLocation

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.

the class DeduceFlattenLocationsFunctionTest method testDeductionOfChainedFlattens.

/**
 * Tests that flattens chained together through PCollections each get a deduced location: a
 * flatten fed only by SDK nodes and a flatten fed only by runner nodes both feed a third flatten
 * whose output is consumed by one SDK node and one runner node.
 */
@Test
public void testDeductionOfChainedFlattens() throws Exception {
    // sdk_node1 --> out --\
    // sdk_node2 --> out --> flatten1 --> out ----\                /-> sdk_node3 --> out
    //                                             flatten3 --> out
    // runner_node1 --> out --> flatten2 --> out -/                \-> runner_node3 --> out
    // runner_node2 --> out --/
    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    Node sdkNode1 = createSdkNode("sdk_node1");
    Node sdkNode1Output = createPCollection("sdk_node1.out");
    Node sdkNode2 = createSdkNode("sdk_node2");
    Node sdkNode2Output = createPCollection("sdk_node2.out");
    Node sdkNode3 = createSdkNode("sdk_node3");
    Node sdkNode3Output = createPCollection("sdk_node3.out");
    Node runnerNode1 = createRunnerNode("runner_node1");
    Node runnerNode1Output = createPCollection("runner_node1.out");
    Node runnerNode2 = createRunnerNode("runner_node2");
    Node runnerNode2Output = createPCollection("runner_node2.out");
    Node runnerNode3 = createRunnerNode("runner_node3");
    Node runnerNode3Output = createPCollection("runner_node3.out");
    Node flatten1 = createFlatten("flatten1");
    Node flatten1Output = createPCollection("flatten1.out");
    Node flatten2 = createFlatten("flatten2");
    Node flatten2Output = createPCollection("flatten2.out");
    Node flatten3 = createFlatten("flatten3");
    Node flatten3Output = createPCollection("flatten3.out");
    // The SDK/runner output PCollections are not added explicitly here; they enter the network
    // as endpoints of the edges added below.
    network.addNode(sdkNode1);
    network.addNode(sdkNode2);
    network.addNode(sdkNode3);
    network.addNode(runnerNode1);
    network.addNode(runnerNode2);
    network.addNode(runnerNode3);
    network.addNode(flatten1);
    network.addNode(flatten1Output);
    network.addNode(flatten2);
    network.addNode(flatten2Output);
    network.addNode(flatten3);
    network.addNode(flatten3Output);
    // Each producer writes to its own output PCollection.
    network.addEdge(sdkNode1, sdkNode1Output, DefaultEdge.create());
    network.addEdge(sdkNode2, sdkNode2Output, DefaultEdge.create());
    network.addEdge(runnerNode1, runnerNode1Output, DefaultEdge.create());
    network.addEdge(runnerNode2, runnerNode2Output, DefaultEdge.create());
    // flatten1 consumes only SDK outputs; flatten2 consumes only runner outputs.
    network.addEdge(sdkNode1Output, flatten1, DefaultEdge.create());
    network.addEdge(sdkNode2Output, flatten1, DefaultEdge.create());
    network.addEdge(runnerNode1Output, flatten2, DefaultEdge.create());
    network.addEdge(runnerNode2Output, flatten2, DefaultEdge.create());
    network.addEdge(flatten1, flatten1Output, DefaultEdge.create());
    network.addEdge(flatten2, flatten2Output, DefaultEdge.create());
    // flatten3 merges both upstream flattens and fans out to one SDK and one runner consumer.
    network.addEdge(flatten1Output, flatten3, DefaultEdge.create());
    network.addEdge(flatten2Output, flatten3, DefaultEdge.create());
    network.addEdge(flatten3, flatten3Output, DefaultEdge.create());
    network.addEdge(flatten3Output, sdkNode3, DefaultEdge.create());
    network.addEdge(flatten3Output, runnerNode3, DefaultEdge.create());
    network.addEdge(sdkNode3, sdkNode3Output, DefaultEdge.create());
    network.addEdge(runnerNode3, runnerNode3Output, DefaultEdge.create());
    network = new DeduceFlattenLocationsFunction().apply(network);
    // assertEquals takes (expected, actual); keeping that order makes failure messages accurate.
    ExecutionLocation flatten1Location = getExecutionLocationOf("flatten1", network);
    assertEquals(ExecutionLocation.SDK_HARNESS, flatten1Location);
    ExecutionLocation flatten2Location = getExecutionLocationOf("flatten2", network);
    assertEquals(ExecutionLocation.RUNNER_HARNESS, flatten2Location);
    ExecutionLocation flatten3Location = getExecutionLocationOf("flatten3", network);
    assertEquals(ExecutionLocation.AMBIGUOUS, flatten3Location);
}
Also used : Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ExecutionLocation(org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge) Test(org.junit.Test)

Example 3 with ExecutionLocation

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.

the class DeduceNodeLocationsFunction method apply.

/**
 * Replaces every deducible node in the network with a copy that carries a concrete execution
 * location: {@code SDK_HARNESS} when {@code executesInSdkHarness} accepts the node's parallel
 * instruction, {@code RUNNER_HARNESS} otherwise. Non-deducible nodes are kept as they are.
 *
 * @param network the network whose nodes are replaced in place
 * @return the same {@code network} instance that was passed in
 */
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
    Networks.replaceDirectedNetworkNodes(network, (Node node) -> {
        if (isDeducible(node)) {
            ParallelInstructionNode instructionNode = (ParallelInstructionNode) node;
            // Only two outcomes are possible here: the instruction runs in the SDK or it does not.
            ExecutionLocation deduced = executesInSdkHarness(instructionNode.getParallelInstruction()) ? ExecutionLocation.SDK_HARNESS : ExecutionLocation.RUNNER_HARNESS;
            return ParallelInstructionNode.create(instructionNode.getParallelInstruction(), deduced);
        }
        // Anything that is not deducible passes through untouched.
        return node;
    });
    return network;
}
Also used : Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ExecutionLocation(org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation)

Example 4 with ExecutionLocation

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.

the class DeduceFlattenLocationsFunction method getConnectedNodeLocations.

/**
 * Aggregates the execution locations of the nodes connected to {@code node} in one direction
 * (its successors or its predecessors, as selected by {@code direction}).
 *
 * <p>A connected {@link ParallelInstructionNode} whose location is not {@code UNKNOWN}
 * contributes that location directly. Any other connected node (a non-instruction node, or an
 * instruction node with an unknown location) contributes the aggregated location obtained by
 * recursing into it in the same direction. Results are memoized in {@code connectedLocationsMap}
 * so repeated queries for the same node do not recurse again.
 *
 * @param node the node whose connections are inspected
 * @param network the network containing {@code node}
 * @param connectedLocationsMap memoization map; read first and updated with this node's result
 * @param direction whether to walk successors or predecessors
 * @return the aggregated location of the connections in the chosen direction
 * @throws IllegalStateException if an unexpected location value is encountered
 */
private AggregatedLocation getConnectedNodeLocations(Node node, MutableNetwork<Node, Edge> network, Map<Node, AggregatedLocation> connectedLocationsMap, SearchDirection direction) {
    // A memoized answer short-circuits everything else.
    if (connectedLocationsMap.containsKey(node)) {
        return connectedLocationsMap.get(node);
    }
    Set<Node> neighbors = (direction == SearchDirection.SUCCESSORS) ? network.successors(node) : network.predecessors(node);
    boolean sawSdk = false;
    boolean sawRunner = false;
    // Inspect each neighbor: use its own location when it has a known one, otherwise recurse
    // through it to find the locations it is connected to.
    for (Node neighbor : neighbors) {
        ExecutionLocation embedded = ExecutionLocation.UNKNOWN;
        if (neighbor instanceof ParallelInstructionNode) {
            embedded = ((ParallelInstructionNode) neighbor).getExecutionLocation();
        }
        if (embedded == ExecutionLocation.UNKNOWN) {
            AggregatedLocation transitive = getConnectedNodeLocations(neighbor, network, connectedLocationsMap, direction);
            switch (transitive) {
                case BOTH:
                    sawSdk = true;
                    sawRunner = true;
                    break;
                case SDK_HARNESS:
                    sawSdk = true;
                    break;
                case RUNNER_HARNESS:
                    sawRunner = true;
                    break;
                case NEITHER:
                    break;
                default:
                    throw new IllegalStateException("Unknown case " + transitive);
            }
        } else {
            switch (embedded) {
                case AMBIGUOUS:
                    sawSdk = true;
                    sawRunner = true;
                    break;
                case SDK_HARNESS:
                    sawSdk = true;
                    break;
                case RUNNER_HARNESS:
                    sawRunner = true;
                    break;
                default:
                    throw new IllegalStateException("Unknown case " + embedded);
            }
        }
        // Once both kinds of connection have been seen the answer cannot change, so stop early.
        if (sawSdk && sawRunner) {
            break;
        }
    }
    // Fold the two flags into a single aggregated value and memoize it before returning.
    AggregatedLocation result;
    if (sawSdk) {
        result = sawRunner ? AggregatedLocation.BOTH : AggregatedLocation.SDK_HARNESS;
    } else {
        result = sawRunner ? AggregatedLocation.RUNNER_HARNESS : AggregatedLocation.NEITHER;
    }
    connectedLocationsMap.put(node, result);
    return result;
}
Also used : Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ExecutionLocation(org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation)

Example 5 with ExecutionLocation

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation in project beam by apache.

the class DeduceFlattenLocationsFunction method apply.

/**
 * Deduces an {@link ExecutionLocation} for every flatten in the network. For each flatten the
 * aggregated locations of its predecessors and of its successors are computed, and the pair is
 * looked up in {@code DEDUCTION_TABLE} to obtain the flatten's own location. Afterwards each
 * flatten is replaced by a copy of itself that carries the deduced {@link ExecutionLocation}.
 *
 * @param network the network whose flatten nodes are replaced in place
 * @return the same {@code network} instance that was passed in
 */
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
    Map<Node, AggregatedLocation> predecessorCache = new HashMap<>();
    Map<Node, AggregatedLocation> successorCache = new HashMap<>();
    Map<Node, ExecutionLocation> deducedByFlatten = new HashMap<>();
    // Snapshot the flattens up front, before any node in the network is replaced.
    ImmutableList<Node> flattenNodes = ImmutableList.copyOf(Iterables.filter(network.nodes(), IsFlatten.INSTANCE));
    // Phase 1: work out a location for each flatten from what surrounds it.
    for (Node flattenNode : flattenNodes) {
        deducedByFlatten.put(flattenNode, DEDUCTION_TABLE.get(getPredecessorLocations(flattenNode, network, predecessorCache), getSuccessorLocations(flattenNode, network, successorCache)));
    }
    // Phase 2: swap each flatten for a copy that carries its deduced location.
    Networks.replaceDirectedNetworkNodes(network, (Node node) -> {
        if (deducedByFlatten.containsKey(node)) {
            ParallelInstructionNode flattenInstruction = (ParallelInstructionNode) node;
            return ParallelInstructionNode.create(flattenInstruction.getParallelInstruction(), deducedByFlatten.get(node));
        }
        return node;
    });
    return network;
}
Also used : HashMap(java.util.HashMap) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ExecutionLocation(org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation)

Aggregations

ExecutionLocation (org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation)6 ParallelInstructionNode (org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode)6 Node (org.apache.beam.runners.dataflow.worker.graph.Nodes.Node)5 DefaultEdge (org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge)2 Edge (org.apache.beam.runners.dataflow.worker.graph.Edges.Edge)2 InstructionOutputNode (org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode)2 HashMap (java.util.HashMap)1 Test (org.junit.Test)1