Use of org.apache.beam.sdk.transforms.ParDo in project beam by apache.
The class SplittableParDoTest, method applySplittableParDo.
private PCollection<String> applySplittableParDo(
    String name, PCollection<Integer> input, DoFn<Integer, String> fn) {
  ParDo.MultiOutput<Integer, String> multiOutput =
      ParDo.of(fn).withOutputTags(MAIN_OUTPUT_TAG, TupleTagList.empty());
  PCollectionTuple output = multiOutput.expand(input);
  output.get(MAIN_OUTPUT_TAG).setName("main");
  AppliedPTransform<PCollection<Integer>, PCollectionTuple, ?> transform =
      AppliedPTransform.of(
          "ParDo",
          PValues.expandInput(input),
          PValues.expandOutput(output),
          multiOutput,
          ResourceHints.create(),
          pipeline);
  return input.apply(name, SplittableParDo.forAppliedParDo(transform)).get(MAIN_OUTPUT_TAG);
}
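SplittableParDo.forAppliedParDo expects the wrapped DoFn to be splittable. For orientation, here is a minimal sketch of a splittable DoFn that a helper like this could be handed; the trivial one-offset OffsetRange restriction is an assumption for illustration and is not the fn used by the actual test:

private static class PassThroughSplittableFn extends DoFn<Integer, String> {

  @GetInitialRestriction
  public OffsetRange getInitialRestriction(@Element Integer element) {
    // A one-offset restriction: the element's entire "work" is a single claimable offset.
    // OffsetRange has a default tracker, so no @NewTracker method is required.
    return new OffsetRange(0, 1);
  }

  @ProcessElement
  public void process(
      @Element Integer element,
      RestrictionTracker<OffsetRange, Long> tracker,
      OutputReceiver<String> out) {
    if (tracker.tryClaim(tracker.currentRestriction().getFrom())) {
      out.output(String.valueOf(element));
    }
  }
}

With such a fn, a call would look like applySplittableParDo("pardo", numbers, new PassThroughSplittableFn()), where numbers is some PCollection<Integer>.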
Use of org.apache.beam.sdk.transforms.ParDo in project beam by apache.
The class ParDoBoundMultiTranslator, method doTranslatePortable.
// static for serializing anonymous functions
private static <InT, OutT> void doTranslatePortable(
    PipelineNode.PTransformNode transform,
    QueryablePipeline pipeline,
    PortableTranslationContext ctx) {
  Map<String, String> outputs = transform.getTransform().getOutputsMap();

  final RunnerApi.ExecutableStagePayload stagePayload;
  try {
    stagePayload =
        RunnerApi.ExecutableStagePayload.parseFrom(
            transform.getTransform().getSpec().getPayload());
  } catch (IOException e) {
    throw new RuntimeException(e);
  }

  String inputId = stagePayload.getInput();
  final MessageStream<OpMessage<InT>> inputStream = ctx.getMessageStreamById(inputId);

  // Analyze side inputs
  final List<MessageStream<OpMessage<Iterable<?>>>> sideInputStreams = new ArrayList<>();
  final Map<SideInputId, PCollectionView<?>> sideInputMapping = new HashMap<>();
  final Map<String, PCollectionView<?>> idToViewMapping = new HashMap<>();
  final RunnerApi.Components components = stagePayload.getComponents();
  for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
    final String sideInputCollectionId =
        components
            .getTransformsOrThrow(sideInputId.getTransformId())
            .getInputsOrThrow(sideInputId.getLocalName());
    final WindowingStrategy<?, BoundedWindow> windowingStrategy =
        WindowUtils.getWindowStrategy(sideInputCollectionId, components);
    final WindowedValue.WindowedValueCoder<?> coder =
        (WindowedValue.WindowedValueCoder) instantiateCoder(sideInputCollectionId, components);

    // Create a runner-side view
    final PCollectionView<?> view = createPCollectionView(sideInputId, coder, windowingStrategy);

    // Use GBK to aggregate the side inputs and then broadcast it out
    final MessageStream<OpMessage<Iterable<?>>> broadcastSideInput =
        groupAndBroadcastSideInput(
            sideInputId,
            sideInputCollectionId,
            components.getPcollectionsOrThrow(sideInputCollectionId),
            (WindowingStrategy) windowingStrategy,
            coder,
            ctx);

    sideInputStreams.add(broadcastSideInput);
    sideInputMapping.put(sideInputId, view);
    idToViewMapping.put(getSideInputUniqueId(sideInputId), view);
  }

  final Map<TupleTag<?>, Integer> tagToIndexMap = new HashMap<>();
  final Map<Integer, String> indexToIdMap = new HashMap<>();
  final Map<String, TupleTag<?>> idToTupleTagMap = new HashMap<>();

  // first output as the main output
  final TupleTag<OutT> mainOutputTag =
      outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());
  AtomicInteger index = new AtomicInteger(0);
  outputs
      .keySet()
      .iterator()
      .forEachRemaining(
          outputName -> {
            TupleTag<?> tupleTag = new TupleTag<>(outputName);
            tagToIndexMap.put(tupleTag, index.get());
            String collectionId = outputs.get(outputName);
            indexToIdMap.put(index.get(), collectionId);
            idToTupleTagMap.put(collectionId, tupleTag);
            index.incrementAndGet();
          });

  WindowedValue.WindowedValueCoder<InT> windowedInputCoder =
      WindowUtils.instantiateWindowedCoder(inputId, pipeline.getComponents());
  // TODO: support schema and side inputs for portable runner
  // Note: transform.getTransform() is an ExecutableStage, not a ParDo, so we need to
  // extract this information from its components.
  final DoFnSchemaInformation doFnSchemaInformation = null;
  final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId);
  final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input);
  final Coder<?> keyCoder =
      StateUtils.isStateful(stagePayload)
          ? ((KvCoder) ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder())
              .getKeyCoder()
          : null;

  final DoFnOp<InT, OutT, RawUnionValue> op =
      new DoFnOp<>(
          mainOutputTag,
          new NoOpDoFn<>(),
          keyCoder,
          // input coder not in use
          windowedInputCoder.getValueCoder(),
          windowedInputCoder,
          // output coders not in use
          Collections.emptyMap(),
          new ArrayList<>(sideInputMapping.values()),
          // used by java runner only
          new ArrayList<>(idToTupleTagMap.values()),
          WindowUtils.getWindowStrategy(inputId, stagePayload.getComponents()),
          idToViewMapping,
          new DoFnOp.MultiOutputManagerFactory(tagToIndexMap),
          ctx.getTransformFullName(),
          ctx.getTransformId(),
          isBounded,
          true,
          stagePayload,
          ctx.getJobInfo(),
          idToTupleTagMap,
          doFnSchemaInformation,
          sideInputMapping);

  final MessageStream<OpMessage<InT>> mergedStreams;
  if (sideInputStreams.isEmpty()) {
    mergedStreams = inputStream;
  } else {
    MessageStream<OpMessage<InT>> mergedSideInputStreams =
        MessageStream.mergeAll(sideInputStreams).flatMap(new SideInputWatermarkFn());
    mergedStreams = inputStream.merge(Collections.singletonList(mergedSideInputStreams));
  }

  final MessageStream<OpMessage<RawUnionValue>> taggedOutputStream =
      mergedStreams.flatMapAsync(OpAdapter.adapt(op));

  for (int outputIndex : tagToIndexMap.values()) {
    @SuppressWarnings("unchecked")
    final MessageStream<OpMessage<OutT>> outputStream =
        taggedOutputStream
            .filter(
                message ->
                    message.getType() != OpMessage.Type.ELEMENT
                        || message.getElement().getValue().getUnionTag() == outputIndex)
            .flatMapAsync(OpAdapter.adapt(new RawUnionValueToValue()));
    ctx.registerMessageStream(indexToIdMap.get(outputIndex), outputStream);
  }
}
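The output routing above hinges on RawUnionValue: each tagged output is assigned an index, the DoFnOp emits every element wrapped with the index of the output it belongs to, and one filtered stream per index recovers the individual collections. A self-contained sketch of that demultiplexing idea, with the Samza MessageStreams replaced by plain Java collections purely for illustration:

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.sdk.transforms.join.RawUnionValue;

public class UnionTagDemuxSketch {
  public static void main(String[] args) {
    // Tagged elements as the op would emit them: index 0 = main output, index 1 = side output.
    List<RawUnionValue> tagged =
        Arrays.asList(
            new RawUnionValue(0, "main-a"),
            new RawUnionValue(1, "side-x"),
            new RawUnionValue(0, "main-b"));

    // One filter pass per output index, mirroring the per-index taggedOutputStream.filter(...) calls.
    for (int outputIndex : new int[] {0, 1}) {
      final int idx = outputIndex;
      List<Object> values =
          tagged.stream()
              .filter(v -> v.getUnionTag() == idx)
              .map(RawUnionValue::getValue)
              .collect(Collectors.toList());
      System.out.println("output " + idx + ": " + values);
    }
  }
}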
Use of org.apache.beam.sdk.transforms.ParDo in project beam by apache.
The class TransformTranslator, method parDo.
private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
  return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {

    @Override
    @SuppressWarnings("unchecked")
    public void evaluate(ParDo.MultiOutput<InputT, OutputT> transform, EvaluationContext context) {
      String stepName = context.getCurrentTransform().getFullName();
      DoFn<InputT, OutputT> doFn = transform.getFn();
      checkState(
          !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(),
          "Not expected to directly translate splittable DoFn, should have been overridden: %s",
          doFn);
      JavaRDD<WindowedValue<InputT>> inRDD =
          ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
      WindowingStrategy<?, ?> windowingStrategy =
          context.getInput(transform).getWindowingStrategy();
      MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
      Coder<InputT> inputCoder = (Coder<InputT>) context.getInput(transform).getCoder();
      Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();
      JavaPairRDD<TupleTag<?>, WindowedValue<?>> all;
      DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
      boolean stateful =
          signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
      DoFnSchemaInformation doFnSchemaInformation;
      doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
      Map<String, PCollectionView<?>> sideInputMapping =
          ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
      MultiDoFnFunction<InputT, OutputT> multiDoFnFunction =
          new MultiDoFnFunction<>(
              metricsAccum,
              stepName,
              doFn,
              context.getSerializableOptions(),
              transform.getMainOutputTag(),
              transform.getAdditionalOutputTags().getAll(),
              inputCoder,
              outputCoders,
              TranslationUtils.getSideInputs(transform.getSideInputs().values(), context),
              windowingStrategy,
              stateful,
              doFnSchemaInformation,
              sideInputMapping);
      if (stateful) {
        // Based on the fact that the signature is stateful, DoFnSignatures ensures
        // that it is also keyed
        all =
            statefulParDoTransform(
                (KvCoder) context.getInput(transform).getCoder(),
                windowingStrategy.getWindowFn().windowCoder(),
                (JavaRDD) inRDD,
                getPartitioner(context),
                (MultiDoFnFunction) multiDoFnFunction,
                signature.processElement().requiresTimeSortedInput());
      } else {
        all = inRDD.mapPartitionsToPair(multiDoFnFunction);
      }
      Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs(transform);
      if (outputs.size() > 1) {
        StorageLevel level = StorageLevel.fromString(context.storageLevel());
        if (canAvoidRddSerialization(level)) {
          // If it is memory-only, reduce the overhead of moving to bytes.
          all = all.persist(level);
        } else {
          // Caching can cause serialization; we need to encode to bytes.
          // More details in https://issues.apache.org/jira/browse/BEAM-2669
          Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
              TranslationUtils.getTupleTagCoders(outputs);
          all =
              all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
                  .persist(level)
                  .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
        }
      }
      for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
        JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
            all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
        // Object is the best we can do since different outputs can have different tags
        JavaRDD<WindowedValue<Object>> values =
            (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
        context.putDataset(output.getValue(), new BoundedDataset<>(values));
      }
    }

    @Override
    public String toNativeString() {
      return "mapPartitions(new <fn>())";
    }
  };
}
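For reference, the user-facing shape this evaluator translates is a multi-output ParDo. A minimal sketch of such an application; the tags, the length-emitting DoFn, and the words PCollection<String> are hypothetical, invented only to show the withOutputTags wiring:

final TupleTag<String> mainTag = new TupleTag<String>() {};
final TupleTag<Integer> lengthTag = new TupleTag<Integer>() {};

PCollectionTuple results =
    words.apply(
        "EmitWordAndLength",
        ParDo.of(
                new DoFn<String, String>() {
                  @ProcessElement
                  public void process(ProcessContext c) {
                    c.output(c.element()); // main output
                    c.output(lengthTag, c.element().length()); // additional output
                  }
                })
            .withOutputTags(mainTag, TupleTagList.of(lengthTag)));

PCollection<String> wordsOut = results.get(mainTag);
PCollection<Integer> lengths = results.get(lengthTag);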
Use of org.apache.beam.sdk.transforms.ParDo in project beam by apache.
The class BeamZetaSqlCalcRelTest, method testNoFieldAccess.
@Test
public void testNoFieldAccess() throws IllegalAccessException {
  String sql = "SELECT 1 FROM KeyValue";
  PCollection<Row> rows = compile(sql);

  final NodeGetter nodeGetter = new NodeGetter(rows);
  pipeline.traverseTopologically(nodeGetter);

  ParDo.MultiOutput<Row, Row> pardo =
      (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
  PCollection<Row> input =
      (PCollection<Row>) Iterables.getOnlyElement(nodeGetter.producer.getInputs().values());

  DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), input);
  FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor();

  Assert.assertFalse(fieldAccess.getAllFields());
  Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty());
  Assert.assertTrue(fieldAccess.getNestedFieldsAccessed().isEmpty());

  pipeline.run().waitUntilFinish();
}
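The assertions above confirm that a query touching no columns yields an empty FieldAccessDescriptor. As a contrast, here is a hedged sketch of a DoFn that does declare field access on its Row input, so its schema information would report the named field; the "Key" field name is hypothetical:

private static class KeyOnlyFn extends DoFn<Row, String> {
  // Declares that only the "Key" field of the input Row schema is read.
  @FieldAccess("keyOnly")
  final FieldAccessDescriptor keyOnlyAccess = FieldAccessDescriptor.withFieldNames("Key");

  @ProcessElement
  public void process(@FieldAccess("keyOnly") Row row, OutputReceiver<String> out) {
    out.output(row.getString("Key"));
  }
}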
Use of org.apache.beam.sdk.transforms.ParDo in project beam by apache.
The class BeamCalcRelTest, method testSingleFieldAccess.
@Test
public void testSingleFieldAccess() throws IllegalAccessException {
  String sql = "SELECT order_id FROM ORDER_DETAILS_BOUNDED";
  PCollection<Row> rows = compilePipeline(sql, pipeline);

  final NodeGetter nodeGetter = new NodeGetter(rows);
  pipeline.traverseTopologically(nodeGetter);

  ParDo.MultiOutput<Row, Row> pardo =
      (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
  PCollection<Row> input =
      (PCollection<Row>) Iterables.getOnlyElement(nodeGetter.producer.getInputs().values());

  DoFnSchemaInformation info = ParDo.getDoFnSchemaInformation(pardo.getFn(), input);
  FieldAccessDescriptor fieldAccess = info.getFieldAccessDescriptor();

  Assert.assertTrue(fieldAccess.referencesSingleField());
  Assert.assertEquals("order_id", Iterables.getOnlyElement(fieldAccess.fieldNamesAccessed()));

  pipeline.run().waitUntilFinish();
}
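NodeGetter itself is not shown in this excerpt. A plausible minimal implementation, assuming it merely records the node that produced the target PCollection (the actual test utility may differ):

private static class NodeGetter extends Pipeline.PipelineVisitor.Defaults {
  private final PValue target;
  private TransformHierarchy.Node producer;

  NodeGetter(PValue target) {
    this.target = target;
  }

  @Override
  public void visitValue(PValue value, TransformHierarchy.Node producer) {
    // Remember the producer of the PCollection we are looking for.
    if (value == this.target) {
      this.producer = producer;
    }
  }
}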