Usage of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in the Apache Beam project.
Class OutputDeduplicator, method ensureSingleProducer:
/**
* Ensure that no {@link PCollection} output by any of the {@code stages} or {@code
* unfusedTransforms} is produced by more than one of those stages or transforms.
*
* <p>For each {@link PCollection} output by multiple stages and/or transforms, each producer is
* rewritten to produce a partial {@link PCollection}, which are then flattened together via an
* introduced Flatten node which produces the original output.
*
* <p>The new partial {@link PCollection PCollections} and the introduced Flatten transforms are
* added to a copy of the pipeline's components. The rewritten stages and transforms are not
* written into those components; they are returned separately in the result.
*
* @param pipeline the pipeline whose components are copied and extended
* @param stages the fused executable stages that are candidate producers
* @param unfusedTransforms the transforms outside any stage that are candidate producers
* @return the extended components, the introduced Flatten transforms, a map from each original
*     stage that needed rewriting to its rewritten stage, and a map from each rewritten
*     transform's id to its rewritten node
* @throws IllegalStateException if a producer wraps neither an {@link ExecutableStage} nor a
*     {@link PTransformNode}
*/
static DeduplicationResult ensureSingleProducer(QueryablePipeline pipeline, Collection<ExecutableStage> stages, Collection<PTransformNode> unfusedTransforms) {
// Mutable copy of the pipeline components; partial PCollections and Flattens are added to it.
RunnerApi.Components.Builder unzippedComponents = pipeline.getComponents().toBuilder();
// Every PCollection mapped to each stage/transform that outputs it.
Multimap<PCollectionNode, StageOrTransform> pcollectionProducers = getProducers(pipeline, stages, unfusedTransforms);
// For each producer, the PCollections it shares with at least one other producer. Each such
// producer (ExecutableStage or PTransform) must be rewritten to output a partial PCollection
// instead, which also means updating its outputs and, for stages, its transforms.
Multimap<StageOrTransform, PCollectionNode> requiresNewOutput = HashMultimap.create();
for (Map.Entry<PCollectionNode, Collection<StageOrTransform>> collectionProducer : pcollectionProducers.asMap().entrySet()) {
if (collectionProducer.getValue().size() > 1) {
for (StageOrTransform producer : collectionProducer.getValue()) {
requiresNewOutput.put(producer, collectionProducer.getKey());
}
}
}
// Original stage -> rewritten stage.
Map<ExecutableStage, ExecutableStage> updatedStages = new LinkedHashMap<>();
// Original transform id -> rewritten transform.
Map<String, PTransformNode> updatedTransforms = new LinkedHashMap<>();
// Original PCollection id -> every partial PCollection introduced to replace it.
Multimap<String, PCollectionNode> originalToPartial = HashMultimap.create();
for (Map.Entry<StageOrTransform, Collection<PCollectionNode>> deduplicationTargets : requiresNewOutput.asMap().entrySet()) {
if (deduplicationTargets.getKey().getStage() != null) {
// The containsPcollections predicate is presumably used to pick ids for the partial
// PCollections that do not collide with existing ones. TODO confirm in deduplicatePCollections.
StageDeduplication deduplication = deduplicatePCollections(deduplicationTargets.getKey().getStage(), deduplicationTargets.getValue(), unzippedComponents::containsPcollections);
for (Entry<String, PCollectionNode> originalToPartialReplacement : deduplication.getOriginalToPartialPCollections().entrySet()) {
originalToPartial.put(originalToPartialReplacement.getKey(), originalToPartialReplacement.getValue());
// Register the partial right away, so the id predicate used for later producers sees it.
unzippedComponents.putPcollections(originalToPartialReplacement.getValue().getId(), originalToPartialReplacement.getValue().getPCollection());
}
updatedStages.put(deduplicationTargets.getKey().getStage(), deduplication.getUpdatedStage());
} else if (deduplicationTargets.getKey().getTransform() != null) {
// Same as the stage branch, but for an unfused transform.
PTransformDeduplication deduplication = deduplicatePCollections(deduplicationTargets.getKey().getTransform(), deduplicationTargets.getValue(), unzippedComponents::containsPcollections);
for (Entry<String, PCollectionNode> originalToPartialReplacement : deduplication.getOriginalToPartialPCollections().entrySet()) {
originalToPartial.put(originalToPartialReplacement.getKey(), originalToPartialReplacement.getValue());
unzippedComponents.putPcollections(originalToPartialReplacement.getValue().getId(), originalToPartialReplacement.getValue().getPCollection());
}
updatedTransforms.put(deduplicationTargets.getKey().getTransform().getId(), deduplication.getUpdatedTransform());
} else {
throw new IllegalStateException(String.format("%s with no %s or %s", StageOrTransform.class.getSimpleName(), ExecutableStage.class.getSimpleName(), PTransformNode.class.getSimpleName()));
}
}
// One Flatten per original PCollection: it consumes all partials and produces the original id.
Set<PTransformNode> introducedFlattens = new LinkedHashSet<>();
for (Map.Entry<String, Collection<PCollectionNode>> partialFlattenTargets : originalToPartial.asMap().entrySet()) {
// Presumably yields an "unzipped_flatten"-prefixed id not already used by a transform.
String flattenId = SyntheticComponents.uniqueId("unzipped_flatten", unzippedComponents::containsTransforms);
PTransform flattenPartialPCollections = createFlattenOfPartials(flattenId, partialFlattenTargets.getKey(), partialFlattenTargets.getValue());
unzippedComponents.putTransforms(flattenId, flattenPartialPCollections);
introducedFlattens.add(PipelineNode.pTransform(flattenId, flattenPartialPCollections));
}
Components components = unzippedComponents.build();
return DeduplicationResult.of(components, introducedFlattens, updatedStages, updatedTransforms);
}
Usage of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in the Apache Beam project.
Class PipelineTranslationTest, method getAccumulatorCoder:
/**
 * Looks up and rehydrates the accumulator coder of a Combine-style transform.
 *
 * @param transform the applied transform to inspect
 * @return the accumulator coder decoded from the transform's combine payload
 * @throws IOException if the transform yields no combine payload, or if translation fails
 */
