Search in sources :

Example 1 with SideInputBroadcast

Use of org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast in the Apache Beam project.

From the class ParDoTranslatorBatch, method createBroadcastSideInputs.

/**
 * Builds the broadcast container holding the encoded contents of every side input.
 *
 * <p>For each view, the backing dataset is collected into a local list, each windowed value is
 * serialized to bytes with the view's full windowed-value coder, and the resulting list of byte
 * arrays is broadcast through the Spark context and registered under the view's tag id.
 *
 * @param sideInputs the side-input views to collect and broadcast
 * @param context translation context providing the Spark session and the side-input datasets
 * @return a {@link SideInputBroadcast} with one entry added per element of {@code sideInputs}
 */
private static SideInputBroadcast createBroadcastSideInputs(List<PCollectionView<?>> sideInputs, AbstractTranslationContext context) {
    JavaSparkContext sparkContext = JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
    SideInputBroadcast broadcasts = new SideInputBroadcast();
    for (PCollectionView<?> view : sideInputs) {
        // Coder for the windowed elements: element coder of the view's PCollection plus its window coder.
        Coder<? extends BoundedWindow> viewWindowCoder = view.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
        Coder<WindowedValue<?>> fullValueCoder = (Coder<WindowedValue<?>>) (Coder<?>) WindowedValue.getFullCoder(view.getPCollection().getCoder(), viewWindowCoder);

        // Materialize the side-input dataset as a local list before encoding it.
        Dataset<WindowedValue<?>> sideInputDataset = context.getSideInputDataSet(view);
        List<WindowedValue<?>> collected = sideInputDataset.collectAsList();

        // The broadcast carries raw bytes; the coder is registered alongside them.
        List<byte[]> encoded = new ArrayList<>();
        for (WindowedValue<?> element : collected) {
            byte[] bytes = CoderHelpers.toByteArray(element, fullValueCoder);
            encoded.add(bytes);
        }

        String tagId = view.getTagInternal().getId();
        broadcasts.add(tagId, sparkContext.broadcast(encoded), fullValueCoder);
    }
    return broadcasts;
}
Also used : SerializableCoder(org.apache.beam.sdk.coders.SerializableCoder) Coder(org.apache.beam.sdk.coders.Coder) MultiOutputCoder(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder) SideInputBroadcast(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ArrayList(java.util.ArrayList) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext)

Example 2 with SideInputBroadcast

Use of org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast in the Apache Beam project.

From the class ParDoTranslatorBatch, method translateTransform.

/**
 * Translates the current ParDo transform into Spark Dataset operations.
 *
 * <p>Steps, in order: reject unsupported DoFn features, gather input/output metadata and coders,
 * broadcast the side inputs, wrap the DoFn in a {@code DoFnFunction}, apply it with
 * {@code mapPartitions}, and finally register one dataset per output in the context.
 *
 * @param transform the transform being translated (its details are read from {@code context})
 * @param context translation context supplying the current transform, datasets, coders and options
 * @throws IllegalStateException if the DoFn is splittable, stateful, or requires time-sorted input
 */
