use of org.apache.beam.sdk.values.PCollectionView in project beam by apache.
the class StreamingSideInputHandlerFactory method forStage.
/**
 * Builds a {@link StreamingSideInputHandlerFactory} for the supplied stage.
 *
 * <p>Every side input declared by {@code stage} is resolved against {@code viewMapping} up front.
 * Because this walks all of the stage's side inputs, invoke it once per stage rather than once per
 * bundle.
 *
 * @param stage the executable stage whose side inputs are resolved
 * @param viewMapping runner-side views keyed by side input id; must contain an entry for every
 *     side input of {@code stage}, otherwise {@code checkNotNull} rejects the lookup
 * @param runnerHandler the runner-side handler the created factory reads side input contents from
 * @return a factory holding the resolved id-to-view mapping and {@code runnerHandler}
 */
public static StreamingSideInputHandlerFactory forStage(
    ExecutableStage stage,
    Map<SideInputId, PCollectionView<?>> viewMapping,
    org.apache.beam.runners.core.SideInputHandler runnerHandler) {
  ImmutableMap.Builder<SideInputId, PCollectionView<?>> resolvedViews = ImmutableMap.builder();
  for (SideInputReference reference : stage.getSideInputs()) {
    SideInputId id =
        SideInputId.newBuilder()
            .setTransformId(reference.transform().getId())
            .setLocalName(reference.localName())
            .build();
    PCollectionView<?> view =
        checkNotNull(
            viewMapping.get(id),
            "No side input for %s/%s",
            id.getTransformId(),
            id.getLocalName());
    resolvedViews.put(id, view);
  }
  return new StreamingSideInputHandlerFactory(resolvedViews.build(), runnerHandler);
}
use of org.apache.beam.sdk.values.PCollectionView in project beam by apache.
the class StreamingSideInputHandlerFactory method forIterableSideInput.
/**
 * Creates a handler that serves the contents of one iterable side input.
 *
 * @param transformId id of the transform consuming the side input
 * @param sideInputId local name of the side input within that transform
 * @param elementCoder coder reported by the returned handler's {@code elementCoder()}
 * @param windowCoder coder for the side input's windows (not read by this method)
 * @return a handler whose {@code get(window)} reads the side input from the runner handler
 */
@Override
public <V, W extends BoundedWindow> IterableSideInputHandler<V, W> forIterableSideInput(
    String transformId, String sideInputId, Coder<V> elementCoder, Coder<W> windowCoder) {
  // Resolve the view registered for this transform/local-name pair.
  SideInputId lookupKey =
      SideInputId.newBuilder().setTransformId(transformId).setLocalName(sideInputId).build();
  PCollectionView view = sideInputToCollection.get(lookupKey);
  checkArgument(view != null, "No side input for %s/%s", transformId, sideInputId);

  return new IterableSideInputHandler<V, W>() {
    @Override
    public Iterable<V> get(W window) {
      Iterable<V> contents = (Iterable<V>) runnerHandler.getIterable(view, window);
      // A null result is treated as the side input not being ready for this window.
      return checkNotNull(contents, "Element processed by SDK before side input is ready");
    }

    @Override
    public Coder<V> elementCoder() {
      return elementCoder;
    }
  };
}
use of org.apache.beam.sdk.values.PCollectionView in project beam by apache.
the class StreamingSideInputHandlerFactory method forMultimapSideInput.
/**
 * Creates a handler that serves one multimap side input by scanning its key/value pairs.
 *
 * <p>Keys are compared through the key coder's structural value rather than {@code equals} on the
 * key objects themselves.
 *
 * @param transformId id of the transform consuming the side input
 * @param sideInputId local name of the side input within that transform
 * @param elementCoder KV coder supplying the key and value coders exposed by the handler
 * @param windowCoder coder for the side input's windows (not read by this method)
 * @return a handler answering per-key value lookups and key enumeration for a window
 */
