use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class GroupAlsoByWindowParDoFnFactory method create.
/**
 * Creates the {@link ParDoFn} that runs GroupAlsoByWindow for this step, optionally with a
 * combiner folded in.
 *
 * <p>The windowing strategy is read from {@code PropertyNames.SERIALIZED_FN}, an optional
 * {@link AppliedCombineFn} from {@code WorkerPropertyNames.COMBINE_FN}, and the input coder
 * (which must be a {@link WindowedValueCoder} whose value coder is a {@link KvCoder}) from
 * {@code WorkerPropertyNames.INPUT_CODER}. A streaming or batch GABW fn is selected according to
 * {@code StreamingOptions#isStreaming()}.
 *
 * @throws IllegalArgumentException if the single output is not the main output tag at receiver
 *     index 0, if the serialized combine fn is not an {@link AppliedCombineFn}, if the input coder
 *     has an unexpected shape, or if the combine phase is neither ALL nor MERGE
 * @throws Exception if the windowing strategy cannot be deserialized and the {@code beam_fn_api}
 *     experiment is not enabled
 */
@Override
public ParDoFn create(
    PipelineOptions options,
    CloudObject cloudUserFn,
    @Nullable List<SideInputInfo> sideInputInfos,
    TupleTag<?> mainOutputTag,
    Map<TupleTag<?>, Integer> outputTupleTagsToReceiverIndices,
    final DataflowExecutionContext<?> executionContext,
    DataflowOperationContext operationContext)
    throws Exception {
  // Exactly one output is expected: the main output, wired to receiver 0.
  Map.Entry<TupleTag<?>, Integer> entry =
      Iterables.getOnlyElement(outputTupleTagsToReceiverIndices.entrySet());
  checkArgument(
      entry.getKey().equals(mainOutputTag),
      "Output tags should reference only the main output tag: %s vs %s",
      entry.getKey(),
      mainOutputTag);
  checkArgument(
      entry.getValue() == 0,
      "There should be a single receiver, but using receiver index %s",
      entry.getValue());

  byte[] encodedWindowingStrategy = getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN);
  WindowingStrategy windowingStrategy;
  try {
    windowingStrategy = deserializeWindowingStrategy(encodedWindowingStrategy);
  } catch (Exception e) {
    // TODO: Catch block disappears, becoming an error once Python SDK is compliant.
    if (DataflowRunner.hasExperiment(
        options.as(DataflowPipelineDebugOptions.class), "beam_fn_api")) {
      LOG.info("FnAPI: Unable to deserialize windowing strategy, assuming default", e);
      windowingStrategy = WindowingStrategy.globalDefault();
    } else {
      throw e;
    }
  }

  // The combine fn is optional; when absent, values are grouped into iterables instead.
  byte[] serializedCombineFn = getBytes(cloudUserFn, WorkerPropertyNames.COMBINE_FN, null);
  AppliedCombineFn<?, ?, ?, ?> combineFn = null;
  if (serializedCombineFn != null) {
    Object combineFnObj =
        SerializableUtils.deserializeFromByteArray(serializedCombineFn, "serialized combine fn");
    // Message arguments are null-safe so a null payload yields IllegalArgumentException, not NPE.
    checkArgument(
        combineFnObj instanceof AppliedCombineFn,
        "unexpected kind of AppliedCombineFn: %s",
        describeClass(combineFnObj));
    combineFn = (AppliedCombineFn<?, ?, ?, ?>) combineFnObj;
  }

  Map<String, Object> inputCoderObject = getObject(cloudUserFn, WorkerPropertyNames.INPUT_CODER);
  Coder<?> inputCoder = CloudObjects.coderFromCloudObject(CloudObject.fromSpec(inputCoderObject));
  checkArgument(
      inputCoder instanceof WindowedValueCoder,
      "Expected WindowedValueCoder for inputCoder, got: %s",
      describeClass(inputCoder));
  WindowedValueCoder<?> windowedValueCoder = (WindowedValueCoder<?>) inputCoder;
  Coder<?> elemCoder = windowedValueCoder.getValueCoder();
  checkArgument(
      elemCoder instanceof KvCoder,
      "Expected KvCoder for inputCoder, got: %s",
      describeClass(elemCoder));
  KvCoder<?, ?> kvCoder = (KvCoder<?, ?>) elemCoder;

  boolean isStreamingPipeline = options.as(StreamingOptions.class).isStreaming();

  SideInputReader sideInputReader = NullSideInputReader.empty();
  @Nullable AppliedCombineFn<?, ?, ?, ?> maybeMergingCombineFn = null;
  if (combineFn != null) {
    sideInputReader =
        executionContext.getSideInputReader(
            sideInputInfos, combineFn.getSideInputViews(), operationContext);
    String phase = getString(cloudUserFn, WorkerPropertyNames.PHASE, CombinePhase.ALL);
    checkArgument(
        phase.equals(CombinePhase.ALL) || phase.equals(CombinePhase.MERGE),
        "Unexpected phase: %s",
        phase);
    // In the MERGE phase the combine fn is replaced by makeAppliedMergingFunction(combineFn).
    if (phase.equals(CombinePhase.MERGE)) {
      maybeMergingCombineFn = makeAppliedMergingFunction(combineFn);
    } else {
      maybeMergingCombineFn = combineFn;
    }
  }

  // State is resolved lazily from the step context each time it is requested.
  StateInternalsFactory<?> stateInternalsFactory =
      key -> executionContext.getStepContext(operationContext).stateInternals();

  // This will be a GABW Fn for either batch or streaming, with combiner in it or not
  GroupAlsoByWindowFn<?, ?> fn;
  // This will be a FakeKeyedWorkItemCoder for streaming or null for batch
  Coder<?> gabwInputCoder;
  // TODO: do not do this with mess of "if"
  if (isStreamingPipeline) {
    if (maybeMergingCombineFn == null) {
      fn =
          StreamingGroupAlsoByWindowsDoFns.createForIterable(
              windowingStrategy, stateInternalsFactory, ((KvCoder) kvCoder).getValueCoder());
      gabwInputCoder = WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(kvCoder);
    } else {
      fn =
          StreamingGroupAlsoByWindowsDoFns.create(
              windowingStrategy,
              stateInternalsFactory,
              (AppliedCombineFn) maybeMergingCombineFn,
              ((KvCoder) kvCoder).getKeyCoder());
      gabwInputCoder =
          WindmillKeyedWorkItem.FakeKeyedWorkItemCoder.of(
              ((AppliedCombineFn) maybeMergingCombineFn).getKvCoder());
    }
  } else {
    if (maybeMergingCombineFn == null) {
      fn =
          BatchGroupAlsoByWindowsDoFns.createForIterable(
              windowingStrategy, stateInternalsFactory, ((KvCoder) kvCoder).getValueCoder());
    } else {
      fn =
          BatchGroupAlsoByWindowsDoFns.create(
              windowingStrategy, (AppliedCombineFn) maybeMergingCombineFn);
    }
    gabwInputCoder = null;
  }

  // Side input views are only forwarded when a combine fn is present.
  return new GroupAlsoByWindowsParDoFn(
      options,
      fn,
      windowingStrategy,
      maybeMergingCombineFn == null
          ? null
          : ((AppliedCombineFn) maybeMergingCombineFn).getSideInputViews(),
      gabwInputCoder,
      sideInputReader,
      mainOutputTag,
      executionContext.getStepContext(operationContext));
}

