Search in sources:

Example 1 with ProcessContext

Use of org.apache.beam.sdk.transforms.DoFn.ProcessContext in project Beam by Apache.

From class ProcessBundleDescriptorsTest, method testLengthPrefixingOfKeyCoderInStatefulExecutableStage.

/**
 * Verifies that the ProcessBundleDescriptor built for a stateful executable stage wraps the
 * stateful transform's (non-model) key coder in a LengthPrefixCoder. The check covers two places:
 * the key component of the stage's main input coder, and the key component of the timer family
 * coder.
 */
@Test
public void testLengthPrefixingOfKeyCoderInStatefulExecutableStage() throws Exception {
    // Build a pipeline whose stateful ParDo is keyed by a coder the model does not know about.
    Pipeline pipeline = Pipeline.create();
    Coder<Void> nonModelKeyCoder = VoidCoder.of();
    assertThat(ModelCoderRegistrar.isKnownCoder(nonModelKeyCoder), is(false));
    pipeline
        .apply("impulse", Impulse.create())
        .apply(
            "create",
            ParDo.of(new DoFn<byte[], KV<Void, String>>() {

                @ProcessElement
                public void process(ProcessContext ctxt) {
                }
            }))
        .setCoder(KvCoder.of(nonModelKeyCoder, StringUtf8Coder.of()))
        .apply(
            "userState",
            ParDo.of(new DoFn<KV<Void, String>, KV<Void, String>>() {

                @StateId("stateId")
                private final StateSpec<BagState<String>> bufferState = StateSpecs.bag(StringUtf8Coder.of());

                @TimerId("timerId")
                private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

                @ProcessElement
                public void processElement(
                    @Element KV<Void, String> element,
                    @StateId("stateId") BagState<String> state,
                    @TimerId("timerId") Timer timer,
                    OutputReceiver<KV<Void, String>> r) {
                }

                @OnTimer("timerId")
                public void onTimer() {
                }
            }))
        .apply("gbk", GroupByKey.create());

    // Fuse the pipeline and select the stage that declares the "stateId" user state.
    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(pipelineProto);
    Optional<ExecutableStage> statefulStage =
        Iterables.tryFind(
            fusedPipeline.getFusedStages(),
            (ExecutableStage candidate) ->
                candidate.getUserStates().stream()
                    .anyMatch(userState -> userState.localName().equals("stateId")));
    checkState(statefulStage.isPresent(), "Expected a stage with user state.");
    ExecutableStage stage = statefulStage.get();
    PipelineNode.PCollectionNode mainInput = stage.getInputPCollection();

    // Before descriptor generation the stage's key coder is the Java-serialized coder,
    // i.e. not yet a LengthPrefixCoder.
    Map<String, RunnerApi.Coder> stageCoders = stage.getComponents().getCodersMap();
    RunnerApi.Coder stageInputCoder = stageCoders.get(mainInput.getPCollection().getCoderId());
    String stageKeyCoderId = ModelCoders.getKvCoderComponents(stageInputCoder).keyCoderId();
    RunnerApi.Coder originalKeyCoder = stageCoders.get(stageKeyCoderId);
    assertThat(originalKeyCoder.getSpec().getUrn(), is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));

    // Generate the descriptor; the key component of its main input coder must be length-prefixed.
    BeamFnApi.ProcessBundleDescriptor descriptor =
        ProcessBundleDescriptors.fromExecutableStage(
                "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
            .getProcessBundleDescriptor();
    Map<String, RunnerApi.Coder> descriptorCoders = descriptor.getCodersMap();
    RunnerApi.Coder descriptorInputCoder =
        descriptorCoders.get(descriptor.getPcollectionsOrThrow(mainInput.getId()).getCoderId());
    String descriptorKeyCoderId = ModelCoders.getKvCoderComponents(descriptorInputCoder).keyCoderId();
    ensureLengthPrefixed(descriptorCoders.get(descriptorKeyCoderId), originalKeyCoder, descriptorCoders);

    // The timer family coder's component 0 is checked the same way.
    TimerReference timerReference = Iterables.getOnlyElement(stage.getTimers());
    String timerTransformId = timerReference.transform().getId();
    RunnerApi.PTransform timerTransform = descriptor.getTransformsOrThrow(timerTransformId);
    RunnerApi.ParDoPayload parDoPayload =
        RunnerApi.ParDoPayload.parseFrom(timerTransform.getSpec().getPayload());
    RunnerApi.TimerFamilySpec timerFamilySpec =
        parDoPayload.getTimerFamilySpecsOrThrow(timerReference.localName());
    RunnerApi.Coder timerFamilyCoder = descriptorCoders.get(timerFamilySpec.getTimerFamilyCoderId());
    RunnerApi.Coder timerKeyCoder = descriptorCoders.get(timerFamilyCoder.getComponentCoderIds(0));
    ensureLengthPrefixed(timerKeyCoder, originalKeyCoder, descriptorCoders);
}
Also used : CoreMatchers.is(org.hamcrest.CoreMatchers.is) Endpoints(org.apache.beam.model.pipeline.v1.Endpoints) StateSpec(org.apache.beam.sdk.state.StateSpec) KV(org.apache.beam.sdk.values.KV) CoderTranslation(org.apache.beam.runners.core.construction.CoderTranslation) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Coder(org.apache.beam.sdk.coders.Coder) Impulse(org.apache.beam.sdk.transforms.Impulse) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Map(java.util.Map) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) ModelCoderRegistrar(org.apache.beam.runners.core.construction.ModelCoderRegistrar) Pipeline(org.apache.beam.sdk.Pipeline) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) DoFn(org.apache.beam.sdk.transforms.DoFn) KvCoder(org.apache.beam.sdk.coders.KvCoder) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ModelCoders(org.apache.beam.runners.core.construction.ModelCoders) Serializable(java.io.Serializable) 
SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) ParDo(org.apache.beam.sdk.transforms.ParDo) ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Timer(org.apache.beam.sdk.state.Timer) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) TimeDomain(org.apache.beam.sdk.state.TimeDomain) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) BagState(org.apache.beam.sdk.state.BagState) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) KvCoder(org.apache.beam.sdk.coders.KvCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) KV(org.apache.beam.sdk.values.KV) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) Timer(org.apache.beam.sdk.state.Timer) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test)

