Search in sources :

Example 66 with Coder

use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.

In class SparkBatchPortablePipelineTranslator, method broadcastSideInputs.

/**
 * Broadcasts every side input consumed by an executable stage. *This can be expensive.*
 *
 * <p>A PCollection that feeds more than one side input of the stage is broadcast only once; later
 * references to the same PCollection ID reuse the entry already recorded.
 *
 * @param stagePayload the executable stage whose side inputs are broadcast
 * @param context the translation context forwarded to {@code broadcastSideInput}
 * @return Map from PCollection ID to Spark broadcast variable and coder to decode its contents.
 */
private static <SideInputT> ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> broadcastSideInputs(RunnerApi.ExecutableStagePayload stagePayload, SparkTranslationContext context) {
    final RunnerApi.Components components = stagePayload.getComponents();
    final Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> broadcastsByPCollection = new HashMap<>();
    for (SideInputId sideInput : stagePayload.getSideInputsList()) {
        // Resolve the side input (transform id + local name) to the PCollection it reads from.
        final String pCollectionId = components.getTransformsOrThrow(sideInput.getTransformId()).getInputsOrThrow(sideInput.getLocalName());
        if (!broadcastsByPCollection.containsKey(pCollectionId)) {
            // First reference to this PCollection: broadcast it and remember the result.
            final Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> broadcastAndCoder = broadcastSideInput(pCollectionId, components, context);
            broadcastsByPCollection.put(pCollectionId, broadcastAndCoder);
        }
    }
    return ImmutableMap.copyOf(broadcastsByPCollection);
}
Also used : RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) PipelineTranslatorUtils.getWindowedValueCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowedValueCoder) HashMap(java.util.HashMap) Broadcast(org.apache.spark.broadcast.Broadcast) Tuple2(scala.Tuple2) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) List(java.util.List) SideInputId(org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId)

Example 67 with Coder

use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.

In class SparkBatchPortablePipelineTranslator, method translateExecutableStage.

/**
 * Translates an executable stage into Spark RDD operations and registers its outputs with the
 * translation context.
 *
 * <p>The stage payload is parsed from the transform spec, the input dataset is popped from the
 * context, and the stage's side inputs are broadcast. A stage that declares user state or timers
 * must have a {@link KvCoder} element coder; its input is passed through {@code groupByKeyPair}
 * before the stage function is applied. Any other stage applies the stage function per partition
 * of the input RDD. The resulting union-valued RDD is then split into one dataset per declared
 * output, or wired to a placeholder sink when the transform declares no outputs.
 *
 * @param transformNode the executable-stage transform to translate
 * @param pipeline the pipeline proto supplying the components (coders, windowing strategies)
 * @param context the translation context that datasets are popped from and pushed to
 * @throws RuntimeException if the stage payload cannot be parsed
 * @throws IllegalStateException if the stage uses state or timers but its element coder is not a
 *     {@link KvCoder}
 */
