Use of org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy in the Apache Beam project.
Class GroupAlsoByWindowParDoFnFactory, method create.
/**
 * Creates the ParDoFn that executes a GroupAlsoByWindow (GABW) step.
 *
 * <p>The step must have exactly one output: the main output tag, mapped to receiver index 0.
 * The windowing strategy is read from the {@code SERIALIZED_FN} property. If it cannot be
 * deserialized and the {@code beam_fn_api} experiment is enabled, the global default strategy
 * is used instead.
 *
 * <p>If a {@code COMBINE_FN} property is present, a combining GABW is built. Its phase must be
 * {@code ALL} or {@code MERGE}; in the {@code MERGE} phase the combine fn is wrapped with
 * {@code makeAppliedMergingFunction}. Streaming pipelines also get a
 * {@code FakeKeyedWorkItemCoder} for the GABW input, while batch pipelines pass {@code null}.
 *
 * @throws Exception if the windowing strategy cannot be deserialized and the
 *     {@code beam_fn_api} experiment is not enabled; a {@code checkArgument} precondition also
 *     fails if the outputs, combine fn, input coder, or combine phase are not as expected
 */
@Override
public ParDoFn create(
    PipelineOptions options,
    CloudObject cloudUserFn,
    @Nullable List<SideInputInfo> sideInputInfos,
    TupleTag<?> mainOutputTag,
    Map<TupleTag<?>, Integer> outputTupleTagsToReceiverIndices,
    final DataflowExecutionContext<?> executionContext,
    DataflowOperationContext operationContext)
    throws Exception {
  // A GABW step writes to a single receiver: the main output at index 0.
  Map.Entry<TupleTag<?>, Integer> onlyOutput =
      Iterables.getOnlyElement(outputTupleTagsToReceiverIndices.entrySet());
  checkArgument(
      onlyOutput.getKey().equals(mainOutputTag),
      "Output tags should reference only the main output tag: %s vs %s",
      onlyOutput.getKey(),
      mainOutputTag);
  checkArgument(
      onlyOutput.getValue() == 0,
      "There should be a single receiver, but using receiver index %s",
      onlyOutput.getValue());

  byte[] windowingStrategyBytes = getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN);
  WindowingStrategy windowingStrategy;
  try {
    windowingStrategy = deserializeWindowingStrategy(windowingStrategyBytes);
  } catch (Exception e) {
    // TODO: Catch block disappears, becoming an error once Python SDK is compliant.
    boolean fnApiEnabled =
        DataflowRunner.hasExperiment(
            options.as(DataflowPipelineDebugOptions.class), "beam_fn_api");
    if (!fnApiEnabled) {
      throw e;
    }
    LOG.info("FnAPI: Unable to deserialize windowing strategy, assuming default", e);
    windowingStrategy = WindowingStrategy.globalDefault();
  }

  // Optional combiner; when absent this is a plain (iterable-producing) GABW.
  AppliedCombineFn<?, ?, ?, ?> combineFn = null;
  byte[] combineFnBytes = getBytes(cloudUserFn, WorkerPropertyNames.COMBINE_FN, null);
  if (combineFnBytes != null) {
    Object deserialized =
        SerializableUtils.deserializeFromByteArray(combineFnBytes, "serialized combine fn");
    checkArgument(
        deserialized instanceof AppliedCombineFn,
        "unexpected kind of AppliedCombineFn: " + deserialized.getClass().getName());
    combineFn = (AppliedCombineFn<?, ?, ?, ?>) deserialized;
  }

  // The input coder must be WindowedValueCoder<KV<?, ?>>.
  Map<String, Object> inputCoderSpec = getObject(cloudUserFn, WorkerPropertyNames.INPUT_CODER);
  Coder<?> inputCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(inputCoderSpec));
  checkArgument(
      inputCoder instanceof WindowedValueCoder,
      "Expected WindowedValueCoder for inputCoder, got: " + inputCoder.getClass().getName());
  @SuppressWarnings("unchecked")
  WindowedValueCoder<?> windowedInputCoder = (WindowedValueCoder<?>) inputCoder;
  Coder<?> elementCoder = windowedInputCoder.getValueCoder();
  checkArgument(
      elementCoder instanceof KvCoder,
      "Expected KvCoder for inputCoder, got: " + elementCoder.getClass().getName());
  @SuppressWarnings("unchecked")
  KvCoder<?, ?> kvCoder = (KvCoder<?, ?>) elementCoder;

  boolean streaming = options.as(StreamingOptions.class).isStreaming();

  // Side inputs are only read by a combiner; without one, an empty reader is used.
  SideInputReader sideInputReader = NullSideInputReader.empty();
  @Nullable AppliedCombineFn<?, ?, ?, ?> maybeMergingCombineFn = null;
  if (combineFn != null) {
    sideInputReader =
        executionContext.getSideInputReader(
            sideInputInfos, combineFn.getSideInputViews(), operationContext);
    String phase = getString(cloudUserFn, WorkerPropertyNames.PHASE, CombinePhase.ALL);
    checkArgument(
        phase.equals(CombinePhase.ALL) || phase.equals(CombinePhase.MERGE),
        "Unexpected phase: %s",
        phase);
    if (!phase.equals(CombinePhase.MERGE)) {
      maybeMergingCombineFn = combineFn;
    } else {
      maybeMergingCombineFn = makeAppliedMergingFunction(combineFn);
    }
  }

  StateInternalsFactory<?> stateInternalsFactory =
      key -> executionContext.getStepContext(operationContext).stateInternals();

  // Select the GABW fn: streaming vs batch, combining vs iterable.
  // Only streaming needs an input coder (a FakeKeyedWorkItemCoder); batch keeps null.
  GroupAlsoByWindowFn<?, ?> gabwFn;
  Coder<?> gabwInputCoder = null;
  if (streaming) {
    if (maybeMergingCombineFn != null) {
      gabwFn =
          StreamingGroupAlsoByWindowsDoFns.create(
              windowingStrategy,
              stateInternalsFactory,
              (AppliedCombineFn) maybeMergingCombineFn,
              ((KvCoder) kvCoder).getKeyCoder());
      gabwInputCoder =
          WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(
              ((AppliedCombineFn) maybeMergingCombineFn).getKvCoder());
    } else {
      gabwFn =
          StreamingGroupAlsoByWindowsDoFns.createForIterable(
              windowingStrategy, stateInternalsFactory, ((KvCoder) kvCoder).getValueCoder());
      gabwInputCoder = WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder);
    }
  } else if (maybeMergingCombineFn != null) {
    gabwFn =
        BatchGroupAlsoByWindowsDoFns.create(
            windowingStrategy, (AppliedCombineFn) maybeMergingCombineFn);
  } else {
    gabwFn =
        BatchGroupAlsoByWindowsDoFns.createForIterable(
            windowingStrategy, stateInternalsFactory, ((KvCoder) kvCoder).getValueCoder());
  }

  // Side-input views come from the combiner when there is one; otherwise null is passed.
  return new GroupAlsoByWindowsParDoFn(
      options,
      gabwFn,
      windowingStrategy,
      maybeMergingCombineFn == null
          ? null
          : ((AppliedCombineFn) maybeMergingCombineFn).getSideInputViews(),
      gabwInputCoder,
      sideInputReader,
      mainOutputTag,
      executionContext.getStepContext(operationContext));
}
Use of org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy in the Apache Beam project.
Class ParDoBoundMultiTranslator, method doTranslatePortable.
// static for serializing anonymous functions
/**
 * Translates a portable ExecutableStage transform into a Samza {@code DoFnOp}.
 *
 * <p>The steps are:
 *
 * <ul>
 *   <li>Parse the stage's {@code ExecutableStagePayload}.
 *   <li>For each side input, build a runner-side {@code PCollectionView} and a grouped,
 *       broadcast side-input stream.
 *   <li>Give each output a consecutive index, starting at 0.
 *   <li>Wrap the stage in a {@code DoFnOp} over a {@code NoOpDoFn}.
 *   <li>Merge any side-input streams into the main input.
 *   <li>Register, per output, a stream filtered to that output's union tag.
 * </ul>
 *
 * <p>NOTE(review): the "main" output is whichever key the outputs map yields first; this relies
 * on the map's iteration order.
 */
