Search in sources :

Example 6 with SideInput

Use of org.apache.beam.model.pipeline.v1.RunnerApi.SideInput in the Apache Beam project.

Class InsertFetchAndFilterStreamingSideInputNodes, method forNetwork.

/**
 * Places a fetch-and-filter node in front of every SDK-harness ParDo in {@code network} whose
 * ParDo payload declares at least one side input.
 *
 * <p>For each such ParDo the single incoming edge is removed and replaced by the chain
 * {@code predecessor -> fetch/filter node -> copy of predecessor -> ParDo}. The network is
 * modified in place and the same instance is returned. When no pipeline proto is available the
 * network is returned untouched.
 *
 * @param network the network to rewrite
 * @return {@code network}, after rewriting
 * @throws RuntimeException if a ParDo transform's payload cannot be parsed as a ParDoPayload
 * @throws IllegalStateException if the main input's windowing strategy cannot be decoded
 */
public MutableNetwork<Node, Edge> forNetwork(MutableNetwork<Node, Edge> network) {
    // Without a pipeline proto there is nothing to look side inputs up in.
    if (pipeline == null) {
        return network;
    }
    RunnerApi.Components components = pipeline.getComponents();
    RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(components);

    // Walk a snapshot of the instruction nodes: the loop body adds nodes and edges.
    ImmutableList<ParallelInstructionNode> instructionNodes =
            ImmutableList.copyOf(Iterables.filter(network.nodes(), ParallelInstructionNode.class));
    for (ParallelInstructionNode parDoNode : instructionNodes) {
        // Nodes that are not ParDos, or that do not execute in the SDK harness, are left alone.
        boolean isParDo = parDoNode.getParallelInstruction().getParDo() != null;
        if (!isParDo || !ExecutionLocation.SDK_HARNESS.equals(parDoNode.getExecutionLocation())) {
            continue;
        }

        ParDoInstruction parDo = parDoNode.getParallelInstruction().getParDo();
        CloudObject fnSpec = CloudObject.fromSpec(parDo.getUserFn());
        String transformId = getString(fnSpec, PropertyNames.SERIALIZED_FN);

        // Combine nodes carry CombinePayloads, which have no side inputs.
        String fnClassName = fnSpec.getClassName();
        boolean isCombine =
                "CombineValuesFn".equals(fnClassName) || "KeyedCombineFn".equals(fnClassName);
        if (isCombine) {
            continue;
        }

        // TODO: only the non-null branch should exist; for migration ease only
        RunnerApi.PTransform parDoPTransform = components.getTransformsOrDefault(transformId, null);
        if (parDoPTransform == null) {
            continue;
        }

        RunnerApi.ParDoPayload payload;
        try {
            payload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
        } catch (InvalidProtocolBufferException exc) {
            throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
        }
        // Nothing to insert for a ParDo without side inputs.
        if (payload.getSideInputsMap().isEmpty()) {
            continue;
        }

        // The main input is the one transform input that is not declared as a side input.
        String mainInputLocalName =
                Iterables.getOnlyElement(
                        Sets.difference(
                                parDoPTransform.getInputsMap().keySet(),
                                payload.getSideInputsMap().keySet()));
        String mainInputPCollectionId = parDoPTransform.getInputsOrThrow(mainInputLocalName);
        String windowingStrategyId =
                components.getPcollectionsOrThrow(mainInputPCollectionId).getWindowingStrategyId();
        RunnerApi.WindowingStrategy windowingStrategyProto =
                components.getWindowingStrategiesOrThrow(windowingStrategyId);
        WindowingStrategy windowingStrategy;
        try {
            windowingStrategy =
                    WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents);
        } catch (InvalidProtocolBufferException e) {
            throw new IllegalStateException(
                    String.format("Unable to decode windowing strategy %s.", windowingStrategyProto), e);
        }

        // Pair each side input view with the window mapping fn declared for it in the payload.
        ImmutableMap.Builder<PCollectionView<?>, RunnerApi.FunctionSpec> windowMappingFnsByView =
                ImmutableMap.builder();
        payload
                .getSideInputsMap()
                .forEach(
                        (sideInputTag, sideInput) ->
                                windowMappingFnsByView.put(
                                        RegisterNodeFunction.transformSideInputForRunner(
                                                pipeline, parDoPTransform, sideInputTag, sideInput),
                                        sideInput.getWindowMappingFn()));

        NameContext handlerNameContext =
                NameContext.create(
                        null,
                        parDoNode.getParallelInstruction().getOriginalName(),
                        parDoNode.getParallelInstruction().getSystemName(),
                        parDoNode.getParallelInstruction().getName());
        Node sideInputHandlerNode =
                FetchAndFilterStreamingSideInputsNode.create(
                        windowingStrategy, windowMappingFnsByView.build(), handlerNameContext);

        // Splice the handler in: predecessor -> handler -> predecessor copy -> ParDo.
        Edge mainInputEdge = Iterables.getOnlyElement(network.inEdges(parDoNode));
        InstructionOutputNode upstream =
                (InstructionOutputNode) network.incidentNodes(mainInputEdge).source();
        InstructionOutputNode upstreamCopy =
                InstructionOutputNode.create(upstream.getInstructionOutput(), upstream.getPcollectionId());
        network.removeEdge(mainInputEdge);
        network.addNode(sideInputHandlerNode);
        network.addNode(upstreamCopy);
        network.addEdge(upstream, sideInputHandlerNode, mainInputEdge.clone());
        network.addEdge(sideInputHandlerNode, upstreamCopy, mainInputEdge.clone());
        network.addEdge(upstreamCopy, parDoNode, mainInputEdge.clone());
    }
    return network;
}
Also used : FetchAndFilterStreamingSideInputsNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.FetchAndFilterStreamingSideInputsNode) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) Structs.getString(org.apache.beam.runners.dataflow.util.Structs.getString) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) ParDoInstruction(com.google.api.services.dataflow.model.ParDoInstruction) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) PCollectionView(org.apache.beam.sdk.values.PCollectionView) CloudObject(org.apache.beam.runners.dataflow.util.CloudObject) RehydratedComponents(org.apache.beam.runners.core.construction.RehydratedComponents) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge)