private static <InputT, OutputT, SideInputT> void translateExecutableStage(PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
    RunnerApi.ExecutableStagePayload stagePayload;
    try {
        stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transformNode.getTransform().getSpec().getPayload());
    } catch (IOException parseFailure) {
        throw new RuntimeException(parseFailure);
    }

    final String inputId = stagePayload.getInput();
    final Dataset input = context.popDataset(inputId);
    final Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
    final BiMap<String, Integer> outputIndexes = createOutputMap(outputs.values());
    final Components components = pipeline.getComponents();
    final Coder windowCoder = getWindowingStrategy(inputId, components).getWindowFn().windowCoder();
    final ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> sideInputBroadcasts = broadcastSideInputs(stagePayload, context);

    final boolean usesStateOrTimers = stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
    final JavaRDD<RawUnionValue> stageOutput;
    if (!usesStateOrTimers) {
        // Stateless stage: run the stage function over each partition of the input RDD.
        JavaRDD<WindowedValue<InputT>> elements = ((BoundedDataset<InputT>) input).getRDD();
        SparkExecutableStageFunction<InputT, SideInputT> stageFn = new SparkExecutableStageFunction<>(context.getSerializableOptions(), stagePayload, context.jobInfo, outputIndexes, SparkExecutableStageContextFactory.getInstance(), sideInputBroadcasts, MetricsAccumulator.getInstance(), windowCoder);
        stageOutput = elements.mapPartitions(stageFn);
    } else {
        Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputId, components);
        Coder elementCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
        // A stateful stage needs KV input so that its elements can be grouped on the key.
        if (!(elementCoder instanceof KvCoder)) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputId, elementCoder.getClass().getSimpleName()));
        }
        KvCoder kvCoder = (KvCoder) elementCoder;
        Coder keyCoder = kvCoder.getKeyCoder();
        Coder kvValueCoder = kvCoder.getValueCoder();
        WindowingStrategy inputWindowing = getWindowingStrategy(inputId, components);
        WindowFn<Object, BoundedWindow> inputWindowFn = inputWindowing.getWindowFn();
        WindowedValue.WindowedValueCoder windowedValueCoder = WindowedValue.FullWindowedValueCoder.of(kvValueCoder, inputWindowFn.windowCoder());
        JavaPairRDD<ByteArray, Iterable<WindowedValue<KV>>> keyedGroups = groupByKeyPair(input, keyCoder, windowedValueCoder);
        SparkExecutableStageFunction<KV, SideInputT> stageFn = new SparkExecutableStageFunction<>(context.getSerializableOptions(), stagePayload, context.jobInfo, outputIndexes, SparkExecutableStageContextFactory.getInstance(), sideInputBroadcasts, MetricsAccumulator.getInstance(), windowCoder);
        stageOutput = keyedGroups.flatMap(stageFn.forPair());
    }

    // Register the raw (union-valued) stage output under an intermediate ID, wrapped in a Dataset
    // that delegates caching, naming and forced evaluation to the underlying RDD.
    final String intermediateId = getExecutableStageIntermediateId(transformNode);
    context.pushDataset(intermediateId, new Dataset() {

        @Override
        public void cache(String storageLevel, Coder<?> coder) {
            stageOutput.persist(StorageLevel.fromString(storageLevel));
        }

        @Override
        public void action() {
            // A no-op foreach forces Spark to compute the RDD.
            stageOutput.foreach(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
            stageOutput.setName(name);
        }
    });
    // Popping the intermediate dataset right away marks the RDD as used.
    context.popDataset(intermediateId);

    if (outputs.isEmpty()) {
        // After pipeline translation, the set of unconsumed PCollections is traversed and a no-op
        // sink is added to each so that Spark materializes them. Some SDK-executed stages have no
        // runner-visible output after fusion, so a sink is added here for that case.
        JavaRDD<WindowedValue<OutputT>> emptySinkRdd = stageOutput.flatMap(ignored -> Collections.emptyIterator());
        context.pushDataset(String.format("EmptyOutputSink_%d", context.nextSinkId()), new BoundedDataset<>(emptySinkRdd));
    } else {
        // Extract each declared output from the union-valued RDD by its index.
        for (String outputPCollectionId : outputs.values()) {
            JavaRDD<WindowedValue<OutputT>> extracted = stageOutput.flatMap(new SparkExecutableStageExtractionFunction<>(outputIndexes.get(outputPCollectionId)));
            context.pushDataset(outputPCollectionId, new BoundedDataset<>(extracted));
        }
    }
}
Also used : PipelineTranslatorUtils.getWindowingStrategy(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) ByteArray(org.apache.beam.runners.spark.util.ByteArray) List(java.util.List) StorageLevel(org.apache.spark.storage.StorageLevel) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) KvCoder(org.apache.beam.sdk.coders.KvCoder) PipelineTranslatorUtils.instantiateCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder) PipelineTranslatorUtils.getWindowedValueCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowedValueCoder) Coder(org.apache.beam.sdk.coders.Coder) ByteArrayCoder(org.apache.beam.sdk.coders.ByteArrayCoder) RawUnionValue(org.apache.beam.sdk.transforms.join.RawUnionValue) KvCoder(org.apache.beam.sdk.coders.KvCoder) IOException(java.io.IOException) KV(org.apache.beam.sdk.values.KV) Tuple2(scala.Tuple2)

Example 68 with Coder

use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.

In class SparkBatchPortablePipelineTranslator, method translateGroupByKey.

/**
 * Translates a GroupByKey transform into Spark RDD operations and pushes the grouped result to the
 * translation context under the transform's output ID.
 *
 * <p>One of three grouping implementations is chosen from the input's windowing strategy:
 *
 * <ul>
 *   <li>global windows with the end-of-window timestamp combiner: {@code
 *       GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow};
 *   <li>strategies accepted by {@code GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow}:
 *       {@code GroupNonMergingWindowsFunctions.groupByKeyAndWindow};
 *   <li>everything else: {@code GroupCombineFunctions.groupByKeyOnly} followed by {@code
 *       SparkGroupAlsoByWindowViaOutputBufferFn}.
 * </ul>
 *
 * @param transformNode the GroupByKey transform to translate
 * @param pipeline the pipeline proto supplying the components (coders, windowing strategies)
 * @param context the translation context that the input is popped from and the output pushed to
 */
