Search in sources:

Example 1 with ExecutableStage

Use of org.apache.beam.runners.core.construction.graph.ExecutableStage in the Apache Beam project.

From the class RemoteExecutionTest, method testExecutionWithUserStateCaching.

/**
 * Runs a stage with two bag user states through two consecutive bundles and verifies the emitted
 * elements, the final contents of the backing state, and the state requests seen by the runner.
 *
 * <p>State {@code "foo"} is seeded with A, B, C and state {@code "bar"} with D. Each bundle echoes
 * the contents of {@code "foo"}; the DoFn clears {@code "bar"} when it is non-empty and emits
 * {@code "Empty"} otherwise. Across both bundles only three state requests are expected.
 */
@Test
public void testExecutionWithUserStateCaching() throws Exception {
    Pipeline p = Pipeline.create();
    launchSdkHarness(p.getOptions());
    final String stateId = "foo";
    final String stateId2 = "bar";
    p.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], KV<String, String>>() {

        @ProcessElement
        public void process(ProcessContext ctxt) {
        }
    })).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("userState", ParDo.of(new DoFn<KV<String, String>, KV<String, String>>() {

        @StateId(stateId)
        private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());

        @StateId(stateId2)
        private final StateSpec<BagState<String>> bufferState2 = StateSpecs.bag(StringUtf8Coder.of());

        @ProcessElement
        public void processElement(@Element KV<String, String> element, @StateId(stateId) BagState<String> state, @StateId(stateId2) BagState<String> state2, OutputReceiver<KV<String, String>> r) {
            // Echo everything currently held in the first bag under the element's key.
            for (String value : state.read()) {
                r.output(KV.of(element.getKey(), value));
            }
            // The second bag is only inspected: report emptiness, otherwise wipe it.
            ReadableState<Boolean> isEmpty = state2.isEmpty();
            if (isEmpty.read()) {
                r.output(KV.of(element.getKey(), "Empty"));
            } else {
                state2.clear();
            }
        }
    })).apply("gbk", GroupByKey.create());
    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
    Optional<ExecutableStage> optionalStage = Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
    checkState(optionalStage.isPresent(), "Expected a stage with user state.");
    ExecutableStage stage = optionalStage.get();
    ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage("test_stage", stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
    BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
    Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
    Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
    Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
    for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
        List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
        outputValues.put(remoteOutputCoder.getKey(), outputContents);
        outputReceivers.put(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
    }
    // Encode the seed values once; the same encodings are reused in the final state assertion.
    ByteString encodedA = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
    ByteString encodedB = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
    ByteString encodedC = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
    ByteString encodedD = ByteString.copyFrom(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));
    // The lists must be mutable because the handler below appends to and clears them.
    List<ByteString> initialBag = new ArrayList<>(Arrays.asList(encodedA, encodedB, encodedC));
    List<ByteString> initialBag2 = new ArrayList<>(Arrays.asList(encodedD));
    Map<String, List<ByteString>> userStateData = ImmutableMap.of(stateId, initialBag, stateId2, initialBag2);
    StoringStateRequestHandler stateRequestHandler = new StoringStateRequestHandler(StateRequestHandlers.forBagUserStateHandlerFactory(descriptor, new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {

        @Override
        public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(String pTransformId, String userStateId, Coder<ByteString> keyCoder, Coder<Object> valueCoder, Coder<BoundedWindow> windowCoder) {
            // The handler is typed on Object values while the backing lists hold ByteString,
            // hence the raw casts below. Key and window are ignored: state is looked up by
            // user state id only.
            return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {

                @Override
                public Iterable<Object> get(ByteString key, BoundedWindow window) {
                    return (Iterable) userStateData.get(userStateId);
                }

                @Override
                public void append(ByteString key, BoundedWindow window, Iterator<Object> values) {
                    Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                }

                @Override
                public void clear(ByteString key, BoundedWindow window) {
                    userStateData.get(userStateId).clear();
                }
            };
        }
    }));
    // Two bundles for the same key "X"; closing each bundle completes it.
    try (RemoteBundle bundle = processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
        Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(KV.of("X", "Y")));
    }
    try (RemoteBundle bundle2 = processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
        Iterables.getOnlyElement(bundle2.getInputReceivers().values()).accept(valueInGlobalWindow(KV.of("X", "Z")));
    }
    // A, B, C are echoed once per bundle; "Empty" appears once.
    for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
        assertThat(windowedValues, containsInAnyOrder(valueInGlobalWindow(KV.of("X", "A")), valueInGlobalWindow(KV.of("X", "B")), valueInGlobalWindow(KV.of("X", "C")), valueInGlobalWindow(KV.of("X", "A")), valueInGlobalWindow(KV.of("X", "B")), valueInGlobalWindow(KV.of("X", "C")), valueInGlobalWindow(KV.of("X", "Empty"))));
    }
    assertThat(userStateData.get(stateId), IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC));
    assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
    // 3 requests expected in total: state read, state2 read, and state2 clear
    assertEquals(3, stateRequestHandler.getRequestCount());
    ByteString.Output out = ByteString.newOutput();
    StringUtf8Coder.of().encode("X", out);
    ByteString expectedKey = out.toByteString();
    // Request 0: read of the first bag.
    assertEquals(stateId, stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getUserStateId());
    assertEquals(expectedKey, stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey());
    assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());
    // Request 1: read of the second bag.
    assertEquals(stateId2, stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getUserStateId());
    assertEquals(expectedKey, stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey());
    assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());
    // Request 2: clear of the second bag.
    assertEquals(stateId2, stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getUserStateId());
    assertEquals(expectedKey, stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey());
    assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
}
Also used : IsEmptyIterable(org.hamcrest.collection.IsEmptyIterable) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) ArrayList(java.util.ArrayList) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) WindowedValue(org.apache.beam.sdk.util.WindowedValue) Iterator(java.util.Iterator) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) ArrayList(java.util.ArrayList) PCollectionList(org.apache.beam.sdk.values.PCollectionList) List(java.util.List) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) BagState(org.apache.beam.sdk.state.BagState) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) BagUserStateHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandlerFactory) BagUserStateHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandler) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) ReadableState(org.apache.beam.sdk.state.ReadableState) KV(org.apache.beam.sdk.values.KV) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Collection(java.util.Collection) PCollection(org.apache.beam.sdk.values.PCollection) Test(org.junit.Test)