/** Returns the runtime class name of {@code value}, or {@code "null"} when it is null. */
private static String describeClass(@Nullable Object value) {
  return value == null ? "null" : value.getClass().getName();
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class BatchSideInputHandlerFactory method forMultimapSideInput.
/**
 * Builds a {@link MultimapSideInputHandler} over the materialized contents of the side input
 * identified by {@code transformId} / {@code sideInputId}.
 *
 * <p>All elements are eagerly grouped in memory, first by the window coder's structural value
 * and then by the key coder's structural value, so lookups are keyed by structural values rather
 * than by the raw window/key objects. Values for a key are kept in the order in which they appear
 * in the side input, and are exposed as unmodifiable views.
 *
 * @throws IllegalArgumentException if no PCollection is registered for the side input
 */
@Override
public <K, V, W extends BoundedWindow> MultimapSideInputHandler<K, V, W> forMultimapSideInput(
    String transformId, String sideInputId, KvCoder<K, V> elementCoder, Coder<W> windowCoder) {
  PCollectionNode collectionNode =
      sideInputToCollection.get(
          SideInputId.newBuilder().setTransformId(transformId).setLocalName(sideInputId).build());
  checkArgument(collectionNode != null, "No side input for %s/%s", transformId, sideInputId);

  Coder<K> keyCoder = elementCoder.getKeyCoder();
  Map<Object /* structural window */, Map<Object /* structural key */, KV<K, List<V>>>> data =
      new HashMap<>();
  List<WindowedValue<KV<K, V>>> broadcastVariable =
      sideInputGetter.getSideInput(collectionNode.getId());
  for (WindowedValue<KV<K, V>> windowedValue : broadcastVariable) {
    K key = windowedValue.getValue().getKey();
    V value = windowedValue.getValue().getValue();
    // The structural key does not depend on the window, so compute it once per element.
    Object structuralK = keyCoder.structuralValue(key);
    for (BoundedWindow boundedWindow : windowedValue.getWindows()) {
      @SuppressWarnings("unchecked")
      W window = (W) boundedWindow;
      Object structuralW = windowCoder.structuralValue(window);
      // Remember the first original key seen for this structural key alongside its values.
      KV<K, List<V>> records =
          data.computeIfAbsent(structuralW, o -> new HashMap<>())
              .computeIfAbsent(structuralK, o -> KV.of(key, new ArrayList<>()));
      records.getValue().add(value);
    }
  }

  return new MultimapSideInputHandler<K, V, W>() {
    @Override
    public Iterable<V> get(K key, W window) {
      KV<K, List<V>> records =
          data.getOrDefault(windowCoder.structuralValue(window), Collections.emptyMap())
              .get(keyCoder.structuralValue(key));
      if (records == null) {
        return Collections.emptyList();
      }
      return Collections.unmodifiableList(records.getValue());
    }

    @Override
    public Coder<V> valueCoder() {
      return elementCoder.getValueCoder();
    }

    @Override
    public Iterable<K> get(W window) {
      Map<Object, KV<K, List<V>>> records =
          data.getOrDefault(windowCoder.structuralValue(window), Collections.emptyMap());
      // Lazily project each grouped entry back to its original (non-structural) key.
      return Iterables.unmodifiableIterable(
          FluentIterable.from(records.values()).transform(KV::getKey));
    }

    @Override
    public Coder<K> keyCoder() {
      return elementCoder.getKeyCoder();
    }
  };
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class ProcessBundleDescriptors method addStageInput.
/**
 * Adds a gRPC read transform for {@code inputPCollection} to {@code components} and returns the
 * destination (runner-side wire coder plus read transform id) that feeds it.
 */
private static RemoteInputDestination<WindowedValue<?>> addStageInput(
    ApiServiceDescriptor dataEndpoint,
    PCollectionNode inputPCollection,
    Components.Builder components,
    WireCoderSetting wireCoderSetting)
    throws IOException {
  String pCollectionId = inputPCollection.getId();

  // The SDK-side wire coder id is what the gRPC port advertises.
  String sdkWireCoderId =
      WireCoders.addSdkWireCoder(inputPCollection, components, wireCoderSetting);
  // The runner-side wire coder is instantiated from the components as they stand now.
  @SuppressWarnings("unchecked")
  Coder<WindowedValue<?>> runnerWireCoder =
      (Coder)
          WireCoders.instantiateRunnerWireCoder(
              inputPCollection, components.build(), wireCoderSetting);

  RemoteGrpcPort port =
      RemoteGrpcPort.newBuilder()
          .setApiServiceDescriptor(dataEndpoint)
          .setCoderId(sdkWireCoderId)
          .build();

  // Choose an id that does not collide with an existing transform before registering it.
  String readTransformId =
      uniqueId(String.format("fn/read/%s", pCollectionId), components::containsTransforms);
  components.putTransforms(
      readTransformId, RemoteGrpcPortRead.readFromPort(port, pCollectionId).toPTransform());

  return RemoteInputDestination.of(runnerWireCoder, readTransformId);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class ProcessBundleDescriptors method lengthPrefixAnyInputCoder.
/**
 * Replaces the coder of PCollection {@code inputPCollectionId} in {@code componentsBuilder} with
 * the coder id returned by {@code LengthPrefixUnknownCoders.addLengthPrefixedCoder}, so that the
 * byte representation of input used at the Runner matches the byte representation received from
 * the SDK Harness.
 */
private static void lengthPrefixAnyInputCoder(
    String inputPCollectionId, Components.Builder componentsBuilder) {
  RunnerApi.PCollection original = componentsBuilder.getPcollectionsOrThrow(inputPCollectionId);
  String prefixedCoderId =
      LengthPrefixUnknownCoders.addLengthPrefixedCoder(
          original.getCoderId(), componentsBuilder, false);
  RunnerApi.PCollection patched = original.toBuilder().setCoderId(prefixedCoderId).build();
  componentsBuilder.putPcollections(inputPCollectionId, patched);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Coder in project beam by apache.
the class ProcessBundleDescriptorsTest method testLengthPrefixingOfInputCoderExecutableStage.
/**
 * Checks the ProcessBundleDescriptor built for the fused stage that contains
 * SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN. The nested components of its main
 * input coder are compared against the stage's original coders via ensureLengthPrefixed.
 *
 * <p>The pipeline is only translated to proto, SDF-expanded and fused; it is never run, so the
 * DoFn bodies are intentionally empty.
 */
@Test
public void testLengthPrefixingOfInputCoderExecutableStage() throws Exception {
Pipeline p = Pipeline.create();
Coder<Void> voidCoder = VoidCoder.of();
// The test relies on VoidCoder not being a known model coder.
assertThat(ModelCoderRegistrar.isKnownCoder(voidCoder), is(false));
// impulse -> plain DoFn producing Void -> splittable DoFn (it declares @GetInitialRestriction
// and @NewTracker). NOTE(review): SomeTracker is presumably declared elsewhere in this class.
p.apply("impulse", Impulse.create()).apply(ParDo.of(new DoFn<byte[], Void>() {
@ProcessElement
public void process(ProcessContext ctxt) {
}
})).setCoder(voidCoder).apply(ParDo.of(new DoFn<Void, Void>() {
@ProcessElement
public void processElement(ProcessContext context, RestrictionTracker<Void, Void> tracker) {
}
@GetInitialRestriction
public Void getInitialRestriction() {
return null;
}
@NewTracker
public SomeTracker newTracker(@Restriction Void restriction) {
return null;
}
})).setCoder(voidCoder);
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
// Replace ParDo transforms using SplittableParDoExpander's sized replacement, then fuse.
RunnerApi.Pipeline pipelineWithSdfExpanded = ProtoOverrides.updateTransform(PTransformTranslation.PAR_DO_TRANSFORM_URN, pipelineProto, SplittableParDoExpander.createSizedReplacement());
FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineWithSdfExpanded);
// Locate the fused stage that processes sized elements and restrictions.
Optional<ExecutableStage> optionalStage = Iterables.tryFind(fused.getFusedStages(), (ExecutableStage stage) -> stage.getTransforms().stream().anyMatch(transform -> transform.getTransform().getSpec().getUrn().equals(PTransformTranslation.SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)));
checkState(optionalStage.isPresent(), "Expected a stage with SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN.");
ExecutableStage stage = optionalStage.get();
PipelineNode.PCollectionNode inputPCollection = stage.getInputPCollection();
// Main input coder as recorded in the stage's own components (before descriptor generation).
Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap();
RunnerApi.Coder originalMainInputCoder = stageCoderMap.get(inputPCollection.getPCollection().getCoderId());
// Main input coder as it appears in the generated ProcessBundleDescriptor.
BeamFnApi.ProcessBundleDescriptor pbd = ProcessBundleDescriptors.fromExecutableStage("test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance()).getProcessBundleDescriptor();
Map<String, RunnerApi.Coder> pbsCoderMap = pbd.getCodersMap();
RunnerApi.Coder pbsMainInputCoder = pbsCoderMap.get(pbd.getPcollectionsOrThrow(inputPCollection.getId()).getCoderId());
// Unwrap descriptor coders: KV key of the main input coder, then that KV's key, and the key of
// its value KV. NOTE(review): under the sized SDF expansion these are presumably the element
// coder ("keyCoder") and the restriction coder ("valueCoder"); confirm against the expander.
RunnerApi.Coder kvCoder = pbsCoderMap.get(ModelCoders.getKvCoderComponents(pbsMainInputCoder).keyCoderId());
RunnerApi.Coder keyCoder = pbsCoderMap.get(ModelCoders.getKvCoderComponents(kvCoder).keyCoderId());
RunnerApi.Coder valueKvCoder = pbsCoderMap.get(ModelCoders.getKvCoderComponents(kvCoder).valueCoderId());
RunnerApi.Coder valueCoder = pbsCoderMap.get(ModelCoders.getKvCoderComponents(valueKvCoder).keyCoderId());
// Same unwrapping applied to the stage's original coders.
RunnerApi.Coder originalKvCoder = stageCoderMap.get(ModelCoders.getKvCoderComponents(originalMainInputCoder).keyCoderId());
RunnerApi.Coder originalKeyCoder = stageCoderMap.get(ModelCoders.getKvCoderComponents(originalKvCoder).keyCoderId());
RunnerApi.Coder originalvalueKvCoder = stageCoderMap.get(ModelCoders.getKvCoderComponents(originalKvCoder).valueCoderId());
RunnerApi.Coder originalvalueCoder = stageCoderMap.get(ModelCoders.getKvCoderComponents(originalvalueKvCoder).keyCoderId());
// NOTE(review): ensureLengthPrefixed is defined elsewhere; presumably it asserts that the
// descriptor coder is a length-prefixed wrapping of the corresponding original coder.
ensureLengthPrefixed(keyCoder, originalKeyCoder, pbsCoderMap);
ensureLengthPrefixed(valueCoder, originalvalueCoder, pbsCoderMap);
}
Aggregations