Search in sources :

Example 1 with Impulse

use of org.apache.beam.sdk.transforms.Impulse in project beam by apache.

the class RemoteExecutionTest method testExecutionWithMultipleStages.

/**
 * Builds a pipeline with two impulse-rooted branches that are flattened, keyed and grouped,
 * fuses it, and then processes one bundle per fused stage, feeding each stage the encoded
 * string "X" and collecting every remote output into a shared list.
 *
 * <p>The test expects the fuser to produce exactly two stages and the collected outputs to be
 * the globally-windowed pairs ("stream1X", "") and ("stream2X", "").
 */
@Test
public void testExecutionWithMultipleStages() throws Exception {
    // NOTE(review): launchSdkHarness is defined outside this snippet; presumably it starts the
    // harness that the controlClient below talks to -- confirm against the enclosing class.
    launchSdkHarness(PipelineOptionsFactory.create());
    Pipeline p = Pipeline.create();
    // Factory for one branch, parameterized by a suffix that keeps the named transforms distinct:
    // impulse -> decode the byte[] element as a UTF-8 string -> prepend "stream" + suffix.
    Function<String, PCollection<String>> pCollectionGenerator = suffix -> p.apply("impulse" + suffix, Impulse.create()).apply("create" + suffix, ParDo.of(new DoFn<byte[], String>() {

        @ProcessElement
        public void process(ProcessContext c) {
            try {
                c.output(CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), c.element()));
            } catch (CoderException e) {
                // Rethrown unchecked, keeping the original exception as the cause.
                throw new RuntimeException(e);
            }
        }
    })).setCoder(StringUtf8Coder.of()).apply(ParDo.of(new DoFn<String, String>() {

        @ProcessElement
        public void processElement(ProcessContext c) {
            c.output("stream" + suffix + c.element());
        }
    }));
    // Two independent branches, merged into a single collection.
    PCollection<String> input1 = pCollectionGenerator.apply("1");
    PCollection<String> input2 = pCollectionGenerator.apply("2");
    PCollection<String> outputMerged = PCollectionList.of(input1).and(input2).apply(Flatten.pCollections());
    // Key every merged string with itself and an empty value, then group by key.
    outputMerged.apply("createKV", ParDo.of(new DoFn<String, KV<String, String>>() {

        @ProcessElement
        public void process(ProcessContext c) {
            c.output(KV.of(c.element(), ""));
        }
    })).setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())).apply("gbk", GroupByKey.create());
    // Translate to the portable proto form and fuse; the test requires exactly two fused stages.
    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
    Set<ExecutableStage> stages = fused.getFusedStages();
    assertThat(stages.size(), equalTo(2));
    // Shared sink for the outputs of every stage; wrapped in a synchronized list because the
    // receivers registered below append to it via outputValues::add.
    List<WindowedValue<?>> outputValues = Collections.synchronizedList(new ArrayList<>());
    for (ExecutableStage stage : stages) {
        // The stage's toString() is used as the descriptor id.
        ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage(stage.toString(), stage, dataServer.getApiServiceDescriptor(), stateServer.getApiServiceDescriptor());
        BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations(), stateDelegator);
        Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
        // One receiver per remote output, each appending decoded values to outputValues.
        Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
        for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
            // The raw Coder from the descriptor is cast (unchecked) to a WindowedValue coder.
            outputReceivers.putIfAbsent(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputValues::add));
        }
        // try-with-resources closes the bundle when the block exits.
        // NOTE(review): presumably close() waits for the bundle to finish so that outputValues is
        // fully populated before the final assertion -- confirm against RemoteBundle.
        try (RemoteBundle bundle = processor.newBundle(outputReceivers, StateRequestHandler.unsupported(), BundleProgressHandler.ignored())) {
            // Each stage is expected to expose exactly one input receiver; feed it the UTF-8
            // encoded string "X" in the global window.
            Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
        }
    }
    // Across all stages, exactly these two keyed values must have been emitted (order is ignored).
    assertThat(outputValues, containsInAnyOrder(valueInGlobalWindow(KV.of("stream1X", "")), valueInGlobalWindow(KV.of("stream2X", ""))));
}
Also used : Arrays(java.util.Arrays) CoderUtils(org.apache.beam.sdk.util.CoderUtils) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Matchers.not(org.hamcrest.Matchers.not) WindowedValue.valueInGlobalWindow(org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) Metrics(org.apache.beam.sdk.metrics.Metrics) Future(java.util.concurrent.Future) GrpcDataService(org.apache.beam.runners.fnexecution.data.GrpcDataService) Map(java.util.Map) SimpleMonitoringInfoBuilder(org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder) GlobalWindow(org.apache.beam.sdk.transforms.windowing.GlobalWindow) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) Iterators(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators) BagUserStateHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandlerFactory) KvCoder(org.apache.beam.sdk.coders.KvCoder) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) Matchers.allOf(org.hamcrest.Matchers.allOf) FnDataReceiver(org.apache.beam.sdk.fn.data.FnDataReceiver) Set(java.util.Set) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) Executors(java.util.concurrent.Executors) GrpcLoggingService(org.apache.beam.runners.fnexecution.logging.GrpcLoggingService) Serializable(java.io.Serializable) ManagedChannelFactory(org.apache.beam.sdk.fn.channel.ManagedChannelFactory) MultimapSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.MultimapSideInputHandler) CountDownLatch(java.util.concurrent.CountDownLatch) CoderException(org.apache.beam.sdk.coders.CoderException) CompletionStage(java.util.concurrent.CompletionStage) 
ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Assert.assertFalse(org.junit.Assert.assertFalse) KV(org.apache.beam.sdk.values.KV) ExperimentalOptions(org.apache.beam.sdk.options.ExperimentalOptions) Duration(org.joda.time.Duration) RunWith(org.junit.runner.RunWith) Impulse(org.apache.beam.sdk.transforms.Impulse) View(org.apache.beam.sdk.transforms.View) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) GrpcStateService(org.apache.beam.runners.fnexecution.state.GrpcStateService) ArrayList(java.util.ArrayList) TimerSpec(org.apache.beam.sdk.state.TimerSpec) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) Pipeline(org.apache.beam.sdk.Pipeline) StateRequestHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandler) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) InProcessServerFactory(org.apache.beam.sdk.fn.server.InProcessServerFactory) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) DoFn(org.apache.beam.sdk.transforms.DoFn) Assert.assertTrue(org.junit.Assert.assertTrue) StateRequestHandlers(org.apache.beam.runners.fnexecution.state.StateRequestHandlers) Test(org.junit.Test) SingleOutput(org.apache.beam.sdk.transforms.ParDo.SingleOutput) ExecutionException(java.util.concurrent.ExecutionException) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PCollectionView(org.apache.beam.sdk.values.PCollectionView) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) Matcher(org.hamcrest.Matcher) TimeDomain(org.apache.beam.sdk.state.TimeDomain) Assert.assertEquals(org.junit.Assert.assertEquals) IsEmptyIterable(org.hamcrest.collection.IsEmptyIterable) StateSpec(org.apache.beam.sdk.state.StateSpec) IsIterableContainingInOrder(org.hamcrest.collection.IsIterableContainingInOrder) 
ScheduledFuture(java.util.concurrent.ScheduledFuture) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ChannelSplit(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse.ChannelSplit) Urns(org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) ExperimentalOptions.addExperiment(org.apache.beam.sdk.options.ExperimentalOptions.addExperiment) PCollectionList(org.apache.beam.sdk.values.PCollectionList) GrpcContextHeaderAccessorProvider(org.apache.beam.sdk.fn.server.GrpcContextHeaderAccessorProvider) ResetDateTimeProvider(org.apache.beam.sdk.testing.ResetDateTimeProvider) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) After(org.junit.After) Assert.fail(org.junit.Assert.fail) ProcessBundleSplitResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleSplitResponse) ThreadFactory(java.util.concurrent.ThreadFactory) Flatten(org.apache.beam.sdk.transforms.Flatten) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) IterableSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.IterableSideInputHandler) PaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo) Collection(java.util.Collection) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) OutboundObserverFactory(org.apache.beam.sdk.fn.stream.OutboundObserverFactory) UUID(java.util.UUID) ThreadFactoryBuilder(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder) List(java.util.List) ParDo(org.apache.beam.sdk.transforms.ParDo) Matchers.containsInAnyOrder(org.hamcrest.Matchers.containsInAnyOrder) Timer(org.apache.beam.sdk.state.Timer) Matchers.equalTo(org.hamcrest.Matchers.equalTo) Entry(java.util.Map.Entry) ProcessBundleResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse) FnHarness(org.apache.beam.fn.harness.FnHarness) 
Slf4jLogWriter(org.apache.beam.runners.fnexecution.logging.Slf4jLogWriter) DistributionData(org.apache.beam.runners.core.metrics.DistributionData) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) DateTimeUtils(org.joda.time.DateTimeUtils) SideInputHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory) Coder(org.apache.beam.sdk.coders.Coder) HashMap(java.util.HashMap) ExecutionStateSampler(org.apache.beam.runners.core.metrics.ExecutionStateSampler) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) PipelineOptionsFactory(org.apache.beam.sdk.options.PipelineOptionsFactory) Function(java.util.function.Function) ConcurrentMap(java.util.concurrent.ConcurrentMap) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) MonitoringInfoConstants(org.apache.beam.runners.core.metrics.MonitoringInfoConstants) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) TypeUrns(org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns) PipelineOptions(org.apache.beam.sdk.options.PipelineOptions) ExecutorService(java.util.concurrent.ExecutorService) ProcessBundleProgressResponse(org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleProgressResponse) MonitoringInfo(org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo) GrpcFnServer(org.apache.beam.sdk.fn.server.GrpcFnServer) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) WithKeys(org.apache.beam.sdk.transforms.WithKeys) Iterator(java.util.Iterator) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) Matchers(org.hamcrest.Matchers) JUnit4(org.junit.runners.JUnit4) PCollection(org.apache.beam.sdk.values.PCollection) TimeUnit(java.util.concurrent.TimeUnit) SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) 
BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) Rule(org.junit.Rule) MonitoringInfoMatchers(org.apache.beam.runners.core.metrics.MonitoringInfoMatchers) Caches(org.apache.beam.fn.harness.Caches) SplitResult(org.apache.beam.sdk.transforms.splittabledofn.SplitResult) Collections(java.util.Collections) BagUserStateHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.BagUserStateHandler) ReadableState(org.apache.beam.sdk.state.ReadableState) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ByteString(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) BundleProcessor(org.apache.beam.runners.fnexecution.control.SdkHarnessClient.BundleProcessor) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ExecutableProcessBundleDescriptor(org.apache.beam.runners.fnexecution.control.ProcessBundleDescriptors.ExecutableProcessBundleDescriptor) KvCoder(org.apache.beam.sdk.coders.KvCoder) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) BigEndianLongCoder(org.apache.beam.sdk.coders.BigEndianLongCoder) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) PCollection(org.apache.beam.sdk.values.PCollection) DoFn(org.apache.beam.sdk.transforms.DoFn) CoderException(org.apache.beam.sdk.coders.CoderException) Test(org.junit.Test)

