use of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in project beam by apache.
the class ParDoTranslation method toProto.
/**
 * Builds the {@link ParDoPayload} describing a multi-output {@link ParDo}.
 *
 * <p>The payload carries, in this order of construction: the translated {@link DoFn} together
 * with the main output tag, one side-input entry per view (keyed by the id of the view's internal
 * tag), every extra {@code processElement} parameter that has a proto representation, and the
 * state and timer specs keyed exactly as in the signature's declaration maps.
 *
 * @param parDo the transform to translate
 * @param components passed to the state-spec translation
 * @return the fully built payload
 * @throws IOException propagated from the nested translations
 */
public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components) throws IOException {
  DoFn<?, ?> fn = parDo.getFn();
  DoFnSignature fnSignature = DoFnSignatures.getSignature(fn.getClass());
  Map<String, StateDeclaration> declaredStates = fnSignature.stateDeclarations();
  Map<String, TimerDeclaration> declaredTimers = fnSignature.timerDeclarations();
  List<Parameter> extraParameters = fnSignature.processElement().extraParameters();

  ParDoPayload.Builder payload = ParDoPayload.newBuilder();
  payload.setDoFn(toProto(parDo.getFn(), parDo.getMainOutputTag()));

  for (PCollectionView<?> view : parDo.getSideInputs()) {
    String sideInputId = view.getTagInternal().getId();
    payload.putSideInputs(sideInputId, toProto(view));
  }

  // Parameters without a proto form come back absent and are simply left out.
  for (Parameter extraParameter : extraParameters) {
    Optional<RunnerApi.Parameter> translated = toProto(extraParameter);
    if (!translated.isPresent()) {
      continue;
    }
    payload.addParameters(translated.get());
  }

  for (Map.Entry<String, StateDeclaration> declared : declaredStates.entrySet()) {
    StateDeclaration declaration = declared.getValue();
    payload.putStateSpecs(
        declared.getKey(), toProto(getStateSpecOrCrash(declaration, fn), components));
  }

  for (Map.Entry<String, TimerDeclaration> declared : declaredTimers.entrySet()) {
    TimerDeclaration declaration = declared.getValue();
    payload.putTimerSpecs(declared.getKey(), toProto(getTimerSpecOrCrash(declaration, fn)));
  }

  return payload.build();
}
use of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in project beam by apache.
the class ParDoTranslation method fromProto.
/**
 * Rebuilds a {@link StateSpec} from its proto form.
 *
 * <p>Every coder id referenced by the spec is resolved with {@code getCodersOrThrow}, so an id
 * that is missing from {@code components} fails immediately with an
 * {@link IllegalArgumentException} instead of handing {@code null} to the coder translation.
 *
 * @param stateSpec the proto state spec to translate
 * @param components the components in which the spec's coder ids are looked up
 * @return the equivalent value, bag, combining, map or set state spec
 * @throws IOException propagated from the coder translation or from unpacking the CombineFn
 * @throws UnsupportedOperationException if a combining spec's CombineFn URN is not
 *     {@code CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN}
 * @throws IllegalArgumentException if no spec case is set, the case is unknown, or a referenced
 *     coder id is not present in {@code components}
 */