Example 2 with ExecutableStage

Use of org.apache.beam.runners.core.construction.graph.ExecutableStage in the Apache Beam project.

From the class RemoteExecutionTest, method testExecutionWithMultipleStages.

/**
 * Builds two parallel branches that are flattened, keyed and grouped, fuses the pipeline, and
 * executes every resulting stage against the SDK harness with the single input {@code "X"}.
 *
 * <p>Exactly two fused stages are expected, and their combined output must be
 * {@code ("stream1X", "")} and {@code ("stream2X", "")}.
 */
@Test
public void testExecutionWithMultipleStages() throws Exception {
    launchSdkHarness(PipelineOptionsFactory.create());
    Pipeline pipeline = Pipeline.create();
    // Creates one branch: impulse -> decode bytes to String -> prepend "stream<suffix>".
    Function<String, PCollection<String>> branchFor = suffix -> {
        PCollection<String> decoded = pipeline.apply("impulse" + suffix, Impulse.create()).apply("create" + suffix, ParDo.of(new DoFn<byte[], String>() {

            @ProcessElement
            public void process(ProcessContext c) {
                try {
                    c.output(CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), c.element()));
                } catch (CoderException e) {
                    throw new RuntimeException(e);
                }
            }
        })).setCoder(StringUtf8Coder.of());
        return decoded.apply(ParDo.of(new DoFn<String, String>() {

            @ProcessElement
            public void processElement(ProcessContext c) {
                c.output("stream" + suffix + c.element());
            }
        }));
    };
    PCollection<String> firstBranch = branchFor.apply("1");
    PCollection<String> secondBranch = branchFor.apply("2");
    PCollection<String> flattened = PCollectionList.of(firstBranch).and(secondBranch).apply(Flatten.pCollections());
    flattened.apply("createKV", ParDo.of(new DoFn<String, KV<String, String>>() {

        @ProcessElement
        public void process(ProcessContext c) {
            c.output(KV.of(c.element(), ""));
        }
    })).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("gbk", GroupByKey.create());
    Set<ExecutableStage> fusedStages = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline)).getFusedStages();
    assertThat(fusedStages.size(), equalTo(2));
    // Every stage writes into this one shared, synchronized sink.
    List<WindowedValue<?>> collected = Collections.synchronizedList(new ArrayList<>());
    for (ExecutableStage fusedStage : fusedStages) {
        ExecutableProcessBundleDescriptor bundleDescriptor = ProcessBundleDescriptors.fromExecutableStage(fusedStage.toString(), fusedStage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
        BundleProcessor bundleProcessor = controlClient.getProcessor(bundleDescriptor.getProcessBundleDescriptor(), bundleDescriptor.getRemoteInputDestinations(), stateDelegator);
        Map<String, Coder> codersByOutput = bundleDescriptor.getRemoteOutputCoders();
        Map<String, RemoteOutputReceiver<?>> receivers = new HashMap<>();
        codersByOutput.forEach((outputId, coder) -> receivers.putIfAbsent(outputId, RemoteOutputReceiver.of((Coder<WindowedValue<?>>) coder, collected::add)));
        // Closing the bundle completes it; these stages use no state.
        try (RemoteBundle bundle = bundleProcessor.newBundle(receivers, StateRequestHandler.unsupported(), BundleProgressHandler.ignored())) {
            Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
        }
    }
    assertThat(collected, containsInAnyOrder(valueInGlobalWindow(KV.of("stream1X", "")), valueInGlobalWindow(KV.of("stream2X", ""))));
}
Also used : Arrays(java.util.Arrays) CoderUtils(org.apache.beam.sdk.util.CoderUtils) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Matchers.not(org.hamcrest.Matchers.not) WindowedValue.valueInGlobalWindow(org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Metrics(org.apache.beam.sdk.metrics.Metrics) Future(java.util.concurrent.Future) GrpcDataService(org.apache.beam.runners.fnexecution.data.GrpcDataService) Map(java.util.Map) SimpleMonitoringInfoBuilder(org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder) GlobalWindow(org.apache.beam.sdk.transforms.windowing.GlobalWindow) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) Iterators(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators) BagUserStateHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandlerFactory) KvCoder(org.apache.beam.sdk.coders.KvCoder) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) Matchers.allOf(org.hamcrest.Matchers.allOf) FnDataReceiver(org.apache.beam.sdk.fn.data.FnDataReceiver) Set(java.util.Set) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) Executors(java.util.concurrent.Executors) GrpcLoggingService(org.apache.beam.runners.fnexecution.logging.GrpcLoggingService) Serializable(java.io.Serializable) ManagedChannelFactory(org.apache.beam.sdk.fn.channel.ManagedChannelFactory) MultimapSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.MultimapSideInputHandler) CountDownLatch(java.util.concurrent.CountDownLatch) CoderException(org.apache.beam.sdk.coders.CoderException) CompletionStage(java.util.concurrent.CompletionStage) 
ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Assert.assertFalse(org.junit.Assert.assertFalse) KV(org.apache.beam.sdk.values.KV) ExperimentalOptions(org.apache.beam.sdk.options.ExperimentalOptions) Duration(org.joda.time.Duration) RunWith(org.junit.runner.RunWith) Impulse(org.apache.beam.sdk.transforms.Impulse) View(org.apache.beam.sdk.transforms.View) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) GrpcStateService(org.apache.beam.runners.fnexecution.state.GrpcStateService) ArrayList(java.util.ArrayList) TimerSpec(org.apache.beam.sdk.state.TimerSpec) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) Pipeline(org.apache.beam.sdk.Pipeline) StateRequestHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandler) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) InProcessServerFactory(org.apache.beam.sdk.fn.server.InProcessServerFactory) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) DoFn(org.apache.beam.sdk.transforms.DoFn) Assert.assertTrue(org.junit.Assert.assertTrue) StateRequestHandlers(org.apache.beam.runners.fnexecution.state.StateRequestHandlers) Test(org.junit.Test) SingleOutput(org.apache.beam.sdk.transforms.ParDo.SingleOutput) ExecutionException(java.util.concurrent.ExecutionException) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PCollectionView(org.apache.beam.sdk.values.PCollectionView) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) Matcher(org.hamcrest.Matcher) TimeDomain(org.apache.beam.sdk.state.TimeDomain) Assert.assertEquals(org.junit.Assert.assertEquals) IsEmptyIterable(org.hamcrest.collection.IsEmptyIterable) StateSpec(org.apache.beam.sdk.state.StateSpec) IsIterableContainingInOrder(org.hamcrest.collection.IsIterableContainingInOrder) 
ScheduledFuture(java.util.concurrent.ScheduledFuture) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ChannelSplit(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse.ChannelSplit) Urns(org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) ExperimentalOptions.addExperiment(org.apache.beam.sdk.options.ExperimentalOptions.addExperiment) PCollectionList(org.apache.beam.sdk.values.PCollectionList) GrpcContextHeaderAccessorProvider(org.apache.beam.sdk.fn.server.GrpcContextHeaderAccessorProvider) ResetDateTimeProvider(org.apache.beam.sdk.testing.ResetDateTimeProvider) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) After(org.junit.After) Assert.fail(org.junit.Assert.fail) ProcessBundleSplitResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse) ThreadFactory(java.util.concurrent.ThreadFactory) Flatten(org.apache.beam.sdk.transforms.Flatten) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) IterableSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.IterableSideInputHandler) PaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo) Collection(java.util.Collection) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) OutboundObserverFactory(org.apache.beam.sdk.fn.stream.OutboundObserverFactory) UUID(java.util.UUID) ThreadFactoryBuilder(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder) List(java.util.List) ParDo(org.apache.beam.sdk.transforms.ParDo) Matchers.containsInAnyOrder(org.hamcrest.Matchers.containsInAnyOrder) Timer(org.apache.beam.sdk.state.Timer) Matchers.equalTo(org.hamcrest.Matchers.equalTo) Entry(java.util.Map.Entry) ProcessBundleResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse) FnHarness(org.apache.beam.fn.harness.FnHarness) 
Slf4jLogWriter(org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter) DistributionData(org.apache.beam.runners.core.metrics.DistributionData) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) DateTimeUtils(org.joda.time.DateTimeUtils) SideInputHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory) Coder(org.apache.beam.sdk.coders.Coder) HashMap(java.util.HashMap) ExecutionStateSampler(org.apache.beam.runners.core.metrics.ExecutionStateSampler) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) PipelineOptionsFactory(org.apache.beam.sdk.options.PipelineOptionsFactory) Function(java.util.function.Function) ConcurrentMap(java.util.concurrent.ConcurrentMap) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) MonitoringInfoConstants(org.apache.beam.runners.core.metrics.MonitoringInfoConstants) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) TypeUrns(org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns) PipelineOptions(org.apache.beam.sdk.options.PipelineOptions) ExecutorService(java.util.concurrent.ExecutorService) ProcessBundleProgressResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse) MonitoringInfo(org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo) GrpcFnServer(org.apache.beam.sdk.fn.server.GrpcFnServer) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) WithKeys(org.apache.beam.sdk.transforms.WithKeys) Iterator(java.util.Iterator) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) Matchers(org.hamcrest.Matchers) JUnit4(org.junit.runners.JUnit4) PCollection(org.apache.beam.sdk.values.PCollection) TimeUnit(java.util.concurrent.TimeUnit) SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) 
BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) Rule(org.junit.Rule) MonitoringInfoMatchers(org.apache.beam.runners.core.metrics.MonitoringInfoMatchers) Caches(org.apache.beam.fn.harness.Caches) SplitResult(org.apache.beam.sdk.transforms.splittabledofn.SplitResult) Collections(java.util.Collections) BagUserStateHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandler) ReadableState(org.apache.beam.sdk.state.ReadableState) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) PCollection(org.apache.beam.sdk.values.PCollection) DoFn(org.apache.beam.sdk.transforms.DoFn) CoderException(org.apache.beam.sdk.coders.CoderException) Test(org.junit.Test)

