Search in sources :

Example 1 with ByteArray

use of org.apache.beam.runners.spark.util.ByteArray in project beam by apache.

the class SparkGroupAlsoByWindowViaWindowSet method groupAlsoByWindow.

/**
 * Streaming group-also-by-window: runs a {@code ReduceFnRunner} per key inside Spark's
 * {@code updateStateByKey} and emits the windowed groupings that fired.
 *
 * <p>Keys and grouped values are encoded to bytes with the supplied coders before the stateful
 * step, and the per-key state ({@code StateAndTimers}) is carried between batches together with
 * the serialized output of the current batch. After the stateful step, entries whose output list
 * is empty are filtered out and the remaining output bytes are decoded with a
 * {@code FullWindowedValueCoder} built from {@code keyCoder} and the value/window coders of
 * {@code wvCoder}.
 *
 * @param inputDStream stream of key-grouped windowed values to be grouped by window.
 * @param keyCoder coder used to encode/decode the key.
 * @param wvCoder coder for the windowed input values; it is cast to
 *     {@code FullWindowedValueCoder} unconditionally, so any other coder type fails the cast.
 * @param windowingStrategy windowing strategy supplying the window coder for timers and the
 *     trigger that drives the {@code ReduceFnRunner}.
 * @param runtimeContext source of the pipeline options (including the checkpoint duration).
 * @param sourceIds source ids handed to {@code SparkTimerInternals.forStreamFromSources}.
 * @return a stream of the fired {@code WindowedValue<KV<K, Iterable<InputT>>>} elements.
 */
