Search in sources :

Example 6 with InstructionOutputNode

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.

Class CreateRegisterFnOperationFunction, method rewireAcrossSdkRunnerPortNode:

/**
 * Splices a gRPC port node between {@code predecessors} and {@code successors}, which are
 * currently connected through {@code outputNode}.
 *
 * <p>Two new output nodes are created, both reusing {@code outputNode}'s instruction output and
 * PCollection id. One sits in front of the port and one behind it.
 *
 * <p>The edges from {@code predecessors} into {@code outputNode} are moved onto the front output
 * node. Every edge leaving {@code outputNode} is left in place. Copies (via {@code clone()}) are
 * added so that:
 *
 * <ul>
 *   <li>the front output node also feeds each successor of {@code outputNode} that is not in
 *       {@code successors};
 *   <li>the output node behind the port feeds every node in {@code successors}.
 * </ul>
 *
 * <p>For example:
 *
 * <pre>{@code
 * predecessors --> outputNode --> successors
 *                            \--> existingSuccessors
 * }</pre>
 *
 * becomes:
 *
 * <pre>{@code
 *                                outputNode -------------------------------\
 *                                          \                                \
 *                                           |-> existingSuccessors           \
 *                                          /                                  \
 * predecessors --> newPredecessorOutputNode --> portNode --> portOutputNode --> successors
 * }</pre>
 *
 * <p>The two edges touching the port are {@link MultiOutputInfoEdge}s. Their tags are taken from
 * {@code idGenerator}: first the front-output-to-port edge, then the port-to-back-output edge.
 *
 * @param network the network to rewire in place
 * @param outputNode the output node currently joining {@code predecessors} and {@code successors}
 * @param predecessors producers whose edges into {@code outputNode} are moved in front of the port
 * @param successors consumers that should receive the data from behind the port
 * @return the port node obtained from {@code portSupplier}
 */
private Node rewireAcrossSdkRunnerPortNode(MutableNetwork<Node, Edge> network, InstructionOutputNode outputNode, Set<Node> predecessors, Set<Node> successors) {
    // Both replacement output nodes mirror the original output's instruction and PCollection id.
    InstructionOutputNode frontOutput =
        InstructionOutputNode.create(outputNode.getInstructionOutput(), outputNode.getPcollectionId());
    InstructionOutputNode backOutput =
        InstructionOutputNode.create(outputNode.getInstructionOutput(), outputNode.getPcollectionId());
    Node grpcPort = portSupplier.get();
    network.addNode(frontOutput);
    network.addNode(grpcPort);

    // Move (rather than copy) each producer's edges so they now terminate at the front output.
    for (Node producer : predecessors) {
        List<Edge> movedEdges = ImmutableList.copyOf(network.edgesConnecting(producer, outputNode));
        for (Edge moved : movedEdges) {
            network.removeEdge(moved);
            network.addEdge(producer, frontOutput, moved);
        }
    }

    // Consumers that are not crossing the port keep being fed, now also from the front output.
    List<Node> stayingConsumers =
        ImmutableList.copyOf(Sets.difference(network.successors(outputNode), successors));
    for (Node consumer : stayingConsumers) {
        List<Edge> templates = ImmutableList.copyOf(network.edgesConnecting(outputNode, consumer));
        for (Edge template : templates) {
            network.addEdge(frontOutput, consumer, template.clone());
        }
    }

    // Route data through the port: front output -> port -> back output.
    MultiOutputInfoEdge intoPort =
        MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId()));
    network.addEdge(frontOutput, grpcPort, intoPort);
    MultiOutputInfoEdge outOfPort =
        MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId()));
    network.addEdge(grpcPort, backOutput, outOfPort);

    // The requested consumers receive cloned edges from the output node behind the port.
    for (Node consumer : successors) {
        List<Edge> templates = ImmutableList.copyOf(network.edgesConnecting(outputNode, consumer));
        for (Edge template : templates) {
            network.addEdge(backOutput, consumer, template.clone());
        }
    }
    return grpcPort;
}
Also used : InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) MultiOutputInfo(com.google.api.services.dataflow.model.MultiOutputInfo) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) HappensBeforeEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.HappensBeforeEdge) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge)

Example 7 with InstructionOutputNode

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.

Class CreateExecutableStageNodeFunction, method apply:

/**
 * Translates a fused network of Dataflow worker instructions into a single
 * {@link ExecutableStageNode}.
 *
 * <p>The network may contain only {@link RemoteGrpcPortNode}s, {@link ParallelInstructionNode}s
 * and {@link InstructionOutputNode}s.
 *
 * <ul>
 *   <li>Each {@link InstructionOutputNode} is registered as a {@code RunnerApi.PCollection} with
 *       a generated coder id and a generated windowing strategy id.
 *   <li>Each {@link ParallelInstructionNode} is registered as a {@code RunnerApi.PTransform}
 *       (ParDo, combine-values, read or flatten) under a generated id.
 * </ul>
 *
 * <p>Side effect: {@code input} is modified in place. Every {@link DefaultEdge} leaving a
 * non-output node is replaced by a {@link MultiOutputInfoEdge} that carries a generated tag.
 *
 * @param input the fused network to translate; its edges are rewritten as described above
 * @return an {@link ExecutableStageNode} wrapping the resulting {@link ExecutableStage} together
 *     with per-transform name contexts, side input infos and PCollection views
 * @throws IllegalArgumentException if the network contains an unsupported node type or
 *     {@link ParallelInstruction}, or if an output coder cannot be encoded
 * @throws UnsupportedOperationException if the network does not yield exactly one input
 *     PCollection for the executable stage
 */