@Override
public <K, V, W extends BoundedWindow> MultimapSideInputHandler<K, V, W> forMultimapSideInput(String transformId, String sideInputId, KvCoder<K, V> elementCoder, Coder<W> windowCoder) {
  PCollectionView collectionNode = sideInputToCollection.get(SideInputId.newBuilder().setTransformId(transformId).setLocalName(sideInputId).build());
  checkArgument(collectionNode != null, "No side input for %s/%s", transformId, sideInputId);
  return new MultimapSideInputHandler<K, V, W>() {
    /** Returns the values whose key is structurally equal to {@code key} in {@code window}. */
    @Override
    public Iterable<V> get(K key, W window) {
      Iterable<KV<K, V>> values = readSideInput(window);
      // The coder lookup and the probe key's structural value do not change per element.
      Coder<K> keyCoder = keyCoder();
      Object structuralK = keyCoder.structuralValue(key);
      List<V> result = new ArrayList<>();
      // find values for the given key
      for (KV<K, V> kv : values) {
        if (structuralK.equals(keyCoder.structuralValue(kv.getKey()))) {
          result.add(kv.getValue());
        }
      }
      return Collections.unmodifiableList(result);
    }

    /** Returns one representative key per distinct structural key value in {@code window}. */
    @Override
    public Iterable<K> get(W window) {
      Iterable<KV<K, V>> values = readSideInput(window);
      Coder<K> keyCoder = keyCoder();
      Map<Object, K> result = new HashMap<>();
      // find all keys
      for (KV<K, V> kv : values) {
        result.putIfAbsent(keyCoder.structuralValue(kv.getKey()), kv.getKey());
      }
      return Collections.unmodifiableCollection(result.values());
    }

    @Override
    public Coder<K> keyCoder() {
      return elementCoder.getKeyCoder();
    }

    @Override
    public Coder<V> valueCoder() {
      return elementCoder.getValueCoder();
    }

    /**
     * Reads the side input contents for {@code window}, applying the same readiness check and
     * message as the iterable side input handler instead of failing later with a message-less
     * NullPointerException when the result is iterated.
     */
    private Iterable<KV<K, V>> readSideInput(W window) {
      return checkNotNull((Iterable<KV<K, V>>) runnerHandler.getIterable(collectionNode, window), "Element processed by SDK before side input is ready");
    }
  };
}
use of org.apache.beam.sdk.values.PCollectionView in project beam by apache.
the class ParDoBoundMultiTranslator method doTranslate.
/**
 * Translates a {@link ParDo.MultiOutput} into message-stream operations.
 *
 * <p>Wraps the transform's fn in a {@code DoFnOp}, feeds it the main input stream merged with any
 * side input streams, and registers one output stream per output PCollection of {@code node}.
 *
 * <p>Kept static for serializing anonymous functions.
 *
 * @param transform the multi-output ParDo being translated
 * @param node hierarchy node whose outputs are indexed and registered
 * @param ctx translation context supplying input streams and receiving output streams
 * @throws UnsupportedOperationException if the fn is splittable or requires time-sorted input
 * @throws IllegalArgumentException if an output of {@code node} is not a PCollection
 */