Example 3 with ExecutableStage

Use of org.apache.beam.runners.core.construction.graph.ExecutableStage in the Apache Beam project.

From the class ProcessBundleDescriptorsTest, method testLengthPrefixingOfKeyCoderInStatefulExecutableStage.

/**
 * Verifies that, for a stage containing a stateful transform, the generated process bundle
 * descriptor length-prefixes the key coder, both for the main input and for the timer family
 * coder.
 */
@Test
public void testLengthPrefixingOfKeyCoderInStatefulExecutableStage() throws Exception {
    // Build a stateful pipeline keyed by a coder that is not a known model coder.
    Pipeline pipeline = Pipeline.create();
    Coder<Void> voidKeyCoder = VoidCoder.of();
    assertThat(ModelCoderRegistrar.isKnownCoder(voidKeyCoder), is(false));
    pipeline.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], KV<Void, String>>() {

        @ProcessElement
        public void process(ProcessContext ctxt) {
        }
    })).setCoder(KvCoder.of(voidKeyCoder, StringUtf8Coder.of())).apply("userState", ParDo.of(new DoFn<KV<Void, String>, KV<Void, String>>() {

        @StateId("stateId")
        private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());

        @TimerId("timerId")
        private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

        @ProcessElement
        public void processElement(@Element KV<Void, String> element, @StateId("stateId") BagState<String> state, @TimerId("timerId") Timer timer, OutputReceiver<KV<Void, String>> r) {
        }

        @OnTimer("timerId")
        public void onTimer() {
        }
    })).apply("gbk", GroupByKey.create());
    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
    // Pick the fused stage that declares the user state named "stateId".
    Optional<ExecutableStage> statefulStage = Iterables.tryFind(fusedPipeline.getFusedStages(), (ExecutableStage candidate) -> candidate.getUserStates().stream().anyMatch(userState -> userState.localName().equals("stateId")));
    checkState(statefulStage.isPresent(), "Expected a stage with user state.");
    ExecutableStage executableStage = statefulStage.get();
    PipelineNode.PCollectionNode mainInput = executableStage.getInputPCollection();

    // Before translation the key coder must still be the plain Java-serialized coder.
    Map<String, RunnerApi.Coder> stageCoders = executableStage.getComponents().getCodersMap();
    RunnerApi.Coder stageInputCoder = stageCoders.get(mainInput.getPCollection().getCoderId());
    RunnerApi.Coder originalKeyCoder = stageCoders.get(ModelCoders.getKvCoderComponents(stageInputCoder).keyCoderId());
    assertThat(originalKeyCoder.getSpec().getUrn(), is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));

    // After translation the main-input key coder has to be length-prefixed.
    BeamFnApi.ProcessBundleDescriptor bundleDescriptor = ProcessBundleDescriptors.fromExecutableStage("test_stage", executableStage, Endpoints.ApiServiceDescriptor.getDefaultInstance()).getProcessBundleDescriptor();
    Map<String, RunnerApi.Coder> descriptorCoders = bundleDescriptor.getCodersMap();
    RunnerApi.Coder descriptorInputCoder = descriptorCoders.get(bundleDescriptor.getPcollectionsOrThrow(mainInput.getId()).getCoderId());
    RunnerApi.Coder translatedKeyCoder = descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorInputCoder).keyCoderId());
    ensureLengthPrefixed(translatedKeyCoder, originalKeyCoder, descriptorCoders);

    // The same must hold for the first component of the timer family coder.
    TimerReference timerReference = Iterables.getOnlyElement(executableStage.getTimers());
    String statefulTransformId = timerReference.transform().getId();
    RunnerApi.ParDoPayload payload = RunnerApi.ParDoPayload.parseFrom(bundleDescriptor.getTransformsOrThrow(statefulTransformId).getSpec().getPayload());
    RunnerApi.TimerFamilySpec timerFamilySpec = payload.getTimerFamilySpecsOrThrow(timerReference.localName());
    RunnerApi.Coder timerFamilyCoder = descriptorCoders.get(timerFamilySpec.getTimerFamilyCoderId());
    RunnerApi.Coder timerKeyCoder = descriptorCoders.get(timerFamilyCoder.getComponentCoderIds(0));
    ensureLengthPrefixed(timerKeyCoder, originalKeyCoder, descriptorCoders);
}
Also used : CoreMatchers.is(org.hamcrest.CoreMatchers.is) Endpoints(org.apache.beam.model.pipeline.v1.Endpoints) StateSpec(org.apache.beam.sdk.state.StateSpec) KV(org.apache.beam.sdk.values.KV) CoderTranslation(org.apache.beam.runners.core.construction.CoderTranslation) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Coder(org.apache.beam.sdk.coders.Coder) Impulse(org.apache.beam.sdk.transforms.Impulse) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Map(java.util.Map) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) ModelCoderRegistrar(org.apache.beam.runners.core.construction.ModelCoderRegistrar) Pipeline(org.apache.beam.sdk.Pipeline) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) DoFn(org.apache.beam.sdk.transforms.DoFn) KvCoder(org.apache.beam.sdk.coders.KvCoder) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ModelCoders(org.apache.beam.runners.core.construction.ModelCoders) Serializable(java.io.Serializable) 
SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) ParDo(org.apache.beam.sdk.transforms.ParDo) ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Timer(org.apache.beam.sdk.state.Timer) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) TimeDomain(org.apache.beam.sdk.state.TimeDomain) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) BagState(org.apache.beam.sdk.state.BagState) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) KvCoder(org.apache.beam.sdk.coders.KvCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) KV(org.apache.beam.sdk.values.KV) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Timer(org.apache.beam.sdk.state.Timer) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test)