@VisibleForTesting
static StateSpec<?> fromProto(RunnerApi.StateSpec stateSpec, RunnerApi.Components components) throws IOException {
  switch (stateSpec.getSpecCase()) {
    case VALUE_SPEC:
      return StateSpecs.value(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getValueSpec().getCoderId()), components));
    case BAG_SPEC:
      return StateSpecs.bag(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getBagSpec().getElementCoderId()),
              components));
    case COMBINING_SPEC:
      FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn().getSpec();
      if (!combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN)) {
        throw new UnsupportedOperationException(
            String.format(
                "Cannot create %s from non-Java %s: %s",
                StateSpec.class.getSimpleName(),
                Combine.CombineFn.class.getSimpleName(),
                combineFnSpec.getUrn()));
      }
      // Read the wrapped bytes through getValue(). Calling toByteArray() on the BytesValue
      // message itself yields its protobuf wire encoding (field tag and length prefix
      // included), which is never a valid Java serialization stream.
      Combine.CombineFn<?, ?, ?> combineFn =
          (Combine.CombineFn<?, ?, ?>)
              SerializableUtils.deserializeFromByteArray(
                  combineFnSpec.getParameter().unpack(BytesValue.class).getValue().toByteArray(),
                  Combine.CombineFn.class.getSimpleName());
      // Raw Coder cast: the wildcard-typed CombineFn gives the compiler no accumulator type to
      // check against. The coder is expected to be a valid accumulator coder for the CombineFn,
      // by construction.
      return StateSpecs.combining(
          (Coder)
              CoderTranslation.fromProto(
                  components.getCodersOrThrow(
                      stateSpec.getCombiningSpec().getAccumulatorCoderId()),
                  components),
          combineFn);
    case MAP_SPEC:
      return StateSpecs.map(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getMapSpec().getKeyCoderId()), components),
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getMapSpec().getValueCoderId()), components));
    case SET_SPEC:
      return StateSpecs.set(
          CoderTranslation.fromProto(
              components.getCodersOrThrow(stateSpec.getSetSpec().getElementCoderId()),
              components));
    case SPEC_NOT_SET:
    default:
      throw new IllegalArgumentException(
          String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec));
  }
}
use of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in project beam by apache.
the class PipelineValidator method validateExecutableStage.
/**
 * Validates a transform whose payload is an {@link ExecutableStagePayload}.
 *
 * <p>Checks, in order, that the stage's declared input is one of the transform's inputs, that the
 * stage lists at least one transform, that every listed transform and every listed output exists
 * in the stage's own components, and finally validates those components as a whole.
 *
 * @param id the id of the transform, used in error messages
 * @param transform the transform whose spec payload is parsed as an executable stage
 * @param outerComponents not consulted by this method
 * @param requirements forwarded to the component validation
 * @throws Exception if the payload cannot be parsed or the component validation fails
 */
private static void validateExecutableStage(String id, PTransform transform, Components outerComponents, Set<String> requirements) throws Exception {
  ExecutableStagePayload stagePayload =
      ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
  // Ids inside the payload are resolved against the payload's own components only.
  Components stageComponents = stagePayload.getComponents();

  String stageInput = stagePayload.getInput();
  checkArgument(
      transform.getInputsMap().containsValue(stageInput),
      "ExecutableStage %s uses unknown input %s",
      id,
      stageInput);

  boolean hasTransforms = !stagePayload.getTransformsList().isEmpty();
  checkArgument(hasTransforms, "ExecutableStage %s contains no transforms", id);

  for (String memberId : stagePayload.getTransformsList()) {
    boolean known = stageComponents.containsTransforms(memberId);
    checkArgument(known, "ExecutableStage %s uses unknown transform %s", id, memberId);
  }

  for (String stageOutput : stagePayload.getOutputsList()) {
    boolean known = stageComponents.containsPcollections(stageOutput);
    checkArgument(known, "ExecutableStage %s uses unknown output %s", id, stageOutput);
  }

  validateComponents("ExecutableStage " + id, stageComponents, requirements);
  // TODO: Also validate that the side inputs of all transforms within
  // stageComponents.getTransforms() are contained within stagePayload.getSideInputsList().
}
use of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in project beam by apache.
the class PipelineValidator method validate.
/**
 * Validates a pipeline proto.
 *
 * <p>Every root transform id must be present in the pipeline's components; after that the
 * components themselves are validated against the pipeline's declared requirements.
 *
 * @param p the pipeline to validate
 * @throws IllegalArgumentException if a root transform id is not among the components'
 *     transforms
 */