public static <K, InputT, W extends BoundedWindow> JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream, final Coder<K> keyCoder, final Coder<WindowedValue<InputT>> wvCoder, final WindowingStrategy<?, W> windowingStrategy, final SparkRuntimeContext runtimeContext, final List<Integer> sourceIds) {
    // coders derived from the inputs: one for the grouped input iterable, one for the fired output.
    final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
    final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
    final Coder<? extends BoundedWindow> wCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder();
    final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder = FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder);
    // timers are (de)serialized with a coder built from the windowing strategy's window coder.
    final TimerInternals.TimerDataCoder timerDataCoder = TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
    long checkpointDurationMillis = runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class).getCheckpointDurationMillis();
    // we have to switch to Scala API to avoid Optional in the Java API, see: SPARK-4819.
    // we also have a broader API for Scala (access to the actual key and entire iterator).
    // we use coders to convert objects in the PCollection to byte arrays, so they
    // can be transferred over the network for the shuffle and be in serialized form
    // for checkpointing.
    // for readability, we add comments with actual type next to byte[].
    // to shorten line length, we use:
    //---- WV: WindowedValue
    //---- Iterable: Itr
    //---- AccumT: A
    //---- InputT: I
    DStream<Tuple2<ByteArray, byte[]>> /*Itr<WV<I>>*/
    pairDStream = inputDStream.transformToPair(new Function<JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>, JavaPairRDD<ByteArray, byte[]>>() {

        // we use mapPartitions with the RDD API because it's the only available API
        // that allows us to preserve partitioning.
        @Override
        public JavaPairRDD<ByteArray, byte[]> call(JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd) throws Exception {
            // unwindow -> (K, Itr<WV<I>>) pairs -> (encoded key, encoded iterable) pairs.
            return rdd.mapPartitions(TranslationUtils.functionToFlatMapFunction(WindowingHelpers.<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()), true).mapPartitionsToPair(TranslationUtils.<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(), true).mapPartitionsToPair(TranslationUtils.pairFunctionToPairFlatMapFunction(CoderHelpers.toByteFunction(keyCoder, itrWvCoder)), true);
        }
    }).dstream();
    // wrap the Scala DStream so that the pair-specific operations (updateStateByKey) are available.
    PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions = DStream.toPairDStreamFunctions(pairDStream, JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(), null);
    // defaultPartitioner$default$1() is the compiled accessor for the Scala default argument.
    int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1();
    Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions);
    // use updateStateByKey to scan through the state and update elements and timers.
    DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> /*WV<KV<K, Itr<I>>>*/
    firedStream = pairDStreamFunctions.updateStateByKey(new SerializableFunction1<scala.collection.Iterator<Tuple3</*K*/
    ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>>>, scala.collection.Iterator<Tuple2</*K*/
    ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>>>() {

        @Override
        public scala.collection.Iterator<Tuple2</*K*/
        ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>>> apply(final scala.collection.Iterator<Tuple3</*K*/
        ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>>>> iter) {
            //--- ACTUAL STATEFUL OPERATION:
            //
            // Input Iterator: the partition (~bundle) of a cogrouping of the input
            // and the previous state (if exists).
            //
            // Output Iterator: the output key, and the updated state.
            //
            // possible input scenarios for (K, Seq, Option<S>):
            // (1) Option<S>.isEmpty: new data with no previous state.
            // (2) Seq.isEmpty: no new data, but evaluating previous state (timer-like behaviour).
            // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state.
            final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn = SystemReduceFn.buffering(((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
            // a single holder is shared by all keys of this partition; it is cleared per key below.
            final OutputWindowedValueHolder<K, InputT> outputHolder = new OutputWindowedValueHolder<>();
            // use in memory Aggregators since Spark Accumulators are not resilient
            // in stateful operators, once done with this partition.
            final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider");
            // NOTE(review): droppedDueToClosedWindow is only read for logging further down; nothing
            // visible in this method passes it on to be incremented - confirm whether it is still needed.
            final CounterCell droppedDueToClosedWindow = cellProvider.getCounter(MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class, GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
            final CounterCell droppedDueToLateness = cellProvider.getCounter(MetricName.named(SparkGroupAlsoByWindowViaWindowSet.class, GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));
            AbstractIterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> /*WV<KV<K, Itr<I>>>*/
            outIter = new AbstractIterator<Tuple2</*K*/
            ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
            List<byte[]>>>>() {

                // produces the next (key, (state, output)) entry, skipping keys that end up with
                // neither state nor output; signals the end once the input iterator is drained.
                @Override
                protected Tuple2</*K*/
                ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
                List<byte[]>>> computeNext() {
                    // (possibly) previous-state and (possibly) new data.
                    while (iter.hasNext()) {
                        // for each element in the partition:
                        Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>> next = iter.next();
                        ByteArray encodedKey = next._1();
                        K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder);
                        Seq<byte[]> seq = next._2();
                        Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt = next._3();
                        SparkStateInternals<K> stateInternals;
                        // fresh timer internals per key, seeded from the global watermark holder.
                        SparkTimerInternals timerInternals = SparkTimerInternals.forStreamFromSources(sourceIds, GlobalWatermarkHolder.get());
                        // get state(internals) per key.
                        if (prevStateAndTimersOpt.isEmpty()) {
                            // no previous state.
                            stateInternals = SparkStateInternals.forKey(key);
                        } else {
                            // with pre-existing state: restore both the state and the pending timers.
                            StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1();
                            stateInternals = SparkStateInternals.forKeyAndState(key, prevStateAndTimers.getState());
                            Collection<byte[]> serTimers = prevStateAndTimers.getTimers();
                            timerInternals.addTimers(SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder));
                        }
                        // the trigger is round-tripped through its proto form to build the state machine.
                        ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner = new ReduceFnRunner<>(key, windowingStrategy, ExecutableTriggerStateMachine.create(TriggerStateMachines.stateMachineForTrigger(TriggerTranslation.toProto(windowingStrategy.getTrigger()))), stateInternals, timerInternals, outputHolder, new UnsupportedSideInputReader("GroupAlsoByWindow"), reduceFn, runtimeContext.getPipelineOptions());
                        // clear before potential use.
                        outputHolder.clear();
                        if (!seq.isEmpty()) {
                            // new input for key.
                            try {
                                // NOTE(review): only seq.head() is decoded - presumably the upstream grouping
                                // yields at most one encoded iterable per key per batch; confirm.
                                Iterable<WindowedValue<InputT>> elementsIterable = CoderHelpers.fromByteArray(seq.head(), itrWvCoder);
                                Iterable<WindowedValue<InputT>> validElements = LateDataUtils.dropExpiredWindows(key, elementsIterable, timerInternals, windowingStrategy, droppedDueToLateness);
                                reduceFnRunner.processElements(validElements);
                            } catch (Exception e) {
                                throw new RuntimeException("Failed to process element with ReduceFnRunner", e);
                            }
                        } else if (stateInternals.getState().isEmpty()) {
                            // no input and no state -> GC evict now.
                            continue;
                        }
                        try {
                            // advance the watermark to HWM to fire by timers.
                            timerInternals.advanceWatermark();
                            // call on timers that are ready.
                            reduceFnRunner.onTimers(timerInternals.getTimersReadyToProcess());
                        } catch (Exception e) {
                            throw new RuntimeException("Failed to process ReduceFnRunner onTimer.", e);
                        }
                        // this is mostly symbolic since actual persist is done by emitting output.
                        reduceFnRunner.persist();
                        // obtain output, if fired.
                        List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = outputHolder.get();
                        if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) {
                            StateAndTimers updated = new StateAndTimers(stateInternals.getState(), SparkTimerInternals.serializeTimers(timerInternals.getTimers(), timerDataCoder));
                            // persist Spark's state by outputting.
                            List<byte[]> serOutput = CoderHelpers.toByteArrays(outputs, wvKvIterCoder);
                            return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput));
                        }
                    // an empty state with no output, can be evicted completely - do nothing.
                    }
                    return endOfData();
                }
            };
            // log if there's something to log.
            // NOTE(review): outIter has only been constructed at this point and computeNext() runs
            // when the returned iterator is consumed, so these counters look like they are read
            // before any element was processed - confirm whether this logging can ever trigger.
            long lateDropped = droppedDueToLateness.getCumulative();
            if (lateDropped > 0) {
                LOG.info(String.format("Dropped %d elements due to lateness.", lateDropped));
                // reset the counter back to zero after reporting.
                droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
            }
            long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
            if (closedWindowDropped > 0) {
                LOG.info(String.format("Dropped %d elements due to closed window.", closedWindowDropped));
                // reset the counter back to zero after reporting.
                droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
            }
            return scala.collection.JavaConversions.asScalaIterator(outIter);
        }
    }, partitioner, true, JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());
    // only set a checkpoint interval when one was configured (> 0).
    if (checkpointDurationMillis > 0) {
        firedStream.checkpoint(new Duration(checkpointDurationMillis));
    }
    // go back to Java now.
    JavaPairDStream<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> /*WV<KV<K, Itr<I>>>*/
    javaFiredStream = JavaPairDStream.fromPairDStream(firedStream, JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(), JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());
    // filter state-only output (nothing to fire) and remove the state from the output.
    return javaFiredStream.filter(new Function<Tuple2</*K*/
    ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>, Boolean>() {

        @Override
        public Boolean call(Tuple2</*K*/
        ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>> t2) throws Exception {
            // keep only entries that carry fired output.
            return !t2._2()._2().isEmpty();
        }
    }).flatMap(new FlatMapFunction<Tuple2</*K*/
    ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
    List<byte[]>>>, WindowedValue<KV<K, Iterable<InputT>>>>() {

        @Override
        public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(Tuple2</*K*/
        ByteArray, Tuple2<StateAndTimers, /*WV<KV<K, Itr<I>>>*/
        List<byte[]>>> t2) throws Exception {
            // decode the serialized output back into windowed values.
            return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder);
        }
    });
}
Also used : MetricsContainerImpl(org.apache.beam.runners.core.metrics.MetricsContainerImpl) CounterCell(org.apache.beam.runners.core.metrics.CounterCell) WindowedValue(org.apache.beam.sdk.util.WindowedValue) OutputWindowedValue(org.apache.beam.runners.core.OutputWindowedValue) ByteArray(org.apache.beam.runners.spark.util.ByteArray) List(java.util.List) ArrayList(java.util.ArrayList) ReduceFnRunner(org.apache.beam.runners.core.ReduceFnRunner) SystemReduceFn(org.apache.beam.runners.core.SystemReduceFn) Duration(org.apache.spark.streaming.Duration) TimerInternals(org.apache.beam.runners.core.TimerInternals) Collection(java.util.Collection) Option(scala.Option) Seq(scala.collection.Seq) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) Function(org.apache.spark.api.java.function.Function) UnsupportedSideInputReader(org.apache.beam.runners.core.UnsupportedSideInputReader) AbstractIterator(com.google.common.collect.AbstractIterator) AbstractIterator(com.google.common.collect.AbstractIterator) Partitioner(org.apache.spark.Partitioner) FullWindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder) KV(org.apache.beam.sdk.values.KV) SparkPipelineOptions(org.apache.beam.runners.spark.SparkPipelineOptions) JavaRDD(org.apache.spark.api.java.JavaRDD) Tuple2(scala.Tuple2) Tuple3(scala.Tuple3)

