
Example 31 with PCollectionNode

Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache.

From the class FlinkBatchPortablePipelineTranslator, the method translateExecutableStage:

private static <InputT> void translateExecutableStage(PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
    // TODO: Fail on splittable DoFns.
    // TODO: Special-case single outputs to avoid multiplexing PCollections.
    RunnerApi.Components components = pipeline.getComponents();
    Map<String, String> outputs = transform.getTransform().getOutputsMap();
    // Mapping from PCollection id to coder tag id.
    BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
    // Collect all output Coders and create a UnionCoder for our tagged outputs.
    List<Coder<?>> unionCoders = Lists.newArrayList();
    // Enforce tuple tag sorting by union tag index.
    Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
    for (String collectionId : new TreeMap<>(outputMap.inverse()).values()) {
        PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, components.getPcollectionsOrThrow(collectionId));
        Coder<WindowedValue<?>> coder;
        try {
            coder = (Coder) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        outputCoders.put(collectionId, coder);
        unionCoders.add(coder);
    }
    UnionCoder unionCoder = UnionCoder.of(unionCoders);
    TypeInformation<RawUnionValue> typeInformation = new CoderTypeInformation<>(unionCoder, context.getPipelineOptions());
    RunnerApi.ExecutableStagePayload stagePayload;
    try {
        stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transform.getTransform().getSpec().getPayload());
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    String inputPCollectionId = stagePayload.getInput();
    Coder<WindowedValue<InputT>> windowedInputCoder = instantiateCoder(inputPCollectionId, components);
    DataSet<WindowedValue<InputT>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
    final FlinkExecutableStageFunction<InputT> function =
        new FlinkExecutableStageFunction<>(
            transform.getTransform().getUniqueName(),
            context.getPipelineOptions(),
            stagePayload,
            context.getJobInfo(),
            outputMap,
            FlinkExecutableStageContextFactory.getInstance(),
            getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder(),
            windowedInputCoder);
    final String operatorName = generateNameFromStagePayload(stagePayload);
    final SingleInputUdfOperator<WindowedValue<InputT>, RawUnionValue, ?> taggedDataset;
    if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
        Coder valueCoder = ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
        // Stateful stages require KV input so that we can group on the key.
        if (!(valueCoder instanceof KvCoder)) {
            throw new IllegalStateException(String.format(Locale.ENGLISH, "The element coder for stateful DoFn '%s' must be KvCoder but is: %s", inputPCollectionId, valueCoder.getClass().getSimpleName()));
        }
        Coder keyCoder = ((KvCoder) valueCoder).getKeyCoder();
        Grouping<WindowedValue<InputT>> groupedInput = inputDataSet.groupBy(new KvKeySelector<>(keyCoder));
        boolean requiresTimeSortedInput = requiresTimeSortedInput(stagePayload, false);
        if (requiresTimeSortedInput) {
            groupedInput = ((UnsortedGrouping<WindowedValue<InputT>>) groupedInput).sortGroup(WindowedValue::getTimestamp, Order.ASCENDING);
        }
        taggedDataset = new GroupReduceOperator<>(groupedInput, typeInformation, function, operatorName);
    } else {
        taggedDataset = new MapPartitionOperator<>(inputDataSet, typeInformation, function, operatorName);
    }
    for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
        String collectionId = stagePayload.getComponents().getTransformsOrThrow(sideInputId.getTransformId()).getInputsOrThrow(sideInputId.getLocalName());
        // Register under the global PCollection name. Only ExecutableStageFunction needs to know the
        // mapping from local name to global name and how to translate the broadcast data to a state
        // API view.
        taggedDataset.withBroadcastSet(context.getDataSetOrThrow(collectionId), collectionId);
    }
    for (String collectionId : outputs.values()) {
        pruneOutput(taggedDataset, context, outputMap.get(collectionId), outputCoders.get(collectionId), collectionId);
    }
    if (outputs.isEmpty()) {
        // NOTE: After pipeline translation, we traverse the set of unconsumed PCollections and add a
        // no-op sink to each to make sure they are materialized by Flink. However, some SDK-executed
        // stages have no runner-visible output after fusion. We handle this case by adding a sink
        // here.
        taggedDataset.output(new DiscardingOutputFormat<>()).name("DiscardingOutput");
    }
}
Also used : DiscardingOutputFormat(org.apache.flink.api.java.io.DiscardingOutputFormat) RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) FlinkExecutableStageFunction(org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageFunction) WindowedValue(org.apache.beam.sdk.util.WindowedValue) SideInputId(org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId) CoderTypeInformation(org.apache.beam.runners.flink.translation.types.CoderTypeInformation) WindowedValueCoder(org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder) KvCoder(org.apache.beam.sdk.coders.KvCoder) UnionCoder(org.apache.beam.sdk.transforms.join.UnionCoder) PipelineTranslatorUtils.instantiateCoder(org.apache.beam.runners.fnexecution.translation.PipelineTranslatorUtils.instantiateCoder) Coder(org.apache.beam.sdk.coders.Coder) ByteArrayCoder(org.apache.beam.sdk.coders.ByteArrayCoder) VoidCoder(org.apache.beam.sdk.coders.VoidCoder) RawUnionValue(org.apache.beam.sdk.transforms.join.RawUnionValue) SingleInputUdfOperator(org.apache.flink.api.java.operators.SingleInputUdfOperator) IOException(java.io.IOException) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode)
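
The helper createOutputMap is referenced above but not shown in this excerpt. The sketch below is an assumption based on how its result is consumed (a BiMap whose integer tags are later walked in ascending order via a TreeMap), not Beam's verbatim implementation; it assumes java.util.TreeSet, java.util.Collection, and the vendored Guava BiMap/ImmutableBiMap.

private static BiMap<String, Integer> createOutputMap(Collection<String> localOutputs) {
    // Sketch: assign each output PCollection id a stable union-tag index.
    // Iterating a sorted copy keeps the tag assignment deterministic across translations.
    ImmutableBiMap.Builder<String, Integer> builder = ImmutableBiMap.builder();
    int unionTag = 0;
    for (String collectionId : new TreeSet<>(localOutputs)) {
        builder.put(collectionId, unionTag++);
    }
    return builder.build();
}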

Example 32 with PCollectionNode

Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache.

From the class FlinkBatchPortablePipelineTranslator, the method translateGroupByKey:

private static <K, V> void translateGroupByKey(PTransformNode transform, RunnerApi.Pipeline pipeline, BatchTranslationContext context) {
    RunnerApi.Components components = pipeline.getComponents();
    String inputPCollectionId = Iterables.getOnlyElement(transform.getTransform().getInputsMap().values());
    PCollectionNode inputCollection = PipelineNode.pCollection(inputPCollectionId, components.getPcollectionsOrThrow(inputPCollectionId));
    DataSet<WindowedValue<KV<K, V>>> inputDataSet = context.getDataSetOrThrow(inputPCollectionId);
    RunnerApi.WindowingStrategy windowingStrategyProto =
        pipeline
            .getComponents()
            .getWindowingStrategiesOrThrow(
                pipeline
                    .getComponents()
                    .getPcollectionsOrThrow(inputPCollectionId)
                    .getWindowingStrategyId());
    RehydratedComponents rehydratedComponents = RehydratedComponents.forComponents(pipeline.getComponents());
    WindowingStrategy<Object, BoundedWindow> windowingStrategy;
    try {
        windowingStrategy = (WindowingStrategy<Object, BoundedWindow>) WindowingStrategyTranslation.fromProto(windowingStrategyProto, rehydratedComponents);
    } catch (InvalidProtocolBufferException e) {
        throw new IllegalStateException(String.format("Unable to hydrate GroupByKey windowing strategy %s.", windowingStrategyProto), e);
    }
    WindowedValueCoder<KV<K, V>> inputCoder;
    try {
        inputCoder = (WindowedValueCoder) WireCoders.instantiateRunnerWireCoder(inputCollection, pipeline.getComponents());
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    KvCoder<K, V> inputElementCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
    Concatenate<V> combineFn = new Concatenate<>();
    Coder<List<V>> accumulatorCoder = combineFn.getAccumulatorCoder(CoderRegistry.createDefault(), inputElementCoder.getValueCoder());
    Coder<WindowedValue<KV<K, List<V>>>> outputCoder = WindowedValue.getFullCoder(KvCoder.of(inputElementCoder.getKeyCoder(), accumulatorCoder), windowingStrategy.getWindowFn().windowCoder());
    TypeInformation<WindowedValue<KV<K, List<V>>>> partialReduceTypeInfo = new CoderTypeInformation<>(outputCoder, context.getPipelineOptions());
    Grouping<WindowedValue<KV<K, V>>> inputGrouping = inputDataSet.groupBy(new KvKeySelector<>(inputElementCoder.getKeyCoder()));
    FlinkPartialReduceFunction<K, V, List<V>, ?> partialReduceFunction = new FlinkPartialReduceFunction<>(combineFn, windowingStrategy, Collections.emptyMap(), context.getPipelineOptions());
    FlinkReduceFunction<K, List<V>, List<V>, ?> reduceFunction = new FlinkReduceFunction<>(combineFn, windowingStrategy, Collections.emptyMap(), context.getPipelineOptions());
    // Partially GroupReduce the values into the intermediate format AccumT (combine)
    GroupCombineOperator<WindowedValue<KV<K, V>>, WindowedValue<KV<K, List<V>>>> groupCombine = new GroupCombineOperator<>(inputGrouping, partialReduceTypeInfo, partialReduceFunction, "GroupCombine: " + transform.getTransform().getUniqueName());
    Grouping<WindowedValue<KV<K, List<V>>>> intermediateGrouping = groupCombine.groupBy(new KvKeySelector<>(inputElementCoder.getKeyCoder()));
    // Fully reduce the values and create the final output format
    GroupReduceOperator<WindowedValue<KV<K, List<V>>>, WindowedValue<KV<K, List<V>>>> outputDataSet = new GroupReduceOperator<>(intermediateGrouping, partialReduceTypeInfo, reduceFunction, transform.getTransform().getUniqueName());
    context.addDataSet(Iterables.getOnlyElement(transform.getTransform().getOutputsMap().values()), outputDataSet);
}
Also used : RunnerApi(org.apache.beam.model.pipeline.v1.RunnerApi) WindowedValue(org.apache.beam.sdk.util.WindowedValue) KV(org.apache.beam.sdk.values.KV) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) FlinkReduceFunction(org.apache.beam.runners.flink.translation.functions.FlinkReduceFunction) List(java.util.List) GroupCombineOperator(org.apache.flink.api.java.operators.GroupCombineOperator) CoderTypeInformation(org.apache.beam.runners.flink.translation.types.CoderTypeInformation) InvalidProtocolBufferException(org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.InvalidProtocolBufferException) FlinkPartialReduceFunction(org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction) KvCoder(org.apache.beam.sdk.coders.KvCoder) IOException(java.io.IOException) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) GroupReduceOperator(org.apache.flink.api.java.operators.GroupReduceOperator) Concatenate(org.apache.beam.runners.core.Concatenate) RehydratedComponents(org.apache.beam.runners.core.construction.RehydratedComponents)
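
This translation lowers GroupByKey to a partial combine followed by a full group-reduce, with Concatenate as the CombineFn that simply collects all values for a key into a List. Beam provides org.apache.beam.runners.core.Concatenate; the sketch below is an illustrative equivalent, not necessarily Beam's verbatim source:

public static class Concatenate<T> extends Combine.CombineFn<T, List<T>, List<T>> {

    @Override
    public List<T> createAccumulator() {
        // Start with an empty list per key and window.
        return new ArrayList<>();
    }

    @Override
    public List<T> addInput(List<T> accumulator, T input) {
        accumulator.add(input);
        return accumulator;
    }

    @Override
    public List<T> mergeAccumulators(Iterable<List<T>> accumulators) {
        // Merging partial results from the combine phase is plain list concatenation.
        List<T> result = createAccumulator();
        for (List<T> accumulator : accumulators) {
            result.addAll(accumulator);
        }
        return result;
    }

    @Override
    public List<T> extractOutput(List<T> accumulator) {
        return accumulator;
    }

    @Override
    public Coder<List<T>> getAccumulatorCoder(CoderRegistry registry, Coder<T> inputCoder) {
        return ListCoder.of(inputCoder);
    }
}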

Example 33 with PCollectionNode

Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache.

From the class BatchSideInputHandlerFactory, the method forMultimapSideInput:

@Override
public <K, V, W extends BoundedWindow> MultimapSideInputHandler<K, V, W> forMultimapSideInput(String transformId, String sideInputId, KvCoder<K, V> elementCoder, Coder<W> windowCoder) {
    PCollectionNode collectionNode = sideInputToCollection.get(SideInputId.newBuilder().setTransformId(transformId).setLocalName(sideInputId).build());
    checkArgument(collectionNode != null, "No side input for %s/%s", transformId, sideInputId);
    Coder<K> keyCoder = elementCoder.getKeyCoder();
    // Nested map: structural window -> (structural key -> key plus accumulated values).
    Map<Object, Map<Object, KV<K, List<V>>>> data = new HashMap<>();
    List<WindowedValue<KV<K, V>>> broadcastVariable = sideInputGetter.getSideInput(collectionNode.getId());
    for (WindowedValue<KV<K, V>> windowedValue : broadcastVariable) {
        K key = windowedValue.getValue().getKey();
        V value = windowedValue.getValue().getValue();
        for (BoundedWindow boundedWindow : windowedValue.getWindows()) {
            @SuppressWarnings("unchecked") W window = (W) boundedWindow;
            Object structuralW = windowCoder.structuralValue(window);
            Object structuralK = keyCoder.structuralValue(key);
            KV<K, List<V>> records = data.computeIfAbsent(structuralW, o -> new HashMap<>()).computeIfAbsent(structuralK, o -> KV.of(key, new ArrayList<>()));
            records.getValue().add(value);
        }
    }
    return new MultimapSideInputHandler<K, V, W>() {

        @Override
        public Iterable<V> get(K key, W window) {
            KV<K, List<V>> records = data.getOrDefault(windowCoder.structuralValue(window), Collections.emptyMap()).get(keyCoder.structuralValue(key));
            if (records == null) {
                return Collections.emptyList();
            }
            return Collections.unmodifiableList(records.getValue());
        }

        @Override
        public Coder<V> valueCoder() {
            return elementCoder.getValueCoder();
        }

        @Override
        public Iterable<K> get(W window) {
            Map<Object, KV<K, List<V>>> records = data.getOrDefault(windowCoder.structuralValue(window), Collections.emptyMap());
            return Iterables.unmodifiableIterable(FluentIterable.concat(records.values()).transform(kListKV -> kListKV.getKey()));
        }

        @Override
        public Coder<K> keyCoder() {
            return elementCoder.getKeyCoder();
        }
    };
}
Also used : KvCoder(org.apache.beam.sdk.coders.KvCoder) KV(org.apache.beam.sdk.values.KV) WindowedValue(org.apache.beam.sdk.util.WindowedValue) ImmutableMultimap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMultimap) IterableSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.IterableSideInputHandler) SideInputHandlerFactory(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.SideInputHandlerFactory) Coder(org.apache.beam.sdk.coders.Coder) HashMap(java.util.HashMap) ImmutableMap(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap) ExecutableStage(org.apache.beam.runners.core.construction.graph.ExecutableStage) ArrayList(java.util.ArrayList) MultimapSideInputHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandlers.MultimapSideInputHandler) List(java.util.List) SideInputReference(org.apache.beam.runners.core.construction.graph.SideInputReference) Map(java.util.Map) BoundedWindow(org.apache.beam.sdk.transforms.windowing.BoundedWindow) Iterables(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables) Preconditions.checkArgument(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument) FluentIterable(org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.FluentIterable) SideInputId(org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) StateRequestHandler(org.apache.beam.runners.fnexecution.state.StateRequestHandler) Collections(java.util.Collections)
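
For orientation, a hypothetical runner-side caller of this factory method could look like the following. The transform id, local name, and coders are illustrative placeholders; only the forMultimapSideInput signature comes from the example above.

// Sketch: `factory` stands for a BatchSideInputHandlerFactory created for the stage.
MultimapSideInputHandler<String, Long, GlobalWindow> handler =
    factory.forMultimapSideInput(
        "pardo-transform-id", // hypothetical transform id
        "side-input-local-name", // hypothetical local name
        KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of()),
        GlobalWindow.Coder.INSTANCE);
// All values for a key in a window, and all keys present in a window:
Iterable<Long> values = handler.get("some-key", GlobalWindow.INSTANCE);
Iterable<String> keys = handler.get(GlobalWindow.INSTANCE);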

Example 34 with PCollectionNode

Use of org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode in project beam by apache.

From the class BatchSideInputHandlerFactoryTest, the method createExecutableStage:

private static ExecutableStage createExecutableStage(Collection<SideInputReference> sideInputs) {
    Components components = Components.getDefaultInstance();
    Environment environment = Environment.getDefaultInstance();
    PCollectionNode inputCollection = PipelineNode.pCollection("collection-id", RunnerApi.PCollection.getDefaultInstance());
    return ImmutableExecutableStage.of(components, environment, inputCollection, sideInputs, Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), DEFAULT_WIRE_CODER_SETTINGS);
}
Also used : Components(org.apache.beam.model.pipeline.v1.RunnerApi.Components) Environment(org.apache.beam.model.pipeline.v1.RunnerApi.Environment) PCollectionNode(org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode)
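
A caller in the same test might build a SideInputReference and pass it in; the ids and names below are hypothetical placeholders, not values from the Beam test suite.

// Sketch: construct one side input reference and wrap it in a stage.
PTransformNode parDo =
    PipelineNode.pTransform("pardo-id", RunnerApi.PTransform.getDefaultInstance());
PCollectionNode sideCollection =
    PipelineNode.pCollection("side-collection-id", RunnerApi.PCollection.getDefaultInstance());
ExecutableStage stage =
    createExecutableStage(
        Collections.singletonList(SideInputReference.of(parDo, "side-input-name", sideCollection)));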

Aggregations

PCollectionNode (org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode) 34
PTransformNode (org.apache.beam.runners.core.construction.graph.PipelineNode.PTransformNode) 22
PTransform (org.apache.beam.model.pipeline.v1.RunnerApi.PTransform) 19
Environment (org.apache.beam.model.pipeline.v1.RunnerApi.Environment) 14
Components (org.apache.beam.model.pipeline.v1.RunnerApi.Components) 12
PCollection (org.apache.beam.model.pipeline.v1.RunnerApi.PCollection) 12
Test (org.junit.Test) 12
RunnerApi (org.apache.beam.model.pipeline.v1.RunnerApi) 10
Map (java.util.Map) 9
Collection (java.util.Collection) 7
LinkedHashSet (java.util.LinkedHashSet) 7
ArrayList (java.util.ArrayList) 6
HashSet (java.util.HashSet) 6
Collectors (java.util.stream.Collectors) 6
List (java.util.List) 5
PTransformTranslation (org.apache.beam.runners.core.construction.PTransformTranslation) 5
DeduplicationResult (org.apache.beam.runners.core.construction.graph.OutputDeduplicator.DeduplicationResult) 5
ImmutableList (org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList) 5
HashMap (java.util.HashMap) 4
TreeSet (java.util.TreeSet) 4