Example 4 with ExecutableStage

Use of org.apache.beam.runners.core.construction.graph.ExecutableStage in the Apache Beam project.

From the class RemoteExecutionTest, method testExecution.

/**
 * Fuses a simple "words -> lengths -> keyed by foo" pipeline into a single stage, runs one bundle
 * through the SDK harness, and checks that every remote output holds the three keyed lengths
 * (4, 3 and 3) for the words "zero", "one" and "two".
 */
@Test
public void testExecution() throws Exception {
    launchSdkHarness(PipelineOptionsFactory.create());
    Pipeline pipeline = Pipeline.create();
    PCollection<String> words = pipeline.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], String>() {

        @ProcessElement
        public void process(ProcessContext ctxt) {
            ctxt.output("zero");
            ctxt.output("one");
            ctxt.output("two");
        }
    }));
    PCollection<Long> lengths = words.apply("len", ParDo.of(new DoFn<String, Long>() {

        @ProcessElement
        public void process(ProcessContext ctxt) {
            ctxt.output((long) ctxt.element().length());
        }
    }));
    lengths.apply("addKeys", WithKeys.of("foo")).setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianLongCoder.of())).apply("gbk", GroupByKey.create());
    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(PipelineTranslation.toProto(pipeline));
    checkState(fusedPipeline.getFusedStages().size() == 1, "Expected exactly one fused stage");
    ExecutableStage onlyStage = fusedPipeline.getFusedStages().iterator().next();
    ExecutableProcessBundleDescriptor bundleDescriptor = ProcessBundleDescriptors.fromExecutableStage("my_stage", onlyStage, dataServer.getApiServiceDescriptor());
    BundleProcessor bundleProcessor = controlClient.getProcessor(bundleDescriptor.getProcessBundleDescriptor(), bundleDescriptor.getRemoteInputDestinations());
    Map<String, ? super Coder<WindowedValue<?>>> codersByOutput = bundleDescriptor.getRemoteOutputCoders();
    Map<String, Collection<? super WindowedValue<?>>> collectedByOutput = new HashMap<>();
    Map<String, RemoteOutputReceiver<?>> receivers = new HashMap<>();
    // One synchronized sink per remote output, registered under the output's id.
    for (Entry<String, ? super Coder<WindowedValue<?>>> coderEntry : codersByOutput.entrySet()) {
        List<? super WindowedValue<?>> sink = Collections.synchronizedList(new ArrayList<>());
        collectedByOutput.put(coderEntry.getKey(), sink);
        outputReceiversPut:
        receivers.put(coderEntry.getKey(), RemoteOutputReceiver.of((Coder) coderEntry.getValue(), (FnDataReceiver<? super WindowedValue<?>>) sink::add));
    }
    // A single empty impulse element drives the whole stage; closing the bundle completes it.
    try (RemoteBundle bundle = bundleProcessor.newBundle(receivers, BundleProgressHandler.ignored())) {
        Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(new byte[0]));
    }
    for (Collection<? super WindowedValue<?>> received : collectedByOutput.values()) {
        assertThat(received, containsInAnyOrder(valueInGlobalWindow(byteValueOf("foo", 4)), valueInGlobalWindow(byteValueOf("foo", 3)), valueInGlobalWindow(byteValueOf("foo", 3))));
    }
}
Also used : ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) FnDataReceiver(org.apache.beam.sdk.fn.data.FnDataReceiver) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) DoFn(org.apache.beam.sdk.transforms.DoFn) Collection(java.util.Collection) PCollection(org.apache.beam.sdk.values.PCollection) Test(org.junit.Test)