Example 7 with SideInput

Use of org.apache.beam.model.pipeline.v1.RunnerApi.SideInput in the Apache Beam project.

Class RegisterNodeFunction, method transformSideInputForRunner.

/**
 * Builds an artificial {@link PCollectionView} that satisfies the API requirements of a {@link
 * SideInputReader} used inside the Dataflow runner harness.
 *
 * <p>The view is backed by the runner-side wire coder of the side input PCollection, that is, a
 * length prefixed coder variant, so that encoding and decoding values matches the length
 * prefixing that occurred when the side input was materialized.
 *
 * @param pipeline pipeline proto whose components contain the side input PCollection
 * @param parDoPTransform the ParDo transform consuming the side input
 * @param sideInputTag local input name of the side input on {@code parDoPTransform}; also used as
 *     the tag of the returned view
 * @param sideInput side input declaration; its access pattern must be the multimap
 *     materialization
 * @return a view tagged with {@code sideInputTag}
 * @throws IllegalStateException if the runner wire coder cannot be instantiated
 */
public static final PCollectionView<?> transformSideInputForRunner(RunnerApi.Pipeline pipeline, RunnerApi.PTransform parDoPTransform, String sideInputTag, RunnerApi.SideInput sideInput) {
    // Only multimap materializations are supported here.
    String accessPatternUrn = sideInput.getAccessPattern().getUrn();
    checkArgument(
            Materializations.MULTIMAP_MATERIALIZATION_URN.equals(accessPatternUrn),
            "This handler is only capable of dealing with %s materializations "
                    + "but was asked to handle %s for PCollectionView with tag %s.",
            Materializations.MULTIMAP_MATERIALIZATION_URN,
            accessPatternUrn,
            sideInputTag);

    String pcollectionId = parDoPTransform.getInputsOrThrow(sideInputTag);
    RunnerApi.Components components = pipeline.getComponents();
    RunnerApi.PCollection pcollection = components.getPcollectionsOrThrow(pcollectionId);

    try {
        FullWindowedValueCoder<KV<Object, Object>> wireCoder =
                (FullWindowedValueCoder)
                        WireCoders.instantiateRunnerWireCoder(
                                PipelineNode.pCollection(pcollectionId, pcollection), components);
        return DataflowPortabilityPCollectionView.with(new TupleTag<>(sideInputTag), wireCoder);
    } catch (IOException e) {
        throw new IllegalStateException("Unable to translate proto to coder", e);
    }
}
Also used : RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FullWindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder) Structs.getString(org.apache.beam.runners.dataflow.util.Structs.getString) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) KV(org.apache.beam.sdk.values.KV) IOException(java.io.IOException)

Example 8 with SideInput

Use of org.apache.beam.model.pipeline.v1.RunnerApi.SideInput in the Apache Beam project.