Example 2 with Impulse

use of org.apache.beam.sdk.transforms.Impulse in project beam by apache.

the class ProcessBundleDescriptorsTest method testLengthPrefixingOfKeyCoderInStatefulExecutableStage.

/**
 * Verifies, via {@code ensureLengthPrefixed}, the key coder of a stateful transform as it
 * appears in the ProcessBundleDescriptor built from its executable stage: once through the
 * stage's main input coder and once through the timer family coder.
 */
@Test
public void testLengthPrefixingOfKeyCoderInStatefulExecutableStage() throws Exception {
    Pipeline pipeline = Pipeline.create();
    // The key coder must not be a known model coder for this test to be meaningful.
    Coder<Void> nonModelKeyCoder = VoidCoder.of();
    assertThat(ModelCoderRegistrar.isKnownCoder(nonModelKeyCoder), is(false));

    // impulse -> create (keyed output) -> userState (bag state + event-time timer) -> gbk.
    // The DoFn bodies are empty: this test only translates and inspects the pipeline graph.
    pipeline
        .apply("impulse", Impulse.create())
        .apply("create", ParDo.of(new DoFn<byte[], KV<Void, String>>() {

            @ProcessElement
            public void process(ProcessContext ctxt) {
            }
        }))
        .setCoder(KvCoder.of(nonModelKeyCoder, StringUtf8Coder.of()))
        .apply("userState", ParDo.of(new DoFn<KV<Void, String>, KV<Void, String>>() {

            @StateId("stateId")
            private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());

            @TimerId("timerId")
            private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

            @ProcessElement
            public void processElement(@Element KV<Void, String> element, @StateId("stateId") BagState<String> state, @TimerId("timerId") Timer timer, OutputReceiver<KV<Void, String>> r) {
            }

            @OnTimer("timerId")
            public void onTimer() {
            }
        }))
        .apply("gbk", GroupByKey.create());

    RunnerApi.Pipeline translated = PipelineTranslation.toProto(pipeline);
    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(translated);

    // Pick the fused stage that declares the user state named "stateId".
    Optional<ExecutableStage> statefulStageOrAbsent =
        Iterables.tryFind(
            fusedPipeline.getFusedStages(),
            (ExecutableStage candidate) ->
                candidate.getUserStates().stream()
                    .anyMatch(userState -> userState.localName().equals("stateId")));
    checkState(statefulStageOrAbsent.isPresent(), "Expected a stage with user state.");
    ExecutableStage statefulStage = statefulStageOrAbsent.get();
    PipelineNode.PCollectionNode stageInput = statefulStage.getInputPCollection();

    // In the stage's own components the key coder is still the Java-serialized coder.
    Map<String, RunnerApi.Coder> stageCoders = statefulStage.getComponents().getCodersMap();
    RunnerApi.Coder stageInputCoder = stageCoders.get(stageInput.getPCollection().getCoderId());
    RunnerApi.Coder stageKeyCoder =
        stageCoders.get(ModelCoders.getKvCoderComponents(stageInputCoder).keyCoderId());
    assertThat(stageKeyCoder.getSpec().getUrn(), is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));

    // Build the descriptor and check the key coder reachable from the main input coder.
    BeamFnApi.ProcessBundleDescriptor descriptor =
        ProcessBundleDescriptors.fromExecutableStage(
                "test_stage", statefulStage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
            .getProcessBundleDescriptor();
    Map<String, RunnerApi.Coder> descriptorCoders = descriptor.getCodersMap();
    RunnerApi.Coder descriptorInputCoder =
        descriptorCoders.get(descriptor.getPcollectionsOrThrow(stageInput.getId()).getCoderId());
    RunnerApi.Coder descriptorKeyCoder =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorInputCoder).keyCoderId());
    ensureLengthPrefixed(descriptorKeyCoder, stageKeyCoder, descriptorCoders);

    // Same check for the first component of the timer family coder of the stage's only timer.
    TimerReference onlyTimer = Iterables.getOnlyElement(statefulStage.getTimers());
    RunnerApi.ParDoPayload timerParDoPayload =
        RunnerApi.ParDoPayload.parseFrom(
            descriptor.getTransformsOrThrow(onlyTimer.transform().getId()).getSpec().getPayload());
    RunnerApi.TimerFamilySpec timerFamily =
        timerParDoPayload.getTimerFamilySpecsOrThrow(onlyTimer.localName());
    RunnerApi.Coder timerFamilyCoder = descriptorCoders.get(timerFamily.getTimerFamilyCoderId());
    RunnerApi.Coder timerKeyCoder = descriptorCoders.get(timerFamilyCoder.getComponentCoderIds(0));
    ensureLengthPrefixed(timerKeyCoder, stageKeyCoder, descriptorCoders);
}
Also used : CoreMatchers.is(org.hamcrest.CoreMatchers.is) Endpoints(org.apache.beam.model.pipeline.v1.Endpoints) StateSpec(org.apache.beam.sdk.state.StateSpec) KV(org.apache.beam.sdk.values.KV) CoderTranslation(org.apache.beam.runners.core.construction.CoderTranslation) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Coder(org.apache.beam.sdk.coders.Coder) Impulse(org.apache.beam.sdk.transforms.Impulse) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Map(java.util.Map) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) ModelCoderRegistrar(org.apache.beam.runners.core.construction.ModelCoderRegistrar) Pipeline(org.apache.beam.sdk.Pipeline) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) DoFn(org.apache.beam.sdk.transforms.DoFn) KvCoder(org.apache.beam.sdk.coders.KvCoder) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ModelCoders(org.apache.beam.runners.core.construction.ModelCoders) Serializable(java.io.Serializable) 
SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) ParDo(org.apache.beam.sdk.transforms.ParDo) ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Timer(org.apache.beam.sdk.state.Timer) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) TimeDomain(org.apache.beam.sdk.state.TimeDomain) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) BagState(org.apache.beam.sdk.state.BagState) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) KvCoder(org.apache.beam.sdk.coders.KvCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) KV(org.apache.beam.sdk.values.KV) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Timer(org.apache.beam.sdk.state.Timer) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test)