Example 5 with ExecutableStage

use of org.apache.beam.runners.core.construction.graph.ExecutableStage in project beam by apache.

the class SparkExecutableStageFunction method getStateRequestHandler.

/**
 * Assembles the {@link StateRequestHandler} for a stage, dispatching on the state key type.
 *
 * <p>The three side-input key types share a single handler whose getter looks up the entry for a
 * PCollection id in {@code sideInputs} and decodes each broadcast byte array with the paired
 * coder. Bag user state is served through {@code bagUserStateHandlerFactory} (created on first
 * use) when the stage declares user state; otherwise it maps to
 * {@link StateRequestHandler#unsupported()}.
 *
 * @param executableStage stage whose side inputs and user states determine the handlers
 * @param processBundleDescriptor descriptor handed to the bag user state handler factory
 * @return a handler that delegates to the per-type handlers
 * @throws RuntimeException wrapping the {@link IOException} if the side input handler cannot be
 *     created
 */
private StateRequestHandler getStateRequestHandler(ExecutableStage executableStage, ProcessBundleDescriptors.ExecutableProcessBundleDescriptor processBundleDescriptor) {
    // Must stay an anonymous class: the getter method is generic, so a lambda cannot implement it.
    BatchSideInputHandlerFactory.SideInputGetter broadcastGetter = new BatchSideInputHandlerFactory.SideInputGetter() {

        @Override
        public <T> List<T> getSideInput(String pCollectionId) {
            Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> entry = sideInputs.get(pCollectionId);
            WindowedValueCoder<SideInputT> elementCoder = entry._2;
            List<byte[]> encoded = entry._1.value();
            return (List<T>) encoded.stream().map(raw -> CoderHelpers.fromByteArray(raw, elementCoder)).collect(Collectors.toList());
        }
    };
    StateRequestHandlers.SideInputHandlerFactory sideInputFactory = BatchSideInputHandlerFactory.forStage(executableStage, broadcastGetter);

    final StateRequestHandler sideInputDelegate;
    try {
        sideInputDelegate = StateRequestHandlers.forSideInputHandlerFactory(ProcessBundleDescriptors.getSideInputs(executableStage), sideInputFactory);
    } catch (IOException e) {
        throw new RuntimeException("Failed to setup state handler", e);
    }

    // Create the bag user state factory on first use.
    if (bagUserStateHandlerFactory == null) {
        bagUserStateHandlerFactory = new InMemoryBagUserStateFactory();
    }

    final StateRequestHandler bagStateDelegate;
    if (executableStage.getUserStates().isEmpty()) {
        bagStateDelegate = StateRequestHandler.unsupported();
    } else {
        // Need to discard the old key's state
        bagUserStateHandlerFactory.resetForNewKey();
        bagStateDelegate = StateRequestHandlers.forBagUserStateHandlerFactory(processBundleDescriptor, bagUserStateHandlerFactory);
    }

    // Every side-input flavour is answered by the same delegate.
    EnumMap<TypeCase, StateRequestHandler> delegatesByType = new EnumMap<>(StateKey.TypeCase.class);
    TypeCase[] sideInputCases = {StateKey.TypeCase.ITERABLE_SIDE_INPUT, StateKey.TypeCase.MULTIMAP_SIDE_INPUT, StateKey.TypeCase.MULTIMAP_KEYS_SIDE_INPUT};
    for (TypeCase sideInputCase : sideInputCases) {
        delegatesByType.put(sideInputCase, sideInputDelegate);
    }
    delegatesByType.put(StateKey.TypeCase.BAG_USER_STATE, bagStateDelegate);
    return StateRequestHandlers.delegateBasedUponType(delegatesByType);
}
Also used : WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) SerializablePipelineOptions(org.apache.beam.runners.core.construction.SerializablePipelineOptions) WindowedValue(org.apache.beam.sdk.util.WindowedValue) TimerInternals(org.apache.beam.runners.core.TimerInternals) BatchSideInputHandlerFactory(org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory) Locale(java.util.Locale) JobBundleFactory(org.apache.beam.runners.fnexecution.control.JobBundleFactory) Map(java.util.Map) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) JobInfo(org.apache.beam.runners.fnexecution.provisioning.JobInfo) TimerReceiverFactory(org.apache.beam.runners.fnexecution.control.TimerReceiverFactory) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) Broadcast(org.apache.spark.broadcast.Broadcast) StageBundleFactory(org.apache.beam.runners.fnexecution.control.StageBundleFactory) EnumMap(java.util.EnumMap) FnDataReceiver(org.apache.beam.sdk.fn.data.FnDataReceiver) BundleProgressHandler(org.apache.beam.runners.fnexecution.control.BundleProgressHandler) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Serializable(java.io.Serializable) List(java.util.List) ByteArray(org.apache.beam.runners.spark.util.ByteArray) SparkPipelineOptions(org.apache.beam.runners.spark.SparkPipelineOptions) StateKey(org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey) ConcurrentLinkedQueue(java.util.concurrent.ConcurrentLinkedQueue) ProcessBundleResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse) Coder(org.apache.beam.sdk.coders.Coder) CoderHelpers(org.apache.beam.runners.spark.coders.CoderHelpers) RawUnionValue(org.apache.beam.sdk.transforms.join.RawUnionValue) RemoteBundle(org.apache.beam.runners.fnexecution.control.RemoteBundle) 
InMemoryBagUserStateFactory(org.apache.beam.runners.fnexecution.state.InMemoryBagUserStateFactory) StateRequestHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandler) ProcessBundleProgressResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) Iterator(java.util.Iterator) OutputReceiverFactory(org.apache.beam.runners.fnexecution.control.OutputReceiverFactory) ProcessBundleDescriptors(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors) MetricsContainerImpl(org.apache.beam.runners.core.metrics.MetricsContainerImpl) PipelineTranslatorUtils(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils) StateRequestHandlers(org.apache.beam.runners.fnexecution.state.StateRequestHandlers) IOException(java.io.IOException) MetricsContainerStepMapAccumulator(org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator) InMemoryTimerInternals(org.apache.beam.runners.core.InMemoryTimerInternals) Timer(org.apache.beam.runners.core.construction.Timer) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) Instant(org.joda.time.Instant) FileSystems(org.apache.beam.sdk.io.FileSystems) Collections(java.util.Collections) TypeCase(org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase) ExecutableStageContext(org.apache.beam.runners.fnexecution.control.ExecutableStageContext) StateRequestHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandler) StateKey(org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey) TypeCase(org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey.TypeCase) IOException(java.io.IOException) InMemoryBagUserStateFactory(org.apache.beam.runners.fnexecution.state.InMemoryBagUserStateFactory) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) Broadcast(org.apache.spark.broadcast.Broadcast) 
BatchSideInputHandlerFactory(org.apache.beam.runners.fnexecution.translation.BatchSideInputHandlerFactory) Tuple2(scala.Tuple2) List(java.util.List) StateRequestHandlers(org.apache.beam.runners.fnexecution.state.StateRequestHandlers) EnumMap(java.util.EnumMap)

Aggregations

ExecutableStage (org.apache.beam.runners.core.construction.graph.ExecutableStage)22 Test (org.junit.Test)17 RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)16 Pipeline (org.apache.beam.sdk.Pipeline)15 Coder (org.apache.beam.sdk.coders.Coder)14 HashMap (java.util.HashMap)12 FusedPipeline (org.apache.beam.runners.core.construction.graph.FusedPipeline)12 KvCoder (org.apache.beam.sdk.coders.KvCoder)12 StringUtf8Coder (org.apache.beam.sdk.coders.StringUtf8Coder)12 WindowedValue (org.apache.beam.sdk.util.WindowedValue)12 ByteString (org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString)11 Map (java.util.Map)10 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)10 ExecutableProcessBundleDescriptor (org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor)10 BundleProcessor (org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor)10 BigEndianLongCoder (org.apache.beam.sdk.coders.BigEndianLongCoder)10 Collection (java.util.Collection)9 KV (org.apache.beam.sdk.values.KV)9 PCollection (org.apache.beam.sdk.values.PCollection)9 ArrayList (java.util.ArrayList)7