Example 16 with WindowedValueCoder

Use of org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder in project beam by apache.

From the class SparkBatchPortablePipelineTranslator, the method translateExecutableStage:

private static <InputT, OutputT, SideInputT> void translateExecutableStage(PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
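    // Translates a fused executable stage into Spark RDD operations: stateful stages
    // are grouped by key before execution; stateless stages are mapped per partition.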
    RunnerApi.ExecutableStagePayload stagePayload;
    try {
        stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transformNode.getTransform().getSpec().getPayload());
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
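    // An executable stage has a single input PCollection; pop its dataset from the context.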
    String inputPCollectionId = stagePayload.getInput();
    Dataset inputDataset = context.popDataset(inputPCollectionId);
    Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
    BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
    Components components = pipeline.getComponents();
    Coder windowCoder = getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder();
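    // Side inputs are broadcast to the workers once per stage, each paired with the
    // WindowedValueCoder needed to deserialize its elements.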
    ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> broadcastVariables = broadcastSideInputs(stagePayload, context);
    JavaRDD<RawUnionValue> staged;
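    // Stages that use state or timers must see all values for a key together, so their
    // input is grouped by key before the stage function runs.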
    if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
        Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
        Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
        // Stateful stages require KV input so that elements can be grouped by key.
        if (!(valueCoder instanceof KvCoder)) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
        }
        Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
        Coder innerValueCoder = ((KvCoder) valueCoder).getValueCoder();
        WindowingStrategy windowingStrategy = getWindowingStrategy(inputPCollectionId, components);
        WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
        WindowedValue.WindowedValueCoder wvCoder = WindowedValue.FullWindowedValueCoder.of(innerValueCoder, windowFn.windowCoder());
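        // Group by the coder-encoded key bytes (ByteArray), so Spark's equality and
        // hashing match Beam's coder-based notion of key equality.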
        JavaPairRDD<ByteArray, Iterable<WindowedValue<KV>>> groupedByKey = groupByKeyPair(inputDataset, keyCoder, wvCoder);
        SparkExecutableStageFunction<KV, SideInputT> function = new SparkExecutableStageFunction<>(context.getSerializableOptions(), stagePayload, context.jobInfo, outputExtractionMap, SparkExecutableStageContextFactory.getInstance(), broadcastVariables, MetricsAccumulator.getInstance(), windowCoder);
        staged = groupedByKey.flatMap(function.forPair());
    } else {
        JavaRDD<WindowedValue<InputT>> inputRdd2 = ((BoundedDataset<InputT>) inputDataset).getRDD();
        SparkExecutableStageFunction<InputT, SideInputT> function2 = new SparkExecutableStageFunction<>(context.getSerializableOptions(), stagePayload, context.jobInfo, outputExtractionMap, SparkExecutableStageContextFactory.getInstance(), broadcastVariables, MetricsAccumulator.getInstance(), windowCoder);
        staged = inputRdd2.mapPartitions(function2);
    }
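    // Register the stage's raw union-tagged output under an intermediate id so it can
    // be named, cached, and materialized like any other dataset.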
    String intermediateId = getExecutableStageIntermediateId(transformNode);
    context.pushDataset(intermediateId, new Dataset() {

        @Override
        public void cache(String storageLevel, Coder<?> coder) {
            StorageLevel level = StorageLevel.fromString(storageLevel);
            staged.persist(level);
        }

        @Override
        public void action() {
            // Empty function to force computation of RDD.
            staged.foreach(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
            staged.setName(name);
        }
    });
    // pop dataset to mark RDD as used
    context.popDataset(intermediateId);
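    // Fan the union-tagged output out into one RDD per output PCollection.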
    for (String outputId : outputs.values()) {
        JavaRDD<WindowedValue<OutputT>> outputRdd = staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
        context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
    }
    if (outputs.isEmpty()) {
        // After pipeline translation, we traverse the set of unconsumed PCollections and add a
        // no-op sink to each to make sure they are materialized by Spark. However, some SDK-executed
        // stages have no runner-visible output after fusion. We handle this case by adding a sink
        // here.
        JavaRDD<WindowedValue<OutputT>> outputRdd = staged.flatMap((rawUnionValue) -> Collections.emptyIterator());
        context.pushDataset(String.format("EmptyOutputSink_%d", context.nextSinkId()), new BoundedDataset<>(outputRdd));
    }
}
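For context, here is a minimal sketch of how a WindowedValue.FullWindowedValueCoder like the wvCoder above is constructed and used to round-trip an element. The class name WindowedValueCoderSketch, the KV<String, Long> element type, and the use of the global window coder are illustrative assumptions, not part of the translator:

import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;

public class WindowedValueCoderSketch {

    public static void main(String[] args) throws Exception {
        // Element coder: a KvCoder, matching the KV requirement for stateful stages above.
        KvCoder<String, Long> kvCoder = KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of());
        // Pair the element coder with a window coder, as the translator does with
        // windowFn.windowCoder(); here we assume the global window for simplicity.
        WindowedValue.WindowedValueCoder<KV<String, Long>> wvCoder =
                WindowedValue.FullWindowedValueCoder.of(kvCoder, GlobalWindow.Coder.INSTANCE);
        // Round-trip a windowed element through the coder, as Spark effectively does
        // when shuffling or persisting windowed values.
        WindowedValue<KV<String, Long>> original =
                WindowedValue.valueInGlobalWindow(KV.of("key", 42L));
        byte[] bytes = CoderUtils.encodeToByteArray(wvCoder, original);
        WindowedValue<KV<String, Long>> decoded = CoderUtils.decodeFromByteArray(wvCoder, bytes);
        System.out.println(decoded.getValue()); // KV{key, 42}
    }
}

This mirrors the wvCoder construction in the stateful branch: the value coder describes the element payload, the window coder describes the window each element belongs to, and FullWindowedValueCoder serializes both together with the timestamp and pane info.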
Also used :

PipelineTranslatorUtils.getWindowingStrategy (org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy)
PipelineTranslatorUtils.instantiateCoder (org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder)
PipelineTranslatorUtils.getWindowedValueCoder (org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowedValueCoder)
RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)
Components (org.apache.beam.model.pipeline.v1.RunnerApi.Components)
WindowedValue (org.apache.beam.sdk.util.WindowedValue)
WindowedValueCoder (org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder)
WindowingStrategy (org.apache.beam.sdk.values.WindowingStrategy)
BoundedWindow (org.apache.beam.sdk.transforms.windowing.BoundedWindow)
Coder (org.apache.beam.sdk.coders.Coder)
KvCoder (org.apache.beam.sdk.coders.KvCoder)
ByteArrayCoder (org.apache.beam.sdk.coders.ByteArrayCoder)
RawUnionValue (org.apache.beam.sdk.transforms.join.RawUnionValue)
KV (org.apache.beam.sdk.values.KV)
ByteArray (org.apache.beam.runners.spark.util.ByteArray)
StorageLevel (org.apache.spark.storage.StorageLevel)
Tuple2 (scala.Tuple2)
List (java.util.List)
IOException (java.io.IOException)

Aggregations

WindowedValueCoder (org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder): 16
RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi): 9
WindowedValue (org.apache.beam.sdk.util.WindowedValue): 8
Coder (org.apache.beam.sdk.coders.Coder): 7
KvCoder (org.apache.beam.sdk.coders.KvCoder): 7
CloudObject (org.apache.beam.runners.dataflow.util.CloudObject): 6
List (java.util.List): 5
Map (java.util.Map): 5
IOException (java.io.IOException): 4
ArrayList (java.util.ArrayList): 4
HashMap (java.util.HashMap): 4
RehydratedComponents (org.apache.beam.runners.core.construction.RehydratedComponents): 4
InvalidProtocolBufferException (org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException): 4
ParallelInstruction (com.google.api.services.dataflow.model.ParallelInstruction): 3
Structs.getString (org.apache.beam.runners.dataflow.util.Structs.getString): 3
ParDoFn (org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn): 3
KV (org.apache.beam.sdk.values.KV): 3
PCollectionView (org.apache.beam.sdk.values.PCollectionView): 3
TupleTag (org.apache.beam.sdk.values.TupleTag): 3
ImmutableMap (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap): 3