@Override
public Node apply(MutableNetwork<Node, Edge> input) {
    for (Node node : input.nodes()) {
        if (node instanceof RemoteGrpcPortNode || node instanceof ParallelInstructionNode || node instanceof InstructionOutputNode) {
            continue;
        }
        // Report the offending node itself, not just the whole network.
        throw new IllegalArgumentException(String.format("Network contains unknown type of node: %s (network: %s)", node, input));
    }
    // Give every DefaultEdge leaving a non-output node a uniquely tagged MultiOutputInfoEdge; the
    // PTransform construction below casts out-edges of instructions to MultiOutputInfoEdge.
    // The node, successor and edge views are snapshotted first because removing and adding edges
    // while iterating live Network views is unsafe.
    for (Node node : ImmutableList.copyOf(input.nodes())) {
        if (node instanceof InstructionOutputNode) {
            continue;
        }
        for (Node successor : ImmutableList.copyOf(input.successors(node))) {
            for (Edge edge : ImmutableList.copyOf(input.edgesConnecting(node, successor))) {
                if (edge instanceof DefaultEdge) {
                    input.removeEdge(edge);
                    input.addEdge(node, successor, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
                }
            }
        }
    }
    RunnerApi.Components.Builder componentsBuilder = RunnerApi.Components.newBuilder();
    componentsBuilder.mergeFrom(this.pipeline.getComponents());
    // Default to the Java environment if the pipeline doesn't specify any environment.
    if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) {
        String envId = Environments.JAVA_SDK_HARNESS_ENVIRONMENT.getUrn() + idGenerator.getId();
        componentsBuilder.putEnvironments(envId, Environments.JAVA_SDK_HARNESS_ENVIRONMENT);
    }
    // By default, use GlobalWindow for all languages.
    // For Java, if there is an IntervalWindowCoder, then use FixedWindows instead.
    // TODO: should get the real WindowingStrategy from the pipeline proto.
    String globalWindowingStrategyId = "generatedGlobalWindowingStrategy" + idGenerator.getId();
    String intervalWindowEncodingWindowingStrategyId = "generatedIntervalWindowEncodingWindowingStrategy" + idGenerator.getId();
    SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents(), null);
    try {
        registerWindowingStrategy(globalWindowingStrategyId, WindowingStrategy.globalDefault(), componentsBuilder, sdkComponents);
        registerWindowingStrategy(intervalWindowEncodingWindowingStrategyId, WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(1))), componentsBuilder, sdkComponents);
    } catch (IOException exc) {
        throw new RuntimeException("Could not convert default windowing strategy to proto", exc);
    }
    Map<Node, String> nodesToPCollections = new HashMap<>();
    ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder();
    ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos = ImmutableMap.builder();
    ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews = ImmutableMap.builder();
    // PCollections the executable stage produces (sent from the SDK worker side).
    Set<PCollectionNode> executableStageOutputs = new HashSet<>();
    // PCollections the executable stage consumes (arriving from the runner side).
    Set<PCollectionNode> executableStageInputs = new HashSet<>();
    for (InstructionOutputNode node : Iterables.filter(input.nodes(), InstructionOutputNode.class)) {
        InstructionOutput instructionOutput = node.getInstructionOutput();
        String coderId = "generatedCoder" + idGenerator.getId();
        String windowingStrategyId;
        try (ByteString.Output output = ByteString.newOutput()) {
            try {
                Coder<?> javaCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec()));
                Coder<?> elementCoder = ((WindowedValueCoder<?>) javaCoder).getValueCoder();
                sdkComponents.registerCoder(elementCoder);
                RunnerApi.Coder coderProto = CoderTranslation.toProto(elementCoder, sdkComponents);
                componentsBuilder.putCoders(coderId, coderProto);
                // For now, the Dataflow runner harness only deals with FixedWindows.
                if (javaCoder instanceof FullWindowedValueCoder) {
                    FullWindowedValueCoder<?> windowedValueCoder = (FullWindowedValueCoder<?>) javaCoder;
                    Coder<?> windowCoder = windowedValueCoder.getWindowCoder();
                    if (windowCoder instanceof IntervalWindowCoder) {
                        windowingStrategyId = intervalWindowEncodingWindowingStrategyId;
                    } else if (windowCoder instanceof GlobalWindow.Coder) {
                        windowingStrategyId = globalWindowingStrategyId;
                    } else {
                        throw new UnsupportedOperationException(String.format("Dataflow portable runner harness doesn't support windowing with %s", windowCoder));
                    }
                } else {
                    throw new UnsupportedOperationException("Dataflow portable runner harness only supports FullWindowedValueCoder");
                }
            } catch (IOException e) {
                throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e);
            } catch (Exception e) {
                // Coder probably wasn't a Java coder.
                // NOTE(review): this also catches the UnsupportedOperationExceptions thrown above
                // (and a ClassCastException for coders that are not WindowedValueCoders), so those
                // cases fall back to the serialized cloud codec with the global windowing strategy
                // rather than failing. Kept as-is to avoid breaking existing pipelines.
                OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec());
                componentsBuilder.putCoders(coderId, RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString())).build());
                // For a non-Java coder, assume GlobalWindows by default.
                // TODO(BEAM-6231): Actually discover the right windowing strategy.
                windowingStrategyId = globalWindowingStrategyId;
            }
        } catch (IOException e) {
            throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e);
        }
        // TODO(BEAM-6275): Set correct IsBounded on generated PCollections
        String pcollectionId = node.getPcollectionId();
        RunnerApi.PCollection pCollection = RunnerApi.PCollection.newBuilder().setCoderId(coderId).setWindowingStrategyId(windowingStrategyId).setIsBounded(RunnerApi.IsBounded.Enum.BOUNDED).build();
        nodesToPCollections.put(node, pcollectionId);
        componentsBuilder.putPcollections(pcollectionId, pCollection);
        // Classify the PCollection as an output of and/or an input to the executable stage.
        if (isExecutableStageOutputPCollection(input, node)) {
            executableStageOutputs.add(PipelineNode.pCollection(pcollectionId, pCollection));
        }
        if (isExecutableStageInputPCollection(input, node)) {
            executableStageInputs.add(PipelineNode.pCollection(pcollectionId, pCollection));
        }
    }
    componentsBuilder.putAllCoders(sdkComponents.toComponents().getCodersMap());
    Set<PTransformNode> executableStageTransforms = new HashSet<>();
    Set<TimerReference> executableStageTimers = new HashSet<>();
    List<UserStateId> userStateIds = new ArrayList<>();
    Set<SideInputReference> executableStageSideInputs = new HashSet<>();
    for (ParallelInstructionNode node : Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
        ImmutableMap.Builder<String, PCollectionNode> sideInputIds = ImmutableMap.builder();
        ParallelInstruction parallelInstruction = node.getParallelInstruction();
        String ptransformId = "generatedPtransform" + idGenerator.getId();
        ptransformIdToNameContexts.put(ptransformId, NameContext.create(null, parallelInstruction.getOriginalName(), parallelInstruction.getSystemName(), parallelInstruction.getName()));
        RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
        RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder();
        List<String> timerIds = new ArrayList<>();
        if (parallelInstruction.getParDo() != null) {
            ParDoInstruction parDoInstruction = parallelInstruction.getParDo();
            CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
            String userFnClassName = userFnSpec.getClassName();
            if (userFnClassName.equals("CombineValuesFn") || userFnClassName.equals("KeyedCombineFn")) {
                transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec);
                ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList());
            } else {
                String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);
                RunnerApi.PTransform parDoPTransform = pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);
                // TODO: only the non-null branch should exist; for migration ease only
                if (parDoPTransform != null) {
                    checkArgument(parDoPTransform.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN), "Found transform \"%s\" for ParallelDo instruction, " + " but that transform had unexpected URN \"%s\" (expected \"%s\")", parDoPTransformId, parDoPTransform.getSpec().getUrn(), PTransformTranslation.PAR_DO_TRANSFORM_URN);
                    RunnerApi.ParDoPayload parDoPayload;
                    try {
                        parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
                    } catch (InvalidProtocolBufferException exc) {
                        throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
                    }
                    // Collect the user timers and user state declared by the ParDo.
                    for (Map.Entry<String, RunnerApi.TimerFamilySpec> entry : parDoPayload.getTimerFamilySpecsMap().entrySet()) {
                        timerIds.add(entry.getKey());
                    }
                    for (Map.Entry<String, RunnerApi.StateSpec> entry : parDoPayload.getStateSpecsMap().entrySet()) {
                        UserStateId.Builder builder = UserStateId.newBuilder();
                        builder.setTransformId(parDoPTransformId);
                        builder.setLocalName(entry.getKey());
                        userStateIds.add(builder.build());
                    }
                    // To facilitate the creation of Set executableStageSideInputs.
                    for (String sideInputTag : parDoPayload.getSideInputsMap().keySet()) {
                        String sideInputPCollectionId = parDoPTransform.getInputsOrThrow(sideInputTag);
                        RunnerApi.PCollection sideInputPCollection = pipeline.getComponents().getPcollectionsOrThrow(sideInputPCollectionId);
                        pTransform.putInputs(sideInputTag, sideInputPCollectionId);
                        PCollectionNode pCollectionNode = PipelineNode.pCollection(sideInputPCollectionId, sideInputPCollection);
                        sideInputIds.put(sideInputTag, pCollectionNode);
                    }
                    // To facilitate the creation of Map(ptransformId -> pCollectionView), which is
                    // required by constructing an ExecutableStageNode.
                    ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder();
                    for (Map.Entry<String, RunnerApi.SideInput> sideInputEntry : parDoPayload.getSideInputsMap().entrySet()) {
                        pcollectionViews.add(RegisterNodeFunction.transformSideInputForRunner(pipeline, parDoPTransform, sideInputEntry.getKey(), sideInputEntry.getValue()));
                    }
                    ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build());
                    transformSpec.setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(parDoPayload.toByteString());
                } else {
                    // legacy path - bytes are the FunctionSpec's payload field, basically, and
                    // SDKs expect it in the PTransform's payload field
                    byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN);
                    transformSpec.setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN).setPayload(ByteString.copyFrom(userFnBytes));
                }
                if (parDoInstruction.getSideInputs() != null) {
                    ptransformIdToSideInputInfos.put(ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true));
                }
            }
        } else if (parallelInstruction.getRead() != null) {
            ReadInstruction readInstruction = parallelInstruction.getRead();
            CloudObject sourceSpec = CloudObject.fromSpec(CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec());
            // TODO: Need to plumb through the SDK specific function spec.
            transformSpec.setUrn(JAVA_SOURCE_URN);
            try {
                byte[] serializedSource = Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE));
                ByteString sourceByteString = ByteString.copyFrom(serializedSource);
                transformSpec.setPayload(sourceByteString);
            } catch (Exception e) {
                throw new IllegalArgumentException(String.format("Unable to process Read %s", parallelInstruction), e);
            }
        } else if (parallelInstruction.getFlatten() != null) {
            transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN);
        } else {
            throw new IllegalArgumentException(String.format("Unknown type of ParallelInstruction %s", parallelInstruction));
        }
        // Register the PCollection of every predecessor output node as a generated input.
        for (Node predecessorOutput : input.predecessors(node)) {
            pTransform.putInputs("generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput));
        }
        // Out-edges were normalized to MultiOutputInfoEdges above; their tags name the outputs.
        for (Edge edge : input.outEdges(node)) {
            Node nodeOutput = input.incidentNodes(edge).target();
            MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge;
            pTransform.putOutputs(edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput));
        }
        pTransform.setSpec(transformSpec);
        PTransformNode pTransformNode = PipelineNode.pTransform(ptransformId, pTransform.build());
        executableStageTransforms.add(pTransformNode);
        for (String timerId : timerIds) {
            executableStageTimers.add(TimerReference.of(pTransformNode, timerId));
        }
        ImmutableMap<String, PCollectionNode> sideInputIdToPCollectionNodes = sideInputIds.build();
        for (String sideInputTag : sideInputIdToPCollectionNodes.keySet()) {
            SideInputReference sideInputReference = SideInputReference.of(pTransformNode, sideInputTag, sideInputIdToPCollectionNodes.get(sideInputTag));
            executableStageSideInputs.add(sideInputReference);
        }
    }
    if (executableStageInputs.size() != 1) {
        throw new UnsupportedOperationException("ExecutableStage only support one input PCollection");
    }
    PCollectionNode executableInput = executableStageInputs.iterator().next();
    RunnerApi.Components executableStageComponents = componentsBuilder.build();
    // Get the Environment from the ptransforms; otherwise use JAVA_SDK_HARNESS_ENVIRONMENT.
    Environment executableStageEnv = getEnvironmentFromPTransform(executableStageComponents, executableStageTransforms);
    if (executableStageEnv == null) {
        executableStageEnv = Environments.JAVA_SDK_HARNESS_ENVIRONMENT;
    }
    Set<UserStateReference> executableStageUserStateReference = new HashSet<>();
    for (UserStateId userStateId : userStateIds) {
        executableStageUserStateReference.add(UserStateReference.fromUserStateId(userStateId, executableStageComponents));
    }
    ExecutableStage executableStage = ImmutableExecutableStage.ofFullComponents(executableStageComponents, executableStageEnv, executableInput, executableStageSideInputs, executableStageUserStateReference, executableStageTimers, executableStageTransforms, executableStageOutputs, DEFAULT_WIRE_CODER_SETTINGS);
    return ExecutableStageNode.create(executableStage, ptransformIdToNameContexts.build(), ptransformIdToSideInputInfos.build(), ptransformIdToPCollectionViews.build());
}
Also used : HashMap(java.util.HashMap) MultiOutputInfo(com.google.api.services.dataflow.model.MultiOutputInfo) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) ExecutableStageNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutableStageNode) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) RemoteGrpcPortNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.RemoteGrpcPortNode) InstructionOutput(com.google.api.services.dataflow.model.InstructionOutput) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) ArrayList(java.util.ArrayList) Structs.getString(org.apache.beam.runners.dataflow.util.Structs.getString) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RehydratedComponents(org.apache.beam.runners.core.construction.RehydratedComponents) SdkComponents(org.apache.beam.runners.core.construction.SdkComponents) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) SideInputReference(org.apache.beam.runners.core.construction.graph.SideInputReference) ImmutableExecutableStage(org.apache.beam.runners.core.construction.graph.ImmutableExecutableStage) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) HashSet(java.util.HashSet) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge) 
ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) ParDoInstruction(com.google.api.services.dataflow.model.ParDoInstruction) PCollectionView(org.apache.beam.sdk.values.PCollectionView) Environment(org.apache.beam.model.pipeline.v1.RunnerApi.Environment) GlobalWindow(org.apache.beam.sdk.transforms.windowing.GlobalWindow) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Map(java.util.Map) HashMap(java.util.HashMap) IntervalWindowCoder(org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder) RemoteGrpcPortNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.RemoteGrpcPortNode) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) ImmutableList(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) SdkComponents(org.apache.beam.runners.core.construction.SdkComponents) ReadInstruction(com.google.api.services.dataflow.model.ReadInstruction) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) FullWindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder) NameContext(org.apache.beam.runners.dataflow.worker.counters.NameContext) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) IOException(java.io.IOException) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) IOException(java.io.IOException) ParallelInstruction(com.google.api.services.dataflow.model.ParallelInstruction) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) 
FullWindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder) CloudObject(org.apache.beam.runners.dataflow.util.CloudObject) UserStateId(org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId) UserStateReference(org.apache.beam.runners.core.construction.graph.UserStateReference) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge)

