Use of org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator in project beam by apache.

Class StreamingTransformTranslator, method parDo:
private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
  return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {

    @Override
    public void evaluate(
        final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
      final DoFn<InputT, OutputT> doFn = transform.getFn();
      checkArgument(
          !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(),
          "Splittable DoFn not yet supported in streaming mode: %s",
          doFn);
      rejectStateAndTimers(doFn);
      final SerializablePipelineOptions options = context.getSerializableOptions();
      final SparkPCollectionView pviews = context.getPViews();
      final WindowingStrategy<?, ?> windowingStrategy =
          context.getInput(transform).getWindowingStrategy();
      Coder<InputT> inputCoder = (Coder<InputT>) context.getInput(transform).getCoder();
      Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();

      @SuppressWarnings("unchecked")
      UnboundedDataset<InputT> unboundedDataset =
          (UnboundedDataset<InputT>) context.borrowDataset(transform);
      JavaDStream<WindowedValue<InputT>> dStream = unboundedDataset.getDStream();

      final DoFnSchemaInformation doFnSchemaInformation =
          ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
      final Map<String, PCollectionView<?>> sideInputMapping =
          ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
      final String stepName = context.getCurrentTransform().getFullName();

      JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
          dStream.transformToPair(
              rdd -> {
                // The metrics accumulator is a singleton shared by every step in the pipeline.
                final MetricsContainerStepMapAccumulator metricsAccum =
                    MetricsAccumulator.getInstance();
                final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                    sideInputs =
                        TranslationUtils.getSideInputs(
                            transform.getSideInputs().values(),
                            JavaSparkContext.fromSparkContext(rdd.context()),
                            pviews);
                return rdd.mapPartitionsToPair(
                    new MultiDoFnFunction<>(
                        metricsAccum,
                        stepName,
                        doFn,
                        options,
                        transform.getMainOutputTag(),
                        transform.getAdditionalOutputTags().getAll(),
                        inputCoder,
                        outputCoders,
                        sideInputs,
                        windowingStrategy,
                        false,
                        doFnSchemaInformation,
                        sideInputMapping));
              });

      Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs(transform);
      if (outputs.size() > 1) {
        // Caching can trigger serialization, so encode to bytes first;
        // more details in https://issues.apache.org/jira/browse/BEAM-2669
        Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
            TranslationUtils.getTupleTagCoders(outputs);
        all =
            all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
                .cache()
                .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
      }
      for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
        @SuppressWarnings("unchecked")
        JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
            all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
        // Object is the best we can do since different outputs can have different tags.
        @SuppressWarnings("unchecked")
        JavaDStream<WindowedValue<Object>> values =
            (JavaDStream<WindowedValue<Object>>)
                (JavaDStream<?>) TranslationUtils.dStreamValues(filtered);
        context.putDataset(
            output.getValue(),
            new UnboundedDataset<>(values, unboundedDataset.getStreamSources()));
      }
    }

    @Override
    public String toNativeString() {
      return "mapPartitions(new <fn>())";
    }
  };
}
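The translator hands this accumulator to MultiDoFnFunction, which records the DoFn's metric updates under the step name. Because MetricsContainerStepMapAccumulator is a Spark AccumulatorV2 over a MetricsContainerStepMap, per-partition updates merged on the executors become readable on the driver. Below is a minimal driver-side sketch of querying the accumulated counters; the helper class is hypothetical, and it assumes the Spark runner has already initialized the accumulator through MetricsAccumulator:

import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
import org.apache.beam.runners.spark.metrics.MetricsAccumulator;
import org.apache.beam.sdk.metrics.MetricQueryResults;
import org.apache.beam.sdk.metrics.MetricResults;
import org.apache.beam.sdk.metrics.MetricsFilter;

// Hypothetical helper: print every counter value accumulated so far.
public class MetricsDump {
  public static void printCounters() {
    // value() yields the MetricsContainerStepMap that MultiDoFnFunction populates.
    MetricsContainerStepMap stepMap = MetricsAccumulator.getInstance().value();
    // Expose the step map through the SDK's MetricResults query API (attempted values).
    MetricResults results = MetricsContainerStepMap.asAttemptedOnlyMetricResults(stepMap);
    MetricQueryResults query = results.queryMetrics(MetricsFilter.builder().build());
    query.getCounters()
        .forEach(c -> System.out.println(c.getName() + " = " + c.getAttempted()));
  }
}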
Use of org.apache.beam.runners.spark.metrics.MetricsContainerStepMapAccumulator in project beam by apache.

Class TransformTranslator, method parDo:
private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
  return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {

    @Override
    @SuppressWarnings("unchecked")
    public void evaluate(ParDo.MultiOutput<InputT, OutputT> transform, EvaluationContext context) {
      String stepName = context.getCurrentTransform().getFullName();
      DoFn<InputT, OutputT> doFn = transform.getFn();
      checkState(
          !DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable(),
          "Not expected to directly translate splittable DoFn, should have been overridden: %s",
          doFn);
      JavaRDD<WindowedValue<InputT>> inRDD =
          ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
      WindowingStrategy<?, ?> windowingStrategy =
          context.getInput(transform).getWindowingStrategy();
      MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
      Coder<InputT> inputCoder = (Coder<InputT>) context.getInput(transform).getCoder();
      Map<TupleTag<?>, Coder<?>> outputCoders = context.getOutputCoders();

      JavaPairRDD<TupleTag<?>, WindowedValue<?>> all;
      DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
      boolean stateful =
          signature.stateDeclarations().size() > 0 || signature.timerDeclarations().size() > 0;
      DoFnSchemaInformation doFnSchemaInformation =
          ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
      Map<String, PCollectionView<?>> sideInputMapping =
          ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
      MultiDoFnFunction<InputT, OutputT> multiDoFnFunction =
          new MultiDoFnFunction<>(
              metricsAccum,
              stepName,
              doFn,
              context.getSerializableOptions(),
              transform.getMainOutputTag(),
              transform.getAdditionalOutputTags().getAll(),
              inputCoder,
              outputCoders,
              TranslationUtils.getSideInputs(transform.getSideInputs().values(), context),
              windowingStrategy,
              stateful,
              doFnSchemaInformation,
              sideInputMapping);
      if (stateful) {
        // Because the signature is stateful, DoFnSignatures ensures it is also keyed.
        all =
            statefulParDoTransform(
                (KvCoder) context.getInput(transform).getCoder(),
                windowingStrategy.getWindowFn().windowCoder(),
                (JavaRDD) inRDD,
                getPartitioner(context),
                (MultiDoFnFunction) multiDoFnFunction,
                signature.processElement().requiresTimeSortedInput());
      } else {
        all = inRDD.mapPartitionsToPair(multiDoFnFunction);
      }

      Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs(transform);
      if (outputs.size() > 1) {
        StorageLevel level = StorageLevel.fromString(context.storageLevel());
        if (canAvoidRddSerialization(level)) {
          // If the storage level is memory-only, skip the overhead of encoding to bytes.
          all = all.persist(level);
        } else {
          // Caching can trigger serialization, so encode to bytes first;
          // more details in https://issues.apache.org/jira/browse/BEAM-2669
          Map<TupleTag<?>, Coder<WindowedValue<?>>> coderMap =
              TranslationUtils.getTupleTagCoders(outputs);
          all =
              all.mapToPair(TranslationUtils.getTupleTagEncodeFunction(coderMap))
                  .persist(level)
                  .mapToPair(TranslationUtils.getTupleTagDecodeFunction(coderMap));
        }
      }
      for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
        JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
            all.filter(new TranslationUtils.TupleTagFilter(output.getKey()));
        // Object is the best we can do since different outputs can have different tags.
        JavaRDD<WindowedValue<Object>> values =
            (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
        context.putDataset(output.getValue(), new BoundedDataset<>(values));
      }
    }

    @Override
    public String toNativeString() {
      return "mapPartitions(new <fn>())";
    }
  };
}
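In both the streaming and batch paths, the translator runs the DoFn once and then fans the single (TupleTag, WindowedValue) pair stream out into one dataset per declared output. For orientation, here is a hedged user-side sketch of the kind of ParDo.MultiOutput these translators receive; the tags, the DoFn logic, and the input PCollection<String> are illustrative, not taken from the Beam sources:

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

// Anonymous subclasses so each tag carries its element type at runtime.
final TupleTag<String> mainTag = new TupleTag<String>() {};
final TupleTag<String> errorTag = new TupleTag<String>() {};

// Given some PCollection<String> named input:
PCollectionTuple results =
    input.apply(
        ParDo.of(
                new DoFn<String, String>() {
                  @ProcessElement
                  public void process(ProcessContext c) {
                    if (c.element().isEmpty()) {
                      c.output(errorTag, "<empty element>"); // additional output
                    } else {
                      c.output(c.element()); // main output
                    }
                  }
                })
            .withOutputTags(mainTag, TupleTagList.of(errorTag)));

PCollection<String> mainOut = results.get(mainTag);
PCollection<String> errors = results.get(errorTag);

Here transform.getMainOutputTag() in the translators corresponds to mainTag and transform.getAdditionalOutputTags() to TupleTagList.of(errorTag); the per-tag filter loop above is what materializes each results.get(tag) as a separate dataset.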