use of org.apache.beam.model.pipeline.v1.RunnerApi.Components in project beam by apache.
the class ParDoTranslation method translateParDo.
/**
 * Translates the {@link ParDo.MultiOutput} held by an applied transform into a {@link
 * ParDoPayload}.
 *
 * <p>The main input is identified as the single input tag that is not also the tag of one of the
 * transform's side inputs; the remaining work is delegated to the overload that takes the main
 * input explicitly.
 *
 * @param appliedPTransform the applied ParDo to translate
 * @param components the components passed through to the delegate overload
 * @return the payload produced by the delegate overload
 * @throws IOException if the delegate overload fails
 */
public static ParDoPayload translateParDo(AppliedPTransform<?, ?, ParDo.MultiOutput<?, ?>> appliedPTransform, SdkComponents components) throws IOException {
  final ParDo.MultiOutput<?, ?> transform = appliedPTransform.getTransform();
  final Pipeline owningPipeline = appliedPTransform.getPipeline();
  final DoFn<?, ?> fn = transform.getFn();

  // Ids of every input of the applied transform (main input and side inputs alike).
  Set<String> inputTagIds =
      appliedPTransform.getInputs().keySet().stream()
          .map(tag -> tag.getId())
          .collect(Collectors.toSet());
  // Ids of the inputs that are consumed as side inputs.
  Set<String> sideInputTagIds =
      transform.getSideInputs().values().stream()
          .map(view -> view.getTagInternal())
          .map(tag -> tag.getId())
          .collect(Collectors.toSet());

  // Exactly one input must remain once the side inputs are removed: the main input.
  String mainInputTagId =
      Iterables.getOnlyElement(Sets.difference(inputTagIds, sideInputTagIds));
  PCollection<?> mainInput =
      (PCollection<?>) appliedPTransform.getInputs().get(new TupleTag<>(mainInputTagId));

  final DoFnSchemaInformation schemaInformation = ParDo.getDoFnSchemaInformation(fn, mainInput);
  return translateParDo(
      (ParDo.MultiOutput) transform, mainInput, schemaInformation, owningPipeline, components);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Components in project beam by apache.
the class ParDoTranslation method translateParDo.
/**
 * Translates a {@link ParDo.MultiOutput} into a {@link ParDoPayload}.
 *
 * <p>The translation is expressed as a {@link ParDoLike} adapter around the {@link DoFn}'s
 * {@link DoFnSignature} and handed to {@code payloadForParDoLike}.
 *
 * <ul>
 *   <li>For a splittable {@code DoFn} a {@link KvCoder} of the restriction coder and the
 *       watermark estimator state coder is registered and its id is reported as the restriction
 *       coder id; otherwise the restriction coder id is the empty string.
 *   <li>For a {@code DoFn} that uses state or timers the main input coder must be a {@link
 *       KvCoder}; its key coder is used when building timer family specs.
 * </ul>
 *
 * @param parDo the transform to translate
 * @param mainInput the main input of {@code parDo}
 * @param doFnSchemaInformation schema information forwarded to the {@code DoFn} translation
 * @param pipeline the pipeline whose coder registry is consulted for splittable {@code DoFn}s
 * @param components the components the restriction coder is registered with and that are passed
 *     to {@code payloadForParDoLike}
 * @return the payload built by {@code payloadForParDoLike}
 * @throws IOException if coder registration or payload construction fails
 * @throws IllegalArgumentException if the {@code DoFn} uses state or timers but the main input
 *     coder is not a {@link KvCoder}
 */
public static <InputT> ParDoPayload translateParDo(ParDo.MultiOutput<InputT, ?> parDo, PCollection<InputT> mainInput, DoFnSchemaInformation doFnSchemaInformation, Pipeline pipeline, SdkComponents components) throws IOException {
  final DoFn<?, ?> doFn = parDo.getFn();
  final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

  final String restrictionCoderId;
  if (signature.processElement().isSplittable()) {
    DoFnInvoker<?, ?> doFnInvoker = DoFnInvokers.invokerFor(doFn);
    final Coder<?> restrictionAndWatermarkStateCoder =
        KvCoder.of(
            doFnInvoker.invokeGetRestrictionCoder(pipeline.getCoderRegistry()),
            doFnInvoker.invokeGetWatermarkEstimatorStateCoder(pipeline.getCoderRegistry()));
    restrictionCoderId = components.registerCoder(restrictionAndWatermarkStateCoder);
  } else {
    restrictionCoderId = "";
  }

  Coder<BoundedWindow> windowCoder =
      (Coder<BoundedWindow>) mainInput.getWindowingStrategy().getWindowFn().windowCoder();

  // Only state/timer-using DoFns are required to have keyed input; otherwise no key coder is
  // available and null is passed along to the timer family spec translation.
  Coder<?> keyCoder;
  if (signature.usesState() || signature.usesTimers()) {
    checkArgument(
        mainInput.getCoder() instanceof KvCoder,
        "DoFn's that use state or timers must have an input PCollection with a KvCoder but received %s",
        mainInput.getCoder());
    keyCoder = ((KvCoder<?, ?>) mainInput.getCoder()).getKeyCoder();
  } else {
    keyCoder = null;
  }

  return payloadForParDoLike(
      new ParDoLike() {
        @Override
        public FunctionSpec translateDoFn(SdkComponents newComponents) {
          return ParDoTranslation.translateDoFn(
              parDo.getFn(),
              parDo.getMainOutputTag(),
              parDo.getSideInputs(),
              doFnSchemaInformation,
              newComponents);
        }

        @Override
        public Map<String, SideInput> translateSideInputs(SdkComponents newComponents) {
          // Keyed by the id of each side input view's internal tag.
          Map<String, SideInput> sideInputs = new HashMap<>();
          for (PCollectionView<?> sideInput : parDo.getSideInputs().values()) {
            sideInputs.put(
                sideInput.getTagInternal().getId(), translateView(sideInput, newComponents));
          }
          return sideInputs;
        }

        @Override
        public Map<String, RunnerApi.StateSpec> translateStateSpecs(SdkComponents newComponents)
            throws IOException {
          Map<String, RunnerApi.StateSpec> stateSpecs = new HashMap<>();
          for (Map.Entry<String, StateDeclaration> state :
              signature.stateDeclarations().entrySet()) {
            RunnerApi.StateSpec spec =
                translateStateSpec(getStateSpecOrThrow(state.getValue(), doFn), newComponents);
            stateSpecs.put(state.getKey(), spec);
          }
          return stateSpecs;
        }

        @Override
        public ParDoLikeTimerFamilySpecs translateTimerFamilySpecs(SdkComponents newComponents) {
          Map<String, RunnerApi.TimerFamilySpec> timerFamilySpecs = new HashMap<>();
          for (Map.Entry<String, TimerDeclaration> timer :
              signature.timerDeclarations().entrySet()) {
            RunnerApi.TimerFamilySpec spec =
                translateTimerFamilySpec(
                    getTimerSpecOrThrow(timer.getValue(), doFn),
                    newComponents,
                    keyCoder,
                    windowCoder);
            timerFamilySpecs.put(timer.getKey(), spec);
          }
          for (Map.Entry<String, DoFnSignature.TimerFamilyDeclaration> timerFamily :
              signature.timerFamilyDeclarations().entrySet()) {
            RunnerApi.TimerFamilySpec spec =
                translateTimerFamilySpec(
                    DoFnSignatures.getTimerFamilySpecOrThrow(timerFamily.getValue(), doFn),
                    newComponents,
                    keyCoder,
                    windowCoder);
            timerFamilySpecs.put(timerFamily.getKey(), spec);
          }

          String onWindowExpirationTimerFamilySpec = null;
          if (signature.onWindowExpiration() != null) {
            // Register the timer coder with the components handed to this callback, exactly as
            // the timer family specs above do, instead of the captured outer instance.
            RunnerApi.TimerFamilySpec spec =
                RunnerApi.TimerFamilySpec.newBuilder()
                    .setTimeDomain(translateTimeDomain(TimeDomain.EVENT_TIME))
                    .setTimerFamilyCoderId(
                        registerCoderOrThrow(newComponents, Timer.Coder.of(keyCoder, windowCoder)))
                    .build();
            // Pick the first "onWindowExpiration<i>" id that does not collide with a declared
            // timer or timer family id.
            for (int i = 0; i < Integer.MAX_VALUE; ++i) {
              onWindowExpirationTimerFamilySpec = "onWindowExpiration" + i;
              if (!timerFamilySpecs.containsKey(onWindowExpirationTimerFamilySpec)) {
                break;
              }
            }
            timerFamilySpecs.put(onWindowExpirationTimerFamilySpec, spec);
          }
          return ParDoLikeTimerFamilySpecs.create(
              timerFamilySpecs, onWindowExpirationTimerFamilySpec);
        }

        @Override
        public boolean isStateful() {
          return !signature.stateDeclarations().isEmpty()
              || !signature.timerDeclarations().isEmpty()
              || !signature.timerFamilyDeclarations().isEmpty()
              || signature.onWindowExpiration() != null;
        }

        @Override
        public boolean isSplittable() {
          return signature.processElement().isSplittable();
        }

        @Override
        public boolean isRequiresStableInput() {
          return signature.processElement().requiresStableInput();
        }

        @Override
        public boolean isRequiresTimeSortedInput() {
          return signature.processElement().requiresTimeSortedInput();
        }

        @Override
        public boolean requestsFinalization() {
          // True when any of startBundle / processElement / finishBundle declares a bundle
          // finalizer parameter.
          return (signature.startBundle() != null
                  && signature.startBundle().extraParameters().contains(Parameter.bundleFinalizer()))
              || (signature.processElement() != null
                  && signature
                      .processElement()
                      .extraParameters()
                      .contains(Parameter.bundleFinalizer()))
              || (signature.finishBundle() != null
                  && signature
                      .finishBundle()
                      .extraParameters()
                      .contains(Parameter.bundleFinalizer()));
        }

        @Override
        public String translateRestrictionCoderId(SdkComponents newComponents) {
          // Computed eagerly above; empty for non-splittable DoFns.
          return restrictionCoderId;
        }
      },
      components);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Components in project beam by apache.
the class ParDoTranslation method fromProto.
/**
 * Rebuilds an SDK {@link StateSpec} from its {@link RunnerApi.StateSpec} proto form, resolving
 * coder ids through {@code components}.
 *
 * @param stateSpec the proto state spec to convert
 * @param components used to look up the coders referenced by id in {@code stateSpec}
 * @return the equivalent SDK state spec
 * @throws IOException if a coder lookup fails
 * @throws UnsupportedOperationException if a combining spec carries a combine fn whose URN is
 *     not the Java serialized combine fn URN
 * @throws IllegalArgumentException if the spec case is unset or not handled here
 */
@VisibleForTesting
static StateSpec<?> fromProto(RunnerApi.StateSpec stateSpec, RehydratedComponents components) throws IOException {
  switch (stateSpec.getSpecCase()) {
    case READ_MODIFY_WRITE_SPEC: {
      String valueCoderId = stateSpec.getReadModifyWriteSpec().getCoderId();
      return StateSpecs.value(components.getCoder(valueCoderId));
    }
    case BAG_SPEC: {
      String elementCoderId = stateSpec.getBagSpec().getElementCoderId();
      return StateSpecs.bag(components.getCoder(elementCoderId));
    }
    case COMBINING_SPEC: {
      FunctionSpec combineFnSpec = stateSpec.getCombiningSpec().getCombineFn();
      // Only combine fns serialized by the Java SDK can be turned back into a CombineFn.
      boolean isJavaCombineFn =
          combineFnSpec.getUrn().equals(CombineTranslation.JAVA_SERIALIZED_COMBINE_FN_URN);
      if (!isJavaCombineFn) {
        throw new UnsupportedOperationException(
            String.format(
                "Cannot create %s from non-Java %s: %s",
                StateSpec.class.getSimpleName(),
                Combine.CombineFn.class.getSimpleName(),
                combineFnSpec.getUrn()));
      }
      byte[] serializedCombineFn = combineFnSpec.getPayload().toByteArray();
      Combine.CombineFn<?, ?, ?> combineFn =
          (Combine.CombineFn<?, ?, ?>)
              SerializableUtils.deserializeFromByteArray(
                  serializedCombineFn, Combine.CombineFn.class.getSimpleName());
      String accumulatorCoderId = stateSpec.getCombiningSpec().getAccumulatorCoderId();
      // Raw Coder cast: the looked-up coder is taken to be the accumulator coder for the
      // CombineFn, by construction.
      return StateSpecs.combining((Coder) components.getCoder(accumulatorCoderId), combineFn);
    }
    case MAP_SPEC: {
      String mapKeyCoderId = stateSpec.getMapSpec().getKeyCoderId();
      String mapValueCoderId = stateSpec.getMapSpec().getValueCoderId();
      return StateSpecs.map(
          components.getCoder(mapKeyCoderId), components.getCoder(mapValueCoderId));
    }
    case SET_SPEC: {
      String elementCoderId = stateSpec.getSetSpec().getElementCoderId();
      return StateSpecs.set(components.getCoder(elementCoderId));
    }
    case SPEC_NOT_SET:
    default:
      throw new IllegalArgumentException(
          String.format("Unknown %s: %s", RunnerApi.StateSpec.class.getName(), stateSpec));
  }
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Components in project beam by apache.
the class WindowingStrategyTranslation method fromProto.
/**
 * Converts a {@link RunnerApi.WindowingStrategy} proto into the SDK's {@link WindowingStrategy}.
 *
 * <p>Each field of the proto (window fn, output time, accumulation mode, trigger, closing
 * behavior, allowed lateness in milliseconds, on-time behavior and environment id) is converted
 * and applied to a strategy created from the window fn.
 *
 * @param proto the windowing strategy proto to convert
 * @param components the components available for dereferencing identifiers found in the proto
 * @return the SDK windowing strategy described by {@code proto}
 * @throws InvalidProtocolBufferException if a nested conversion fails to parse its payload
 */
public static WindowingStrategy<?, ?> fromProto(RunnerApi.WindowingStrategy proto, RehydratedComponents components) throws InvalidProtocolBufferException {
  // Convert every proto field up front, in the order the fields are read.
  FunctionSpec protoWindowFn = proto.getWindowFn();
  WindowFn<?, ?> sdkWindowFn = windowFnFromProto(protoWindowFn);
  TimestampCombiner combiner = timestampCombinerFromProto(proto.getOutputTime());
  AccumulationMode mode = fromProto(proto.getAccumulationMode());
  Trigger sdkTrigger = TriggerTranslation.fromProto(proto.getTrigger());
  ClosingBehavior closing = fromProto(proto.getClosingBehavior());
  Duration lateness = Duration.millis(proto.getAllowedLateness());
  OnTimeBehavior onTime = fromProto(proto.getOnTimeBehavior());
  String envId = proto.getEnvironmentId();

  return WindowingStrategy.of(sdkWindowFn)
      .withAllowedLateness(lateness)
      .withMode(mode)
      .withTrigger(sdkTrigger)
      .withTimestampCombiner(combiner)
      .withClosingBehavior(closing)
      .withOnTimeBehavior(onTime)
      .withEnvironmentId(envId);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.Components in project beam by apache.
the class ExecutableStage method fromPayload.
/**
 * Builds an {@link ExecutableStage} from an {@link ExecutableStagePayload}.
 *
 * <p>See {@link #toPTransform} for how the payload is constructed.
 *
 * <p>The ids listed in the payload (input, side inputs, user states, timers, transforms and
 * outputs) are resolved against the payload's own components, so the stage is assembled from the
 * payload alone.
 *
 * @param payload the payload to read the stage from
 * @return an immutable stage holding the payload's components, environment, resolved nodes and
 *     wire coder settings
 */
static ExecutableStage fromPayload(ExecutableStagePayload payload) {
  Components components = payload.getComponents();
  Environment environment = payload.getEnvironment();
  Collection<WireCoderSetting> wireCoderSettings = payload.getWireCoderSettingsList();

  String inputId = payload.getInput();
  PCollectionNode input =
      PipelineNode.pCollection(inputId, components.getPcollectionsOrThrow(payload.getInput()));

  List<SideInputReference> sideInputs =
      payload.getSideInputsList().stream()
          .map(ref -> SideInputReference.fromSideInputId(ref, components))
          .collect(Collectors.toList());
  List<UserStateReference> userStates =
      payload.getUserStatesList().stream()
          .map(ref -> UserStateReference.fromUserStateId(ref, components))
          .collect(Collectors.toList());
  List<TimerReference> timers =
      payload.getTimersList().stream()
          .map(ref -> TimerReference.fromTimerId(ref, components))
          .collect(Collectors.toList());
  List<PTransformNode> transforms =
      payload.getTransformsList().stream()
          .map(transformId ->
              PipelineNode.pTransform(transformId, components.getTransformsOrThrow(transformId)))
          .collect(Collectors.toList());
  List<PCollectionNode> outputs =
      payload.getOutputsList().stream()
          .map(outputId ->
              PipelineNode.pCollection(outputId, components.getPcollectionsOrThrow(outputId)))
          .collect(Collectors.toList());

  return ImmutableExecutableStage.of(
      components,
      environment,
      input,
      sideInputs,
      userStates,
      timers,
      transforms,
      outputs,
      wireCoderSettings);
}
Aggregations