private static <InT, OutT> void doTranslate(
    ParDo.MultiOutput<InT, OutT> transform, TransformHierarchy.Node node, TranslationContext ctx) {
  final PCollection<? extends InT> input = ctx.getInput(transform);

  // Coders of the current transform's PCollection outputs, keyed by output tag.
  final Map<TupleTag<?>, Coder<?>> outputCoders =
      ctx.getCurrentTransform().getOutputs().entrySet().stream()
          .filter(output -> output.getValue() instanceof PCollection)
          .collect(
              Collectors.toMap(
                  output -> output.getKey(),
                  output -> ((PCollection<?>) output.getValue()).getCoder()));

  // A key coder is only extracted when the fn is stateful.
  final Coder<?> keyCoder;
  if (StateUtils.isStateful(transform.getFn())) {
    keyCoder = ((KvCoder<?, ?>) input.getCoder()).getKeyCoder();
  } else {
    keyCoder = null;
  }

  if (DoFnSignatures.isSplittable(transform.getFn())) {
    throw new UnsupportedOperationException("Splittable DoFn is not currently supported");
  }
  if (DoFnSignatures.requiresTimeSortedInput(transform.getFn())) {
    throw new UnsupportedOperationException(
        "@RequiresTimeSortedInput annotation is not currently supported");
  }

  final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStream(input);
  final List<MessageStream<OpMessage<InT>>> sideInputStreams =
      transform.getSideInputs().values().stream()
          .map(ctx::<InT>getViewStream)
          .collect(Collectors.toList());

  // Give every output a sequential index, remembering tag -> index and index -> PCollection.
  final List<Map.Entry<TupleTag<?>, PCollection<?>>> taggedOutputs =
      new ArrayList<>(node.getOutputs().entrySet());
  final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
  final Map<Integer, PCollection<?>> indexToPCollectionMap = new HashMap<>();
  int position = 0;
  for (Map.Entry<TupleTag<?>, PCollection<?>> entry : taggedOutputs) {
    tagToIndexMap.put(entry.getKey(), position);
    if (!(entry.getValue() instanceof PCollection)) {
      throw new IllegalArgumentException(
          "Expected side output to be PCollection, but was: " + entry.getValue());
    }
    final PCollection<?> outputCollection = entry.getValue();
    indexToPCollectionMap.put(position, outputCollection);
    position++;
  }

  // Side input views keyed by the id the context assigns to each view.
  final HashMap<String, PCollectionView<?>> idToPValueMap = new HashMap<>();
  for (PCollectionView<?> sideInputView : transform.getSideInputs().values()) {
    idToPValueMap.put(ctx.getViewId(sideInputView), sideInputView);
  }

  final DoFnSchemaInformation doFnSchemaInformation =
      ParDoTranslation.getSchemaInformation(ctx.getCurrentTransform());
  final Map<String, PCollectionView<?>> sideInputMapping =
      ParDoTranslation.getSideInputMapping(ctx.getCurrentTransform());

  final DoFnOp<InT, OutT, RawUnionValue> op =
      new DoFnOp<>(
          transform.getMainOutputTag(),
          transform.getFn(),
          keyCoder,
          (Coder<InT>) input.getCoder(),
          null,
          outputCoders,
          transform.getSideInputs().values(),
          transform.getAdditionalOutputTags().getAll(),
          input.getWindowingStrategy(),
          idToPValueMap,
          new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
          ctx.getTransformFullName(),
          ctx.getTransformId(),
          input.isBounded(),
          false,
          null,
          null,
          Collections.emptyMap(),
          doFnSchemaInformation,
          sideInputMapping);

  // With side inputs present, their merged stream is folded into the main input stream.
  final MessageStream<OpMessage<InT>> mergedStreams;
  if (!sideInputStreams.isEmpty()) {
    MessageStream<OpMessage<InT>> mergedSideInputStreams =
        MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
    mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams));
  } else {
    mergedStreams = inputStream;
  }

  final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
      mergedStreams.flatMapAsync(OpAdapter.adapt(op));

  // Each output keeps every non-ELEMENT message plus the elements tagged with its own index.
  for (int outputIndex : tagToIndexMap.values()) {
    @SuppressWarnings("unchecked")
    final MessageStream<OpMessage<OutT>> outputStream =
        taggedOutputStream
            .filter(
                message ->
                    message.getType() != OpMessage.Type.ELEMENT
                        || message.getElement().getValue().getUnionTag() == outputIndex)
            .flatMapAsync(OpAdapter.adapt(new RawUnionValueToValue()));
    ctx.registerMessageStream(indexToPCollectionMap.get(outputIndex), outputStream);
  }
}
use of org.apache.beam.sdk.values.PCollectionView in project beam by apache.
the class ParDoBoundMultiTranslator method doTranslatePortable.
/**
 * Translates a portable transform node, whose spec payload is an ExecutableStagePayload, into
 * message-stream operations.
 *
 * <p>Steps, in order: parse the stage payload; build a runner-side view and a broadcast stream for
 * every side input of the stage; assign each output a sequential index; construct a DoFnOp around
 * a NoOpDoFn carrying the stage payload; merge side input streams into the main input stream; and
 * register one filtered output stream per output collection id with {@code ctx}.
 *
 * @param transform node whose transform spec payload is parsed as an ExecutableStagePayload
 * @param pipeline used to look up the input PCollection and to instantiate the windowed input coder
 * @param ctx translation context supplying the input stream and receiving the output streams
 * @throws RuntimeException wrapping the IOException raised if the payload cannot be parsed
 */