Class RegisterNodeFunction, method apply.

/**
 * Translates a network of instruction nodes into one {@link RegisterRequestNode} wrapping a single
 * {@link ProcessBundleDescriptor}.
 *
 * <p>The translation proceeds in this order:
 *
 * <ol>
 *   <li>Verify that the network only holds {@link RemoteGrpcPortNode}, {@link
 *       ParallelInstructionNode} and {@link InstructionOutputNode} nodes.
 *   <li>Replace every {@link DefaultEdge} leaving a non-output node with a {@link
 *       MultiOutputInfoEdge} carrying a generated tag. This mutates {@code input}.
 *   <li>Register a placeholder global-default windowing strategy.
 *   <li>Register one coder and one PCollection per {@link InstructionOutputNode}.
 *   <li>Register one PTransform per {@link ParallelInstructionNode} (ParDo, Read or Flatten).
 *   <li>Register one PTransform per {@link RemoteGrpcPortNode}.
 * </ol>
 *
 * @param input the network to translate; its default edges are replaced in place
 * @return a register request node for the assembled descriptor together with the name context,
 *     side input and view lookups gathered along the way
 * @throws IllegalArgumentException if the network contains an unsupported node, a coder cannot be
 *     encoded, a Read cannot be processed, or an instruction is neither ParDo, Read nor Flatten
 * @throws IllegalStateException if a {@link RemoteGrpcPortNode} does not have exactly one of
 *     predecessors or successors
 * @throws RuntimeException if the placeholder windowing strategy cannot be converted or a ParDo
 *     payload cannot be parsed
 */