private static Coder<?> getAccumulatorCoder(AppliedPTransform<?, ?, ?> transform) throws IOException {
  // A fresh registry built from the pipeline's options receives the translated coders.
  SdkComponents registry = SdkComponents.create(transform.getPipeline().getOptions());
  String accumulatorCoderId =
      getCombinePayload(transform, registry)
          .map(payload -> payload.getAccumulatorCoderId())
          .orElseThrow(() -> new IOException("Transform does not contain an AccumulatorCoder"));
  Components translated = registry.toComponents();
  return CoderTranslation.fromProto(
      translated.getCodersOrThrow(accumulatorCoderId),
      RehydratedComponents.forComponents(translated),
      TranslationContext.DEFAULT);
}
Usage of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in the Apache Beam project.
Class LengthPrefixUnknownCodersTest, method test:
/**
 * Registers {@code original}, length-prefixes it, and checks that the rehydrated result equals
 * {@code expected}.
 */
@Test
public void test() throws IOException {
  SdkComponents sdk = SdkComponents.create();
  sdk.registerEnvironment(Environments.createDockerEnvironment("java"));
  String originalCoderId = sdk.registerCoder(original);
  Components.Builder builder = sdk.toComponents().toBuilder();
  String prefixedCoderId =
      LengthPrefixUnknownCoders.addLengthPrefixedCoder(originalCoderId, builder, replaceWithByteArray);
  RehydratedComponents rehydrated = RehydratedComponents.forComponents(builder.build());
  assertEquals(expected, rehydrated.getCoder(prefixedCoderId));
}
Usage of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in the Apache Beam project.
Class PTransformTranslation, method toProto:
/**
 * Translates an {@link AppliedPTransform} into a runner API proto.
 *
 * <p>Does not register the {@code appliedPTransform} within the provided {@link SdkComponents}.
 * Its input and output {@link PCollection PCollections} are registered there, however.
 *
 * <p>Outputs that are not {@link PCollection PCollections} are currently skipped rather than
 * rejected (see the TODO below). Subtransforms are referenced through
 * {@link SdkComponents#getExistingPTransformId}, so they are presumably expected to be registered
 * already. The spec is set only when the transform's class has a known payload translator.
 *
 * @param appliedPTransform the transform to translate
 * @param subtransforms the transform's direct subtransforms, referenced by id
 * @param components the registry used for PCollections, subtransform ids, and payloads
 * @return the translated transform proto
 * @throws IllegalArgumentException if any input is not a {@link PCollection}
 * @throws IOException if payload translation fails
 */
