Use of org.apache.beam.sdk.coders.KvCoder in project beam by apache.
The class SparkStreamingPortablePipelineTranslator, method translateGroupByKey.
private static <K, V> void translateGroupByKey(
    PTransformNode transformNode,
    RunnerApi.Pipeline pipeline,
    SparkStreamingTranslationContext context) {
  RunnerApi.Components components = pipeline.getComponents();
  String inputId = getInputId(transformNode);
  UnboundedDataset<KV<K, V>> inputDataset =
      (UnboundedDataset<KV<K, V>>) context.popDataset(inputId);
  List<Integer> streamSources = inputDataset.getStreamSources();
  WindowedValue.WindowedValueCoder<KV<K, V>> inputCoder =
      getWindowedValueCoder(inputId, components);
  KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
  WindowingStrategy windowingStrategy = getWindowingStrategy(inputId, components);
  WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
  WindowedValue.WindowedValueCoder<V> wvCoder =
      WindowedValue.FullWindowedValueCoder.of(
          inputKvCoder.getValueCoder(), windowFn.windowCoder());
  JavaDStream<WindowedValue<KV<K, Iterable<V>>>> outStream =
      SparkGroupAlsoByWindowViaWindowSet.groupByKeyAndWindow(
          inputDataset.getDStream(),
          inputKvCoder.getKeyCoder(),
          wvCoder,
          windowingStrategy,
          context.getSerializableOptions(),
          streamSources,
          transformNode.getId());
  context.pushDataset(
      getOutputId(transformNode), new UnboundedDataset<>(outStream, streamSources));
}
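For context on the cast to KvCoder<K, V> above: a GroupByKey input carries elements of type KV<K, V>, whose coder is a KvCoder composed of a key coder and a value coder, and the translator simply takes that composition apart again. A minimal stand-alone sketch of the same composition (the class name KvCoderSketch and the String/Long element types are invented for illustration):

import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;

public class KvCoderSketch {
  public static void main(String[] args) {
    // Compose a coder for KV<String, Long> from its component coders.
    KvCoder<String, Long> kvCoder = KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of());
    // The translator recovers the two halves: the key coder keys the shuffle,
    // and the value coder is wrapped into a WindowedValueCoder for the values.
    Coder<String> keyCoder = kvCoder.getKeyCoder();
    Coder<Long> valueCoder = kvCoder.getValueCoder();
    System.out.println(keyCoder + " / " + valueCoder);
  }
}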
Use of org.apache.beam.sdk.coders.KvCoder in project beam by apache.
The class SparkBatchPortablePipelineTranslator, method translateExecutableStage.
private static <InputT, OutputT, SideInputT> void translateExecutableStage(
    PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
  RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(
            transformNode.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  String inputPCollectionId = stagePayload.getInput();
  Dataset inputDataset = context.popDataset(inputPCollectionId);
  Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
  BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
  Components components = pipeline.getComponents();
  Coder windowCoder =
      getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder();
  ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
      broadcastVariables = broadcastSideInputs(stagePayload, context);
  JavaRDD<RawUnionValue> staged;
  if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
    Coder<WindowedValue<InputT>> windowedInputCoder =
        instantiateCoder(inputPCollectionId, components);
    Coder valueCoder =
        ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
    // Stateful stages are only allowed for KV input, so the runner can group on the key.
    if (!(valueCoder instanceof KvCoder)) {
      throw new IllegalStateException(
          String.format(
              Locale.ENGLISH,
              "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
              inputPCollectionId,
              valueCoder.getClass().getSimpleName()));
    }
    Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
    Coder innerValueCoder = ((KvCoder) valueCoder).getValueCoder();
    WindowingStrategy windowingStrategy = getWindowingStrategy(inputPCollectionId, components);
    WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
    WindowedValue.WindowedValueCoder wvCoder =
        WindowedValue.FullWindowedValueCoder.of(innerValueCoder, windowFn.windowCoder());
    JavaPairRDD<ByteArray, Iterable<WindowedValue<KV>>> groupedByKey =
        groupByKeyPair(inputDataset, keyCoder, wvCoder);
    SparkExecutableStageFunction<KV, SideInputT> function =
        new SparkExecutableStageFunction<>(
            context.getSerializableOptions(),
            stagePayload,
            context.jobInfo,
            outputExtractionMap,
            SparkExecutableStageContextFactory.getInstance(),
            broadcastVariables,
            MetricsAccumulator.getInstance(),
            windowCoder);
    staged = groupedByKey.flatMap(function.forPair());
  } else {
    JavaRDD<WindowedValue<InputT>> inputRdd2 = ((BoundedDataset<InputT>) inputDataset).getRDD();
    SparkExecutableStageFunction<InputT, SideInputT> function2 =
        new SparkExecutableStageFunction<>(
            context.getSerializableOptions(),
            stagePayload,
            context.jobInfo,
            outputExtractionMap,
            SparkExecutableStageContextFactory.getInstance(),
            broadcastVariables,
            MetricsAccumulator.getInstance(),
            windowCoder);
    staged = inputRdd2.mapPartitions(function2);
  }
  String intermediateId = getExecutableStageIntermediateId(transformNode);
  context.pushDataset(
      intermediateId,
      new Dataset() {
        @Override
        public void cache(String storageLevel, Coder<?> coder) {
          StorageLevel level = StorageLevel.fromString(storageLevel);
          staged.persist(level);
        }

        @Override
        public void action() {
          // Empty function to force computation of RDD.
          staged.foreach(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
          staged.setName(name);
        }
      });
  // pop dataset to mark RDD as used
  context.popDataset(intermediateId);
  for (String outputId : outputs.values()) {
    JavaRDD<WindowedValue<OutputT>> outputRdd =
        staged.flatMap(
            new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
    context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
  }
  if (outputs.isEmpty()) {
    // After pipeline translation, we traverse the set of unconsumed PCollections and add a
    // no-op sink to each to make sure they are materialized by Spark. However, some
    // SDK-executed stages have no runner-visible output after fusion. We handle this case by
    // adding a sink here.
    JavaRDD<WindowedValue<OutputT>> outputRdd =
        staged.flatMap((rawUnionValue) -> Collections.emptyIterator());
    context.pushDataset(
        String.format("EmptyOutputSink_%d", context.nextSinkId()),
        new BoundedDataset<>(outputRdd));
  }
}
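The KvCoder requirement above is what lets the runner shuffle a stateful stage by key before handing elements to the SDK harness. A hedged illustration of that guard pulled out into a helper (the class and method names are invented; only the instanceof check and the exception mirror the translator):

import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;

final class StatefulStageCoders {
  private StatefulStageCoders() {}

  // Returns the key coder if the element coder is a KvCoder; otherwise fails fast,
  // mirroring the IllegalStateException thrown by translateExecutableStage.
  static Coder<?> requireKvKeyCoder(String pCollectionId, Coder<?> elementCoder) {
    if (!(elementCoder instanceof KvCoder)) {
      throw new IllegalStateException(
          String.format(
              "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
              pCollectionId, elementCoder.getClass().getSimpleName()));
    }
    return ((KvCoder<?, ?>) elementCoder).getKeyCoder();
  }
}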
Use of org.apache.beam.sdk.coders.KvCoder in project beam by apache.
The class SparkBatchPortablePipelineTranslator, method translateGroupByKey.
private static <K, V> void translateGroupByKey(
    PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
  RunnerApi.Components components = pipeline.getComponents();
  String inputId = getInputId(transformNode);
  Dataset inputDataset = context.popDataset(inputId);
  JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) inputDataset).getRDD();
  WindowedValueCoder<KV<K, V>> inputCoder = getWindowedValueCoder(inputId, components);
  KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
  Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
  Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
  WindowingStrategy windowingStrategy = getWindowingStrategy(inputId, components);
  WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
  WindowedValue.WindowedValueCoder<V> wvCoder =
      WindowedValue.FullWindowedValueCoder.of(inputValueCoder, windowFn.windowCoder());
  JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedByKeyAndWindow;
  Partitioner partitioner = getPartitioner(context);
  // As this is batch, we can ignore triggering and allowed lateness parameters.
  if (windowingStrategy.getWindowFn().equals(new GlobalWindows())
      && windowingStrategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW)) {
    // we can drop the windows and recover them later
    groupedByKeyAndWindow =
        GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
            inputRdd, inputKeyCoder, inputValueCoder, partitioner);
  } else if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
    // we can have a memory sensitive translation for non-merging windows
    groupedByKeyAndWindow =
        GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
            inputRdd, inputKeyCoder, inputValueCoder, windowingStrategy, partitioner);
  } else {
    JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupedByKeyOnly =
        GroupCombineFunctions.groupByKeyOnly(inputRdd, inputKeyCoder, wvCoder, partitioner);
    // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
    groupedByKeyAndWindow =
        groupedByKeyOnly.flatMap(
            new SparkGroupAlsoByWindowViaOutputBufferFn<>(
                windowingStrategy,
                new TranslationUtils.InMemoryStateInternalsFactory<>(),
                SystemReduceFn.buffering(inputValueCoder),
                context.serializablePipelineOptions));
  }
  context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(groupedByKeyAndWindow));
}
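In every grouping path the KvCoder is split the same way: the key coder identifies the grouping key, while each value is re-wrapped in a FullWindowedValueCoder so it keeps its window and timestamp through the shuffle. A minimal sketch of that coder assembly, assuming the global window and Beam's internal org.apache.beam.sdk.util.WindowedValue (the package of WindowedValue may differ across Beam versions):

import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;

public class WindowedValueCoderSketch {
  public static void main(String[] args) {
    KvCoder<String, Integer> kvCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of());
    // The value coder is wrapped together with a window coder, so each shuffled value
    // carries its window and timestamp metadata alongside the payload.
    WindowedValue.WindowedValueCoder<Integer> wvCoder =
        WindowedValue.FullWindowedValueCoder.of(
            kvCoder.getValueCoder(), GlobalWindow.Coder.INSTANCE);
    System.out.println(wvCoder);
  }
}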
Use of org.apache.beam.sdk.coders.KvCoder in project beam by apache.
The class TransformTranslator, method groupByKey.
private static <K, V, W extends BoundedWindow> TransformEvaluator<GroupByKey<K, V>> groupByKey() {
  return new TransformEvaluator<GroupByKey<K, V>>() {
    @Override
    public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
      @SuppressWarnings("unchecked")
      JavaRDD<WindowedValue<KV<K, V>>> inRDD =
          ((BoundedDataset<KV<K, V>>) context.borrowDataset(transform)).getRDD();
      final KvCoder<K, V> coder = (KvCoder<K, V>) context.getInput(transform).getCoder();
      @SuppressWarnings("unchecked")
      final WindowingStrategy<?, W> windowingStrategy =
          (WindowingStrategy<?, W>) context.getInput(transform).getWindowingStrategy();
      @SuppressWarnings("unchecked")
      final WindowFn<Object, W> windowFn = (WindowFn<Object, W>) windowingStrategy.getWindowFn();

      // --- coders.
      final Coder<K> keyCoder = coder.getKeyCoder();
      final WindowedValue.WindowedValueCoder<V> wvCoder =
          WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());

      JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedByKey;
      Partitioner partitioner = getPartitioner(context);
      // As this is batch, we can ignore triggering and allowed lateness parameters.
      if (windowingStrategy.getWindowFn().equals(new GlobalWindows())
          && windowingStrategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW)) {
        // we can drop the windows and recover them later
        groupedByKey =
            GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
                inRDD, keyCoder, coder.getValueCoder(), partitioner);
      } else if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(windowingStrategy)) {
        // we can have a memory sensitive translation for non-merging windows
        groupedByKey =
            GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
                inRDD, keyCoder, coder.getValueCoder(), windowingStrategy, partitioner);
      } else {
        // --- group by key only.
        JavaRDD<KV<K, Iterable<WindowedValue<V>>>> groupedByKeyOnly =
            GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, partitioner);
        // --- now group also by window.
        // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
        groupedByKey =
            groupedByKeyOnly.flatMap(
                new SparkGroupAlsoByWindowViaOutputBufferFn<>(
                    windowingStrategy,
                    new TranslationUtils.InMemoryStateInternalsFactory<>(),
                    SystemReduceFn.buffering(coder.getValueCoder()),
                    context.getSerializableOptions()));
      }
      context.putDataset(transform, new BoundedDataset<>(groupedByKey));
    }

    @Override
    public String toNativeString() {
      return "groupByKey()";
    }
  };
}
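The reason groupByKeyOnly receives the key coder at all is that the Spark runner shuffles encoded bytes rather than Java objects: keys are serialized with their coder so that Spark's hashing and equality operate on a stable byte representation, and they are decoded again after the shuffle. A small sketch of that round trip using CoderUtils (the runner itself wraps the bytes in its own key type, which is omitted here):

import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

public class KeyEncodingSketch {
  public static void main(String[] args) throws CoderException {
    Coder<String> keyCoder = StringUtf8Coder.of();
    // Encode the key to bytes before the shuffle so grouping compares encoded forms
    // rather than relying on the key type's equals/hashCode.
    byte[] encodedKey = CoderUtils.encodeToByteArray(keyCoder, "user-42");
    // Decode it again when rebuilding KV<K, Iterable<V>> after grouping.
    String decodedKey = CoderUtils.decodeFromByteArray(keyCoder, encodedKey);
    System.out.println(decodedKey);
  }
}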
Use of org.apache.beam.sdk.coders.KvCoder in project beam by apache.
The class TransformTranslator, method parDo.
private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
  return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
    @Override
    @SuppressWarnings("unchecked")
    public void evaluate(ParDo.MultiOutput<InputT, OutputT> transform, EvaluationContext context) {
      String stepName = context.getCurrentTransform().getFullName();
      DoFn<InputT, OutputT> doFn = transform.getFn();
      checkState(
          !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(),
          "Not expected to directly translate splittable DoFn, should have been overridden: %s",
          doFn);
      JavaRDD<WindowedValue<InputT>> inRDD =
          ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
      WindowingStrategy<?, ?> windowingStrategy =
          context.getInput(transform).getWindowingStrategy();
      MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
      Coder<InputT> inputCoder = (Coder<InputT>) context.getInput(transform).getCoder();
      Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();
      JavaPairRDD<TupleTag<?>, WindowedValue<?>> all;
      DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
      boolean stateful =
          signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
      DoFnSchemaInformation doFnSchemaInformation;
      doFnSchemaInformation =
          ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
      Map<String, PCollectionView<?>> sideInputMapping =
          ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
      MultiDoFnFunction<InputT, OutputT> multiDoFnFunction =
          new MultiDoFnFunction<>(
              metricsAccum,
              stepName,
              doFn,
              context.getSerializableOptions(),
              transform.getMainOutputTag(),
              transform.getAdditionalOutputTags().getAll(),
              inputCoder,
              outputCoders,
              TranslationUtils.getSideInputs(transform.getSideInputs().values(), context),
              windowingStrategy,
              stateful,
              doFnSchemaInformation,
              sideInputMapping);
      if (stateful) {
        // Based on the fact that the signature is stateful, DoFnSignatures ensures
        // that it is also keyed
        all =
            statefulParDoTransform(
                (KvCoder) context.getInput(transform).getCoder(),
                windowingStrategy.getWindowFn().windowCoder(),
                (JavaRDD) inRDD,
                getPartitioner(context),
                (MultiDoFnFunction) multiDoFnFunction,
                signature.processElement().requiresTimeSortedInput());
      } else {
        all = inRDD.mapPartitionsToPair(multiDoFnFunction);
      }
      Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs(transform);
      if (outputs.size() > 1) {
        StorageLevel level = StorageLevel.fromString(context.storageLevel());
        if (canAvoidRddSerialization(level)) {
          // If the storage level is memory-only, persist directly and avoid the overhead of
          // converting to bytes.
          all = all.persist(level);
        } else {
          // Caching can cause serialization, so we need to encode to bytes first;
          // more details in https://issues.apache.org/jira/browse/BEAM-2669
          Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
              TranslationUtils.getTupleTagCoders(outputs);
          all =
              all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
                  .persist(level)
                  .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
        }
      }
      for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
        JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
            all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
        // Object is the best we can do since different outputs can have different tags
        JavaRDD<WindowedValue<Object>> values =
            (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
        context.putDataset(output.getValue(), new BoundedDataset<>(values));
      }
    }

    @Override
    public String toNativeString() {
      return "mapPartitions(new <fn>())";
    }
  };
}
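The stateful flag above comes from inspecting the DoFn's signature: any state or timer declaration makes the ParDo stateful, which in turn forces the input coder to be a KvCoder so statefulParDoTransform can partition by key. A stand-alone sketch of that check (the DoFn body is invented for illustration):

import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.values.KV;

public class StatefulDoFnCheckSketch {
  // A stateful DoFn must consume KV elements so the runner can shuffle by key first.
  static class CountPerKeyFn extends DoFn<KV<String, Long>, Long> {
    @StateId("count")
    private final StateSpec<ValueState<Long>> countSpec = StateSpecs.value();

    @ProcessElement
    public void processElement(ProcessContext c, @StateId("count") ValueState<Long> count) {
      Long current = count.read();
      long next = (current == null ? 0L : current) + 1;
      count.write(next);
      c.output(next);
    }
  }

  public static void main(String[] args) {
    DoFnSignature signature = DoFnSignatures.getSignature(CountPerKeyFn.class);
    boolean stateful =
        !signature.stateDeclarations().isEmpty() || !signature.timerDeclarations().isEmpty();
    System.out.println("stateful = " + stateful); // prints: stateful = true
  }
}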