private static <InT, OutT> void doTranslatePortable(
    PipelineNode.PTransformNode transform,
    QueryablePipeline pipeline,
    PortableTranslationContext ctx) {
  final Map<String, String> outputs = transform.getTransform().getOutputsMap();

  final RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }

  final String inputId = stagePayload.getInput();
  final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStreamById(inputId);

  // Side inputs: one runner-side view plus one grouped/broadcast stream per side input.
  final RunnerApi.Components stageComponents = stagePayload.getComponents();
  final List<MessageStream<OpMessage<Iterable<?>>>> sideInputStreams = new ArrayList<>();
  final Map<SideInputId, PCollectionView<?>> sideInputMapping = new HashMap<>();
  final Map<String, PCollectionView<?>> idToViewMapping = new HashMap<>();
  for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
    final String sideCollectionId =
        stageComponents
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    final WindowingStrategy<?, BoundedWindow> sideWindowingStrategy =
        WindowUtils.getWindowStrategy(sideCollectionId, stageComponents);
    final WindowedValue.WindowedValueCoder<?> sideCoder =
        (WindowedValue.WindowedValueCoder) instantiateCoder(sideCollectionId, stageComponents);

    final PCollectionView<?> view =
        createPCollectionView(sideInputId, sideCoder, sideWindowingStrategy);
    // Use GBK to aggregate the side inputs and then broadcast it out.
    final MessageStream<OpMessage<Iterable<?>>> broadcastSideInput =
        groupAndBroadcastSideInput(
            sideInputId,
            sideCollectionId,
            stageComponents.getPcollectionsOrThrow(sideCollectionId),
            (WindowingStrategy) sideWindowingStrategy,
            sideCoder,
            ctx);
    sideInputStreams.add(broadcastSideInput);
    sideInputMapping.put(sideInputId, view);
    idToViewMapping.put(getSideInputUniqueId(sideInputId), view);
  }

  // The first output key is used as the main output tag.
  final TupleTag<OutT> mainOutputTag =
      outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());

  // Number outputs 0..n-1 in map iteration order; the index becomes the RawUnionValue tag.
  final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
  final Map<Integer, String> indexToIdMap = new HashMap<>();
  final Map<String, TupleTag<?>> idToTupleTagMap = new HashMap<>();
  int nextIndex = 0;
  for (Map.Entry<String, String> output : outputs.entrySet()) {
    final TupleTag<?> outputTag = new TupleTag<>(output.getKey());
    final String outputCollectionId = output.getValue();
    tagToIndexMap.put(outputTag, nextIndex);
    indexToIdMap.put(nextIndex, outputCollectionId);
    idToTupleTagMap.put(outputCollectionId, outputTag);
    nextIndex++;
  }

  final RunnerApi.Components pipelineComponents = pipeline.getComponents();
  final WindowedValue.WindowedValueCoder<InT> windowedInputCoder =
      WindowUtils.instantiateWindowedCoder(inputId, pipelineComponents);

  // TODO: support schema and side inputs for portable runner
  // Note: transform.getTransform() is an ExecutableStage, not ParDo, so we need to extract
  // these info from its components.
  final DoFnSchemaInformation doFnSchemaInformation = null;

  final RunnerApi.PCollection input = pipelineComponents.getPcollectionsOrThrow(inputId);
  final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input);
  // A key coder is only extracted for stateful stages (input is then expected to be KV).
  final Coder<?> keyCoder =
      StateUtils.isStateful(stagePayload)
          ? ((KvCoder) ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder())
              .getKeyCoder()
          : null;

  final DoFnOp<InT, OutT, RawUnionValue> op =
      new DoFnOp<>(
          mainOutputTag,
          new NoOpDoFn<>(),
          keyCoder,
          windowedInputCoder.getValueCoder(), // input coder not in use
          windowedInputCoder,
          Collections.emptyMap(), // output coders not in use
          new ArrayList<>(sideInputMapping.values()),
          new ArrayList<>(idToTupleTagMap.values()), // used by java runner only
          WindowUtils.getWindowStrategy(inputId, stageComponents),
          idToViewMapping,
          new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
          ctx.getTransformFullName(),
          ctx.getTransformId(),
          isBounded,
          true,
          stagePayload,
          ctx.getJobInfo(),
          idToTupleTagMap,
          doFnSchemaInformation,
          sideInputMapping);

  // Side-input streams are mapped by SideInputWatermarkFn and merged into the main input.
  final MessageStream<OpMessage<InT>> mergedStreams;
  if (!sideInputStreams.isEmpty()) {
    final MessageStream<OpMessage<InT>> sideInputsAsMain =
        MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
    mergedStreams = inputStream.merge(Collections.singletonList(sideInputsAsMain));
  } else {
    mergedStreams = inputStream;
  }

  final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
      mergedStreams.flatMapAsync(OpAdapter.adapt(op));

  // Per output: keep non-element messages plus the elements carrying this output's union tag.
  for (int unionTag : tagToIndexMap.values()) {
    @SuppressWarnings("unchecked")
    final MessageStream<OpMessage<OutT>> outputStream =
        taggedOutputStream
            .filter(
                message ->
                    message.getType() != OpMessage.Type.ELEMENT
                        || message.getElement().getValue().getUnionTag() == unionTag)
            .flatMapAsync(OpAdapter.adapt(new RawUnionValueToValue()));
    ctx.registerMessageStream(indexToIdMap.get(unionTag), outputStream);
  }
}
Use of org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy in the Apache Beam project.
Class SparkStreamingPortablePipelineTranslator, method translateGroupByKey.
/**
 * Translates a portable GroupByKey for Spark streaming.
 *
 * <p>Pops the input {@code UnboundedDataset} and groups it with
 * {@code SparkGroupAlsoByWindowViaWindowSet.groupByKeyAndWindow}. The grouping uses the input's
 * key coder, a windowed coder for the values, and the input's windowing strategy. The result is
 * pushed under this transform's output id as a new {@code UnboundedDataset} that carries over
 * the input's stream sources.
 */