Example 3 with Impulse

use of org.apache.beam.sdk.transforms.Impulse in project beam by apache.

the class QueryablePipelineTest method transformWithSideAndMainInputs.

/**
 * Checks how {@link QueryablePipeline} classifies the inputs of a ParDo that has both a main
 * input and a side input: the side input must show up in {@link
 * QueryablePipeline#getSideInputs(PTransformNode)}, and the ParDo must be a per-element consumer
 * (see {@link QueryablePipeline#getPerElementConsumers(PCollectionNode)}) of the main input only.
 */
@Test
public void transformWithSideAndMainInputs() {
    Pipeline pipeline = Pipeline.create();
    PCollection<byte[]> impulseOutput = pipeline.apply("Impulse", Impulse.create());
    PCollectionView<String> sideView =
        pipeline.apply("Create", Create.of("foo")).apply("View", View.asSingleton());
    impulseOutput.apply(
        "par_do",
        ParDo.of(new TestFn())
            .withSideInputs(sideView)
            .withOutputTags(new TupleTag<>(), TupleTagList.empty()));

    Components components = PipelineTranslation.toProto(pipeline).getComponents();
    QueryablePipeline queryable = QueryablePipeline.forPrimitivesIn(components);

    // The main input of par_do is the single output of the Impulse transform.
    PTransformNode impulseNode =
        PipelineNode.pTransform("Impulse", components.getTransformsOrThrow("Impulse"));
    String mainInputId = getOnlyElement(impulseNode.getTransform().getOutputsMap().values());
    PCollectionNode mainInput =
        PipelineNode.pCollection(mainInputId, components.getPcollectionsOrThrow(mainInputId));

    // The side input is the one par_do input whose PCollection id is not the main input's.
    PTransform parDoProto = components.getTransformsOrThrow("par_do");
    String sideInputTag =
        getOnlyElement(
            parDoProto.getInputsMap().entrySet().stream()
                .filter(input -> !input.getValue().equals(mainInputId))
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet()));
    String sideInputId = parDoProto.getInputsOrThrow(sideInputTag);
    PCollectionNode sideInput =
        PipelineNode.pCollection(sideInputId, components.getPcollectionsOrThrow(sideInputId));

    PTransformNode parDoNode = PipelineNode.pTransform("par_do", parDoProto);
    SideInputReference expectedSideInput = SideInputReference.of(parDoNode, sideInputTag, sideInput);

    assertThat(queryable.getSideInputs(parDoNode), contains(expectedSideInput));
    assertThat(queryable.getPerElementConsumers(mainInput), contains(parDoNode));
    assertThat(queryable.getPerElementConsumers(sideInput), not(contains(parDoNode)));
}
Also used : Count(org.apache.beam.sdk.transforms.Count) PBegin(org.apache.beam.sdk.values.PBegin) Matchers.not(org.hamcrest.Matchers.not) Matchers.hasKey(org.hamcrest.Matchers.hasKey) ImmutableSet(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet) PCollectionList(org.apache.beam.sdk.values.PCollectionList) FunctionSpec(org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec) Create(org.apache.beam.sdk.transforms.Create) Map(java.util.Map) Window(org.apache.beam.sdk.transforms.windowing.Window) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) Flatten(org.apache.beam.sdk.transforms.Flatten) MapElements(org.apache.beam.sdk.transforms.MapElements) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) Collection(java.util.Collection) Set(java.util.Set) Collectors(java.util.stream.Collectors) ParDo(org.apache.beam.sdk.transforms.ParDo) Matchers.contains(org.hamcrest.Matchers.contains) Matchers.equalTo(org.hamcrest.Matchers.equalTo) TypeDescriptors(org.apache.beam.sdk.values.TypeDescriptors) Matchers.is(org.hamcrest.Matchers.is) SideInput(org.apache.beam.model.pipeline.v1.RunnerApi.SideInput) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) PTransform(org.apache.beam.model.pipeline.v1.RunnerApi.PTransform) Duration(org.joda.time.Duration) RunWith(org.junit.runner.RunWith) Impulse(org.apache.beam.sdk.transforms.Impulse) View(org.apache.beam.sdk.transforms.View) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) TupleTagList(org.apache.beam.sdk.values.TupleTagList) Environments(org.apache.beam.runners.core.construction.Environments) ParDoPayload(org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload) Read(org.apache.beam.sdk.io.Read) TupleTag(org.apache.beam.sdk.values.TupleTag) Matchers.hasSize(org.hamcrest.Matchers.hasSize) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) 
PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) Pipeline(org.apache.beam.sdk.Pipeline) ExpectedException(org.junit.rules.ExpectedException) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) DoFn(org.apache.beam.sdk.transforms.DoFn) CountingSource(org.apache.beam.sdk.io.CountingSource) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) WithKeys(org.apache.beam.sdk.transforms.WithKeys) FixedWindows(org.apache.beam.sdk.transforms.windowing.FixedWindows) Test(org.junit.Test) JUnit4(org.junit.runners.JUnit4) PCollection(org.apache.beam.sdk.values.PCollection) Matchers.emptyIterable(org.hamcrest.Matchers.emptyIterable) Rule(org.junit.Rule) PCollectionView(org.apache.beam.sdk.values.PCollectionView) Iterables.getOnlyElement(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.getOnlyElement) PTransformNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) TupleTag(org.apache.beam.sdk.values.TupleTag) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) Pipeline(org.apache.beam.sdk.Pipeline) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) PTransform(org.apache.beam.model.pipeline.v1.RunnerApi.PTransform) Test(org.junit.Test)