Example 2 with ByteArray

use of org.apache.beam.runners.spark.util.ByteArray in project beam by apache.

the class GroupNonMergingWindowsFunctions method bringWindowToKey.

/**
 * Builds a pair RDD whose key is the encoded original key followed by the encoded window.
 *
 * <p>Every input element is exploded into one element per window. For each of them the key bytes
 * and the window bytes are concatenated into a {@link ByteArray}, and the single-window value is
 * passed through {@code mappingFn} to produce the pair's value. A warning is logged when the
 * key/window coders are reported as not consistent with equals.
 *
 * @param rdd input elements, possibly assigned to several windows each.
 * @param keyCoder coder used to encode the element key.
 * @param windowCoder coder used to encode the window.
 * @param mappingFn maps each single-window element to the output value.
 * @return pairs of (key bytes + window bytes, mapped value).
 */
static <K, V, OutputT, W extends BoundedWindow> JavaPairRDD<ByteArray, OutputT> bringWindowToKey(JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, Coder<W> windowCoder, SerializableFunction<WindowedValue<KV<K, V>>, OutputT> mappingFn) {
    final boolean codersConsistent = isKeyAndWindowCoderConsistentWithEquals(keyCoder, windowCoder);
    if (!codersConsistent) {
        LOG.warn("Either coder {} or {} is not consistent with equals. " + "That might cause issues on some runners.", keyCoder, windowCoder);
    }
    return rdd.flatMapToPair((WindowedValue<KV<K, V>> element) -> {
        // the key is encoded once per element and shared by all of its windows.
        final byte[] encodedKey = CoderHelpers.toByteArray(element.getValue().getKey(), keyCoder);
        return Iterators.transform(element.explodeWindows().iterator(), exploded -> {
            Objects.requireNonNull(exploded, "Exploded window can not be null.");
            // after explodeWindows() each element is expected to carry exactly one window.
            @SuppressWarnings("unchecked") final W onlyWindow = (W) Iterables.getOnlyElement(exploded.getWindows());
            final ByteArray compositeKey = new ByteArray(Bytes.concat(encodedKey, CoderHelpers.toByteArray(onlyWindow, windowCoder)));
            final WindowedValue<KV<K, V>> singleWindowValue = WindowedValue.of(exploded.getValue(), exploded.getTimestamp(), onlyWindow, exploded.getPane());
            return new Tuple2<>(compositeKey, mappingFn.apply(singleWindowValue));
        });
    });
}
Also used : WindowedValue(org.apache.beam.sdk.util.WindowedValue) KV(org.apache.beam.sdk.values.KV) Tuple2(scala.Tuple2) ByteArray(org.apache.beam.runners.spark.util.ByteArray) KV(org.apache.beam.sdk.values.KV)