Example 8 with InstructionOutputNode

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.

Class CreateRegisterFnOperationFunction, method apply:

/**
 * Splits {@code network} into one runner-side network in which each fused SDK-harness portion
 * is replaced by a single node produced by {@code registerFnOperationFunction}.
 *
 * <p>If no {@link ParallelInstructionNode} executes in the SDK harness, {@code network} is
 * returned unchanged.
 *
 * <p>Otherwise {@code network} is modified in place:
 *
 * <ul>
 *   <li>every output node connecting runner and SDK instructions is rewired across a gRPC port
 *       node;
 *   <li>original output nodes left without predecessors are removed.
 * </ul>
 *
 * <p>The returned network is the induced runner subgraph. Each replacement node is connected to
 * it with {@link HappensBeforeEdge}s: directly to the port nodes, or, when
 * {@code useExecutableStageBundleExecution} is set, to the output nodes adjacent to those ports.
 *
 * @param network the instruction network to split
 * @return the runner-side network
 */
@Override
public MutableNetwork<Node, Edge> apply(MutableNetwork<Node, Edge> network) {
    // Classify every instruction by where it executes, and remember the instructions with no
    // inputs as roots of their respective portion.
    Set<Node> runnerRoots = new HashSet<>();
    Set<Node> sdkInstructions = new HashSet<>();
    Set<Node> sdkRoots = new HashSet<>();
    for (ParallelInstructionNode instruction : Iterables.filter(network.nodes(), ParallelInstructionNode.class)) {
        boolean isRoot = network.inDegree(instruction) == 0;
        if (!executesInSdkHarness(instruction)) {
            if (isRoot) {
                runnerRoots.add(instruction);
            }
            continue;
        }
        sdkInstructions.add(instruction);
        if (isRoot) {
            sdkRoots.add(instruction);
        }
    }

    // Nothing runs inside the SDK harness: there is nothing to split out.
    if (sdkInstructions.isEmpty()) {
        return network;
    }

    // gRPC port nodes carrying data Runner -> SDK, and SDK -> Runner, respectively.
    Set<Node> runnerToSdkPorts = new HashSet<>();
    Set<Node> sdkToRunnerPorts = new HashSet<>();

    // Walk an immutable snapshot, since the loop rewires 'network' itself.
    ImmutableNetwork<Node, Edge> snapshot = ImmutableNetwork.copyOf(network);
    for (InstructionOutputNode output : Iterables.filter(snapshot.nodes(), InstructionOutputNode.class)) {
        Set<Node> runnerProducers = new HashSet<>();
        Set<Node> sdkProducers = new HashSet<>();
        for (Node producer : snapshot.predecessors(output)) {
            Set<Node> bucket = sdkInstructions.contains(producer) ? sdkProducers : runnerProducers;
            bucket.add(producer);
        }

        Set<Node> runnerConsumers = new HashSet<>();
        Set<Node> sdkConsumers = new HashSet<>();
        for (Node consumer : snapshot.successors(output)) {
            Set<Node> bucket = sdkInstructions.contains(consumer) ? sdkConsumers : runnerConsumers;
            bucket.add(consumer);
        }

        // Runner producers feeding SDK consumers are joined across a Runner -> SDK port; the port
        // later acts as the root of an SDK portion.
        if (!runnerProducers.isEmpty() && !sdkConsumers.isEmpty()) {
            runnerToSdkPorts.add(rewireAcrossSdkRunnerPortNode(network, output, runnerProducers, sdkConsumers));
        }
        // SDK producers feeding runner consumers are joined across an SDK -> Runner port.
        if (!sdkProducers.isEmpty() && !runnerConsumers.isEmpty()) {
            sdkToRunnerPorts.add(rewireAcrossSdkRunnerPortNode(network, output, sdkProducers, runnerConsumers));
        }
        // Drop the original output node once nothing feeds it any more.
        if (network.inDegree(output) == 0) {
            network.removeNode(output);
        }
    }

    // The runner portion: everything reachable from runner roots and SDK -> Runner ports without
    // passing through a Runner -> SDK port.
    Set<Node> runnerSide = Networks.reachableNodes(network, Sets.union(runnerRoots, sdkToRunnerPorts), runnerToSdkPorts);
    if (this.useExecutableStageBundleExecution) {
        // When using the shared library, the runner graph contains no gRPC port nodes.
        runnerSide = Sets.difference(runnerSide, Sets.union(runnerToSdkPorts, sdkToRunnerPorts));
    }
    MutableNetwork<Node, Edge> runnerNetwork = Graphs.inducedSubgraph(network, runnerSide);

    // Each SDK root (or Runner -> SDK port) starts one fused SDK portion, bounded by the
    // SDK -> Runner ports. Replace each portion by a single node inside the runner network.
    for (Node sdkRoot : Sets.union(sdkRoots, runnerToSdkPorts)) {
        Set<Node> portionNodes = Networks.reachableNodes(network, ImmutableSet.of(sdkRoot), sdkToRunnerPorts);
        MutableNetwork<Node, Edge> sdkPortion = Graphs.inducedSubgraph(network, portionNodes);
        Node replacement = registerFnOperationFunction.apply(sdkPortion);
        runnerNetwork.addNode(replacement);

        // With the shared library the ports are absent from the runner network, so attach the
        // HappensBefore edges to the output node on the far side of each port instead.
        boolean attachToAdjacentOutputs = this.useExecutableStageBundleExecution;
        for (Node inboundPort : Sets.intersection(portionNodes, runnerToSdkPorts)) {
            Node upstream = attachToAdjacentOutputs ? network.predecessors(inboundPort).iterator().next() : inboundPort;
            runnerNetwork.addEdge(upstream, replacement, HappensBeforeEdge.create());
        }
        for (Node outboundPort : Sets.intersection(portionNodes, sdkToRunnerPorts)) {
            Node downstream = attachToAdjacentOutputs ? network.successors(outboundPort).iterator().next() : outboundPort;
            runnerNetwork.addEdge(replacement, downstream, HappensBeforeEdge.create());
        }
    }
    return runnerNetwork;
}
Also used : InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) HappensBeforeEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.HappensBeforeEdge) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge) HashSet(java.util.HashSet)