Example 4 with Impulse

use of org.apache.beam.sdk.transforms.Impulse in project beam by apache.

the class ProcessBundleDescriptorsTest method testLengthPrefixingOfInputCoderExecutableStage.

/**
 * Expands a splittable ParDo that uses a non-model coder, fuses the pipeline, and compares the
 * coders nested inside the main input coder of the resulting stage with the corresponding coders
 * in the generated ProcessBundleDescriptor using {@code ensureLengthPrefixed}.
 */
@Test
public void testLengthPrefixingOfInputCoderExecutableStage() throws Exception {
    Pipeline pipeline = Pipeline.create();
    // VoidCoder is required to be unknown to the model coder registry for this test.
    Coder<Void> nonModelCoder = VoidCoder.of();
    assertThat(ModelCoderRegistrar.isKnownCoder(nonModelCoder), is(false));

    // impulse -> plain ParDo -> splittable ParDo. The DoFn bodies are stubs: this test only
    // translates and inspects the pipeline graph.
    pipeline
        .apply("impulse", Impulse.create())
        .apply(ParDo.of(new DoFn<byte[], Void>() {

            @ProcessElement
            public void process(ProcessContext ctxt) {
            }
        }))
        .setCoder(nonModelCoder)
        .apply(ParDo.of(new DoFn<Void, Void>() {

            @ProcessElement
            public void processElement(ProcessContext context, RestrictionTracker<Void, Void> tracker) {
            }

            @GetInitialRestriction
            public Void getInitialRestriction() {
                return null;
            }

            @NewTracker
            public SomeTracker newTracker(@Restriction Void restriction) {
                return null;
            }
        }))
        .setCoder(nonModelCoder);

    // Replace ParDo transforms using the sized splittable expansion before fusing.
    RunnerApi.Pipeline translated = PipelineTranslation.toProto(pipeline);
    RunnerApi.Pipeline sdfExpanded =
        ProtoOverrides.updateTransform(
            PTransformTranslation.PAR_DO_TRANSFORM_URN,
            translated,
            SplittableParDoExpander.createSizedReplacement());
    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(sdfExpanded);

    // Locate the stage containing the "process sized elements and restrictions" transform.
    Optional<ExecutableStage> sizedStageOrAbsent =
        Iterables.tryFind(
            fusedPipeline.getFusedStages(),
            (ExecutableStage candidate) ->
                candidate.getTransforms().stream()
                    .anyMatch(
                        node ->
                            node.getTransform()
                                .getSpec()
                                .getUrn()
                                .equals(PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)));
    checkState(sizedStageOrAbsent.isPresent(), "Expected a stage with SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN.");
    ExecutableStage sizedStage = sizedStageOrAbsent.get();
    PipelineNode.PCollectionNode stageInput = sizedStage.getInputPCollection();

    Map<String, RunnerApi.Coder> stageCoders = sizedStage.getComponents().getCodersMap();
    RunnerApi.Coder stageInputCoder = stageCoders.get(stageInput.getPCollection().getCoderId());

    BeamFnApi.ProcessBundleDescriptor descriptor =
        ProcessBundleDescriptors.fromExecutableStage(
                "test_stage", sizedStage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
            .getProcessBundleDescriptor();
    Map<String, RunnerApi.Coder> descriptorCoders = descriptor.getCodersMap();
    RunnerApi.Coder descriptorInputCoder =
        descriptorCoders.get(descriptor.getPcollectionsOrThrow(stageInput.getId()).getCoderId());

    // Walk the descriptor's input coder: key of the input KV, then that KV's key, its value,
    // and the key of that value.
    RunnerApi.Coder descriptorOuterKey =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorInputCoder).keyCoderId());
    RunnerApi.Coder descriptorInnerKey =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorOuterKey).keyCoderId());
    RunnerApi.Coder descriptorInnerValue =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorOuterKey).valueCoderId());
    RunnerApi.Coder descriptorInnerValueKey =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorInnerValue).keyCoderId());

    // Walk the stage's original input coder along the same path.
    RunnerApi.Coder stageOuterKey =
        stageCoders.get(ModelCoders.getKvCoderComponents(stageInputCoder).keyCoderId());
    RunnerApi.Coder stageInnerKey =
        stageCoders.get(ModelCoders.getKvCoderComponents(stageOuterKey).keyCoderId());
    RunnerApi.Coder stageInnerValue =
        stageCoders.get(ModelCoders.getKvCoderComponents(stageOuterKey).valueCoderId());
    RunnerApi.Coder stageInnerValueKey =
        stageCoders.get(ModelCoders.getKvCoderComponents(stageInnerValue).keyCoderId());

    ensureLengthPrefixed(descriptorInnerKey, stageInnerKey, descriptorCoders);
    ensureLengthPrefixed(descriptorInnerValueKey, stageInnerValueKey, descriptorCoders);
}
Also used : CoreMatchers.is(org.hamcrest.CoreMatchers.is) Endpoints(org.apache.beam.model.pipeline.v1.Endpoints) StateSpec(org.apache.beam.sdk.state.StateSpec) KV(org.apache.beam.sdk.values.KV) CoderTranslation(org.apache.beam.runners.core.construction.CoderTranslation) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Coder(org.apache.beam.sdk.coders.Coder) Impulse(org.apache.beam.sdk.transforms.Impulse) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Map(java.util.Map) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) ModelCoderRegistrar(org.apache.beam.runners.core.construction.ModelCoderRegistrar) Pipeline(org.apache.beam.sdk.Pipeline) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) DoFn(org.apache.beam.sdk.transforms.DoFn) KvCoder(org.apache.beam.sdk.coders.KvCoder) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ModelCoders(org.apache.beam.runners.core.construction.ModelCoders) Serializable(java.io.Serializable) 
SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) ParDo(org.apache.beam.sdk.transforms.ParDo) ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Timer(org.apache.beam.sdk.state.Timer) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) TimeDomain(org.apache.beam.sdk.state.TimeDomain) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) KvCoder(org.apache.beam.sdk.coders.KvCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test)

