use of org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload in project beam by apache.
the class PipelineValidator method validateParDo.
/**
 * Validates the {@link ParDoPayload} carried by a ParDo transform.
 *
 * <p>The following checks are performed:
 *
 * <ul>
 *   <li>every side input named in the payload is also one of the transform's inputs;
 *   <li>a payload with state specs or timer family specs requires the stateful-processing URN;
 *   <li>a payload with a restriction coder id must reference a coder present in {@code
 *       components} and requires the splittable-DoFn URN;
 *   <li>bundle finalization, stable input and time-sorted input each require their own URN.
 * </ul>
 *
 * @param id the id of the transform, used only in error messages
 * @param transform the transform whose spec payload is parsed as a {@link ParDoPayload}
 * @param components the pipeline components used to resolve the restriction coder id
 * @param requirements the requirement URNs declared by the pipeline
 * @throws Exception if the payload cannot be parsed; a failed check is reported by {@code
 *     checkArgument}
 */
private static void validateParDo(String id, PTransform transform, Components components, Set<String> requirements) throws Exception {
  ParDoPayload payload = ParDoPayload.parseFrom(transform.getSpec().getPayload());
  // side_inputs
  for (String sideInputId : payload.getSideInputsMap().keySet()) {
    checkArgument(
        transform.containsInputs(sideInputId),
        "Transform %s side input %s is not listed in the transform's inputs",
        id,
        sideInputId);
  }
  if (payload.getStateSpecsCount() > 0 || payload.getTimerFamilySpecsCount() > 0) {
    checkArgument(
        requirements.contains(ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN),
        "Transform %s uses state or timers but the pipeline does not declare requirement %s",
        id,
        ParDoTranslation.REQUIRES_STATEFUL_PROCESSING_URN);
    // TODO: Validate state_specs and timer_specs
  }
  if (!payload.getRestrictionCoderId().isEmpty()) {
    checkArgument(
        components.containsCoders(payload.getRestrictionCoderId()),
        "Transform %s uses unknown restriction coder %s",
        id,
        payload.getRestrictionCoderId());
    checkArgument(
        requirements.contains(ParDoTranslation.REQUIRES_SPLITTABLE_DOFN_URN),
        "Transform %s has a restriction coder but the pipeline does not declare requirement %s",
        id,
        ParDoTranslation.REQUIRES_SPLITTABLE_DOFN_URN);
  }
  if (payload.getRequestsFinalization()) {
    checkArgument(
        requirements.contains(ParDoTranslation.REQUIRES_BUNDLE_FINALIZATION_URN),
        "Transform %s requests finalization but the pipeline does not declare requirement %s",
        id,
        ParDoTranslation.REQUIRES_BUNDLE_FINALIZATION_URN);
  }
  if (payload.getRequiresStableInput()) {
    checkArgument(
        requirements.contains(ParDoTranslation.REQUIRES_STABLE_INPUT_URN),
        "Transform %s requires stable input but the pipeline does not declare requirement %s",
        id,
        ParDoTranslation.REQUIRES_STABLE_INPUT_URN);
  }
  if (payload.getRequiresTimeSortedInput()) {
    checkArgument(
        requirements.contains(ParDoTranslation.REQUIRES_TIME_SORTED_INPUT_URN),
        "Transform %s requires time-sorted input but the pipeline does not declare requirement %s",
        id,
        ParDoTranslation.REQUIRES_TIME_SORTED_INPUT_URN);
  }
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload in project beam by apache.
the class EnvironmentsTest method getEnvironmentPTransform.
/**
 * Translates a no-op ParDo with a Docker environment registered, wraps the resulting payload in a
 * PTransform proto that carries the only registered environment id, and checks that {@code
 * Environments.getEnvironment} resolves the same environment the components hold for that id.
 */
@Test
public void getEnvironmentPTransform() throws IOException {
  Pipeline p = Pipeline.create();
  SdkComponents components = SdkComponents.create();
  Environment dockerEnvironment = Environments.createDockerEnvironment("java");
  components.registerEnvironment(dockerEnvironment);

  // A DoFn that does nothing; only its translation to a payload matters here.
  DoFn<String, String> noopFn =
      new DoFn<String, String>() {
        @ProcessElement
        public void process(ProcessContext ctxt) {
        }
      };
  ParDo.MultiOutput<String, String> parDo =
      ParDo.of(noopFn).withOutputTags(new TupleTag<>(), TupleTagList.empty());
  PCollection<String> mainInput =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of());
  ParDoPayload payload =
      ParDoTranslation.translateParDo(
          parDo, mainInput, DoFnSchemaInformation.create(), Pipeline.create(), components);

  RehydratedComponents rehydrated =
      RehydratedComponents.forComponents(components.toComponents());
  FunctionSpec parDoSpec =
      FunctionSpec.newBuilder()
          .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
          .setPayload(payload.toByteString())
          .build();
  PTransform ptransform =
      PTransform.newBuilder()
          .setSpec(parDoSpec)
          .setEnvironmentId(components.getOnlyEnvironmentId())
          .build();

  Environment resolved = Environments.getEnvironment(ptransform, rehydrated).get();
  Environment expected =
      components.toComponents().getEnvironmentsOrThrow(ptransform.getEnvironmentId());
  assertThat(resolved, equalTo(expected));
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload in project beam by apache.
the class ParDoTranslation method getMainInput.
/**
 * Returns the main input PCollection of a ParDo transform.
 *
 * <p>The main input is the single transform input whose local name is not a side input name in
 * the unpacked {@link ParDoPayload}.
 *
 * @param ptransform the transform to inspect; its spec URN must equal {@code
 *     PAR_DO_TRANSFORM_URN}
 * @param components the components used to look up the PCollection by id
 * @return the PCollection proto registered under the main input's PCollection id
 * @throws IOException declared by the signature (presumably raised when the payload cannot be
 *     unpacked)
 */
public static RunnerApi.PCollection getMainInput(RunnerApi.PTransform ptransform, Components components) throws IOException {
  RunnerApi.FunctionSpec spec = ptransform.getSpec();
  String urn = spec.getUrn();
  checkArgument(urn.equals(PAR_DO_TRANSFORM_URN), "Unexpected payload type %s", urn);

  ParDoPayload payload = spec.getParameter().unpack(ParDoPayload.class);

  // Whatever input name is left after removing the side input names is the main input;
  // getOnlyElement insists that exactly one such name exists.
  String mainInputId =
      Iterables.getOnlyElement(
          Sets.difference(
              ptransform.getInputsMap().keySet(), payload.getSideInputsMap().keySet()));
  String mainInputPCollectionId = ptransform.getInputsOrThrow(mainInputId);
  return components.getPcollectionsOrThrow(mainInputPCollectionId);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload in project beam by apache.
the class ParDoTranslation method toProto.
/**
 * Builds the {@link ParDoPayload} proto for a multi-output ParDo.
 *
 * <p>The payload is populated with the translated DoFn (paired with the main output tag), one
 * side input entry per view keyed by the view's internal tag id, every extra process-element
 * parameter that has a proto form, and the state and timer specs declared in the DoFn's
 * signature, each keyed by its declaration id.
 *
 * @param parDo the ParDo to translate
 * @param components the SDK components handed to the state spec translation
 * @return the assembled payload
 * @throws IOException declared by the signature (presumably raised by one of the nested
 *     translation helpers)
 */
public static ParDoPayload toProto(ParDo.MultiOutput<?, ?> parDo, SdkComponents components) throws IOException {
  DoFn<?, ?> doFn = parDo.getFn();
  DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

  ParDoPayload.Builder payload = ParDoPayload.newBuilder();
  payload.setDoFn(toProto(parDo.getFn(), parDo.getMainOutputTag()));

  for (PCollectionView<?> view : parDo.getSideInputs()) {
    payload.putSideInputs(view.getTagInternal().getId(), toProto(view));
  }

  // Parameters without a proto form are skipped.
  List<Parameter> extraParameters = signature.processElement().extraParameters();
  for (Parameter extraParameter : extraParameters) {
    Optional<RunnerApi.Parameter> translated = toProto(extraParameter);
    if (translated.isPresent()) {
      payload.addParameters(translated.get());
    }
  }

  Map<String, StateDeclaration> stateDeclarations = signature.stateDeclarations();
  for (Map.Entry<String, StateDeclaration> declared : stateDeclarations.entrySet()) {
    RunnerApi.StateSpec stateSpec =
        toProto(getStateSpecOrCrash(declared.getValue(), doFn), components);
    payload.putStateSpecs(declared.getKey(), stateSpec);
  }

  Map<String, TimerDeclaration> timerDeclarations = signature.timerDeclarations();
  for (Map.Entry<String, TimerDeclaration> declared : timerDeclarations.entrySet()) {
    RunnerApi.TimerSpec timerSpec = toProto(getTimerSpecOrCrash(declared.getValue(), doFn));
    payload.putTimerSpecs(declared.getKey(), timerSpec);
  }

  return payload.build();
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload in project beam by apache.
the class ProcessBundleHandlerTest method setupProcessBundleHandlerForSimpleRecordingDoFn.
/**
 * Builds a {@link ProcessBundleHandler} wired to a two-transform bundle descriptor (registered
 * under id "1L"): a data input transform ("2L") feeding a ParDo ("3L") that runs {@code
 * SimpleDoFn} with one event-time timer family.
 *
 * <p>The data input runner factory is replaced by one that registers an incoming data endpoint
 * whose consumer appends each element's value to {@code dataOutput}. The mocked data client
 * returns an outbound aggregator whose stream observer appends every timers entry of each
 * received {@code Elements} message to {@code timerOutput}.
 *
 * @param dataOutput receives the value of every element delivered to the data input endpoint
 * @param timerOutput receives every {@code Timers} entry sent through the outbound aggregator
 * @param enableOutputEmbedding when true, the handler is created with the
 *     control-response-elements-embedding runner capability; otherwise with no capabilities
 * @return the configured handler
 */
private ProcessBundleHandler setupProcessBundleHandlerForSimpleRecordingDoFn(List<String> dataOutput, List<Timers> timerOutput, boolean enableOutputEmbedding) throws Exception {
  DoFnWithExecutionInformation doFnWithExecutionInformation = DoFnWithExecutionInformation.of(new SimpleDoFn(), SimpleDoFn.MAIN_OUTPUT_TAG, Collections.emptyMap(), DoFnSchemaInformation.create());
  RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder().setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN).setPayload(ByteString.copyFrom(SerializableUtils.serializeToByteArray(doFnWithExecutionInformation))).build();
  RunnerApi.ParDoPayload parDoPayload = ParDoPayload.newBuilder().setDoFn(functionSpec).putTimerFamilySpecs("tfs-" + SimpleDoFn.TIMER_FAMILY_ID, TimerFamilySpec.newBuilder().setTimeDomain(RunnerApi.TimeDomain.Enum.EVENT_TIME).setTimerFamilyCoderId("timer-coder").build()).build();
  BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = ProcessBundleDescriptor.newBuilder()
      .putTransforms("2L", PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()).putOutputs("2L-output", "2L-output-pc").build())
      .putTransforms("3L", PTransform.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN).setPayload(parDoPayload.toByteString())).putInputs("3L-input", "2L-output-pc").build())
      .putPcollections("2L-output-pc", PCollection.newBuilder().setWindowingStrategyId("window-strategy").setCoderId("2L-output-coder").setIsBounded(IsBounded.Enum.BOUNDED).build())
      .putWindowingStrategies("window-strategy", WindowingStrategy.newBuilder().setWindowCoderId("window-strategy-coder").setWindowFn(FunctionSpec.newBuilder().setUrn("beam:window_fn:global_windows:v1")).setOutputTime(OutputTime.Enum.END_OF_WINDOW).setAccumulationMode(AccumulationMode.Enum.ACCUMULATING).setTrigger(Trigger.newBuilder().setAlways(Always.getDefaultInstance())).setClosingBehavior(ClosingBehavior.Enum.EMIT_ALWAYS).setOnTimeBehavior(OnTimeBehavior.Enum.FIRE_ALWAYS).build())
      .setTimerApiServiceDescriptor(ApiServiceDescriptor.newBuilder().setUrl("url").build())
      .putCoders("string_coder", CoderTranslation.toProto(StringUtf8Coder.of()).getCoder())
      .putCoders("2L-output-coder", Coder.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn(ModelCoders.KV_CODER_URN).build()).addComponentCoderIds("string_coder").addComponentCoderIds("string_coder").build())
      .putCoders("window-strategy-coder", Coder.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN).build()).build())
      .putCoders("timer-coder", Coder.newBuilder().setSpec(FunctionSpec.newBuilder().setUrn(ModelCoders.TIMER_CODER_URN)).addComponentCoderIds("string_coder").addComponentCoderIds("window-strategy-coder").build())
      .build();
  Map<String, BeamFnApi.ProcessBundleDescriptor> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
  Map<String, PTransformRunnerFactory> urnToPTransformRunnerFactoryMap = Maps.newHashMap(REGISTERED_RUNNER_FACTORIES);
  // Replace the data input runner with one that records each element's value in dataOutput.
  urnToPTransformRunnerFactoryMap.put(DATA_INPUT_URN, (PTransformRunnerFactory<Object>) (context) -> {
    context.addIncomingDataEndpoint(ApiServiceDescriptor.getDefaultInstance(), KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()), (input) -> {
      dataOutput.add(input.getValue());
    });
    return null;
  });
  Mockito.doAnswer((invocation) -> new BeamFnDataOutboundAggregator(PipelineOptionsFactory.create(), invocation.getArgument(1), new StreamObserver<Elements>() {
    @Override
    public void onNext(Elements elements) {
      // Record each timers entry exactly once. Calling addAll inside a loop over the same
      // list would record k * k entries for a message that carries k of them.
      timerOutput.addAll(elements.getTimersList());
    }
    @Override
    public void onError(Throwable throwable) {
    }
    @Override
    public void onCompleted() {
    }
  }, invocation.getArgument(2))).when(beamFnDataClient).createOutboundAggregator(any(), any(), anyBoolean());
  return new ProcessBundleHandler(
      PipelineOptionsFactory.create(),
      enableOutputEmbedding ? Collections.singleton(BeamUrns.getUrn(StandardRunnerProtocols.Enum.CONTROL_RESPONSE_ELEMENTS_EMBEDDING)) : Collections.emptySet(),
      fnApiRegistry::get,
      beamFnDataClient,
      null, /* beamFnStateClient */
      null, /* finalizeBundleHandler */
      new ShortIdMap(),
      urnToPTransformRunnerFactoryMap,
      Caches.noop(),
      new BundleProcessorCache());
}
Aggregations