Example 9 with InstructionOutputNode

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.

Class InsertFetchAndFilterStreamingSideInputNodes, method forNetwork:

/**
 * Inserts a {@link FetchAndFilterStreamingSideInputsNode} in front of every SDK-harness ParDo that
 * declares side inputs.
 *
 * <p>For each qualifying {@link ParallelInstructionNode}, its single incoming main-input edge
 * {@code predecessor --> node} is replaced by
 * {@code predecessor --> fetchAndFilterNode --> predecessorCopy --> node}, where
 * {@code predecessorCopy} is a new {@link InstructionOutputNode} created from the predecessor's
 * instruction output and PCollection id. All three new edges are clones of the removed edge.
 *
 * <p>A node is left untouched when it is not a ParDo, is not executed in the SDK harness, is a
 * combine ParDo ({@code CombineValuesFn} / {@code KeyedCombineFn}), has no matching transform in
 * the pipeline components, or declares no side inputs. When no pipeline was supplied, the network
 * is returned unchanged.
 *
 * @param network the network to rewrite; it is mutated in place
 * @return the same {@code network} instance
 * @throws RuntimeException if a ParDo transform's payload cannot be parsed as a ParDoPayload
 * @throws IllegalStateException if the main input's windowing strategy cannot be decoded
 */