@Override
public void translateTransform(PTransform<PCollection<InputT>, PCollectionTuple> transform, AbstractTranslationContext context) {
    String transformName = context.getCurrentTransform().getFullName();

    // Reject advanced features that this translator does not support.
    // TODO: add support of Splittable DoFn
    DoFn<InputT, OutputT> fn = getDoFn(context);
    checkState(!DoFnSignatures.isSplittable(fn), "Not expected to directly translate splittable DoFn, should have been overridden: %s", fn);
    // TODO: add support of states and timers
    checkState(!DoFnSignatures.isStateful(fn), "States and timers are not supported for the moment.");
    checkState(!DoFnSignatures.requiresTimeSortedInput(fn), "@RequiresTimeSortedInput is not " + "supported for the moment");
    DoFnSchemaInformation schemaInfo = ParDoTranslation.getSchemaInformation(context.getCurrentTransform());

    // Main input, outputs and the coders derived from them.
    PValue mainInput = context.getInput();
    Dataset<WindowedValue<InputT>> inputRows = context.getDataset(mainInput);
    Map<TupleTag<?>, PCollection<?>> outputsByTag = context.getOutputs();
    TupleTag<?> mainTag = getTupleTag(context);
    List<TupleTag<?>> allTags = new ArrayList<>(outputsByTag.keySet());
    PCollection<InputT> inputCollection = (PCollection<InputT>) mainInput;
    WindowingStrategy<?, ?> inputWindowing = inputCollection.getWindowingStrategy();
    Coder<InputT> elementCoder = inputCollection.getCoder();
    Coder<? extends BoundedWindow> windowCoder = inputWindowing.getWindowFn().windowCoder();

    // Map each side input to its WindowingStrategy so that the DoFn runner
    // can map main-input windows to side-input windows.
    List<PCollectionView<?>> views = getSideInputs(context);
    Map<PCollectionView<?>, WindowingStrategy<?, ?>> windowingByView = new HashMap<>();
    for (PCollectionView<?> view : views) {
        windowingByView.put(view, view.getPCollection().getWindowingStrategy());
    }
    SideInputBroadcast sideInputBroadcast = createBroadcastSideInputs(views, context);
    Map<TupleTag<?>, Coder<?>> codersByTag = context.getOutputCoders();
    MetricsContainerStepMapAccumulator metrics = MetricsAccumulator.getInstance();

    // Every output tag except the main one is an additional output.
    List<TupleTag<?>> extraTags = new ArrayList<>();
    for (TupleTag<?> candidate : allTags) {
        if (candidate.equals(mainTag)) {
            continue;
        }
        extraTags.add(candidate);
    }

    Map<String, PCollectionView<?>> viewsByName = ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
    @SuppressWarnings("unchecked") DoFnFunction<InputT, OutputT> partitionFn = new DoFnFunction(metrics, transformName, fn, inputWindowing, windowingByView, context.getSerializableOptions(), extraTags, mainTag, elementCoder, codersByTag, sideInputBroadcast, schemaInfo, viewsByName);
    MultiOutputCoder taggedCoder = MultiOutputCoder.of(SerializableCoder.of(TupleTag.class), codersByTag, windowCoder);
    Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> taggedOutputs = inputRows.mapPartitions(partitionFn, EncoderHelpers.fromBeamCoder(taggedCoder));

    if (outputsByTag.size() <= 1) {
        // Single output: drop the tag and register the values for the only output PCollection.
        Coder<OutputT> mainOutputCoder = ((PCollection<OutputT>) outputsByTag.get(mainTag)).getCoder();
        Coder<WindowedValue<?>> fullValueCoder = (Coder<WindowedValue<?>>) (Coder<?>) WindowedValue.getFullCoder(mainOutputCoder, windowCoder);
        Dataset<WindowedValue<?>> untagged = taggedOutputs.map((MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>) tagged -> tagged._2, EncoderHelpers.fromBeamCoder(fullValueCoder));
        context.putDatasetWildcard(outputsByTag.entrySet().iterator().next().getValue(), untagged);
        return;
    }

    // Several outputs: persist the tagged dataset (presumably so each per-tag pruning below
    // does not recompute it -- confirm), then handle each output entry in turn.
    taggedOutputs.persist();
    for (Map.Entry<TupleTag<?>, PCollection<?>> entry : outputsByTag.entrySet()) {
        pruneOutputFilteredByTag(context, taggedOutputs, entry, windowCoder);
    }
}
Also used : SideInputBroadcast(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast) SerializableCoder(org.apache.beam.sdk.coders.SerializableCoder) WindowedValue(org.apache.beam.sdk.util.WindowedValue) Dataset(org.apache.spark.sql.Dataset) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Coder(org.apache.beam.sdk.coders.Coder) HashMap(java.util.HashMap) AbstractTranslationContext(org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext) DoFnSchemaInformation(org.apache.beam.sdk.transforms.DoFnSchemaInformation) ArrayList(java.util.ArrayList) PTransform(org.apache.beam.sdk.transforms.PTransform) DoFnSignatures(org.apache.beam.sdk.transforms.reflect.DoFnSignatures) EncoderHelpers(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers) TupleTag(org.apache.beam.sdk.values.TupleTag) Map(java.util.Map) PCollectionTuple(org.apache.beam.sdk.values.PCollectionTuple) MultiOutputCoder(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder) CoderHelpers(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers) MapFunction(org.apache.spark.api.java.function.MapFunction) ParDoTranslation(org.apache.beam.runners.core.construction.ParDoTranslation) DoFn(org.apache.beam.sdk.transforms.DoFn) MetricsAccumulator(org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator) IOException(java.io.IOException) PCollection(org.apache.beam.sdk.values.PCollection) Tuple2(scala.Tuple2) List(java.util.List) PValue(org.apache.beam.sdk.values.PValue) Preconditions.checkState(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState) PCollectionView(org.apache.beam.sdk.values.PCollectionView) TransformTranslator(org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) 
MetricsContainerStepMapAccumulator(org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) FilterFunction(org.apache.spark.api.java.function.FilterFunction) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) TupleTag(org.apache.beam.sdk.values.TupleTag) WindowingStrategy(org.apache.beam.sdk.values.WindowingStrategy) WindowedValue(org.apache.beam.sdk.util.WindowedValue) SerializableCoder(org.apache.beam.sdk.coders.SerializableCoder) Coder(org.apache.beam.sdk.coders.Coder) MultiOutputCoder(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder) SideInputBroadcast(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast) MultiOutputCoder(org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder) PValue(org.apache.beam.sdk.values.PValue) MetricsContainerStepMapAccumulator(org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator) PCollection(org.apache.beam.sdk.values.PCollection) PCollectionView(org.apache.beam.sdk.values.PCollectionView) DoFnSchemaInformation(org.apache.beam.sdk.transforms.DoFnSchemaInformation) Tuple2(scala.Tuple2) HashMap(java.util.HashMap) Map(java.util.Map)

Aggregations

ArrayList (java.util.ArrayList)2 MultiOutputCoder (org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder)2 SideInputBroadcast (org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast)2 Coder (org.apache.beam.sdk.coders.Coder)2 SerializableCoder (org.apache.beam.sdk.coders.SerializableCoder)2 WindowedValue (org.apache.beam.sdk.util.WindowedValue)2 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)2 IOException (java.io.IOException)1 HashMap (java.util.HashMap)1 List (java.util.List)1 Map (java.util.Map)1 ParDoTranslation (org.apache.beam.runners.core.construction.ParDoTranslation)1 MetricsAccumulator (org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator)1 MetricsContainerStepMapAccumulator (org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator)1 AbstractTranslationContext (org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext)1 TransformTranslator (org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator)1 CoderHelpers (org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers)1 EncoderHelpers (org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers)1 DoFn (org.apache.beam.sdk.transforms.DoFn)1 DoFnSchemaInformation (org.apache.beam.sdk.transforms.DoFnSchemaInformation)1