Search in sources :

Example 16 with ExecutableStage

use of org.apache.beam.runners.core.construction.graph.ExecutableStage in project beam by apache.

the class RemoteExecutionTest method testSplit.

/**
 * Executes a fused stage that contains a splittable DoFn and keeps requesting splits until the
 * bundle terminates, then validates the last split response and the stage output.
 *
 * <p>Split requests are issued from a single-threaded scheduler every 100 ms at fraction 0.5. The
 * scheduler is stopped in a {@code finally} block so that a failure while opening, feeding or
 * closing the bundle cannot leave the scheduling thread running.
 *
 * @throws Exception if the pipeline cannot be prepared or the bundle execution fails
 */
@Test(timeout = 60000L)
public void testSplit() throws Exception {
    launchSdkHarness(PipelineOptionsFactory.create());
    Pipeline p = Pipeline.create();
    p.apply("impulse", Impulse.create())
        .apply("create", ParDo.of(new DoFn<byte[], String>() {

            @ProcessElement
            public void process(ProcessContext ctxt) {
                ctxt.output("zero");
                ctxt.output(WaitingTillSplitRestrictionTracker.WAIT_TILL_SPLIT);
                ctxt.output("two");
            }
        }))
        .apply("forceSplit", ParDo.of(new DoFn<String, String>() {

            @GetInitialRestriction
            public String getInitialRestriction(@Element String element) {
                return element;
            }

            @NewTracker
            public WaitingTillSplitRestrictionTracker newTracker(@Restriction String restriction) {
                return new WaitingTillSplitRestrictionTracker(restriction);
            }

            @ProcessElement
            public void process(RestrictionTracker<String, Void> tracker, ProcessContext context) {
                // Spin until the tracker refuses further claims.
                while (tracker.tryClaim(null)) {
                }
                context.output(tracker.currentRestriction());
            }
        }))
        .apply("addKeys", WithKeys.of("foo"))
        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
        .apply("gbk", GroupByKey.create());
    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(p);
    // Expand any splittable DoFns within the graph to enable sizing and splitting of bundles.
    RunnerApi.Pipeline pipelineWithSdfExpanded = ProtoOverrides.updateTransform(
        PTransformTranslation.PAR_DO_TRANSFORM_URN,
        pipeline,
        SplittableParDoExpander.createSizedReplacement());
    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineWithSdfExpanded);
    // Find the fused stage with the SDF ProcessSizedElementAndRestriction transform
    Optional<ExecutableStage> optionalStage = Iterables.tryFind(
        fused.getFusedStages(),
        (ExecutableStage stage) -> Iterables.any(
            stage.getTransforms(),
            (PTransformNode node) -> PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN.equals(node.getTransform().getSpec().getUrn())));
    checkState(optionalStage.isPresent(), "Expected a stage with SDF ProcessSizedElementAndRestriction.");
    ExecutableStage stage = optionalStage.get();
    ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage("my_stage", stage, dataServer.getApiServiceDescriptor());
    BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations());
    Map<String, ? super Coder<WindowedValue<?>>> remoteOutputCoders = descriptor.getRemoteOutputCoders();
    Map<String, Collection<? super WindowedValue<?>>> outputValues = new HashMap<>();
    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
    // Collect every remote output into a synchronized list keyed by its output id.
    for (Entry<String, ? super Coder<WindowedValue<?>>> remoteOutputCoder : remoteOutputCoders.entrySet()) {
        List<? super WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
        outputValues.put(remoteOutputCoder.getKey(), outputContents);
        outputReceivers.put(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder) remoteOutputCoder.getValue(), (FnDataReceiver<? super WindowedValue<?>>) outputContents::add));
    }
    List<ProcessBundleSplitResponse> splitResponses = new ArrayList<>();
    List<ProcessBundleResponse> checkpointResponses = new ArrayList<>();
    List<String> requestsFinalization = new ArrayList<>();
    ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
    // Stays null when the bundle fails before the periodic split task is scheduled.
    ScheduledFuture<?> future = null;
    // Execute the remote bundle.
    try (RemoteBundle bundle = processor.newBundle(outputReceivers, Collections.emptyMap(), StateRequestHandler.unsupported(), BundleProgressHandler.ignored(), splitResponses::add, checkpointResponses::add, requestsFinalization::add)) {
        Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(sdfSizedElementAndRestrictionForTest(WaitingTillSplitRestrictionTracker.WAIT_TILL_SPLIT)));
        // Keep sending splits until the bundle terminates.
        future = executor.scheduleWithFixedDelay(() -> bundle.split(0.5), 0L, 100L, TimeUnit.MILLISECONDS);
    } finally {
        // Runs after the bundle has been closed, on both the success and the failure path, so the
        // scheduler thread never outlives the test.
        if (future != null) {
            future.cancel(false);
        }
        executor.shutdown();
    }
    assertTrue(requestsFinalization.isEmpty());
    assertTrue(checkpointResponses.isEmpty());
    // We only validate the last split response since it is the only one that could possibly
    // contain the SDF split, all others will be a reduction in the ChannelSplit range.
    assertFalse(splitResponses.isEmpty());
    ProcessBundleSplitResponse splitResponse = splitResponses.get(splitResponses.size() - 1);
    ChannelSplit channelSplit = Iterables.getOnlyElement(splitResponse.getChannelSplitsList());
    // There is only one outcome for the final split that can happen since the SDF is blocking the
    // bundle from completing and hence needed to be split.
    assertEquals(-1L, channelSplit.getLastPrimaryElement());
    assertEquals(1L, channelSplit.getFirstResidualElement());
    assertEquals(1, splitResponse.getPrimaryRootsCount());
    assertEquals(1, splitResponse.getResidualRootsCount());
    assertThat(Iterables.getOnlyElement(outputValues.values()), containsInAnyOrder(valueInGlobalWindow(KV.of("foo", WaitingTillSplitRestrictionTracker.PRIMARY))));
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) ArrayList(java.util.ArrayList) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) ProcessBundleResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) FnDataReceiver(org.apache.beam.sdk.fn.data.FnDataReceiver) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) ProcessBundleSplitResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse) DoFn(org.apache.beam.sdk.transforms.DoFn) Collection(java.util.Collection) PCollection(org.apache.beam.sdk.values.PCollection) ChannelSplit(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse.ChannelSplit) Test(org.junit.Test)