public MutableNetwork<Node, Edge> forNetwork(MutableNetwork<Node, Edge> network) {
    if (pipeline == null) {
        return network;
    }
    // Loop-invariant view of the pipeline's components.
    RunnerApi.Components components = pipeline.getComponents();
    RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components);
    // Snapshot the ParDo nodes up front: the loop body adds nodes to the network it iterates.
    for (ParallelInstructionNode node : ImmutableList.copyOf(Iterables.filter(network.nodes(), ParallelInstructionNode.class))) {
        // Skip nodes that are not ParDos or that are not executed in the SDK harness; they are
        // not rewired by this pass.
        if (node.getParallelInstruction().getParDo() == null || !ExecutionLocation.SDK_HARNESS.equals(node.getExecutionLocation())) {
            continue;
        }
        ParDoInstruction parDoInstruction = node.getParallelInstruction().getParDo();
        CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
        String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);
        // Skip ParDoInstruction nodes that contain payloads without side inputs.
        String userFnClassName = userFnSpec.getClassName();
        if ("CombineValuesFn".equals(userFnClassName) || "KeyedCombineFn".equals(userFnClassName)) {
            // These nodes have CombinePayloads which have no side inputs.
            continue;
        }
        RunnerApi.PTransform parDoPTransform = components.getTransformsOrDefault(parDoPTransformId, null);
        // TODO: only the non-null branch should exist; for migration ease only
        if (parDoPTransform == null) {
            continue;
        }
        RunnerApi.ParDoPayload parDoPayload;
        try {
            parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
        } catch (InvalidProtocolBufferException exc) {
            throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
        }
        // Skip any ParDo that doesn't have a side input.
        if (parDoPayload.getSideInputsMap().isEmpty()) {
            continue;
        }
        // The main input is the one transform input whose local name is not a side-input tag.
        String mainInputPCollectionLocalName = Iterables.getOnlyElement(Sets.difference(parDoPTransform.getInputsMap().keySet(), parDoPayload.getSideInputsMap().keySet()));
        String mainInputPCollectionId = parDoPTransform.getInputsOrThrow(mainInputPCollectionLocalName);
        String windowingStrategyId = components.getPcollectionsOrThrow(mainInputPCollectionId).getWindowingStrategyId();
        RunnerApi.WindowingStrategy windowingStrategyProto = components.getWindowingStrategiesOrThrow(windowingStrategyId);
        WindowingStrategy<?, ?> windowingStrategy;
        try {
            windowingStrategy = WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents);
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalStateException(String.format("Unable to decode windowing strategy %s.", windowingStrategyProto), e);
        }
        // Gather all the side input window mapping fns which we need to request the SDK to map
        ImmutableMap.Builder<PCollectionView<?>, RunnerApi.FunctionSpec> pCollectionViewsToWindowMappingFns = ImmutableMap.builder();
        parDoPayload.getSideInputsMap().forEach((sideInputTag, sideInput) -> pCollectionViewsToWindowMappingFns.put(RegisterNodeFunction.transformSideInputForRunner(pipeline, parDoPTransform, sideInputTag, sideInput), sideInput.getWindowMappingFn()));
        Node streamingSideInputWindowHandlerNode = FetchAndFilterStreamingSideInputsNode.create(windowingStrategy, pCollectionViewsToWindowMappingFns.build(), NameContext.create(null, node.getParallelInstruction().getOriginalName(), node.getParallelInstruction().getSystemName(), node.getParallelInstruction().getName()));
        // Rewire the graph such that streaming side inputs ParDos are preceded by a
        // node which filters any side inputs that aren't ready and fetches any ready side inputs.
        Edge mainInput = Iterables.getOnlyElement(network.inEdges(node));
        InstructionOutputNode predecessor = (InstructionOutputNode) network.incidentNodes(mainInput).source();
        InstructionOutputNode predecessorCopy = InstructionOutputNode.create(predecessor.getInstructionOutput(), predecessor.getPcollectionId());
        network.removeEdge(mainInput);
        network.addNode(streamingSideInputWindowHandlerNode);
        network.addNode(predecessorCopy);
        network.addEdge(predecessor, streamingSideInputWindowHandlerNode, mainInput.clone());
        network.addEdge(streamingSideInputWindowHandlerNode, predecessorCopy, mainInput.clone());
        network.addEdge(predecessorCopy, node, mainInput.clone());
    }
    return network;
}
Also used : FetchAndFilterStreamingSideInputsNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.FetchAndFilterStreamingSideInputsNode) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) Structs.getString(org.apache.beam.runners.dataflow.util.Structs.getString) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) ParDoInstruction(com.google.api.services.dataflow.model.ParDoInstruction) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) PCollectionView(org.apache.beam.sdk.values.PCollectionView) CloudObject(org.apache.beam.runners.dataflow.util.CloudObject) RehydratedComponents(org.apache.beam.runners.core.construction.RehydratedComponents) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge)