private static <K, V> void translateGroupByKey(
    PTransformNode transformNode,
    RunnerApi.Pipeline pipeline,
    SparkStreamingTranslationContext context) {
  final RunnerApi.Components components = pipeline.getComponents();
  final String inputId = getInputId(transformNode);
  final UnboundedDataset<KV<K, V>> source =
      (UnboundedDataset<KV<K, V>>) context.popDataset(inputId);
  final List<Integer> sourceIds = source.getStreamSources();

  final WindowedValue.WindowedValueCoder<KV<K, V>> windowedKvCoder =
      getWindowedValueCoder(inputId, components);
  final KvCoder<K, V> kvCoder = (KvCoder<K, V>) windowedKvCoder.getValueCoder();
  final WindowingStrategy strategy = getWindowingStrategy(inputId, components);
  final WindowFn<Object, BoundedWindow> windowFn = strategy.getWindowFn();
  // Coder for the grouped values: the value coder paired with the window coder.
  final WindowedValue.WindowedValueCoder<V> windowedValueCoder =
      WindowedValue.FullWindowedValueCoder.of(kvCoder.getValueCoder(), windowFn.windowCoder());

  final JavaDStream<WindowedValue<KV<K, Iterable<V>>>> grouped =
      SparkGroupAlsoByWindowViaWindowSet.groupByKeyAndWindow(
          source.getDStream(),
          kvCoder.getKeyCoder(),
          windowedValueCoder,
          strategy,
          context.getSerializableOptions(),
          sourceIds,
          transformNode.getId());
  final String outputId = getOutputId(transformNode);
  context.pushDataset(outputId, new UnboundedDataset<>(grouped, sourceIds));
}
Use of org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy in the Apache Beam project.
Class SparkBatchPortablePipelineTranslator, method translateExecutableStage.
/**
 * Translates a portable ExecutableStage for Spark batch execution.
 *
 * <p>How the stage runs depends on whether it uses state:
 *
 * <ul>
 *   <li>Stages with user state or timers require KV input. The input is grouped by key with
 *       {@code groupByKeyPair}, and each group is passed to the stage function.
 *   <li>All other stages run the stage function over each input partition.
 * </ul>
 *
 * <p>The resulting {@code RawUnionValue} RDD is registered under an intermediate id, with
 * cache/action/setName hooks, and then popped. It is then split into one {@code BoundedDataset}
 * per stage output. A stage with no runner-visible outputs gets an empty sink dataset, so it is
 * still materialized.
 *
 * @throws IllegalStateException if a stateful stage's input element coder is not a KvCoder
 * @throws RuntimeException wrapping the IOException if the stage payload cannot be parsed
 */