Example 17 with ExecutableStage

use of org.apache.beam.runners.core.construction.graph.ExecutableStage in project beam by apache.

the class RemoteExecutionTest method testExecutionWithUserState.

/**
 * Runs a stage whose DoFn uses two bag user states against an in-memory state store.
 *
 * <p>The DoFn emits every value already stored under the first state id, appends the element's
 * value to it and clears the second state. The test then checks both the emitted elements and the
 * resulting contents of the in-memory store.
 *
 * @throws Exception if the pipeline cannot be prepared or the bundle execution fails
 */
@Test
public void testExecutionWithUserState() throws Exception {
    launchSdkHarness(PipelineOptionsFactory.create());
    // Must stay compile-time constants: both ids are used as annotation values below.
    final String stateId = "foo";
    final String stateId2 = "foo2";
    Pipeline pipeline = Pipeline.create();
    pipeline
        .apply("impulse", Impulse.create())
        .apply("create", ParDo.of(new DoFn<byte[], KV<String, String>>() {

            @ProcessElement
            public void process(ProcessContext ctxt) {
            }
        }))
        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
        .apply("userState", ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {

            @StateId(stateId)
            private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());

            @StateId(stateId2)
            private final StateSpec<BagState<String>> bufferState2 = StateSpecs.bag(StringUtf8Coder.of());

            @ProcessElement
            public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> buffered,
                    @StateId(stateId2) BagState<String> cleared,
                    OutputReceiver<KV<String, String>> receiver) {
                for (String previous : buffered.read()) {
                    receiver.output(KV.of(element.getKey(), previous));
                }
                buffered.add(element.getValue());
                cleared.clear();
            }
        }))
        .apply("gbk", GroupByKey.create());

    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
    Optional<ExecutableStage> statefulStage = Iterables.tryFind(
        fusedPipeline.getFusedStages(),
        (ExecutableStage candidate) -> !candidate.getUserStates().isEmpty());
    checkState(statefulStage.isPresent(), "Expected a stage with user state.");
    ExecutableStage stage = statefulStage.get();

    ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage(
        "test_stage",
        stage,
        dataServer.getApiServiceDescriptor(),
        stateServer.getApiServiceDescriptor());
    BundleProcessor processor = controlClient.getProcessor(
        descriptor.getProcessBundleDescriptor(),
        descriptor.getRemoteInputDestinations(),
        stateDelegator);

    // One synchronized sink per remote output, keyed by the output id.
    Map<String, Coder> coders = descriptor.getRemoteOutputCoders();
    Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
    for (Entry<String, Coder> coderEntry : coders.entrySet()) {
        String outputId = coderEntry.getKey();
        List<WindowedValue<?>> received = Collections.synchronizedList(new ArrayList<>());
        outputValues.put(outputId, received);
        outputReceivers.put(outputId, RemoteOutputReceiver.of((Coder<WindowedValue<?>>) coderEntry.getValue(), received::add));
    }

    // Pre-encoded fixtures shared by the initial store contents and the final expectation.
    ByteString encodedA = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
    ByteString encodedB = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
    ByteString encodedC = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
    ByteString encodedD = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));
    ByteString encodedY = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y", Coder.Context.NESTED));

    Map<String, List<ByteString>> userStateData = ImmutableMap.of(
        stateId, new ArrayList(Arrays.asList(encodedA, encodedB, encodedC)),
        stateId2, new ArrayList(Arrays.asList(encodedD)));

    // State handler that serves reads, appends and clears straight from the map above.
    StateRequestHandler stateRequestHandler = StateRequestHandlers.forBagUserStateHandlerFactory(
        descriptor,
        new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {

            @Override
            public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                    String pTransformId,
                    String userStateId,
                    Coder<ByteString> keyCoder,
                    Coder<Object> valueCoder,
                    Coder<BoundedWindow> windowCoder) {
                return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {

                    @Override
                    public Iterable<Object> get(ByteString key, BoundedWindow window) {
                        return (Iterable) userStateData.get(userStateId);
                    }

                    @Override
                    public void append(ByteString key, BoundedWindow window, Iterator<Object> values) {
                        Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                    }

                    @Override
                    public void clear(ByteString key, BoundedWindow window) {
                        userStateData.get(userStateId).clear();
                    }
                };
            }
        });

    try (RemoteBundle bundle = processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
        Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(KV.of("X", "Y")));
    }

    for (Collection<WindowedValue<?>> emitted : outputValues.values()) {
        assertThat(
            emitted,
            containsInAnyOrder(
                valueInGlobalWindow(KV.of("X", "A")),
                valueInGlobalWindow(KV.of("X", "B")),
                valueInGlobalWindow(KV.of("X", "C"))));
    }
    assertThat(
        userStateData.get(stateId),
        IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC, encodedY));
    assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) ArrayList(java.util.ArrayList) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) ArrayList(java.util.ArrayList) PCollectionList(org.apache.beam.sdk.values.PCollectionList) List(java.util.List) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) BagUserStateHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandler) Collection(java.util.Collection) PCollection(org.apache.beam.sdk.values.PCollection) StateRequestHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandler) IsEmptyIterable(org.hamcrest.collection.IsEmptyIterable) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) Iterator(java.util.Iterator) BagState(org.apache.beam.sdk.state.BagState) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) KV(org.apache.beam.sdk.values.KV) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Test(org.junit.Test)