Example 3 with ByteArray

use of org.apache.beam.runners.spark.util.ByteArray in project beam by apache.

the class TransformTranslatorTest method testSplitBySameKey.

/**
 * Feeds four (key+timestamp bytes, windowed value bytes) pairs - two for key 1 followed by two
 * for key 2 - into {@code TransformTranslator.splitBySameKey} and checks that the first two
 * returned iterators hold the decoded elements of key 1 and key 2 respectively.
 */
@Test
public void testSplitBySameKey() {
    final VarIntCoder coder = VarIntCoder.of();
    final WindowedValue.WindowedValueCoder<Integer> wvCoder = WindowedValue.FullWindowedValueCoder.of(coder, GlobalWindow.Coder.INSTANCE);
    final Instant now = Instant.now();
    final Instant nowPlusOne = now.plus(Duration.millis(1));
    final Instant nowPlusTwo = now.plus(Duration.millis(2));
    final List<GlobalWindow> window = Arrays.asList(GlobalWindow.INSTANCE);
    final PaneInfo paneInfo = PaneInfo.NO_FIRING;

    // encoded input: values 1 and 2 under key 1, values 3 and 4 under key 2.
    final List<Tuple2<ByteArray, byte[]>> firstKey = Arrays.asList(
        new Tuple2<ByteArray, byte[]>(new ByteArray(CoderHelpers.toByteArrayWithTs(1, coder, now)), CoderHelpers.toByteArray(WindowedValue.of(1, now, window, paneInfo), wvCoder)),
        new Tuple2<ByteArray, byte[]>(new ByteArray(CoderHelpers.toByteArrayWithTs(1, coder, nowPlusOne)), CoderHelpers.toByteArray(WindowedValue.of(2, nowPlusOne, window, paneInfo), wvCoder)));
    final List<Tuple2<ByteArray, byte[]>> secondKey = Arrays.asList(
        new Tuple2<ByteArray, byte[]>(new ByteArray(CoderHelpers.toByteArrayWithTs(2, coder, now)), CoderHelpers.toByteArray(WindowedValue.of(3, now, window, paneInfo), wvCoder)),
        new Tuple2<ByteArray, byte[]>(new ByteArray(CoderHelpers.toByteArrayWithTs(2, coder, nowPlusTwo)), CoderHelpers.toByteArray(WindowedValue.of(4, nowPlusTwo, window, paneInfo), wvCoder)));

    // expected decoded groups, in the order the keys appear in the input.
    final List<List<WindowedValue<KV<Integer, Integer>>>> expectedGroups = new ArrayList<>();
    expectedGroups.add(Arrays.asList(WindowedValue.of(KV.of(1, 1), now, window, paneInfo), WindowedValue.of(KV.of(1, 2), nowPlusOne, window, paneInfo)));
    expectedGroups.add(Arrays.asList(WindowedValue.of(KV.of(2, 3), now, window, paneInfo), WindowedValue.of(KV.of(2, 4), nowPlusTwo, window, paneInfo)));

    final Iterator<Iterator<WindowedValue<KV<Integer, Integer>>>> keySplit = TransformTranslator.splitBySameKey(Iterables.concat(firstKey, secondKey).iterator(), coder, wvCoder);
    for (List<WindowedValue<KV<Integer, Integer>>> expected : expectedGroups) {
        final List<WindowedValue<KV<Integer, Integer>>> actual = new ArrayList<>();
        Iterators.addAll(actual, keySplit.next());
        assertEquals(expected, actual);
    }
}
Also used : VarIntCoder(org.apache.beam.sdk.coders.VarIntCoder) Instant(org.joda.time.Instant) ArrayList(java.util.ArrayList) KV(org.apache.beam.sdk.values.KV) WindowedValue(org.apache.beam.sdk.util.WindowedValue) Tuple2(scala.Tuple2) PaneInfo(org.apache.beam.sdk.transforms.windowing.PaneInfo) Iterator(java.util.Iterator) ByteArray(org.apache.beam.runners.spark.util.ByteArray) GlobalWindow(org.apache.beam.sdk.transforms.windowing.GlobalWindow) Test(org.junit.Test)