private static <K, V> void translateGroupByKey(PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
    final RunnerApi.Components pipelineComponents = pipeline.getComponents();
    final String inputPCollectionId = getInputId(transformNode);
    final Dataset poppedInput = context.popDataset(inputPCollectionId);
    final JavaRDD<WindowedValue<KV<K, V>>> kvRdd = ((BoundedDataset<KV<K, V>>) poppedInput).getRDD();

    // Unpack the key and value coders from the input's windowed KV coder.
    final WindowedValueCoder<KV<K, V>> windowedKvCoder = getWindowedValueCoder(inputPCollectionId, pipelineComponents);
    final KvCoder<K, V> kvCoder = (KvCoder<K, V>) windowedKvCoder.getValueCoder();
    final Coder<K> keyCoder = kvCoder.getKeyCoder();
    final Coder<V> valueCoder = kvCoder.getValueCoder();

    final WindowingStrategy windowingStrategy = getWindowingStrategy(inputPCollectionId, pipelineComponents);
    final WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
    final WindowedValue.WindowedValueCoder<V> windowedValueCoder = WindowedValue.FullWindowedValueCoder.of(valueCoder, windowFn.windowCoder());
    final Partitioner partitioner = getPartitioner(context);

    // Batch translation: triggering and allowed lateness play no part in choosing the path.
    final boolean globalWindowEndOfWindow = windowingStrategy.getWindowFn().equals(new GlobalWindows()) && windowingStrategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW);

    final JavaRDD<WindowedValue<KV<K, Iterable<V>>>> grouped;
    if (globalWindowEndOfWindow) {
        // The windows can be dropped for grouping and recovered afterwards.
        grouped = GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(kvRdd, keyCoder, valueCoder, partitioner);
    } else if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
        // Non-merging windows get a memory-sensitive translation.
        grouped = GroupNonMergingWindowsFunctions.groupByKeyAndWindow(kvRdd, keyCoder, valueCoder, windowingStrategy, partitioner);
    } else {
        final JavaRDD<KV<K, Iterable<WindowedValue<V>>>> keyedOnly = GroupCombineFunctions.groupByKeyOnly(kvRdd, keyCoder, windowedValueCoder, partitioner);
        // In batch, GroupAlsoByWindow is backed by an in-memory StateInternals.
        grouped = keyedOnly.flatMap(new SparkGroupAlsoByWindowViaOutputBufferFn<>(windowingStrategy, new TranslationUtils.InMemoryStateInternalsFactory<>(), SystemReduceFn.buffering(valueCoder), context.serializablePipelineOptions));
    }

    context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(grouped));
}
Also used : PipelineTranslatorUtils.getWindowingStrategy(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) KV(org.apache.beam.sdk.values.KV) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) Partitioner(org.apache.spark.Partitioner) HashPartitioner(org.apache.spark.HashPartitioner) GlobalWindows(org.apache.beam.sdk.transforms.windowing.GlobalWindows) KvCoder(org.apache.beam.sdk.coders.KvCoder) KV(org.apache.beam.sdk.values.KV) JavaRDD(org.apache.spark.api.java.JavaRDD)

Aggregations

RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)48 Coder (org.apache.beam.sdk.coders.Coder)33 WindowedValue (org.apache.beam.sdk.util.WindowedValue)32 KvCoder (org.apache.beam.sdk.coders.KvCoder)30 Test (org.junit.Test)30 Map (java.util.Map)23 ByteString (org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString)23 HashMap (java.util.HashMap)21 KV (org.apache.beam.sdk.values.KV)20 ArrayList (java.util.ArrayList)19 IOException (java.io.IOException)18 StringUtf8Coder (org.apache.beam.sdk.coders.StringUtf8Coder)17 List (java.util.List)16 ExecutableStage (org.apache.beam.runners.core.construction.graph.ExecutableStage)16 BoundedWindow (org.apache.beam.sdk.transforms.windowing.BoundedWindow)15 ImmutableMap (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap)15 Collection (java.util.Collection)13 Pipeline (org.apache.beam.sdk.Pipeline)13 FusedPipeline (org.apache.beam.runners.core.construction.graph.FusedPipeline)12 ConcurrentHashMap (java.util.concurrent.ConcurrentHashMap)11