Example 10 with InstructionOutputNode

use of org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode in project beam by apache.

Used in the class MapTaskToNetworkFunction, method apply.

/**
 * Converts a {@link MapTask} into a directed network of instruction nodes and instruction-output
 * nodes.
 *
 * <p>Each {@link ParallelInstruction} is deep-copied (with its outputs cleared) into a
 * {@link ParallelInstructionNode} whose execution location is {@code UNKNOWN}. Each of its outputs
 * becomes an {@link InstructionOutputNode} with a freshly generated PCollection id. The instruction
 * is connected to output {@code j} by a {@link MultiOutputInfoEdge} built from the ParDo's
 * {@code j}-th multi-output info when the instruction is a ParDo, and by a {@link DefaultEdge}
 * otherwise. Afterwards, {@code attachInput} is invoked for every input of each Flatten, ParDo,
 * PartialGroupByKey and Write instruction. Reads have no inputs.
 *
 * @param mapTask the map task to convert
 * @return a new network that allows parallel edges but not self loops
 * @throws IllegalArgumentException if an instruction is not one of the recognized kinds
 */
@Override
public MutableNetwork<Node, Edge> apply(MapTask mapTask) {
    List<ParallelInstruction> parallelInstructions = Apiary.listOrEmpty(mapTask.getInstructions());
    MutableNetwork<Node, Edge> network = NetworkBuilder.directed().allowsSelfLoops(false).allowsParallelEdges(true).expectedNodeCount(parallelInstructions.size() * 2).build();
    // The factory used to deep-copy each instruction does not depend on the instruction, so it is
    // resolved once rather than once per loop iteration.
    JsonFactory factory = MoreObjects.firstNonNull(mapTask.getFactory(), Transport.getJsonFactory());
    // Add all the instruction nodes and output nodes
    ParallelInstructionNode[] instructionNodes = new ParallelInstructionNode[parallelInstructions.size()];
    InstructionOutputNode[][] outputNodes = new InstructionOutputNode[parallelInstructions.size()][];
    for (int i = 0; i < parallelInstructions.size(); ++i) {
        List<InstructionOutput> outputs = Apiary.listOrEmpty(parallelInstructions.get(i).getOutputs());
        outputNodes[i] = new InstructionOutputNode[outputs.size()];
        // InstructionOutputNode's are the source of truth on instruction outputs.
        // Clear the instruction's outputs (on a copy) to reduce chance for confusion.
        ParallelInstruction parallelInstruction = clone(factory, parallelInstructions.get(i)).setOutputs(null);
        ParallelInstructionNode instructionNode = ParallelInstructionNode.create(parallelInstruction, Nodes.ExecutionLocation.UNKNOWN);
        instructionNodes[i] = instructionNode;
        network.addNode(instructionNode);
        // Connect the instruction node output to the output PCollection node
        for (int j = 0; j < outputs.size(); ++j) {
            InstructionOutput instructionOutput = outputs.get(j);
            InstructionOutputNode outputNode = InstructionOutputNode.create(instructionOutput, "generatedPcollection" + this.idGenerator.getId());
            network.addNode(outputNode);
            if (parallelInstruction.getParDo() != null) {
                // ParDo outputs are matched to their multi-output info by position.
                network.addEdge(instructionNode, outputNode, MultiOutputInfoEdge.create(parallelInstruction.getParDo().getMultiOutputInfos().get(j)));
            } else {
                network.addEdge(instructionNode, outputNode, DefaultEdge.create());
            }
            outputNodes[i][j] = outputNode;
        }
    }
    // Connect PCollections as inputs to instructions
    for (ParallelInstructionNode instructionNode : instructionNodes) {
        ParallelInstruction parallelInstruction = instructionNode.getParallelInstruction();
        if (parallelInstruction.getFlatten() != null) {
            for (InstructionInput input : Apiary.listOrEmpty(parallelInstruction.getFlatten().getInputs())) {
                attachInput(input, network, instructionNode, outputNodes);
            }
        } else if (parallelInstruction.getParDo() != null) {
            attachInput(parallelInstruction.getParDo().getInput(), network, instructionNode, outputNodes);
        } else if (parallelInstruction.getPartialGroupByKey() != null) {
            attachInput(parallelInstruction.getPartialGroupByKey().getInput(), network, instructionNode, outputNodes);
        } else if (parallelInstruction.getRead() != null) {
            // Reads have no inputs so nothing to do
        } else if (parallelInstruction.getWrite() != null) {
            attachInput(parallelInstruction.getWrite().getInput(), network, instructionNode, outputNodes);
        } else {
            throw new IllegalArgumentException(String.format("Unknown type of instruction %s for map task %s", parallelInstruction, mapTask));
        }
    }
    return network;
}
Also used : Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) InstructionOutput(com.google.api.services.dataflow.model.InstructionOutput) JsonFactory(com.google.api.client.json.JsonFactory) ParallelInstruction(com.google.api.services.dataflow.model.ParallelInstruction) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) InstructionInput(com.google.api.services.dataflow.model.InstructionInput) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge)

Aggregations

InstructionOutputNode (org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode)21 Edge (org.apache.beam.runners.dataflow.worker.graph.Edges.Edge)19 Node (org.apache.beam.runners.dataflow.worker.graph.Nodes.Node)19 ParallelInstructionNode (org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode)19 DefaultEdge (org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge)14 InstructionOutput (com.google.api.services.dataflow.model.InstructionOutput)13 MultiOutputInfoEdge (org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge)13 ParallelInstruction (com.google.api.services.dataflow.model.ParallelInstruction)11 Test (org.junit.Test)10 ReadInstruction (com.google.api.services.dataflow.model.ReadInstruction)9 MapTask (com.google.api.services.dataflow.model.MapTask)8 MultiOutputInfo (com.google.api.services.dataflow.model.MultiOutputInfo)5 ParDoInstruction (com.google.api.services.dataflow.model.ParDoInstruction)5 RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)5 RemoteGrpcPortNode (org.apache.beam.runners.dataflow.worker.graph.Nodes.RemoteGrpcPortNode)5 CloudObject (org.apache.beam.runners.dataflow.util.CloudObject)4 ImmutableMap (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap)4 WriteInstruction (com.google.api.services.dataflow.model.WriteInstruction)3 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3