static RunnerApi.PTransform toProto(AppliedPTransform<?, ?, ?> appliedPTransform, List<AppliedPTransform<?, ?, ?>> subtransforms, SdkComponents components) throws IOException {
  RunnerApi.PTransform.Builder transformBuilder = RunnerApi.PTransform.newBuilder();
  for (Map.Entry<TupleTag<?>, PValue> taggedInput : appliedPTransform.getInputs().entrySet()) {
    checkArgument(taggedInput.getValue() instanceof PCollection, "Unexpected input type %s", taggedInput.getValue().getClass());
    transformBuilder.putInputs(toProto(taggedInput.getKey()), components.registerPCollection((PCollection<?>) taggedInput.getValue()));
  }
  for (Map.Entry<TupleTag<?>, PValue> taggedOutput : appliedPTransform.getOutputs().entrySet()) {
    PValue output = taggedOutput.getValue();
    // TODO: Remove gating. Until then, non-PCollection outputs are silently skipped. (The former
    // checkArgument inside this branch could never fail, so it was removed.)
    if (!(output instanceof PCollection)) {
      continue;
    }
    transformBuilder.putOutputs(toProto(taggedOutput.getKey()), components.registerPCollection((PCollection<?>) output));
  }
  for (AppliedPTransform<?, ?, ?> subtransform : subtransforms) {
    transformBuilder.addSubtransforms(components.getExistingPTransformId(subtransform));
  }
  transformBuilder.setUniqueName(appliedPTransform.getFullName());
  // TODO: Display Data
  PTransform<?, ?> transform = appliedPTransform.getTransform();
  if (KNOWN_PAYLOAD_TRANSLATORS.containsKey(transform.getClass())) {
    FunctionSpec payload = KNOWN_PAYLOAD_TRANSLATORS.get(transform.getClass()).translate(appliedPTransform, components);
    transformBuilder.setSpec(payload);
  }
  return transformBuilder.build();
}
Usage of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in the Apache Beam project.
Class ParDoTranslation, method getMainInput:
/**
 * Returns the main-input PCollection of a ParDo transform proto.
 *
 * <p>The main input is the single input tag that is not listed among the payload's side inputs.
 *
 * @param ptransform a transform whose spec URN must be {@code PAR_DO_TRANSFORM_URN}
 * @param components the components holding the referenced PCollections
 * @return the main input's PCollection proto
 * @throws IllegalArgumentException if the URN does not match, or if there is not exactly one
 *     non-side input
 * @throws IOException if the ParDo payload cannot be unpacked
 */
public static RunnerApi.PCollection getMainInput(RunnerApi.PTransform ptransform, Components components) throws IOException {
  String urn = ptransform.getSpec().getUrn();
  checkArgument(urn.equals(PAR_DO_TRANSFORM_URN), "Unexpected payload type %s", urn);
  ParDoPayload payload = ptransform.getSpec().getParameter().unpack(ParDoPayload.class);
  String mainInputTag =
      Iterables.getOnlyElement(
          Sets.difference(ptransform.getInputsMap().keySet(), payload.getSideInputsMap().keySet()));
  String mainInputPCollectionId = ptransform.getInputsOrThrow(mainInputTag);
  return components.getPcollectionsOrThrow(mainInputPCollectionId);
}
Aggregations