Example 18 with ExecutableStage

use of org.apache.beam.runners.core.construction.graph.ExecutableStage in project beam by apache.

the class RemoteExecutionTest method testExecutionWithSideInputCaching.

/**
 * Processes two bundles through a stage that reads an iterable and a multimap side input, with a
 * cache token registered for each side input, and verifies the produced output as well as the
 * state requests recorded by the {@code StoringStateRequestHandler}.
 *
 * <p>The request assertions use JUnit's {@code assertEquals(expected, actual)} argument order so
 * that failure messages report the expected and actual state keys correctly.
 *
 * @throws Exception if the pipeline cannot be prepared or a bundle execution fails
 */
@Test
public void testExecutionWithSideInputCaching() throws Exception {
    Pipeline p = Pipeline.create();
    addExperiment(p.getOptions().as(ExperimentalOptions.class), "beam_fn_api");
    // TODO(BEAM-10097): Remove experiment once all portable runners support this view type
    addExperiment(p.getOptions().as(ExperimentalOptions.class), "use_runner_v2");
    launchSdkHarness(p.getOptions());
    PCollection<String> input = p.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], String>() {

        @ProcessElement
        public void process(ProcessContext ctxt) {
            ctxt.output("zero");
            ctxt.output("one");
            ctxt.output("two");
        }
    })).setCoder(StringUtf8Coder.of());
    PCollectionView<Iterable<String>> iterableView = input.apply("createIterableSideInput", View.asIterable());
    PCollectionView<Map<String, Iterable<String>>> multimapView = input.apply(WithKeys.of("key")).apply("createMultimapSideInput", View.asMultimap());
    input.apply("readSideInput", ParDo.of(new DoFn<String, KV<String, String>>() {

        @ProcessElement
        public void processElement(ProcessContext context) {
            // Pair the element with every value of the iterable side input ...
            for (String value : context.sideInput(iterableView)) {
                context.output(KV.of(context.element(), value));
            }
            // ... and with every "key:value" combination of the multimap side input.
            for (Map.Entry<String, Iterable<String>> entry : context.sideInput(multimapView).entrySet()) {
                for (String value : entry.getValue()) {
                    context.output(KV.of(context.element(), entry.getKey() + ":" + value));
                }
            }
        }
    }).withSideInputs(iterableView, multimapView)).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("gbk", GroupByKey.create());
    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
    Optional<ExecutableStage> optionalStage = Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> !stage.getSideInputs().isEmpty());
    checkState(optionalStage.isPresent(), "Expected a stage with side inputs.");
    ExecutableStage stage = optionalStage.get();
    ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage("test_stage", stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
    BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
    Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
    Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
    for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
        List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
        outputValues.put(remoteOutputCoder.getKey(), outputContents);
        outputReceivers.put(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
    }
    // Side input handlers returning fixed data, wrapped so that incoming requests are stored.
    StoringStateRequestHandler stateRequestHandler = new StoringStateRequestHandler(StateRequestHandlers.forSideInputHandlerFactory(descriptor.getSideInputSpecs(), new SideInputHandlerFactory() {

        @Override
        public <V, W extends BoundedWindow> IterableSideInputHandler<V, W> forIterableSideInput(String pTransformId, String sideInputId, Coder<V> elementCoder, Coder<W> windowCoder) {
            return new IterableSideInputHandler<V, W>() {

                @Override
                public Iterable<V> get(W window) {
                    return (Iterable) Arrays.asList("A", "B", "C");
                }

                @Override
                public Coder<V> elementCoder() {
                    return elementCoder;
                }
            };
        }

        @Override
        public <K, V, W extends BoundedWindow> MultimapSideInputHandler<K, V, W> forMultimapSideInput(String pTransformId, String sideInputId, KvCoder<K, V> elementCoder, Coder<W> windowCoder) {
            return new MultimapSideInputHandler<K, V, W>() {

                @Override
                public Iterable<K> get(W window) {
                    return (Iterable) Arrays.asList("key1", "key2");
                }

                @Override
                public Iterable<V> get(K key, W window) {
                    if ("key1".equals(key)) {
                        return (Iterable) Arrays.asList("H", "I", "J");
                    } else if ("key2".equals(key)) {
                        return (Iterable) Arrays.asList("M", "N", "O");
                    }
                    return Collections.emptyList();
                }

                @Override
                public Coder<K> keyCoder() {
                    return elementCoder.getKeyCoder();
                }

                @Override
                public Coder<V> valueCoder() {
                    return elementCoder.getValueCoder();
                }
            };
        }
    }));
    String transformId = Iterables.get(stage.getSideInputs(), 0).transform().getId();
    String iterableSideInputId = iterableView.getTagInternal().getId();
    String multimapSideInputId = multimapView.getTagInternal().getId();
    stateRequestHandler.addCacheToken(BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder().setSideInput(BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder().setSideInputId(iterableSideInputId).setTransformId(transformId).build()).setToken(ByteString.copyFromUtf8("IterableSideInputToken")).build());
    stateRequestHandler.addCacheToken(BeamFnApi.ProcessBundleRequest.CacheToken.newBuilder().setSideInput(BeamFnApi.ProcessBundleRequest.CacheToken.SideInput.newBuilder().setSideInputId(multimapSideInputId).setTransformId(transformId).build()).setToken(ByteString.copyFromUtf8("MulitmapSideInputToken")).build());
    BundleProgressHandler progressHandler = BundleProgressHandler.ignored();
    try (RemoteBundle bundle = processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
        Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow("X"));
    }
    try (RemoteBundle bundle = processor.newBundle(outputReceivers, stateRequestHandler, progressHandler)) {
        Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow("Y"));
    }
    for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
        assertThat(
            windowedValues,
            containsInAnyOrder(
                valueInGlobalWindow(KV.of("X", "A")),
                valueInGlobalWindow(KV.of("X", "B")),
                valueInGlobalWindow(KV.of("X", "C")),
                valueInGlobalWindow(KV.of("X", "key1:H")),
                valueInGlobalWindow(KV.of("X", "key1:I")),
                valueInGlobalWindow(KV.of("X", "key1:J")),
                valueInGlobalWindow(KV.of("X", "key2:M")),
                valueInGlobalWindow(KV.of("X", "key2:N")),
                valueInGlobalWindow(KV.of("X", "key2:O")),
                valueInGlobalWindow(KV.of("Y", "A")),
                valueInGlobalWindow(KV.of("Y", "B")),
                valueInGlobalWindow(KV.of("Y", "C")),
                valueInGlobalWindow(KV.of("Y", "key1:H")),
                valueInGlobalWindow(KV.of("Y", "key1:I")),
                valueInGlobalWindow(KV.of("Y", "key1:J")),
                valueInGlobalWindow(KV.of("Y", "key2:M")),
                valueInGlobalWindow(KV.of("Y", "key2:N")),
                valueInGlobalWindow(KV.of("Y", "key2:O"))));
    }
    // Expect the following requests for the first bundle:
    // * one to read iterable side input
    // * one to read keys from multimap side input
    // * one to read key1 iterable from multimap side input
    // * one to read key2 iterable from multimap side input
    // The total is checked after both bundles have run, so no additional requests are expected
    // for the second bundle.
    assertEquals(4, stateRequestHandler.receivedRequests.size());
    assertEquals(
        BeamFnApi.StateKey.IterableSideInput.newBuilder().setSideInputId(iterableSideInputId).setTransformId(transformId).build(),
        stateRequestHandler.receivedRequests.get(0).getStateKey().getIterableSideInput());
    assertEquals(
        BeamFnApi.StateKey.MultimapKeysSideInput.newBuilder().setSideInputId(multimapSideInputId).setTransformId(transformId).build(),
        stateRequestHandler.receivedRequests.get(1).getStateKey().getMultimapKeysSideInput());
    assertEquals(
        BeamFnApi.StateKey.MultimapSideInput.newBuilder().setSideInputId(multimapSideInputId).setTransformId(transformId).setKey(encode("key1")).build(),
        stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapSideInput());
    assertEquals(
        BeamFnApi.StateKey.MultimapSideInput.newBuilder().setSideInputId(multimapSideInputId).setTransformId(transformId).setKey(encode("key2")).build(),
        stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput());
}
Also used : IsEmptyIterable(org.hamcrest.collection.IsEmptyIterable) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ExperimentalOptions(org.apache.beam.sdk.options.ExperimentalOptions) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) Entry(java.util.Map.Entry) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) WindowedValue(org.apache.beam.sdk.util.WindowedValue) KV(org.apache.beam.sdk.values.KV) SideInputHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) KvCoder(org.apache.beam.sdk.coders.KvCoder) KV(org.apache.beam.sdk.values.KV) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) IterableSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.IterableSideInputHandler) Collection(java.util.Collection) PCollection(org.apache.beam.sdk.values.PCollection) MultimapSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.MultimapSideInputHandler) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) 
ConcurrentMap(java.util.concurrent.ConcurrentMap) Test(org.junit.Test)