Aggregations

Map (java.util.Map)4 RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)4 PTransformTranslation (org.apache.beam.runners.core.construction.PTransformTranslation)4 PipelineTranslation (org.apache.beam.runners.core.construction.PipelineTranslation)4 Pipeline (org.apache.beam.sdk.Pipeline)4 DoFn (org.apache.beam.sdk.transforms.DoFn)4 GroupByKey (org.apache.beam.sdk.transforms.GroupByKey)4 Impulse (org.apache.beam.sdk.transforms.Impulse)4 ParDo (org.apache.beam.sdk.transforms.ParDo)4 MatcherAssert.assertThat (org.hamcrest.MatcherAssert.assertThat)4 Test (org.junit.Test)4 Serializable (java.io.Serializable)3 BeamFnApi (org.apache.beam.model.fnexecution.v1.BeamFnApi)3 ExecutableStage (org.apache.beam.runners.core.construction.graph.ExecutableStage)3 FusedPipeline (org.apache.beam.runners.core.construction.graph.FusedPipeline)3 GreedyPipelineFuser (org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser)3 ProtoOverrides (org.apache.beam.runners.core.construction.graph.ProtoOverrides)3 SplittableParDoExpander (org.apache.beam.runners.core.construction.graph.SplittableParDoExpander)3 Coder (org.apache.beam.sdk.coders.Coder)3 KvCoder (org.apache.beam.sdk.coders.KvCoder)3