@Override
public Node apply(MutableNetwork<Node, Edge> input) {
    // Reject networks holding node types this translation does not understand, naming the
    // offending node rather than dumping the whole network.
    for (Node node : input.nodes()) {
        if (node instanceof RemoteGrpcPortNode || node instanceof ParallelInstructionNode || node instanceof InstructionOutputNode) {
            continue;
        }
        throw new IllegalArgumentException(String.format("Network contains unknown type of node: %s", node));
    }

    // Fix all non output nodes to have named edges: every default edge is swapped for an edge
    // carrying a generated output tag. The node, successor and edge collections returned by the
    // network are live views, so each is snapshotted before the network is mutated underneath it.
    for (Node node : ImmutableList.copyOf(input.nodes())) {
        if (node instanceof InstructionOutputNode) {
            continue;
        }
        for (Node successor : ImmutableList.copyOf(input.successors(node))) {
            for (Edge edge : ImmutableList.copyOf(input.edgesConnecting(node, successor))) {
                if (edge instanceof DefaultEdge) {
                    input.removeEdge(edge);
                    input.addEdge(node, successor, MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
                }
            }
        }
    }

    // Descriptor that accumulates every coder, PCollection and PTransform registered below.
    ProcessBundleDescriptor.Builder processBundleDescriptor =
            ProcessBundleDescriptor.newBuilder()
                    .setId(idGenerator.getId())
                    .setStateApiServiceDescriptor(stateApiServiceDescriptor);

    // For intermediate PCollections we fabricate, we make a bogus WindowingStrategy
    // TODO: create a correct windowing strategy, including coders and environment
    SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents(), null);
    // Default to use the Java environment if pipeline doesn't have environment specified.
    if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) {
        sdkComponents.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT);
    }
    String fakeWindowingStrategyId = "fakeWindowingStrategy" + idGenerator.getId();
    try {
        RunnerApi.MessageWithComponents fakeWindowingStrategyProto =
                WindowingStrategyTranslation.toMessageProto(WindowingStrategy.globalDefault(), sdkComponents);
        processBundleDescriptor
                .putWindowingStrategies(fakeWindowingStrategyId, fakeWindowingStrategyProto.getWindowingStrategy())
                .putAllCoders(fakeWindowingStrategyProto.getComponents().getCodersMap())
                .putAllEnvironments(fakeWindowingStrategyProto.getComponents().getEnvironmentsMap());
    } catch (IOException exc) {
        throw new RuntimeException("Could not convert default windowing stratey to proto", exc);
    }

    Map<Node, String> nodesToPCollections = new HashMap<>();
    ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder();
    ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos = ImmutableMap.builder();
    ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews = ImmutableMap.builder();
    ImmutableMap.Builder<String, NameContext> pcollectionIdToNameContexts = ImmutableMap.builder();
    ImmutableMap.Builder<InstructionOutputNode, String> instructionOutputNodeToCoderIdBuilder = ImmutableMap.builder();

    // For every instruction output: register a coder under a generated id, then generate a
    // PCollection id and register the PCollection with the ProcessBundleDescriptor.
    for (InstructionOutputNode node : Iterables.filter(input.nodes(), InstructionOutputNode.class)) {
        InstructionOutput instructionOutput = node.getInstructionOutput();
        String coderId = "generatedCoder" + idGenerator.getId();
        instructionOutputNodeToCoderIdBuilder.put(node, coderId);
        try (ByteString.Output output = ByteString.newOutput()) {
            try {
                Coder<?> javaCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec()));
                sdkComponents.registerCoder(javaCoder);
                RunnerApi.Coder coderProto = CoderTranslation.toProto(javaCoder, sdkComponents);
                processBundleDescriptor.putCoders(coderId, coderProto);
            } catch (IOException e) {
                throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e);
            } catch (Exception e) {
                // Coder probably wasn't a java coder: fall back to shipping the serialized codec
                // spec as the payload of an otherwise empty coder proto.
                OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec());
                processBundleDescriptor.putCoders(coderId, RunnerApi.Coder.newBuilder().setSpec(RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString())).build());
            }
        } catch (IOException e) {
            throw new IllegalArgumentException(String.format("Unable to encode coder %s for output %s", instructionOutput.getCodec(), instructionOutput), e);
        }
        // Generate new PCollection ID and map it to relevant node.
        // Will later be used to fill PTransform inputs/outputs information.
        String pcollectionId = "generatedPcollection" + idGenerator.getId();
        processBundleDescriptor.putPcollections(pcollectionId, RunnerApi.PCollection.newBuilder().setCoderId(coderId).setWindowingStrategyId(fakeWindowingStrategyId).build());
        nodesToPCollections.put(node, pcollectionId);
        pcollectionIdToNameContexts.put(pcollectionId, NameContext.create(null, instructionOutput.getOriginalName(), instructionOutput.getSystemName(), instructionOutput.getName()));
    }
    // Copy over the coders that were registered with the SDK components along the way.
    processBundleDescriptor.putAllCoders(sdkComponents.toComponents().getCodersMap());
    Map<InstructionOutputNode, String> instructionOutputNodeToCoderIdMap = instructionOutputNodeToCoderIdBuilder.build();

    // Register one PTransform per parallel instruction.
    for (ParallelInstructionNode node : Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
        ParallelInstruction parallelInstruction = node.getParallelInstruction();
        String ptransformId = "generatedPtransform" + idGenerator.getId();
        ptransformIdToNameContexts.put(ptransformId, NameContext.create(null, parallelInstruction.getOriginalName(), parallelInstruction.getSystemName(), parallelInstruction.getName()));
        RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
        RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder();
        if (parallelInstruction.getParDo() != null) {
            ParDoInstruction parDoInstruction = parallelInstruction.getParDo();
            CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
            String userFnClassName = userFnSpec.getClassName();
            if ("CombineValuesFn".equals(userFnClassName) || "KeyedCombineFn".equals(userFnClassName)) {
                // Combine instructions get their own spec and never expose views.
                transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec);
                ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList());
            } else {
                String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);
                RunnerApi.PTransform parDoPTransform = pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);
                // TODO: only the non-null branch should exist; for migration ease only
                if (parDoPTransform != null) {
                    checkArgument(parDoPTransform.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN), "Found transform \"%s\" for ParallelDo instruction, " + " but that transform had unexpected URN \"%s\" (expected \"%s\")", parDoPTransformId, parDoPTransform.getSpec().getUrn(), PTransformTranslation.PAR_DO_TRANSFORM_URN);
                    RunnerApi.ParDoPayload parDoPayload;
                    try {
                        parDoPayload = RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
                    } catch (InvalidProtocolBufferException exc) {
                        throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
                    }
                    // Each side input is translated twice: once into a runner-side view and once
                    // into the SDK-facing descriptor/transform being built.
                    ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder();
                    for (Map.Entry<String, SideInput> sideInputEntry : parDoPayload.getSideInputsMap().entrySet()) {
                        pcollectionViews.add(transformSideInputForRunner(pipeline, parDoPTransform, sideInputEntry.getKey(), sideInputEntry.getValue()));
                        transformSideInputForSdk(pipeline, parDoPTransform, sideInputEntry.getKey(), processBundleDescriptor, pTransform);
                    }
                    ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build());
                    transformSpec.setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(parDoPayload.toByteString());
                } else {
                    // legacy path - bytes are the FunctionSpec's payload field, basically, and
                    // SDKs expect it in the PTransform's payload field
                    byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN);
                    transformSpec.setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN).setPayload(ByteString.copyFrom(userFnBytes));
                }
                // Add side input information for batch pipelines
                if (parDoInstruction.getSideInputs() != null) {
                    ptransformIdToSideInputInfos.put(ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true));
                }
            }
        } else if (parallelInstruction.getRead() != null) {
            ReadInstruction readInstruction = parallelInstruction.getRead();
            CloudObject sourceSpec = CloudObject.fromSpec(CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec());
            // TODO: Need to plumb through the SDK specific function spec.
            transformSpec.setUrn(JAVA_SOURCE_URN);
            try {
                // The serialized source is stored as a Base64 string in the source spec.
                byte[] serializedSource = Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE));
                ByteString sourceByteString = ByteString.copyFrom(serializedSource);
                transformSpec.setPayload(sourceByteString);
            } catch (Exception e) {
                throw new IllegalArgumentException(String.format("Unable to process Read %s", parallelInstruction), e);
            }
        } else if (parallelInstruction.getFlatten() != null) {
            transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN);
        } else {
            throw new IllegalArgumentException(String.format("Unknown type of ParallelInstruction %s", parallelInstruction));
        }
        // Inputs get generated local names; outputs reuse the tag stored on the outgoing edge.
        for (Node predecessorOutput : input.predecessors(node)) {
            pTransform.putInputs("generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput));
        }
        for (Edge edge : input.outEdges(node)) {
            Node nodeOutput = input.incidentNodes(edge).target();
            MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge;
            pTransform.putOutputs(edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput));
        }
        pTransform.setSpec(transformSpec);
        processBundleDescriptor.putTransforms(ptransformId, pTransform.build());
    }

    // Add the PTransforms representing the remote gRPC nodes
    for (RemoteGrpcPortNode node : Iterables.filter(input.nodes(), RemoteGrpcPortNode.class)) {
        RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
        Set<Node> predecessors = input.predecessors(node);
        Set<Node> successors = input.successors(node);
        if (predecessors.isEmpty() && !successors.isEmpty()) {
            // Only successors: the port feeds data in, so it becomes a data input transform.
            Node instructionOutputNode = Iterables.getOnlyElement(successors);
            pTransform.putOutputs("generatedOutput" + idGenerator.getId(), nodesToPCollections.get(instructionOutputNode));
            pTransform.setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).setPayload(node.getRemoteGrpcPort().toBuilder().setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode)).build().toByteString()).build());
        } else if (!predecessors.isEmpty() && successors.isEmpty()) {
            // Only predecessors: the port takes data out, so it becomes a data output transform.
            Node instructionOutputNode = Iterables.getOnlyElement(predecessors);
            pTransform.putInputs("generatedInput" + idGenerator.getId(), nodesToPCollections.get(instructionOutputNode));
            pTransform.setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).setPayload(node.getRemoteGrpcPort().toBuilder().setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode)).build().toByteString()).build());
        } else {
            throw new IllegalStateException("Expected either one input OR one output " + "InstructionOutputNode for this RemoteGrpcPortNode");
        }
        processBundleDescriptor.putTransforms(node.getPrimitiveTransformId(), pTransform.build());
    }
    return RegisterRequestNode.create(RegisterRequest.newBuilder().addProcessBundleDescriptor(processBundleDescriptor).build(), ptransformIdToNameContexts.build(), ptransformIdToSideInputInfos.build(), ptransformIdToPCollectionViews.build(), pcollectionIdToNameContexts.build());
}
Also used : HashMap(java.util.HashMap) MultiOutputInfo(com.google.api.services.dataflow.model.MultiOutputInfo) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RegisterRequestNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.RegisterRequestNode) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) Node(org.apache.beam.runners.dataflow.worker.graph.Nodes.Node) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) RemoteGrpcPortNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.RemoteGrpcPortNode) InstructionOutput(com.google.api.services.dataflow.model.InstructionOutput) Structs.getString(org.apache.beam.runners.dataflow.util.Structs.getString) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) SideInput(org.apache.beam.model.pipeline.v1.RunnerApi.SideInput) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) ParDoInstruction(com.google.api.services.dataflow.model.ParDoInstruction) DataflowPortabilityPCollectionView(org.apache.beam.runners.dataflow.worker.DataflowPortabilityPCollectionView) PCollectionView(org.apache.beam.sdk.values.PCollectionView) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Map(java.util.Map) HashMap(java.util.HashMap) RemoteGrpcPortNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.RemoteGrpcPortNode) ProcessBundleDescriptor(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor) ImmutableList(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList) 
ParallelInstructionNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode) SdkComponents(org.apache.beam.runners.core.construction.SdkComponents) ReadInstruction(com.google.api.services.dataflow.model.ReadInstruction) InstructionOutputNode(org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode) NameContext(org.apache.beam.runners.dataflow.worker.counters.NameContext) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) IOException(java.io.IOException) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) IOException(java.io.IOException) ParallelInstruction(com.google.api.services.dataflow.model.ParallelInstruction) CloudObject(org.apache.beam.runners.dataflow.util.CloudObject) Edge(org.apache.beam.runners.dataflow.worker.graph.Edges.Edge) MultiOutputInfoEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.MultiOutputInfoEdge) DefaultEdge(org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge)