public static void validate(RunnerApi.Pipeline p) {
  Components pipelineComponents = p.getComponents();
  p.getRootTransformIdsList()
      .forEach(
          rootId ->
              checkArgument(
                  pipelineComponents.containsTransforms(rootId),
                  "Root transform id %s is unknown",
                  rootId));
  ImmutableSet<String> declaredRequirements = ImmutableSet.copyOf(p.getRequirementsList());
  validateComponents("pipeline", pipelineComponents, declaredRequirements);
}
use of org.apache.beam.sdk.common.runner.v1.RunnerApi.Components in project beam by apache.
the class GreedyPipelineFuser method sanitizeDanglingPTransformInputs.
/**
 * Returns a copy of {@code stage} in which no transform consumes a PCollection that the stage
 * cannot supply.
 *
 * <p>An input is kept only if it is one of:
 *
 * <ul>
 *   <li>the stage's explicit input PCollection,
 *   <li>an explicit output PCollection of the stage,
 *   <li>a side input PCollection of the stage,
 *   <li>an output of a transform within the same stage.
 * </ul>
 *
 * <p>Any other input is "dangling": it is dropped from the consuming transforms and its
 * PCollection is removed from the stage's components.
 *
 * @param stage the stage to sanitize
 * @return a new stage with rebuilt transforms and components; all other parts are carried over
 */
private static ExecutableStage sanitizeDanglingPTransformInputs(ExecutableStage stage) {
  // NOTE(review): timer PCollections are not added to the allowed set below. Confirm that this
  // is intentional.
  Set<String> allowedInputIds = new HashSet<>();
  allowedInputIds.add(stage.getInputPCollection().getId());
  for (PCollectionNode stageOutput : stage.getOutputPCollections()) {
    allowedInputIds.add(stageOutput.getId());
  }
  stage.getSideInputs().stream().map(s -> s.collection().getId()).forEach(allowedInputIds::add);
  for (PTransformNode member : stage.getTransforms()) {
    allowedInputIds.addAll(member.getTransform().getOutputsMap().values());
  }

  // Every consumed PCollection id that nothing in the stage can provide.
  Set<String> danglingIds = new HashSet<>();
  for (PTransformNode member : stage.getTransforms()) {
    for (String consumedId : member.getTransform().getInputsMap().values()) {
      if (!allowedInputIds.contains(consumedId)) {
        danglingIds.add(consumedId);
      }
    }
  }

  ImmutableList.Builder<PTransformNode> sanitizedNodes = ImmutableList.builder();
  for (PTransformNode member : stage.getTransforms()) {
    PTransform memberProto = member.getTransform();
    Map<String, String> keptInputs =
        memberProto.getInputsMap().entrySet().stream()
            .filter(input -> !danglingIds.contains(input.getValue()))
            .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
    boolean untouched = keptInputs.equals(memberProto.getInputsMap());
    // Only transforms that actually lost an input are rebuilt; the rest are reused as they are.
    PTransformNode sanitized =
        untouched
            ? member
            : PipelineNode.pTransform(
                member.getId(),
                memberProto.toBuilder().clearInputs().putAllInputs(keptInputs).build());
    sanitizedNodes.add(sanitized);
  }
  ImmutableList<PTransformNode> stageTransforms = sanitizedNodes.build();

  Components.Builder rebuiltComponents = stage.getComponents().toBuilder();

  // Swap in the sanitized transforms.
  Map<String, PTransform> transformsById =
      stageTransforms.stream()
          .collect(Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform));
  rebuiltComponents.clearTransforms().putAllTransforms(transformsById);

  // Drop the PCollections that were only referenced as dangling inputs.
  Map<String, PCollection> retainedPCollections =
      stage.getComponents().getPcollectionsMap().entrySet().stream()
          .filter(pc -> !danglingIds.contains(pc.getKey()))
          .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
  rebuiltComponents.clearPcollections().putAllPcollections(retainedPCollections);

  return ImmutableExecutableStage.of(
      rebuiltComponents.build(),
      stage.getEnvironment(),
      stage.getInputPCollection(),
      stage.getSideInputs(),
      stage.getUserStates(),
      stage.getTimers(),
      stageTransforms,
      stage.getOutputPCollections(),
      stage.getWireCoderSettings());
}
Aggregations