Example 4 with ByteArray

use of org.apache.beam.runners.spark.util.ByteArray in project beam by apache.

the class GroupCombineFunctions method groupByKeyOnly.

/**
 * An implementation of
 * {@link org.apache.beam.runners.core.GroupByKeyViaGroupByKeyOnly.GroupByKeyOnly}
 * for the Spark runner.
 *
 * <p>Elements are reified with their timestamps and windows, encoded to bytes with the given
 * coders, grouped by the encoded key, and then decoded and re-wrapped as windowed values.
 *
 * @param rdd the keyed, windowed input elements.
 * @param keyCoder coder used to encode/decode the key around the shuffle.
 * @param wvCoder coder used to encode/decode the windowed values around the shuffle.
 * @return one element per key holding all of that key's windowed values.
 */
public static <K, V> JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupByKeyOnly(JavaRDD<WindowedValue<KV<K, V>>> rdd, Coder<K> keyCoder, WindowedValueCoder<V> wvCoder) {
    // objects in the PCollection are converted to byte arrays with coders so that
    // they can be transferred over the network for the shuffle.
    final JavaPairRDD<ByteArray, byte[]> encodedPairs = rdd
        .map(new ReifyTimestampsAndWindowsFunction<K, V>())
        .map(WindowingHelpers.<KV<K, WindowedValue<V>>>unwindowFunction())
        .mapToPair(TranslationUtils.<K, WindowedValue<V>>toPairFunction())
        .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder));
    // shuffle with a HashPartitioner sized to the context's default parallelism.
    final int parallelism = rdd.rdd().sparkContext().defaultParallelism();
    final Partitioner hashPartitioner = new HashPartitioner(parallelism);
    final JavaPairRDD<ByteArray, Iterable<byte[]>> groupedBytes = encodedPairs.groupByKey(hashPartitioner);
    // every step after the shuffle asks Spark to preserve partitioning (the 'true' flag)
    // to avoid an unnecessary shuffle downstream.
    return groupedBytes
        .mapPartitionsToPair(TranslationUtils.pairFunctionToPairFlatMapFunction(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)), true)
        .mapPartitions(TranslationUtils.<K, Iterable<WindowedValue<V>>>fromPairFlatMapFunction(), true)
        .mapPartitions(TranslationUtils.functionToFlatMapFunction(WindowingHelpers.<KV<K, Iterable<WindowedValue<V>>>>windowFunction()), true);
}
Also used : WindowedValue(org.apache.beam.sdk.util.WindowedValue) KV(org.apache.beam.sdk.values.KV) ByteArray(org.apache.beam.runners.spark.util.ByteArray) HashPartitioner(org.apache.spark.HashPartitioner) KV(org.apache.beam.sdk.values.KV) HashPartitioner(org.apache.spark.HashPartitioner) Partitioner(org.apache.spark.Partitioner)