// static for serializing anonymous functions
private static <InT, OutT> void doTranslatePortable(PipelineNode.PTransformNode transform, QueryablePipeline pipeline, PortableTranslationContext ctx) {
// Output local name -> output collection id, as declared on the transform.
Map<String, String> outputs = transform.getTransform().getOutputsMap();
final RunnerApi.ExecutableStagePayload stagePayload;
try {
stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload());
} catch (IOException e) {
throw new RuntimeException(e);
}
String inputId = stagePayload.getInput();
final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStreamById(inputId);
// Analyze side inputs
final List<MessageStream<OpMessage<Iterable<?>>>> sideInputStreams = new ArrayList<>();
// The same view is indexed twice: by SideInputId and by the string from getSideInputUniqueId.
final Map<SideInputId, PCollectionView<?>> sideInputMapping = new HashMap<>();
final Map<String, PCollectionView<?>> idToViewMapping = new HashMap<>();
final RunnerApi.Components components = stagePayload.getComponents();
for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
// The side input's collection id is the consuming transform's input registered under the
// side input's local name.
final String sideInputCollectionId = components.getTransformsOrThrow(sideInputId.getTransformId()).getInputsOrThrow(sideInputId.getLocalName());
final WindowingStrategy<?, BoundedWindow> windowingStrategy = WindowUtils.getWindowStrategy(sideInputCollectionId, components);
final WindowedValue.WindowedValueCoder<?> coder = (WindowedValue.WindowedValueCoder) instantiateCoder(sideInputCollectionId, components);
// Create a runner-side view
final PCollectionView<?> view = createPCollectionView(sideInputId, coder, windowingStrategy);
// Use GBK to aggregate the side inputs and then broadcast it out
final MessageStream<OpMessage<Iterable<?>>> broadcastSideInput = groupAndBroadcastSideInput(sideInputId, sideInputCollectionId, components.getPcollectionsOrThrow(sideInputCollectionId), (WindowingStrategy) windowingStrategy, coder, ctx);
sideInputStreams.add(broadcastSideInput);
sideInputMapping.put(sideInputId, view);
idToViewMapping.put(getSideInputUniqueId(sideInputId), view);
}
// Output bookkeeping: tag -> index, index -> collection id, collection id -> tag.
final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
final Map<Integer, String> indexToIdMap = new HashMap<>();
final Map<String, TupleTag<?>> idToTupleTagMap = new HashMap<>();
// first output as the main output
final TupleTag<OutT> mainOutputTag = outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());
// Counter object so the lambda below can advance the running output index.
AtomicInteger index = new AtomicInteger(0);
outputs.keySet().iterator().forEachRemaining(outputName -> {
TupleTag<?> tupleTag = new TupleTag<>(outputName);
tagToIndexMap.put(tupleTag, index.get());
String collectionId = outputs.get(outputName);
indexToIdMap.put(index.get(), collectionId);
idToTupleTagMap.put(collectionId, tupleTag);
index.incrementAndGet();
});
WindowedValue.WindowedValueCoder<InT> windowedInputCoder = WindowUtils.instantiateWindowedCoder(inputId, pipeline.getComponents());
// TODO: support schema and side inputs for portable runner
// Note: transform.getTransform() is an ExecutableStage, not ParDo, so we need to extract
// these info from its components.
final DoFnSchemaInformation doFnSchemaInformation = null;
final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId);
final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input);
// A key coder is only extracted when StateUtils reports the stage as stateful; the casts assume
// the input coder is a FullWindowedValueCoder wrapping a KvCoder in that case.
final Coder<?> keyCoder = StateUtils.isStateful(stagePayload) ? ((KvCoder) ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder()).getKeyCoder() : null;
// The fn is a NoOpDoFn placeholder; the stage payload and job info are passed alongside it.
final DoFnOp<InT, OutT, RawUnionValue> op = new DoFnOp<>(mainOutputTag, new NoOpDoFn<>(), keyCoder, // input coder not in use
windowedInputCoder.getValueCoder(), windowedInputCoder, // output coders not in use
Collections.emptyMap(), new ArrayList<>(sideInputMapping.values()), // used by java runner only
new ArrayList<>(idToTupleTagMap.values()), WindowUtils.getWindowStrategy(inputId, stagePayload.getComponents()), idToViewMapping, new DoFnOp.MultiOutputManagerFactory(tagToIndexMap), ctx.getTransformFullName(), ctx.getTransformId(), isBounded, true, stagePayload, ctx.getJobInfo(), idToTupleTagMap, doFnSchemaInformation, sideInputMapping);
// Without side inputs the main input is used as-is; otherwise all side input streams are merged,
// passed through SideInputWatermarkFn, and merged into the main input stream.
final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
mergedStreams = inputStream;
} else {
MessageStream<OpMessage<InT>> mergedSideInputStreams = MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams));
}
final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream = mergedStreams.flatMapAsync(OpAdapter.adapt(op));
// Each output stream keeps every non-ELEMENT message plus the ELEMENT messages whose union tag
// equals that output's index, and is registered under the output's collection id.
for (int outputIndex : tagToIndexMap.values()) {
@SuppressWarnings("unchecked") final MessageStream<OpMessage<OutT>> outputStream = taggedOutputStream.filter(message -> message.getType() != OpMessage.Type.ELEMENT || message.getElement().getValue().getUnionTag() == outputIndex).flatMapAsync(OpAdapter.adapt(new RawUnionValueToValue()));
ctx.registerMessageStream(indexToIdMap.get(outputIndex), outputStream);
}
}
Aggregations