Example 19 with ExecutableStage

use of org.apache.beam.runners.core.construction.graph.ExecutableStage in project beam by apache.

the class RemoteExecutionTest method testExecutionWithTimer.

/**
 * Runs a stage whose DoFn declares an event-time timer, a processing-time timer and an
 * {@code @OnWindowExpiration} callback.
 *
 * <p>One element and three timers are fed into a single bundle. The test then checks the elements
 * written to the main output and the timers that were set through the event-time and
 * processing-time timer outputs. The time reported by {@code DateTimeUtils} is fixed for the
 * duration of the bundle and restored afterwards.
 *
 * @throws Exception if the pipeline cannot be prepared or the bundle execution fails
 */
@Test
public void testExecutionWithTimer() throws Exception {
    launchSdkHarness(PipelineOptionsFactory.create());
    Pipeline pipeline = Pipeline.create();
    pipeline
        .apply("impulse", Impulse.create())
        .apply("create", ParDo.of(new DoFn<byte[], KV<String, String>>() {

            @ProcessElement
            public void process(ProcessContext ctxt) {
            }
        }))
        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
        .apply("timer", ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {

            @TimerId("event")
            private final TimerSpec eventTimerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

            @TimerId("processing")
            private final TimerSpec processingTimerSpec = TimerSpecs.timer(TimeDomain.PROCESSING_TIME);

            @ProcessElement
            public void processElement(
                    ProcessContext context,
                    @TimerId("event") Timer eventTimeTimer,
                    @TimerId("processing") Timer processingTimeTimer) {
                context.output(KV.of("main" + context.element().getKey(), ""));
                eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.timestamp().plus(Duration.millis(1L)));
                processingTimeTimer.offset(Duration.millis(2L));
                processingTimeTimer.setRelative();
            }

            @OnTimer("event")
            public void eventTimer(
                    OnTimerContext context,
                    @Key String key,
                    @TimerId("event") Timer eventTimeTimer,
                    @TimerId("processing") Timer processingTimeTimer) {
                context.output(KV.of("event", key));
                eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.fireTimestamp().plus(Duration.millis(11L)));
                processingTimeTimer.offset(Duration.millis(12L));
                processingTimeTimer.setRelative();
            }

            @OnTimer("processing")
            public void processingTimer(
                    OnTimerContext context,
                    @Key String key,
                    @TimerId("event") Timer eventTimeTimer,
                    @TimerId("processing") Timer processingTimeTimer) {
                context.output(KV.of("processing", key));
                eventTimeTimer.withOutputTimestamp(context.timestamp()).set(context.fireTimestamp().plus(Duration.millis(21L)));
                processingTimeTimer.offset(Duration.millis(22L));
                processingTimeTimer.setRelative();
            }

            @OnWindowExpiration
            public void onWindowExpiration(@Key String key, OutputReceiver<KV<String, String>> outputReceiver) {
                outputReceiver.output(KV.of("onWindowExpiration", key));
            }
        }))
        .apply("gbk", GroupByKey.create());

    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
    Optional<ExecutableStage> timerStage = Iterables.tryFind(
        fusedPipeline.getFusedStages(),
        (ExecutableStage candidate) -> !candidate.getTimers().isEmpty());
    checkState(timerStage.isPresent(), "Expected a stage with timers.");
    ExecutableStage stage = timerStage.get();

    ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage(
        "test_stage",
        stage,
        dataServer.getApiServiceDescriptor(),
        stateServer.getApiServiceDescriptor());
    BundleProcessor processor = controlClient.getProcessor(
        descriptor.getProcessBundleDescriptor(),
        descriptor.getRemoteInputDestinations(),
        stateDelegator,
        descriptor.getTimerSpecs());

    // One synchronized sink per remote output, keyed by the output id.
    Map<String, Collection<WindowedValue<?>>> mainOutputs = new HashMap<>();
    Map<String, RemoteOutputReceiver<?>> mainReceivers = new HashMap<>();
    for (Entry<String, Coder> coderEntry : descriptor.getRemoteOutputCoders().entrySet()) {
        String outputId = coderEntry.getKey();
        List<WindowedValue<?>> received = Collections.synchronizedList(new ArrayList<>());
        mainOutputs.put(outputId, received);
        mainReceivers.put(outputId, RemoteOutputReceiver.of((Coder<WindowedValue<?>>) coderEntry.getValue(), received::add));
    }

    // One synchronized sink per timer, keyed by (transform id, timer id).
    Map<KV<String, String>, Collection<org.apache.beam.runners.core.construction.Timer<?>>> timersSet = new HashMap<>();
    Map<KV<String, String>, RemoteOutputReceiver<org.apache.beam.runners.core.construction.Timer<?>>> timerSinks = new HashMap<>();
    for (Map.Entry<String, Map<String, ProcessBundleDescriptors.TimerSpec>> perTransform : descriptor.getTimerSpecs().entrySet()) {
        for (ProcessBundleDescriptors.TimerSpec spec : perTransform.getValue().values()) {
            KV<String, String> timerKey = KV.of(spec.transformId(), spec.timerId());
            List<org.apache.beam.runners.core.construction.Timer<?>> received = Collections.synchronizedList(new ArrayList<>());
            timersSet.put(timerKey, received);
            timerSinks.put(timerKey, RemoteOutputReceiver.of((Coder<org.apache.beam.runners.core.construction.Timer<?>>) spec.coder(), received::add));
        }
    }

    // Sort the timer specs into the three kinds this test drives; anything else is a failure.
    ProcessBundleDescriptors.TimerSpec eventSpec = null;
    ProcessBundleDescriptors.TimerSpec processingSpec = null;
    ProcessBundleDescriptors.TimerSpec expirationSpec = null;
    for (Map<String, ProcessBundleDescriptors.TimerSpec> specsById : descriptor.getTimerSpecs().values()) {
        for (ProcessBundleDescriptors.TimerSpec spec : specsById.values()) {
            if ("onWindowExpiration0".equals(spec.timerId())) {
                expirationSpec = spec;
            } else if (TimeDomain.EVENT_TIME.equals(spec.getTimerSpec().getTimeDomain())) {
                eventSpec = spec;
            } else if (TimeDomain.PROCESSING_TIME.equals(spec.getTimerSpec().getTimeDomain())) {
                processingSpec = spec;
            } else {
                fail(String.format("Unknown timer specification %s", spec));
            }
        }
    }

    // Set the current system time to a fixed value to get stable values for processing time timer
    // output.
    DateTimeUtils.setCurrentMillisFixed(BoundedWindow.TIMESTAMP_MIN_VALUE.getMillis() + 10000L);
    try {
        try (RemoteBundle bundle = processor.newBundle(
                mainReceivers,
                timerSinks,
                StateRequestHandler.unsupported(),
                BundleProgressHandler.ignored(),
                null,
                null)) {
            Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(KV.of("X", "X")));
            bundle.getTimerReceivers()
                .get(KV.of(eventSpec.transformId(), eventSpec.timerId()))
                .accept(timerForTest("Y", 1000L, 100L));
            bundle.getTimerReceivers()
                .get(KV.of(processingSpec.transformId(), processingSpec.timerId()))
                .accept(timerForTest("Z", 2000L, 200L));
            bundle.getTimerReceivers()
                .get(KV.of(expirationSpec.transformId(), expirationSpec.timerId()))
                .accept(timerForTest("key", 5001L, 5000L));
        }
        String mainOutputTransform = Iterables.getOnlyElement(descriptor.getRemoteOutputCoders().keySet());
        assertThat(
            mainOutputs.get(mainOutputTransform),
            containsInAnyOrder(
                valueInGlobalWindow(KV.of("mainX", "")),
                WindowedValue.timestampedValueInGlobalWindow(KV.of("event", "Y"), BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(100L))),
                WindowedValue.timestampedValueInGlobalWindow(KV.of("processing", "Z"), BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(200L))),
                WindowedValue.timestampedValueInGlobalWindow(KV.of("onWindowExpiration", "key"), BoundedWindow.TIMESTAMP_MIN_VALUE.plus(Duration.millis(5000L)))));
        assertThat(
            timersSet.get(KV.of(eventSpec.transformId(), eventSpec.timerId())),
            containsInAnyOrder(
                timerForTest("X", 1L, 0L),
                timerForTest("Y", 1011L, 100L),
                timerForTest("Z", 2021L, 200L)));
        assertThat(
            timersSet.get(KV.of(processingSpec.transformId(), processingSpec.timerId())),
            containsInAnyOrder(
                timerForTest("X", 10002L, 0L),
                timerForTest("Y", 10012L, 100L),
                timerForTest("Z", 10022L, 200L)));
    } finally {
        // Always hand the clock back to the system time, even when an assertion fails.
        DateTimeUtils.setCurrentMillisSystem();
    }
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) TimerSpec(org.apache.beam.sdk.state.TimerSpec) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) KV(org.apache.beam.sdk.values.KV) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Timer(org.apache.beam.sdk.state.Timer) Collection(java.util.Collection) PCollection(org.apache.beam.sdk.values.PCollection) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ConcurrentMap(java.util.concurrent.ConcurrentMap) Test(org.junit.Test)