private static <InputT, OutputT, SideInputT> void translateExecutableStage(
    PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
  RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(
            transformNode.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  String inputPCollectionId = stagePayload.getInput();
  Dataset inputDataset = context.popDataset(inputPCollectionId);
  Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
  BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
  Components components = pipeline.getComponents();

  // Resolve the input windowing strategy once. Its window coder is needed by the stage
  // function and, for stateful stages, by the grouped-value coder.
  WindowingStrategy windowingStrategy = getWindowingStrategy(inputPCollectionId, components);
  WindowFn<Object, BoundedWindow> windowFn = windowingStrategy.getWindowFn();
  Coder windowCoder = windowFn.windowCoder();

  ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
      broadcastVariables = broadcastSideInputs(stagePayload, context);

  JavaRDD<RawUnionValue> staged;
  if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
    Coder<WindowedValue<InputT>> windowedInputCoder =
        instantiateCoder(inputPCollectionId, components);
    // getValueCoder() is defined on every WindowedValueCoder, so the cast does not need to
    // require the FullWindowedValueCoder subclass.
    Coder valueCoder = ((WindowedValue.WindowedValueCoder) windowedInputCoder).getValueCoder();
    // Stateful stages are only allowed of KV input to be able to group on the key
    if (!(valueCoder instanceof KvCoder)) {
      throw new IllegalStateException(
          String.format(
              Locale.ENGLISH,
              "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
              inputPCollectionId,
              valueCoder.getClass().getSimpleName()));
    }
    Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
    Coder innerValueCoder = ((KvCoder) valueCoder).getValueCoder();
    WindowedValue.WindowedValueCoder wvCoder =
        WindowedValue.FullWindowedValueCoder.of(innerValueCoder, windowCoder);
    JavaPairRDD<ByteArray, Iterable<WindowedValue<KV>>> groupedByKey =
        groupByKeyPair(inputDataset, keyCoder, wvCoder);
    SparkExecutableStageFunction<KV, SideInputT> function =
        new SparkExecutableStageFunction<>(
            context.getSerializableOptions(),
            stagePayload,
            context.jobInfo,
            outputExtractionMap,
            SparkExecutableStageContextFactory.getInstance(),
            broadcastVariables,
            MetricsAccumulator.getInstance(),
            windowCoder);
    staged = groupedByKey.flatMap(function.forPair());
  } else {
    JavaRDD<WindowedValue<InputT>> inputRdd2 = ((BoundedDataset<InputT>) inputDataset).getRDD();
    SparkExecutableStageFunction<InputT, SideInputT> function2 =
        new SparkExecutableStageFunction<>(
            context.getSerializableOptions(),
            stagePayload,
            context.jobInfo,
            outputExtractionMap,
            SparkExecutableStageContextFactory.getInstance(),
            broadcastVariables,
            MetricsAccumulator.getInstance(),
            windowCoder);
    staged = inputRdd2.mapPartitions(function2);
  }

  String intermediateId = getExecutableStageIntermediateId(transformNode);
  context.pushDataset(
      intermediateId,
      new Dataset() {
        @Override
        public void cache(String storageLevel, Coder<?> coder) {
          StorageLevel level = StorageLevel.fromString(storageLevel);
          staged.persist(level);
        }

        @Override
        public void action() {
          // Empty function to force computation of RDD.
          staged.foreach(TranslationUtils.emptyVoidFunction());
        }

        @Override
        public void setName(String name) {
          staged.setName(name);
        }
      });
  // pop dataset to mark RDD as used
  context.popDataset(intermediateId);

  for (String outputId : outputs.values()) {
    JavaRDD<WindowedValue<OutputT>> outputRdd =
        staged.flatMap(
            new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
    context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
  }
  if (outputs.isEmpty()) {
    // After pipeline translation, we traverse the set of unconsumed PCollections and add a
    // no-op sink to each to make sure they are materialized by Spark. However, some SDK-executed
    // stages have no runner-visible output after fusion. We handle this case by adding a sink
    // here.
    JavaRDD<WindowedValue<OutputT>> outputRdd =
        staged.flatMap((rawUnionValue) -> Collections.emptyIterator());
    context.pushDataset(
        String.format("EmptyOutputSink_%d", context.nextSinkId()),
        new BoundedDataset<>(outputRdd));
  }
}
Use of org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy in the Apache Beam project.
Class SparkBatchPortablePipelineTranslator, method translateGroupByKey.
/**
 * Translates a portable GroupByKey for Spark batch execution.
 *
 * <p>One of three grouping paths is chosen from the input windowing strategy:
 *
 * <ul>
 *   <li>Global windows with an END_OF_WINDOW timestamp combiner: group in the global window.
 *       Windows can be dropped and recovered later.
 *   <li>Strategies accepted by {@code GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow}:
 *       group by key and window.
 *   <li>Anything else: group by key only, then apply GroupAlsoByWindow with in-memory state
 *       and a buffering reduce fn.
 * </ul>
 *
 * <p>The grouped RDD is pushed as a {@code BoundedDataset} under the transform's output id.
 */
private static <K, V> void translateGroupByKey(
    PTransformNode transformNode, RunnerApi.Pipeline pipeline, SparkTranslationContext context) {
  final RunnerApi.Components components = pipeline.getComponents();
  final String inputId = getInputId(transformNode);
  final JavaRDD<WindowedValue<KV<K, V>>> inputRdd =
      ((BoundedDataset<KV<K, V>>) context.popDataset(inputId)).getRDD();

  final WindowedValueCoder<KV<K, V>> windowedKvCoder = getWindowedValueCoder(inputId, components);
  final KvCoder<K, V> kvCoder = (KvCoder<K, V>) windowedKvCoder.getValueCoder();
  final Coder<K> keyCoder = kvCoder.getKeyCoder();
  final Coder<V> valueCoder = kvCoder.getValueCoder();
  final WindowingStrategy strategy = getWindowingStrategy(inputId, components);
  final WindowFn<Object, BoundedWindow> windowFn = strategy.getWindowFn();
  final WindowedValue.WindowedValueCoder<V> windowedValueCoder =
      WindowedValue.FullWindowedValueCoder.of(valueCoder, windowFn.windowCoder());
  final Partitioner partitioner = getPartitioner(context);

  // As this is batch, we can ignore triggering and allowed lateness parameters.
  final boolean globalWindowAtEndOfWindow =
      windowFn.equals(new GlobalWindows())
          && strategy.getTimestampCombiner().equals(TimestampCombiner.END_OF_WINDOW);

  final JavaRDD<WindowedValue<KV<K, Iterable<V>>>> grouped;
  if (globalWindowAtEndOfWindow) {
    // we can drop the windows and recover them later
    grouped =
        GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow(
            inputRdd, keyCoder, valueCoder, partitioner);
  } else if (GroupNonMergingWindowsFunctions.isEligibleForGroupByWindow(strategy)) {
    // we can have a memory sensitive translation for non-merging windows
    grouped =
        GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
            inputRdd, keyCoder, valueCoder, strategy, partitioner);
  } else {
    final JavaRDD<KV<K, Iterable<WindowedValue<V>>>> byKeyOnly =
        GroupCombineFunctions.groupByKeyOnly(inputRdd, keyCoder, windowedValueCoder, partitioner);
    // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
    grouped =
        byKeyOnly.flatMap(
            new SparkGroupAlsoByWindowViaOutputBufferFn<>(
                strategy,
                new TranslationUtils.InMemoryStateInternalsFactory<>(),
                SystemReduceFn.buffering(valueCoder),
                context.serializablePipelineOptions));
  }
  context.pushDataset(getOutputId(transformNode), new BoundedDataset<>(grouped));
}
Aggregations