Example 9 with SideInput

Use of org.apache.beam.model.pipeline.v1.RunnerApi.SideInput in the Apache Beam project.

Class ExecutableStageTest, method testRoundTripToFromTransform.

/**
 * Builds an executable stage around a single ParDo with a side input, user state and a timer
 * family, converts it to a PTransform, and checks both the resulting inputs/outputs and that
 * parsing the payload back yields an equal stage.
 */
@Test
public void testRoundTripToFromTransform() throws Exception {
    Environment env = org.apache.beam.runners.core.construction.Environments.createDockerEnvironment("foo");

    // ParDo payload declaring one side input, one state spec and one timer family.
    ParDoPayload parDoPayload =
            ParDoPayload.newBuilder()
                    .setDoFn(FunctionSpec.newBuilder())
                    .putSideInputs("side_input", SideInput.getDefaultInstance())
                    .putStateSpecs("user_state", StateSpec.getDefaultInstance())
                    .putTimerFamilySpecs("timer", TimerFamilySpec.getDefaultInstance())
                    .build();
    FunctionSpec parDoSpec =
            FunctionSpec.newBuilder()
                    .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                    .setPayload(parDoPayload.toByteString())
                    .build();
    PTransform pt =
            PTransform.newBuilder()
                    .putInputs("input", "input.out")
                    .putInputs("side_input", "sideInput.in")
                    .putInputs("timer", "timer.out")
                    .putOutputs("output", "output.out")
                    .putOutputs("timer", "timer.out")
                    .setSpec(parDoSpec)
                    .setEnvironmentId("foo")
                    .build();

    PCollection input = PCollection.newBuilder().setUniqueName("input.out").build();
    PCollection sideInput = PCollection.newBuilder().setUniqueName("sideInput.in").build();
    PCollection timer = PCollection.newBuilder().setUniqueName("timer.out").build();
    PCollection output = PCollection.newBuilder().setUniqueName("output.out").build();
    Components components =
            Components.newBuilder()
                    .putTransforms("pt", pt)
                    .putPcollections("input.out", input)
                    .putPcollections("sideInput.in", sideInput)
                    .putPcollections("timer.out", timer)
                    .putPcollections("output.out", output)
                    .putEnvironments("foo", env)
                    .build();

    // References from the stage back into the ParDo.
    PTransformNode transformNode = PipelineNode.pTransform("pt", pt);
    SideInputReference sideInputRef =
            SideInputReference.of(transformNode, "side_input", PipelineNode.pCollection("sideInput.in", sideInput));
    UserStateReference userStateRef =
            UserStateReference.of(transformNode, "user_state", PipelineNode.pCollection("input.out", input));
    TimerReference timerRef = TimerReference.of(transformNode, "timer");

    ImmutableExecutableStage stage =
            ImmutableExecutableStage.of(
                    components,
                    env,
                    PipelineNode.pCollection("input.out", input),
                    Collections.singleton(sideInputRef),
                    Collections.singleton(userStateRef),
                    Collections.singleton(timerRef),
                    Collections.singleton(PipelineNode.pTransform("pt", pt)),
                    Collections.singleton(PipelineNode.pCollection("output.out", output)),
                    DEFAULT_WIRE_CODER_SETTINGS);

    PTransform stagePTransform = stage.toPTransform("foo");

    // One output, and exactly the main input plus the side input as inputs.
    assertThat(stagePTransform.getOutputsMap(), hasValue("output.out"));
    assertThat(stagePTransform.getOutputsCount(), equalTo(1));
    assertThat(stagePTransform.getInputsMap(), allOf(hasValue("input.out"), hasValue("sideInput.in")));
    assertThat(stagePTransform.getInputsCount(), equalTo(2));

    // The payload lists the contained transform and round-trips to an equal stage.
    ExecutableStagePayload payload = ExecutableStagePayload.parseFrom(stagePTransform.getSpec().getPayload());
    assertThat(payload.getTransformsList(), contains("pt"));
    assertThat(ExecutableStage.fromPayload(payload), equalTo(stage));
}
Also used : Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) PCollection(org.apache.beam.model.pipeline.v1.RunnerApi.PCollection) ExecutableStagePayload(org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) Environment(org.apache.beam.model.pipeline.v1.RunnerApi.Environment) PTransform(org.apache.beam.model.pipeline.v1.RunnerApi.PTransform) Test(org.junit.Test)