Example 20 with ExecutableStage

use of org.apache.beam.runners.core.construction.graph.ExecutableStage in project beam by apache.

the class SingleEnvironmentInstanceJobBundleFactoryTest method closeShutsDownEnvironments.

/**
 * Checks that closing the factory also closes the remote environment that was handed out for a
 * stage.
 */
@Test
public void closeShutsDownEnvironments() throws Exception {
    // Stub the remote environment first; it does not depend on the pipeline built below.
    RemoteEnvironment environment = mock(RemoteEnvironment.class);
    when(environment.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);

    // Build a small pipeline with the "beam_fn_api" experiment switched on.
    Pipeline pipeline = Pipeline.create();
    ExperimentalOptions experimentalOptions = pipeline.getOptions().as(ExperimentalOptions.class);
    ExperimentalOptions.addExperiment(experimentalOptions, "beam_fn_api");
    pipeline.apply("Create", Create.of(1, 2, 3));

    // Fuse the translated pipeline and take the first fused stage.
    ExecutableStage firstStage =
        GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline))
            .getFusedStages()
            .stream()
            .findFirst()
            .get();

    // The environment factory returns the stubbed environment for this stage's environment.
    when(environmentFactory.createEnvironment(firstStage.getEnvironment(), GENERATED_ID))
        .thenReturn(environment);

    factory.forStage(firstStage);
    factory.close();

    verify(environment).close();
}
Also used : RemoteEnvironment(org.apache.beam.runners.fnexecution.environment.RemoteEnvironment) ExperimentalOptions(org.apache.beam.sdk.options.ExperimentalOptions) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) Pipeline(org.apache.beam.sdk.Pipeline) Test(org.junit.Test)

Aggregations

ExecutableStage (org.apache.beam.runners.core.construction.graph.ExecutableStage): 22, Test (org.junit.Test): 17, RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi): 16, Pipeline (org.apache.beam.sdk.Pipeline): 15, Coder (org.apache.beam.sdk.coders.Coder): 14, HashMap (java.util.HashMap): 12, FusedPipeline (org.apache.beam.runners.core.construction.graph.FusedPipeline): 12, KvCoder (org.apache.beam.sdk.coders.KvCoder): 12, StringUtf8Coder (org.apache.beam.sdk.coders.StringUtf8Coder): 12, WindowedValue (org.apache.beam.sdk.util.WindowedValue): 12, ByteString (org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString): 11, Map (java.util.Map): 10, ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap): 10, ExecutableProcessBundleDescriptor (org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor): 10, BundleProcessor (org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor): 10, BigEndianLongCoder (org.apache.beam.sdk.coders.BigEndianLongCoder): 10, Collection (java.util.Collection): 9, KV (org.apache.beam.sdk.values.KV): 9, PCollection (org.apache.beam.sdk.values.PCollection): 9, ArrayList (java.util.ArrayList): 7