Example 5 with ByteArray

use of org.apache.beam.runners.spark.util.ByteArray in project beam by apache.

the class SparkBatchPortablePipelineTranslator method translateExecutableStage.

/**
 * Translates an executable-stage transform into Spark RDD operations.
 *
 * <p>The stage payload is parsed from the transform spec and the input dataset is popped from the
 * context. If the payload declares user states or timers, the input (which must be KV-coded) is
 * grouped by key before the stage function is applied; otherwise the stage function runs over
 * the input partitions directly. The resulting union-valued RDD is registered under an
 * intermediate id and immediately popped again, then one dataset per declared output is pushed.
 * When the transform has no outputs, an empty sink dataset is pushed instead.
 *
 * @param transformNode the executable-stage transform to translate.
 * @param pipeline the pipeline proto supplying the components.
 * @param context translation context that datasets are popped from and pushed to.
 * @throws RuntimeException if the stage payload cannot be parsed.
 * @throws IllegalStateException if a stateful stage's element coder is not a {@code KvCoder}.
 */
private static <InputT, OutputT, SideInputT> void translateExecutableStage(PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
    final RunnerApi.ExecutableStagePayload stagePayload;
    try {
        stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transformNode.getTransform().getSpec().getPayload());
    } catch (IOException parseFailure) {
        throw new RuntimeException(parseFailure);
    }

    String inputPCollectionId = stagePayload.getInput();
    Dataset inputDataset = context.popDataset(inputPCollectionId);
    Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
    BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
    Components components = pipeline.getComponents();
    Coder windowCoder = getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder();
    ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>> broadcastVariables = broadcastSideInputs(stagePayload, context);

    final boolean isStateful = stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
    // assigned exactly once in either branch, so it stays capturable by the Dataset below.
    JavaRDD<RawUnionValue> stagedRdd;
    if (isStateful) {
        Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
        Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
        // Stateful stages are only allowed of KV input to be able to group on the key
        if (!(valueCoder instanceof KvCoder)) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
        }
        KvCoder kvCoder = (KvCoder) valueCoder;
        Coder keyCoder = kvCoder.getKeyCoder();
        Coder innerValueCoder = kvCoder.getValueCoder();
        WindowingStrategy windowingStrategy = getWindowingStrategy(inputPCollectionId, components);
        WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
        WindowedValue.WindowedValueCoder wvCoder = WindowedValue.FullWindowedValueCoder.of(innerValueCoder, windowFn.windowCoder());
        JavaPairRDD<ByteArray, Iterable<WindowedValue<KV>>> groupedByKey = groupByKeyPair(inputDataset, keyCoder, wvCoder);
        SparkExecutableStageFunction<KV, SideInputT> statefulFn = new SparkExecutableStageFunction<>(context.getSerializableOptions(), stagePayload, context.jobInfo, outputExtractionMap, SparkExecutableStageContextFactory.getInstance(), broadcastVariables, MetricsAccumulator.getInstance(), windowCoder);
        stagedRdd = groupedByKey.flatMap(statefulFn.forPair());
    } else {
        JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
        SparkExecutableStageFunction<InputT, SideInputT> statelessFn = new SparkExecutableStageFunction<>(context.getSerializableOptions(), stagePayload, context.jobInfo, outputExtractionMap, SparkExecutableStageContextFactory.getInstance(), broadcastVariables, MetricsAccumulator.getInstance(), windowCoder);
        stagedRdd = inputRdd.mapPartitions(statelessFn);
    }

    String intermediateId = getExecutableStageIntermediateId(transformNode);
    // expose the union-valued RDD through a minimal Dataset view.
    Dataset stagedDataset = new Dataset() {

        @Override
        public void cache(String storageLevel, Coder<?> coder) {
            stagedRdd.persist(StorageLevel.fromString(storageLevel));
        }

        @Override
        public void action() {
            // Empty function to force computation of RDD.
            stagedRdd.foreach(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
            stagedRdd.setName(name);
        }
    };
    context.pushDataset(intermediateId, stagedDataset);
    // pop dataset to mark RDD as used
    context.popDataset(intermediateId);

    if (outputs.isEmpty()) {
        // After pipeline translation, we traverse the set of unconsumed PCollections and add a
        // no-op sink to each to make sure they are materialized by Spark. However, some SDK-executed
        // stages have no runner-visible output after fusion. We handle this case by adding a sink
        // here.
        JavaRDD<WindowedValue<OutputT>> emptySinkRdd = stagedRdd.flatMap((rawUnionValue) -> Collections.emptyIterator());
        context.pushDataset(String.format("EmptyOutputSink_%d", context.nextSinkId()), new BoundedDataset<>(emptySinkRdd));
    } else {
        // one dataset per declared output, extracted from the union by its output index.
        for (String outputId : outputs.values()) {
            JavaRDD<WindowedValue<OutputT>> outputRdd = stagedRdd.flatMap(new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
            context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
        }
    }
}
Also used : PipelineTranslatorUtils.getWindowingStrategy(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowingStrategy) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) ByteArray(org.apache.beam.runners.spark.util.ByteArray) List(java.util.List) StorageLevel(org.apache.spark.storage.StorageLevel) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) KvCoder(org.apache.beam.sdk.coders.KvCoder) PipelineTranslatorUtils.instantiateCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder) PipelineTranslatorUtils.getWindowedValueCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.getWindowedValueCoder) Coder(org.apache.beam.sdk.coders.Coder) ByteArrayCoder(org.apache.beam.sdk.coders.ByteArrayCoder) RawUnionValue(org.apache.beam.sdk.transforms.join.RawUnionValue) KvCoder(org.apache.beam.sdk.coders.KvCoder) IOException(java.io.IOException) KV(org.apache.beam.sdk.values.KV) Tuple2(scala.Tuple2)

Aggregations

ByteArray (org.apache.beam.runners.spark.util.ByteArray)5 WindowedValue (org.apache.beam.sdk.util.WindowedValue)5 KV (org.apache.beam.sdk.values.KV)5 Tuple2 (scala.Tuple2)4 ArrayList (java.util.ArrayList)2 List (java.util.List)2 Partitioner (org.apache.spark.Partitioner)2 AbstractIterator (com.google.common.collect.AbstractIterator)1 IOException (java.io.IOException)1 Collection (java.util.Collection)1 Iterator (java.util.Iterator)1 RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi)1 Components (org.apache.beam.model.pipeline.v1.RunnerApi.Components)1 OutputWindowedValue (org.apache.beam.runners.core.OutputWindowedValue)1 ReduceFnRunner (org.apache.beam.runners.core.ReduceFnRunner)1 SystemReduceFn (org.apache.beam.runners.core.SystemReduceFn)1 TimerInternals (org.apache.beam.runners.core.TimerInternals)1 UnsupportedSideInputReader (org.apache.beam.runners.core.UnsupportedSideInputReader)1 CounterCell (org.apache.beam.runners.core.metrics.CounterCell)1 MetricsContainerImpl (org.apache.beam.runners.core.metrics.MetricsContainerImpl)1