Example 2 with ProcessContext

Use of org.apache.beam.sdk.transforms.DoFn.ProcessContext in project Beam by Apache.

From class ProcessBundleDescriptorsTest, method testLengthPrefixingOfInputCoderExecutableStage.

/**
 * Verifies that, after the splittable ParDo is expanded with the sized replacement and the
 * pipeline is fused, the ProcessBundleDescriptor for the stage containing
 * SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN length-prefixes two coders nested inside
 * its main input coder.
 */
@Test
public void testLengthPrefixingOfInputCoderExecutableStage() throws Exception {
    // Use a coder unknown to the model so the descriptor has to length-prefix it.
    Pipeline pipeline = Pipeline.create();
    Coder<Void> nonModelCoder = VoidCoder.of();
    assertThat(ModelCoderRegistrar.isKnownCoder(nonModelCoder), is(false));
    pipeline
        .apply("impulse", Impulse.create())
        .apply(ParDo.of(new DoFn<byte[], Void>() {

            @ProcessElement
            public void process(ProcessContext ctxt) {
            }
        }))
        .setCoder(nonModelCoder)
        .apply(ParDo.of(new DoFn<Void, Void>() {

            @ProcessElement
            public void processElement(ProcessContext context, RestrictionTracker<Void, Void> tracker) {
            }

            @GetInitialRestriction
            public Void getInitialRestriction() {
                return null;
            }

            @NewTracker
            public SomeTracker newTracker(@Restriction Void restriction) {
                return null;
            }
        }))
        .setCoder(nonModelCoder);

    // Expand the splittable ParDo, fuse, and select the stage that processes sized
    // element/restriction pairs.
    RunnerApi.Pipeline originalProto = PipelineTranslation.toProto(pipeline);
    RunnerApi.Pipeline expandedProto =
        ProtoOverrides.updateTransform(
            PTransformTranslation.PAR_DO_TRANSFORM_URN,
            originalProto,
            SplittableParDoExpander.createSizedReplacement());
    FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(expandedProto);
    Optional<ExecutableStage> sdfStage =
        Iterables.tryFind(
            fusedPipeline.getFusedStages(),
            (ExecutableStage candidate) ->
                candidate.getTransforms().stream()
                    .anyMatch(
                        node ->
                            node.getTransform()
                                .getSpec()
                                .getUrn()
                                .equals(PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)));
    checkState(sdfStage.isPresent(), "Expected a stage with SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN.");
    ExecutableStage stage = sdfStage.get();
    PipelineNode.PCollectionNode mainInput = stage.getInputPCollection();

    Map<String, RunnerApi.Coder> stageCoders = stage.getComponents().getCodersMap();
    RunnerApi.Coder stageInputCoder = stageCoders.get(mainInput.getPCollection().getCoderId());

    BeamFnApi.ProcessBundleDescriptor descriptor =
        ProcessBundleDescriptors.fromExecutableStage(
                "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
            .getProcessBundleDescriptor();
    Map<String, RunnerApi.Coder> descriptorCoders = descriptor.getCodersMap();
    RunnerApi.Coder descriptorInputCoder =
        descriptorCoders.get(descriptor.getPcollectionsOrThrow(mainInput.getId()).getCoderId());

    // The main input coder is walked as KV<KV<A, KV<B, ...>>, ...>; the test compares A
    // (the element coder) and B. NOTE(review): B is presumably the restriction coder — confirm
    // against SplittableParDoExpander.
    RunnerApi.Coder pairCoder =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(descriptorInputCoder).keyCoderId());
    RunnerApi.Coder elementCoder =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(pairCoder).keyCoderId());
    RunnerApi.Coder restrictionPairCoder =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(pairCoder).valueCoderId());
    RunnerApi.Coder restrictionCoder =
        descriptorCoders.get(ModelCoders.getKvCoderComponents(restrictionPairCoder).keyCoderId());

    // Same walk over the stage's own (pre-descriptor) coders.
    RunnerApi.Coder originalPairCoder =
        stageCoders.get(ModelCoders.getKvCoderComponents(stageInputCoder).keyCoderId());
    RunnerApi.Coder originalElementCoder =
        stageCoders.get(ModelCoders.getKvCoderComponents(originalPairCoder).keyCoderId());
    RunnerApi.Coder originalRestrictionPairCoder =
        stageCoders.get(ModelCoders.getKvCoderComponents(originalPairCoder).valueCoderId());
    RunnerApi.Coder originalRestrictionCoder =
        stageCoders.get(ModelCoders.getKvCoderComponents(originalRestrictionPairCoder).keyCoderId());

    ensureLengthPrefixed(elementCoder, originalElementCoder, descriptorCoders);
    ensureLengthPrefixed(restrictionCoder, originalRestrictionCoder, descriptorCoders);
}
Also used : CoreMatchers.is(org.hamcrest.CoreMatchers.is) Endpoints(org.apache.beam.model.pipeline.v1.Endpoints) StateSpec(org.apache.beam.sdk.state.StateSpec) KV(org.apache.beam.sdk.values.KV) CoderTranslation(org.apache.beam.runners.core.construction.CoderTranslation) TimerSpecs(org.apache.beam.sdk.state.TimerSpecs) Coder(org.apache.beam.sdk.coders.Coder) Impulse(org.apache.beam.sdk.transforms.Impulse) GreedyPipelineFuser(org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser) PipelineTranslation(org.apache.beam.runners.core.construction.PipelineTranslation) Optional(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) TimerSpec(org.apache.beam.sdk.state.TimerSpec) Map(java.util.Map) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) TimerReference(org.apache.beam.runners.core.construction.graph.TimerReference) MatcherAssert.assertThat(org.hamcrest.MatcherAssert.assertThat) ModelCoderRegistrar(org.apache.beam.runners.core.construction.ModelCoderRegistrar) Pipeline(org.apache.beam.sdk.Pipeline) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RestrictionTracker(org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) DoFn(org.apache.beam.sdk.transforms.DoFn) KvCoder(org.apache.beam.sdk.coders.KvCoder) GroupByKey(org.apache.beam.sdk.transforms.GroupByKey) PTransformTranslation(org.apache.beam.runners.core.construction.PTransformTranslation) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ModelCoders(org.apache.beam.runners.core.construction.ModelCoders) Serializable(java.io.Serializable) 
SplittableParDoExpander(org.apache.beam.runners.core.construction.graph.SplittableParDoExpander) BagState(org.apache.beam.sdk.state.BagState) StateSpecs(org.apache.beam.sdk.state.StateSpecs) ParDo(org.apache.beam.sdk.transforms.ParDo) ProtoOverrides(org.apache.beam.runners.core.construction.graph.ProtoOverrides) Timer(org.apache.beam.sdk.state.Timer) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) TimeDomain(org.apache.beam.sdk.state.TimeDomain) ProcessContext(org.apache.beam.sdk.transforms.DoFn.ProcessContext) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) Coder(org.apache.beam.sdk.coders.Coder) StringUtf8Coder(org.apache.beam.sdk.coders.StringUtf8Coder) KvCoder(org.apache.beam.sdk.coders.KvCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) BeamFnApi(org.apache.beam.model.fnexecution.v1.BeamFnApi) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) PipelineNode(org.apache.beam.runners.core.construction.graph.PipelineNode) Pipeline(org.apache.beam.sdk.Pipeline) FusedPipeline(org.apache.beam.runners.core.construction.graph.FusedPipeline) ProcessElement(org.apache.beam.sdk.transforms.DoFn.ProcessElement) Test(org.junit.Test)

