Use of org.apache.spark.api.java.JavaRDD in the Apache Beam project.
Class SparkGroupAlsoByWindowViaWindowSet, method groupAlsoByWindow.
/**
 * Groups an already key-grouped stream also by window, keeping Beam state and timers across
 * micro-batches via Spark's {@code updateStateByKey}.
 *
 * <p>Keys and values are encoded with the given coders so they can be shuffled and
 * checkpointed. For every key in a batch a {@code ReduceFnRunner} processes the new elements (if
 * any), then the watermark is advanced and ready timers are fired; fired outputs are emitted
 * downstream. A key whose state is empty and that produced no output is evicted.
 *
 * @param inputDStream stream of elements already grouped by key (not yet by window)
 * @param keyCoder coder for the key type
 * @param wvCoder coder for the windowed input values; it is cast to {@code FullWindowedValueCoder}
 * @param windowingStrategy windowing strategy of the input
 * @param runtimeContext runtime context used to read the pipeline options
 * @param sourceIds stream source ids handed to {@code SparkTimerInternals.forStreamFromSources}
 * @return stream of the outputs fired by the {@code ReduceFnRunner}
 */
public static <K, InputT, W extends BoundedWindow>
    JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> groupAlsoByWindow(
        JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> inputDStream,
        final Coder<K> keyCoder,
        final Coder<WindowedValue<InputT>> wvCoder,
        final WindowingStrategy<?, W> windowingStrategy,
        final SparkRuntimeContext runtimeContext,
        final List<Integer> sourceIds) {
  final IterableCoder<WindowedValue<InputT>> itrWvCoder = IterableCoder.of(wvCoder);
  final Coder<InputT> iCoder = ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder();
  final Coder<? extends BoundedWindow> wCoder =
      ((FullWindowedValueCoder<InputT>) wvCoder).getWindowCoder();
  final Coder<WindowedValue<KV<K, Iterable<InputT>>>> wvKvIterCoder =
      FullWindowedValueCoder.of(KvCoder.of(keyCoder, IterableCoder.of(iCoder)), wCoder);
  final TimerInternals.TimerDataCoder timerDataCoder =
      TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
  long checkpointDurationMillis =
      runtimeContext.getPipelineOptions().as(SparkPipelineOptions.class)
          .getCheckpointDurationMillis();

  // we have to switch to Scala API to avoid Optional in the Java API, see: SPARK-4819.
  // we also have a broader API for Scala (access to the actual key and entire iterator).
  // we use coders to convert objects in the PCollection to byte arrays, so they
  // can be transferred over the network for the shuffle and be in serialized form
  // for checkpointing.
  // for readability, we add comments with actual type next to byte[].
  // to shorten line length, we use:
  //---- WV: WindowedValue
  //---- Iterable: Itr
  //---- AccumT: A
  //---- InputT: I
  // value type: Itr<WV<I>>
  DStream<Tuple2<ByteArray, byte[]>> pairDStream =
      inputDStream
          .transformToPair(
              new Function<
                  JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>>,
                  JavaPairRDD<ByteArray, byte[]>>() {
                // we use mapPartitions with the RDD API because it's the only available API
                // that allows preserving partitioning.
                @Override
                public JavaPairRDD<ByteArray, byte[]> call(
                    JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<InputT>>>>> rdd)
                    throws Exception {
                  return rdd.mapPartitions(
                          TranslationUtils.functionToFlatMapFunction(
                              WindowingHelpers
                                  .<KV<K, Iterable<WindowedValue<InputT>>>>unwindowFunction()),
                          true)
                      .mapPartitionsToPair(
                          TranslationUtils
                              .<K, Iterable<WindowedValue<InputT>>>toPairFlatMapFunction(),
                          true)
                      .mapPartitionsToPair(
                          TranslationUtils.pairFunctionToPairFlatMapFunction(
                              CoderHelpers.toByteFunction(keyCoder, itrWvCoder)),
                          true);
                }
              })
          .dstream();

  PairDStreamFunctions<ByteArray, byte[]> pairDStreamFunctions =
      DStream.toPairDStreamFunctions(
          pairDStream,
          JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
          JavaSparkContext$.MODULE$.<byte[]>fakeClassTag(),
          null);
  int defaultNumPartitions = pairDStreamFunctions.defaultPartitioner$default$1();
  Partitioner partitioner = pairDStreamFunctions.defaultPartitioner(defaultNumPartitions);

  // use updateStateByKey to scan through the state and update elements and timers.
  // output value type: (StateAndTimers, List<WV<KV<K, Itr<I>>>> serialized)
  DStream<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> firedStream =
      pairDStreamFunctions.updateStateByKey(
          new SerializableFunction1<
              scala.collection.Iterator<
                  Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>>>,
              scala.collection.Iterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>>>() {

            @Override
            public scala.collection.Iterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>>
                apply(
                    final scala.collection.Iterator<
                            Tuple3<ByteArray, Seq<byte[]>,
                                Option<Tuple2<StateAndTimers, List<byte[]>>>>>
                        iter) {
              //--- ACTUAL STATEFUL OPERATION:
              //
              // Input Iterator: the partition (~bundle) of a cogrouping of the input
              // and the previous state (if exists).
              //
              // Output Iterator: the output key, and the updated state.
              //
              // possible input scenarios for (K, Seq, Option<S>):
              // (1) Option<S>.isEmpty: new data with no previous state.
              // (2) Seq.isEmpty: no new data, but evaluating previous state (timer-like behaviour).
              // (3) Seq.nonEmpty && Option<S>.isDefined: new data with previous state.
              final SystemReduceFn<K, InputT, Iterable<InputT>, Iterable<InputT>, W> reduceFn =
                  SystemReduceFn.buffering(
                      ((FullWindowedValueCoder<InputT>) wvCoder).getValueCoder());
              final OutputWindowedValueHolder<K, InputT> outputHolder =
                  new OutputWindowedValueHolder<>();
              // use in memory Aggregators since Spark Accumulators are not resilient
              // in stateful operators, once done with this partition.
              final MetricsContainerImpl cellProvider = new MetricsContainerImpl("cellProvider");
              // NOTE(review): nothing visible in this method increments this counter.
              final CounterCell droppedDueToClosedWindow =
                  cellProvider.getCounter(
                      MetricName.named(
                          SparkGroupAlsoByWindowViaWindowSet.class,
                          GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_CLOSED_WINDOW_COUNTER));
              final CounterCell droppedDueToLateness =
                  cellProvider.getCounter(
                      MetricName.named(
                          SparkGroupAlsoByWindowViaWindowSet.class,
                          GroupAlsoByWindowsAggregators.DROPPED_DUE_TO_LATENESS_COUNTER));

              AbstractIterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>> outIter =
                  new AbstractIterator<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>>() {
                    @Override
                    protected Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>
                        computeNext() {
                      // (possibly) previous-state and (possibly) new data.
                      while (iter.hasNext()) {
                        // for each element in the partition:
                        Tuple3<ByteArray, Seq<byte[]>, Option<Tuple2<StateAndTimers, List<byte[]>>>>
                            next = iter.next();
                        ByteArray encodedKey = next._1();
                        K key = CoderHelpers.fromByteArray(encodedKey.getValue(), keyCoder);

                        Seq<byte[]> seq = next._2();

                        Option<Tuple2<StateAndTimers, List<byte[]>>> prevStateAndTimersOpt =
                            next._3();

                        SparkStateInternals<K> stateInternals;
                        SparkTimerInternals timerInternals =
                            SparkTimerInternals.forStreamFromSources(
                                sourceIds, GlobalWatermarkHolder.get());
                        // get state(internals) per key.
                        if (prevStateAndTimersOpt.isEmpty()) {
                          // no previous state.
                          stateInternals = SparkStateInternals.forKey(key);
                        } else {
                          // with pre-existing state.
                          StateAndTimers prevStateAndTimers = prevStateAndTimersOpt.get()._1();
                          stateInternals =
                              SparkStateInternals.forKeyAndState(
                                  key, prevStateAndTimers.getState());
                          Collection<byte[]> serTimers = prevStateAndTimers.getTimers();
                          timerInternals.addTimers(
                              SparkTimerInternals.deserializeTimers(serTimers, timerDataCoder));
                        }

                        ReduceFnRunner<K, InputT, Iterable<InputT>, W> reduceFnRunner =
                            new ReduceFnRunner<>(
                                key,
                                windowingStrategy,
                                ExecutableTriggerStateMachine.create(
                                    TriggerStateMachines.stateMachineForTrigger(
                                        TriggerTranslation.toProto(
                                            windowingStrategy.getTrigger()))),
                                stateInternals,
                                timerInternals,
                                outputHolder,
                                new UnsupportedSideInputReader("GroupAlsoByWindow"),
                                reduceFn,
                                runtimeContext.getPipelineOptions());

                        // clear before potential use.
                        outputHolder.clear();
                        if (!seq.isEmpty()) {
                          // new input for key.
                          try {
                            Iterable<WindowedValue<InputT>> elementsIterable =
                                CoderHelpers.fromByteArray(seq.head(), itrWvCoder);
                            Iterable<WindowedValue<InputT>> validElements =
                                LateDataUtils.dropExpiredWindows(
                                    key,
                                    elementsIterable,
                                    timerInternals,
                                    windowingStrategy,
                                    droppedDueToLateness);
                            reduceFnRunner.processElements(validElements);
                          } catch (Exception e) {
                            throw new RuntimeException(
                                "Failed to process element with ReduceFnRunner", e);
                          }
                        } else if (stateInternals.getState().isEmpty()) {
                          // no input and no state -> GC evict now.
                          continue;
                        }
                        try {
                          // advance the watermark to HWM to fire by timers.
                          timerInternals.advanceWatermark();
                          // call on timers that are ready.
                          reduceFnRunner.onTimers(timerInternals.getTimersReadyToProcess());
                        } catch (Exception e) {
                          throw new RuntimeException(
                              "Failed to process ReduceFnRunner onTimer.", e);
                        }
                        // this is mostly symbolic since actual persist is done by emitting output.
                        reduceFnRunner.persist();
                        // obtain output, if fired.
                        List<WindowedValue<KV<K, Iterable<InputT>>>> outputs = outputHolder.get();
                        if (!outputs.isEmpty() || !stateInternals.getState().isEmpty()) {
                          StateAndTimers updated =
                              new StateAndTimers(
                                  stateInternals.getState(),
                                  SparkTimerInternals.serializeTimers(
                                      timerInternals.getTimers(), timerDataCoder));
                          // persist Spark's state by outputting.
                          List<byte[]> serOutput =
                              CoderHelpers.toByteArrays(outputs, wvKvIterCoder);
                          return new Tuple2<>(encodedKey, new Tuple2<>(updated, serOutput));
                        }
                        // an empty state with no output, can be evicted completely - do nothing.
                      }
                      // The partition is exhausted. Spark drains this iterator lazily, so the
                      // drop counters are only complete at this point; logging them when the
                      // iterator was created (before any element was processed) always saw 0.
                      long lateDropped = droppedDueToLateness.getCumulative();
                      if (lateDropped > 0) {
                        LOG.info(
                            String.format("Dropped %d elements due to lateness.", lateDropped));
                        droppedDueToLateness.inc(-droppedDueToLateness.getCumulative());
                      }
                      long closedWindowDropped = droppedDueToClosedWindow.getCumulative();
                      if (closedWindowDropped > 0) {
                        LOG.info(
                            String.format(
                                "Dropped %d elements due to closed window.",
                                closedWindowDropped));
                        droppedDueToClosedWindow.inc(-droppedDueToClosedWindow.getCumulative());
                      }
                      return endOfData();
                    }
                  };

              return scala.collection.JavaConversions.asScalaIterator(outIter);
            }
          },
          partitioner,
          true,
          JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());

  if (checkpointDurationMillis > 0) {
    firedStream.checkpoint(new Duration(checkpointDurationMillis));
  }

  // go back to Java now.
  JavaPairDStream<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> javaFiredStream =
      JavaPairDStream.fromPairDStream(
          firedStream,
          JavaSparkContext$.MODULE$.<ByteArray>fakeClassTag(),
          JavaSparkContext$.MODULE$.<Tuple2<StateAndTimers, List<byte[]>>>fakeClassTag());

  // filter state-only output (nothing to fire) and remove the state from the output.
  return javaFiredStream
      .filter(
          new Function<Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>, Boolean>() {
            @Override
            public Boolean call(Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> t2)
                throws Exception {
              // filter output if defined.
              return !t2._2()._2().isEmpty();
            }
          })
      .flatMap(
          new FlatMapFunction<
              Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>>,
              WindowedValue<KV<K, Iterable<InputT>>>>() {
            @Override
            public Iterable<WindowedValue<KV<K, Iterable<InputT>>>> call(
                Tuple2<ByteArray, Tuple2<StateAndTimers, List<byte[]>>> t2) throws Exception {
              // return in serialized form.
              return CoderHelpers.fromByteArrays(t2._2()._2(), wvKvIterCoder);
            }
          });
}
Use of org.apache.spark.api.java.JavaRDD in the Apache Beam project.
Class StreamingTransformTranslator, method groupByKey.
private static <K, V, W extends BoundedWindow> TransformEvaluator<GroupByKey<K, V>> groupByKey() {
  return new TransformEvaluator<GroupByKey<K, V>>() {
    /**
     * Translates a streaming {@code GroupByKey} in two steps: a per-batch group-by-key on every
     * RDD of the input stream, followed by a stateful group-also-by-window across batches.
     */
    @Override
    public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
      @SuppressWarnings("unchecked")
      final UnboundedDataset<KV<K, V>> input =
          (UnboundedDataset<KV<K, V>>) context.borrowDataset(transform);
      final List<Integer> sources = input.getStreamSources();
      final JavaDStream<WindowedValue<KV<K, V>>> inputStream = input.getDStream();

      @SuppressWarnings("unchecked")
      final KvCoder<K, V> kvCoder = (KvCoder<K, V>) context.getInput(transform).getCoder();
      final SparkRuntimeContext sparkContext = context.getRuntimeContext();
      @SuppressWarnings("unchecked")
      final WindowingStrategy<?, W> strategy =
          (WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy();
      @SuppressWarnings("unchecked")
      final WindowFn<Object, W> fn = (WindowFn<Object, W>) strategy.getWindowFn();

      // Values travel together with their windows, hence a full windowed-value coder.
      final WindowedValue.WindowedValueCoder<V> windowedValueCoder =
          WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), fn.windowCoder());

      // Step 1: group by key only, independently inside every micro-batch RDD.
      final Function<
              JavaRDD<WindowedValue<KV<K, V>>>,
              JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>>
          groupPerBatch =
              new Function<
                  JavaRDD<WindowedValue<KV<K, V>>>,
                  JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>>>() {
                @Override
                public JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> call(
                    JavaRDD<WindowedValue<KV<K, V>>> batch) throws Exception {
                  return GroupCombineFunctions.groupByKeyOnly(
                      batch, kvCoder.getKeyCoder(), windowedValueCoder);
                }
              };
      final JavaDStream<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> keyGrouped =
          inputStream.transform(groupPerBatch);

      // Step 2: group also by window, carrying state and timers across batches.
      final JavaDStream<WindowedValue<KV<K, Iterable<V>>>> windowGrouped =
          SparkGroupAlsoByWindowViaWindowSet.groupAlsoByWindow(
              keyGrouped,
              kvCoder.getKeyCoder(),
              windowedValueCoder,
              strategy,
              sparkContext,
              sources);
      context.putDataset(transform, new UnboundedDataset<>(windowGrouped, sources));
    }

    @Override
    public String toNativeString() {
      return "groupByKey()";
    }
  };
}
Use of org.apache.spark.api.java.JavaRDD in the Apache Beam project.
Class StreamingTransformTranslator, method flattenPColl.
private static <T> TransformEvaluator<Flatten.PCollections<T>> flattenPColl() {
  return new TransformEvaluator<Flatten.PCollections<T>>() {
    /**
     * Unions all inputs of the {@code Flatten} into a single DStream.
     *
     * <p>Unbounded inputs contribute their DStream (and stream sources). Each bounded input is
     * lifted into a queue stream holding its single RDD. With no inputs at all, an empty queue
     * stream is used, mirroring the batch translation's empty RDD. Without it,
     * {@code dStreams.remove(0)} would throw {@code IndexOutOfBoundsException}.
     */
    @SuppressWarnings("unchecked")
    @Override
    public void evaluate(Flatten.PCollections<T> transform, EvaluationContext context) {
      Map<TupleTag<?>, PValue> pcs = context.getInputs(transform);
      // since this is a streaming pipeline, at least one of the PCollections to "flatten" are
      // unbounded, meaning it represents a DStream.
      // So we could end up with an unbounded unified DStream.
      final List<JavaDStream<WindowedValue<T>>> dStreams = new ArrayList<>();
      final List<Integer> streamingSources = new ArrayList<>();
      for (PValue pv : pcs.values()) {
        checkArgument(
            pv instanceof PCollection,
            "Flatten had non-PCollection value in input: %s of type %s",
            pv,
            pv.getClass().getSimpleName());
        PCollection<T> pcol = (PCollection<T>) pv;
        Dataset dataset = context.borrowDataset(pcol);
        if (dataset instanceof UnboundedDataset) {
          UnboundedDataset<T> unboundedDataset = (UnboundedDataset<T>) dataset;
          streamingSources.addAll(unboundedDataset.getStreamSources());
          dStreams.add(unboundedDataset.getDStream());
        } else {
          // create a single RDD stream.
          Queue<JavaRDD<WindowedValue<T>>> q = new LinkedBlockingQueue<>();
          q.offer(((BoundedDataset<T>) dataset).getRDD());
          //TODO: this is not recoverable from checkpoint!
          JavaDStream<WindowedValue<T>> dStream = context.getStreamingContext().queueStream(q);
          dStreams.add(dStream);
        }
      }
      if (dStreams.isEmpty()) {
        // Flatten of zero inputs: emit an empty stream instead of failing below.
        Queue<JavaRDD<WindowedValue<T>>> emptyQueue = new LinkedBlockingQueue<>();
        //TODO: like the bounded case above, this is not recoverable from checkpoint!
        dStreams.add(context.getStreamingContext().queueStream(emptyQueue));
      }
      // start by unifying streams into a single stream.
      JavaDStream<WindowedValue<T>> unifiedStreams =
          context.getStreamingContext().union(dStreams.remove(0), dStreams);
      context.putDataset(transform, new UnboundedDataset<>(unifiedStreams, streamingSources));
    }

    @Override
    public String toNativeString() {
      return "streamingContext.union(...)";
    }
  };
}
Use of org.apache.spark.api.java.JavaRDD in the Apache Beam project.
Class StreamingTransformTranslator, method parDo.
private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
  return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
    /**
     * Runs the {@code DoFn} over every micro-batch RDD of the input stream. The result is a
     * stream of (output tag, value) pairs, which is then filtered once per output tag to
     * register one DStream per output.
     *
     * <p>Splittable and stateful/timer {@code DoFn}s are rejected up front.
     */
    @Override
    public void evaluate(
        final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
      final DoFn<InputT, OutputT> doFn = transform.getFn();
      rejectSplittable(doFn);
      rejectStateAndTimers(doFn);
      final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
      final SparkPCollectionView pviews = context.getPViews();
      final WindowingStrategy<?, ?> windowingStrategy =
          context.getInput(transform).getWindowingStrategy();
      @SuppressWarnings("unchecked")
      UnboundedDataset<InputT> unboundedDataset =
          ((UnboundedDataset<InputT>) context.borrowDataset(transform));
      JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();

      final String stepName = context.getCurrentTransform().getFullName();
      JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
          dStream.transformToPair(
              new Function<
                  JavaRDD<WindowedValue<InputT>>, JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
                @Override
                public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
                    JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
                  // accumulators and side inputs are resolved per batch, inside the
                  // transform function, rather than captured at translation time.
                  final Accumulator<NamedAggregators> aggAccum =
                      AggregatorsAccumulator.getInstance();
                  final Accumulator<MetricsContainerStepMap> metricsAccum =
                      MetricsAccumulator.getInstance();
                  final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                      sideInputs =
                          TranslationUtils.getSideInputs(
                              transform.getSideInputs(),
                              JavaSparkContext.fromSparkContext(rdd.context()),
                              pviews);
                  return rdd.mapPartitionsToPair(
                      new MultiDoFnFunction<>(
                          aggAccum,
                          metricsAccum,
                          stepName,
                          doFn,
                          runtimeContext,
                          transform.getMainOutputTag(),
                          transform.getAdditionalOutputTags().getAll(),
                          sideInputs,
                          windowingStrategy,
                          false));
                }
              });
      Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
      if (outputs.size() > 1) {
        // cache the DStream if we're going to filter it more than once.
        all.cache();
      }
      for (Map.Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
        @SuppressWarnings("unchecked")
        JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
            all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
        // Object is the best we can do since different outputs can have different types.
        @SuppressWarnings("unchecked")
        JavaDStream<WindowedValue<Object>> values =
            (JavaDStream<WindowedValue<Object>>)
                (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
        context.putDataset(
            output.getValue(),
            new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
      }
    }

    @Override
    public String toNativeString() {
      return "mapPartitions(new <fn>())";
    }
  };
}
Use of org.apache.spark.api.java.JavaRDD in the Apache Beam project.
Class TransformTranslator, method flattenPColl.
private static <T> TransformEvaluator<Flatten.PCollections<T>> flattenPColl() {
  return new TransformEvaluator<Flatten.PCollections<T>>() {
    /**
     * Unions the RDDs of all (bounded) inputs into one RDD. With no inputs, an empty RDD is
     * produced instead.
     */
    @SuppressWarnings("unchecked")
    @Override
    public void evaluate(Flatten.PCollections<T> transform, EvaluationContext context) {
      final Collection<PValue> inputs = context.getInputs(transform).values();
      final JavaRDD<WindowedValue<T>> unified;
      if (inputs.isEmpty()) {
        unified = context.getSparkContext().emptyRDD();
      } else {
        // a (generic) array is required by the varargs union overload.
        final JavaRDD<WindowedValue<T>>[] inputRdds = new JavaRDD[inputs.size()];
        int position = 0;
        for (PValue input : inputs) {
          checkArgument(
              input instanceof PCollection,
              "Flatten had non-PCollection value in input: %s of type %s",
              input,
              input.getClass().getSimpleName());
          inputRdds[position++] = ((BoundedDataset<T>) context.borrowDataset(input)).getRDD();
        }
        unified = context.getSparkContext().union(inputRdds);
      }
      context.putDataset(transform, new BoundedDataset<>(unified));
    }

    @Override
    public String toNativeString() {
      return "sparkContext.union(...)";
    }
  };
}
Aggregations