Example 10 with SideInput

Use of org.apache.beam.model.pipeline.v1.RunnerApi.SideInput in the Apache Beam project.

Class FlinkStreamingPortablePipelineTranslator, method transformSideInputs.

/**
 * Materializes every side input of an executable stage and merges them into one stream of
 * {@link RawUnionValue}s.
 *
 * <p>A first pass assigns each side input a consecutive integer tag and derives its coders from
 * the type information of the stream registered for its PCollection. A second pass groups each
 * stream under a void key, wraps the grouped values in a union value carrying the integer tag,
 * and unions the results.
 *
 * @param stagePayload payload of the executable stage whose side inputs are translated
 * @param components pipeline components used to resolve transforms and their inputs
 * @param context translation context providing the data streams and pipeline options
 * @return the integer-tag-to-view mapping together with the unioned stream; the stream is
 *     {@code null} when the stage has no side inputs
 * @throws IllegalStateException if a side input stream's type information is not a
 *     {@link CoderTypeInformation}
 */
private TransformedSideInputs transformSideInputs(RunnerApi.ExecutableStagePayload stagePayload, RunnerApi.Components components, StreamingTranslationContext context) {
    LinkedHashMap<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> sideInputs = getSideInputIdToPCollectionViewMap(stagePayload, components);
    Map<TupleTag<?>, Integer> tagToIntMapping = new HashMap<>();
    Map<Integer, PCollectionView<?>> intToViewMapping = new HashMap<>();
    List<WindowedValueCoder<KV<Void, Object>>> kvCoders = new ArrayList<>();
    List<Coder<?>> viewCoders = new ArrayList<>();

    // First pass: number the side inputs and collect their coders.
    int nextIndex = 0;
    for (Map.Entry<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> entry : sideInputs.entrySet()) {
        RunnerApi.ExecutableStagePayload.SideInputId sideInputId = entry.getKey();
        PCollectionView<?> view = entry.getValue();

        TupleTag<?> tag = view.getTagInternal();
        intToViewMapping.put(nextIndex, view);
        tagToIntMapping.put(tag, nextIndex);
        nextIndex++;

        String collectionId =
                components
                        .getTransformsOrThrow(sideInputId.getTransformId())
                        .getInputsOrThrow(sideInputId.getLocalName());
        DataStream<Object> sideInputStream = context.getDataStreamOrThrow(collectionId);
        TypeInformation<Object> typeInfo = sideInputStream.getType();
        if (!(typeInfo instanceof CoderTypeInformation)) {
            throw new IllegalStateException("Input Stream TypeInformation is no CoderTypeInformation.");
        }
        WindowedValueCoder<Object> windowedCoder = (WindowedValueCoder) ((CoderTypeInformation) typeInfo).getCoder();
        Coder<Object> elementCoder = windowedCoder.getValueCoder();

        // Coder of the void-keyed elements fed into the grouping.
        Coder<KV<Void, Object>> keyedCoder = KvCoder.of(VoidCoder.of(), elementCoder);
        kvCoders.add(windowedCoder.withValueCoder(keyedCoder));

        // coder for materialized view matching GBK below
        WindowedValueCoder<KV<Void, Iterable<Object>>> materializedCoder =
                windowedCoder.withValueCoder(KvCoder.of(VoidCoder.of(), IterableCoder.of(elementCoder)));
        viewCoders.add(materializedCoder);
    }

    // second pass, now that we gathered the input coders
    UnionCoder unionCoder = UnionCoder.of(viewCoders);
    CoderTypeInformation<RawUnionValue> unionTypeInformation = new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());

    // transform each side input to RawUnionValue and union them
    DataStream<RawUnionValue> sideInputUnion = null;
    for (Map.Entry<RunnerApi.ExecutableStagePayload.SideInputId, PCollectionView<?>> entry : sideInputs.entrySet()) {
        RunnerApi.ExecutableStagePayload.SideInputId sideInputId = entry.getKey();
        PCollectionView<?> view = entry.getValue();

        TupleTag<?> tag = view.getTagInternal();
        final int intTag = tagToIntMapping.get(tag);
        RunnerApi.PTransform pTransform = components.getTransformsOrThrow(sideInputId.getTransformId());
        String collectionId = pTransform.getInputsOrThrow(sideInputId.getLocalName());
        DataStream<WindowedValue<?>> sideInputStream = context.getDataStreamOrThrow(collectionId);

        // insert GBK to materialize side input view
        String viewName = sideInputId.getTransformId() + "-" + sideInputId.getLocalName();
        WindowedValueCoder<KV<Void, Object>> kvCoder = kvCoders.get(intTag);
        DataStream<WindowedValue<KV<Void, Object>>> keyedSideInputStream = sideInputStream.map(new ToVoidKeyValue(context.getPipelineOptions()));
        SingleOutputStreamOperator<WindowedValue<KV<Void, Iterable<Object>>>> viewStream = addGBK(keyedSideInputStream, view.getWindowingStrategyInternal(), kvCoder, viewName, context);
        // Assign a unique but consistent id to re-map operator state
        viewStream.uid(pTransform.getUniqueName() + "-" + sideInputId.getLocalName());

        DataStream<RawUnionValue> unionValueStream = viewStream.map(new FlinkStreamingTransformTranslators.ToRawUnion<>(intTag, context.getPipelineOptions())).returns(unionTypeInformation);
        sideInputUnion = (sideInputUnion == null) ? unionValueStream : sideInputUnion.union(unionValueStream);
    }
    return new TransformedSideInputs(intToViewMapping, sideInputUnion);
}
Also used : LinkedHashMap(java.util.LinkedHashMap) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) TupleTag(org.apache.beam.sdk.values.TupleTag) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) CoderTypeInformation(org.apache.beam.runners.flink.translation.types.CoderTypeInformation) SingletonKeyedWorkItemCoder(org.apache.beam.runners.flink.translation.wrappers.streaming.SingletonKeyedWorkItemCoder) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) KvCoder(org.apache.beam.sdk.coders.KvCoder) PipelineTranslatorUtils.instantiateCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder) IterableCoder(org.apache.beam.sdk.coders.IterableCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) UnionCoder(org.apache.beam.sdk.transforms.join.UnionCoder) Coder(org.apache.beam.sdk.coders.Coder) ByteArrayCoder(org.apache.beam.sdk.coders.ByteArrayCoder) UnionCoder(org.apache.beam.sdk.transforms.join.UnionCoder) RawUnionValue(org.apache.beam.sdk.transforms.join.RawUnionValue) KV(org.apache.beam.sdk.values.KV) RunnerPCollectionView(org.apache.beam.runners.core.construction.RunnerPCollectionView) PCollectionView(org.apache.beam.sdk.values.PCollectionView) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Map(java.util.Map) LinkedHashMap(java.util.LinkedHashMap) BiMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap) TreeMap(java.util.TreeMap) PipelineTranslatorUtils.createOutputMap(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.createOutputMap) HashMap(java.util.HashMap)

Aggregations

RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)7 PCollectionView (org.apache.beam.sdk.values.PCollectionView)7 Map (java.util.Map)5 HashMap (java.util.HashMap)4 PTransform (org.apache.beam.model.pipeline.v1.RunnerApi.PTransform)4 PTransformNode (org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode)4 ByteString (org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString)4 ImmutableMap (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap)4 IOException (java.io.IOException)3 Components (org.apache.beam.model.pipeline.v1.RunnerApi.Components)3 ExecutableStagePayload (org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload)3 SideInput (org.apache.beam.model.pipeline.v1.RunnerApi.SideInput)3 Structs.getString (org.apache.beam.runners.dataflow.util.Structs.getString)3 Test (org.junit.Test)3 ParDoInstruction (com.google.api.services.dataflow.model.ParDoInstruction)2 ArrayList (java.util.ArrayList)2 Environment (org.apache.beam.model.pipeline.v1.RunnerApi.Environment)2 PCollection (org.apache.beam.model.pipeline.v1.RunnerApi.PCollection)2 ParDoPayload (org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload)2 PCollectionNode (org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode)2