Aggregations

Serializable (java.io.Serializable)2 Map (java.util.Map)2 BeamFnApi (org.apache.beam.model.fnexecution.v1.BeamFnApi)2 Endpoints (org.apache.beam.model.pipeline.v1.Endpoints)2 RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)2 CoderTranslation (org.apache.beam.runners.core.construction.CoderTranslation)2 ModelCoderRegistrar (org.apache.beam.runners.core.construction.ModelCoderRegistrar)2 ModelCoders (org.apache.beam.runners.core.construction.ModelCoders)2 PTransformTranslation (org.apache.beam.runners.core.construction.PTransformTranslation)2 PipelineTranslation (org.apache.beam.runners.core.construction.PipelineTranslation)2 ExecutableStage (org.apache.beam.runners.core.construction.graph.ExecutableStage)2 FusedPipeline (org.apache.beam.runners.core.construction.graph.FusedPipeline)2 GreedyPipelineFuser (org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser)2 PipelineNode (org.apache.beam.runners.core.construction.graph.PipelineNode)2 ProtoOverrides (org.apache.beam.runners.core.construction.graph.ProtoOverrides)2 SplittableParDoExpander (org.apache.beam.runners.core.construction.graph.SplittableParDoExpander)2 TimerReference (org.apache.beam.runners.core.construction.graph.TimerReference)2 Pipeline (org.apache.beam.sdk.Pipeline)2 Coder (org.apache.beam.sdk.coders.Coder)2 KvCoder (org.apache.beam.sdk.coders.KvCoder)2