use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId in project beam by apache.
the class RemoteExecutionTest method testExecutionWithUserStateCaching.
/**
 * Processes two bundles for the same key against a stage that uses two bag user states and
 * verifies both the emitted elements and the state requests that reached the runner.
 *
 * <p>The DoFn emits every value of the first state and, depending on whether the second state is
 * empty, either emits {@code "Empty"} or clears it. The test asserts that exactly three state
 * requests are recorded across both bundles (read of the first state, read of the second state,
 * clear of the second state); i.e. the second bundle is expected to be served without issuing new
 * state requests.
 */
@Test
public void testExecutionWithUserStateCaching() throws Exception {
  Pipeline p = Pipeline.create();
  launchSdkHarness(p.getOptions());

  final String stateId = "foo";
  final String stateId2 = "bar";

  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "userState",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {
                @StateId(stateId)
                private final StateSpec<BagState<String>> bufferState =
                    StateSpecs.bag(StringUtf8Coder.of());

                @StateId(stateId2)
                private final StateSpec<BagState<String>> bufferState2 =
                    StateSpecs.bag(StringUtf8Coder.of());

                @ProcessElement
                public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> state,
                    @StateId(stateId2) BagState<String> state2,
                    OutputReceiver<KV<String, String>> r) {
                  for (String value : state.read()) {
                    r.output(KV.of(element.getKey(), value));
                  }
                  ReadableState<Boolean> isEmpty = state2.isEmpty();
                  if (isEmpty.read()) {
                    r.output(KV.of(element.getKey(), "Empty"));
                  } else {
                    state2.clear();
                  }
                }
              }))
      .apply("gbk", GroupByKey.create());

  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = optionalStage.get();

  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator);

  // Collect everything the stage emits, keyed by the remote output it was written to.
  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }

  // Encoded fixture values, shared between the initial state and the final assertions.
  ByteString encodedA =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
  ByteString encodedB =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
  ByteString encodedC =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
  ByteString encodedD =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));

  // The map itself is immutable, but the per-state lists are mutable so the handler below can
  // append to and clear them.
  List<ByteString> initialState = new ArrayList<>(Arrays.asList(encodedA, encodedB, encodedC));
  List<ByteString> initialState2 = new ArrayList<>(Arrays.asList(encodedD));
  Map<String, List<ByteString>> userStateData =
      ImmutableMap.of(stateId, initialState, stateId2, initialState2);

  // Wrap the in-memory handler so that every incoming state request is recorded.
  StoringStateRequestHandler stateRequestHandler =
      new StoringStateRequestHandler(
          StateRequestHandlers.forBagUserStateHandlerFactory(
              descriptor,
              new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
                @Override
                public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                    String pTransformId,
                    String userStateId,
                    Coder<ByteString> keyCoder,
                    Coder<Object> valueCoder,
                    Coder<BoundedWindow> windowCoder) {
                  return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
                    @Override
                    public Iterable<Object> get(ByteString key, BoundedWindow window) {
                      return (Iterable) userStateData.get(userStateId);
                    }

                    @Override
                    public void append(
                        ByteString key, BoundedWindow window, Iterator<Object> values) {
                      Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                    }

                    @Override
                    public void clear(ByteString key, BoundedWindow window) {
                      userStateData.get(userStateId).clear();
                    }
                  };
                }
              }));

  // First bundle: reads both states and clears the (non-empty) second one.
  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Y")));
  }
  // Second bundle for the same key: the second state is now empty, so "Empty" is emitted.
  try (RemoteBundle bundle2 =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle2.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Z")));
  }

  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "Empty"))));
  }
  assertThat(
      userStateData.get(stateId),
      IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC));
  assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());

  // 3 Requests expected: state read, state2 read, and state2 clear
  assertEquals(3, stateRequestHandler.getRequestCount());

  ByteString.Output out = ByteString.newOutput();
  StringUtf8Coder.of().encode("X", out);
  ByteString expectedKey = out.toByteString();

  // assertEquals takes (expected, actual); keep that order so failure messages read correctly.
  assertEquals(
      stateId,
      stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      expectedKey,
      stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());

  assertEquals(
      stateId2,
      stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      expectedKey,
      stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());

  assertEquals(
      stateId2,
      stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      expectedKey,
      stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId in project beam by apache.
the class ExecutableStage method fromPayload.
/**
 * Rebuilds an {@link ExecutableStage} from its {@link ExecutableStagePayload} proto form.
 *
 * <p>Each id carried by the payload (input, side inputs, user states, timers, transforms and
 * outputs) is resolved against the {@link Components} embedded in that same payload, so nothing
 * besides the payload is consulted. See {@link #toPTransform} for how the payload is constructed.
 *
 * @param payload the proto description of the stage
 * @return the stage described by {@code payload}
 */
static ExecutableStage fromPayload(ExecutableStagePayload payload) {
  final Components stageComponents = payload.getComponents();
  final Environment stageEnvironment = payload.getEnvironment();
  final Collection<WireCoderSetting> coderSettings = payload.getWireCoderSettingsList();

  // The single input PCollection of the stage.
  final String inputId = payload.getInput();
  final PCollectionNode inputNode =
      PipelineNode.pCollection(inputId, stageComponents.getPcollectionsOrThrow(inputId));

  // References that are resolved by their own factory methods.
  final List<SideInputReference> sideInputRefs =
      payload.getSideInputsList().stream()
          .map(ref -> SideInputReference.fromSideInputId(ref, stageComponents))
          .collect(Collectors.toList());
  final List<UserStateReference> userStateRefs =
      payload.getUserStatesList().stream()
          .map(ref -> UserStateReference.fromUserStateId(ref, stageComponents))
          .collect(Collectors.toList());
  final List<TimerReference> timerRefs =
      payload.getTimersList().stream()
          .map(ref -> TimerReference.fromTimerId(ref, stageComponents))
          .collect(Collectors.toList());

  // Transforms and outputs are stored as plain ids; pair each id with its proto definition.
  final List<PTransformNode> transformNodes =
      payload.getTransformsList().stream()
          .map(
              transformId ->
                  PipelineNode.pTransform(
                      transformId, stageComponents.getTransformsOrThrow(transformId)))
          .collect(Collectors.toList());
  final List<PCollectionNode> outputNodes =
      payload.getOutputsList().stream()
          .map(
              outputId ->
                  PipelineNode.pCollection(
                      outputId, stageComponents.getPcollectionsOrThrow(outputId)))
          .collect(Collectors.toList());

  return ImmutableExecutableStage.of(
      stageComponents,
      stageEnvironment,
      inputNode,
      sideInputRefs,
      userStateRefs,
      timerRefs,
      transformNodes,
      outputNodes,
      coderSettings);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId in project beam by apache.
the class CreateExecutableStageNodeFunction method apply.
/**
 * Converts a network of Dataflow instruction nodes into a single {@link ExecutableStageNode}.
 *
 * <p>The conversion runs in this order (the order matters because ids are drawn sequentially
 * from {@code idGenerator}):
 *
 * <ol>
 *   <li>Reject networks containing anything other than {@link RemoteGrpcPortNode}, {@link
 *       ParallelInstructionNode} or {@link InstructionOutputNode}.
 *   <li>Replace every {@link DefaultEdge} leaving a non-output node with a {@link
 *       MultiOutputInfoEdge} carrying a generated tag.
 *   <li>Seed a components builder from the pipeline proto, adding a default Java environment and
 *       two generated windowing strategies.
 *   <li>Create a coder and a PCollection for every {@link InstructionOutputNode}.
 *   <li>Create a PTransform for every {@link ParallelInstructionNode}, collecting timers, user
 *       state ids and side inputs on the way.
 *   <li>Assemble the {@link ExecutableStage} and wrap it in an {@link ExecutableStageNode}.
 * </ol>
 *
 * @param input the network to convert; its edges are modified in place by step 2
 * @return the node wrapping the generated executable stage
 * @throws IllegalArgumentException if the network contains an unsupported node type, a coder
 *     cannot be encoded, a read instruction cannot be processed, or a parallel instruction is of
 *     an unknown kind
 * @throws UnsupportedOperationException if the stage does not have exactly one input PCollection
 */
@Override
public Node apply(MutableNetwork<Node, Edge> input) {
  for (Node node : input.nodes()) {
    if (node instanceof RemoteGrpcPortNode
        || node instanceof ParallelInstructionNode
        || node instanceof InstructionOutputNode) {
      continue;
    }
    // Name the offending node itself; the message previously printed the entire network.
    throw new IllegalArgumentException(
        String.format("Network contains unknown type of node: %s", node));
  }

  // Fix all non output nodes to have named edges.
  for (Node node : input.nodes()) {
    if (node instanceof InstructionOutputNode) {
      continue;
    }
    for (Node successor : input.successors(node)) {
      for (Edge edge : input.edgesConnecting(node, successor)) {
        if (edge instanceof DefaultEdge) {
          input.removeEdge(edge);
          input.addEdge(
              node,
              successor,
              MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
        }
      }
    }
  }

  RunnerApi.Components.Builder componentsBuilder = RunnerApi.Components.newBuilder();
  componentsBuilder.mergeFrom(this.pipeline.getComponents());

  // Default to use the Java environment if pipeline doesn't have environment specified.
  if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) {
    String envId = Environments.JAVA_SDK_HARNESS_ENVIRONMENT.getUrn() + idGenerator.getId();
    componentsBuilder.putEnvironments(envId, Environments.JAVA_SDK_HARNESS_ENVIRONMENT);
  }

  // By default, use GlobalWindow for all languages.
  // For java, if there is a IntervalWindowCoder, then use FixedWindow instead.
  // TODO: should get real WindowingStategy from pipeline proto.
  String globalWindowingStrategyId = "generatedGlobalWindowingStrategy" + idGenerator.getId();
  String intervalWindowEncodingWindowingStrategyId =
      "generatedIntervalWindowEncodingWindowingStrategy" + idGenerator.getId();

  SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents(), null);
  try {
    registerWindowingStrategy(
        globalWindowingStrategyId,
        WindowingStrategy.globalDefault(),
        componentsBuilder,
        sdkComponents);
    registerWindowingStrategy(
        intervalWindowEncodingWindowingStrategyId,
        WindowingStrategy.of(FixedWindows.of(Duration.standardSeconds(1))),
        componentsBuilder,
        sdkComponents);
  } catch (IOException exc) {
    throw new RuntimeException("Could not convert default windowing strategy to proto", exc);
  }

  Map<Node, String> nodesToPCollections = new HashMap<>();
  ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder();
  ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos =
      ImmutableMap.builder();
  ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews =
      ImmutableMap.builder();

  // A field of ExecutableStage which includes the PCollection goes to worker side.
  Set<PCollectionNode> executableStageOutputs = new HashSet<>();
  // A field of ExecutableStage which includes the PCollection goes to runner side.
  Set<PCollectionNode> executableStageInputs = new HashSet<>();

  for (InstructionOutputNode node :
      Iterables.filter(input.nodes(), InstructionOutputNode.class)) {
    InstructionOutput instructionOutput = node.getInstructionOutput();

    String coderId = "generatedCoder" + idGenerator.getId();
    String windowingStrategyId;
    try (ByteString.Output output = ByteString.newOutput()) {
      try {
        Coder<?> javaCoder =
            CloudObjects.coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec()));
        Coder<?> elementCoder = ((WindowedValueCoder<?>) javaCoder).getValueCoder();
        sdkComponents.registerCoder(elementCoder);
        RunnerApi.Coder coderProto = CoderTranslation.toProto(elementCoder, sdkComponents);
        componentsBuilder.putCoders(coderId, coderProto);
        // For now, Dataflow runner harness only deal with FixedWindow.
        if (javaCoder instanceof FullWindowedValueCoder) {
          FullWindowedValueCoder<?> windowedValueCoder = (FullWindowedValueCoder<?>) javaCoder;
          Coder<?> windowCoder = windowedValueCoder.getWindowCoder();
          if (windowCoder instanceof IntervalWindowCoder) {
            windowingStrategyId = intervalWindowEncodingWindowingStrategyId;
          } else if (windowCoder instanceof GlobalWindow.Coder) {
            windowingStrategyId = globalWindowingStrategyId;
          } else {
            throw new UnsupportedOperationException(
                String.format(
                    "Dataflow portable runner harness doesn't support windowing with %s",
                    windowCoder));
          }
        } else {
          throw new UnsupportedOperationException(
              "Dataflow portable runner harness only supports FullWindowedValueCoder");
        }
      } catch (IOException e) {
        throw new IllegalArgumentException(
            String.format(
                "Unable to encode coder %s for output %s",
                instructionOutput.getCodec(), instructionOutput),
            e);
      } catch (Exception e) {
        // NOTE(review): this clause also catches the UnsupportedOperationExceptions thrown just
        // above (and any ClassCastException from the WindowedValueCoder cast), so those cases
        // silently take this fallback instead of failing. Left unchanged because callers may
        // rely on the fallback; confirm whether that is intended.
        // Coder probably wasn't a java coder
        OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec());
        componentsBuilder.putCoders(
            coderId,
            RunnerApi.Coder.newBuilder()
                .setSpec(RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString()))
                .build());
        // For non-java coder, hope it's GlobalWindows by default.
        // TODO(BEAM-6231): Actually discover the right windowing strategy.
        windowingStrategyId = globalWindowingStrategyId;
      }
    } catch (IOException e) {
      throw new IllegalArgumentException(
          String.format(
              "Unable to encode coder %s for output %s",
              instructionOutput.getCodec(), instructionOutput),
          e);
    }

    // TODO(BEAM-6275): Set correct IsBounded on generated PCollections
    String pcollectionId = node.getPcollectionId();
    RunnerApi.PCollection pCollection =
        RunnerApi.PCollection.newBuilder()
            .setCoderId(coderId)
            .setWindowingStrategyId(windowingStrategyId)
            .setIsBounded(RunnerApi.IsBounded.Enum.BOUNDED)
            .build();
    nodesToPCollections.put(node, pcollectionId);
    componentsBuilder.putPcollections(pcollectionId, pCollection);

    // Record the PCollection as a stage output and/or stage input, as decided by the two
    // helper predicates.
    if (isExecutableStageOutputPCollection(input, node)) {
      executableStageOutputs.add(PipelineNode.pCollection(pcollectionId, pCollection));
    }
    if (isExecutableStageInputPCollection(input, node)) {
      executableStageInputs.add(PipelineNode.pCollection(pcollectionId, pCollection));
    }
  }

  componentsBuilder.putAllCoders(sdkComponents.toComponents().getCodersMap());

  Set<PTransformNode> executableStageTransforms = new HashSet<>();
  Set<TimerReference> executableStageTimers = new HashSet<>();
  List<UserStateId> userStateIds = new ArrayList<>();
  Set<SideInputReference> executableStageSideInputs = new HashSet<>();

  for (ParallelInstructionNode node :
      Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
    ImmutableMap.Builder<String, PCollectionNode> sideInputIds = ImmutableMap.builder();
    ParallelInstruction parallelInstruction = node.getParallelInstruction();
    String ptransformId = "generatedPtransform" + idGenerator.getId();
    ptransformIdToNameContexts.put(
        ptransformId,
        NameContext.create(
            null,
            parallelInstruction.getOriginalName(),
            parallelInstruction.getSystemName(),
            parallelInstruction.getName()));

    RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
    RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder();

    List<String> timerIds = new ArrayList<>();
    if (parallelInstruction.getParDo() != null) {
      ParDoInstruction parDoInstruction = parallelInstruction.getParDo();
      CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
      String userFnClassName = userFnSpec.getClassName();

      if (userFnClassName.equals("CombineValuesFn") || userFnClassName.equals("KeyedCombineFn")) {
        transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec);
        ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList());
      } else {
        String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);

        RunnerApi.PTransform parDoPTransform =
            pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);

        // TODO: only the non-null branch should exist; for migration ease only
        if (parDoPTransform != null) {
          checkArgument(
              parDoPTransform.getSpec().getUrn().equals(PTransformTranslation.PAR_DO_TRANSFORM_URN),
              "Found transform \"%s\" for ParallelDo instruction, "
                  + " but that transform had unexpected URN \"%s\" (expected \"%s\")",
              parDoPTransformId,
              parDoPTransform.getSpec().getUrn(),
              PTransformTranslation.PAR_DO_TRANSFORM_URN);

          RunnerApi.ParDoPayload parDoPayload;
          try {
            parDoPayload =
                RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
          } catch (InvalidProtocolBufferException exc) {
            throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
          }

          // Collect the ids of the ParDo's timer families and of its user state specs.
          timerIds.addAll(parDoPayload.getTimerFamilySpecsMap().keySet());
          for (String stateLocalName : parDoPayload.getStateSpecsMap().keySet()) {
            UserStateId.Builder builder = UserStateId.newBuilder();
            builder.setTransformId(parDoPTransformId);
            builder.setLocalName(stateLocalName);
            userStateIds.add(builder.build());
          }

          // To facilitate the creation of Set executableStageSideInputs.
          for (String sideInputTag : parDoPayload.getSideInputsMap().keySet()) {
            String sideInputPCollectionId = parDoPTransform.getInputsOrThrow(sideInputTag);
            RunnerApi.PCollection sideInputPCollection =
                pipeline.getComponents().getPcollectionsOrThrow(sideInputPCollectionId);

            pTransform.putInputs(sideInputTag, sideInputPCollectionId);

            PCollectionNode pCollectionNode =
                PipelineNode.pCollection(sideInputPCollectionId, sideInputPCollection);
            sideInputIds.put(sideInputTag, pCollectionNode);
          }

          // To facilitate the creation of Map(ptransformId -> pCollectionView), which is
          // required by constructing an ExecutableStageNode.
          ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder();
          for (Map.Entry<String, RunnerApi.SideInput> sideInputEntry :
              parDoPayload.getSideInputsMap().entrySet()) {
            pcollectionViews.add(
                RegisterNodeFunction.transformSideInputForRunner(
                    pipeline,
                    parDoPTransform,
                    sideInputEntry.getKey(),
                    sideInputEntry.getValue()));
          }
          ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build());

          transformSpec
              .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
              .setPayload(parDoPayload.toByteString());
        } else {
          // legacy path - bytes are the FunctionSpec's payload field, basically, and
          // SDKs expect it in the PTransform's payload field
          byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN);
          transformSpec
              .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
              .setPayload(ByteString.copyFrom(userFnBytes));
        }

        if (parDoInstruction.getSideInputs() != null) {
          ptransformIdToSideInputInfos.put(
              ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true));
        }
      }
    } else if (parallelInstruction.getRead() != null) {
      ReadInstruction readInstruction = parallelInstruction.getRead();
      CloudObject sourceSpec =
          CloudObject.fromSpec(
              CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec());
      // TODO: Need to plumb through the SDK specific function spec.
      transformSpec.setUrn(JAVA_SOURCE_URN);
      try {
        byte[] serializedSource =
            Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE));
        ByteString sourceByteString = ByteString.copyFrom(serializedSource);
        transformSpec.setPayload(sourceByteString);
      } catch (Exception e) {
        throw new IllegalArgumentException(
            String.format("Unable to process Read %s", parallelInstruction), e);
      }
    } else if (parallelInstruction.getFlatten() != null) {
      transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN);
    } else {
      throw new IllegalArgumentException(
          String.format("Unknown type of ParallelInstruction %s", parallelInstruction));
    }

    // Every predecessor output becomes an input of the PTransform under a generated local
    // name. For a ParDo this PCollection is called the "main input".
    for (Node predecessorOutput : input.predecessors(node)) {
      pTransform.putInputs(
          "generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput));
    }

    // The cast presumably relies on the edge-renaming pass at the top of this method having
    // left only MultiOutputInfoEdges on outgoing edges — TODO confirm.
    for (Edge edge : input.outEdges(node)) {
      Node nodeOutput = input.incidentNodes(edge).target();
      MultiOutputInfoEdge edge2 = (MultiOutputInfoEdge) edge;
      pTransform.putOutputs(
          edge2.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput));
    }

    pTransform.setSpec(transformSpec);
    PTransformNode pTransformNode = PipelineNode.pTransform(ptransformId, pTransform.build());
    executableStageTransforms.add(pTransformNode);

    for (String timerId : timerIds) {
      executableStageTimers.add(TimerReference.of(pTransformNode, timerId));
    }

    for (Map.Entry<String, PCollectionNode> sideInput : sideInputIds.build().entrySet()) {
      executableStageSideInputs.add(
          SideInputReference.of(pTransformNode, sideInput.getKey(), sideInput.getValue()));
    }
  }

  if (executableStageInputs.size() != 1) {
    throw new UnsupportedOperationException("ExecutableStage only support one input PCollection");
  }

  PCollectionNode executableInput = executableStageInputs.iterator().next();
  RunnerApi.Components executableStageComponents = componentsBuilder.build();

  // Get Environment from ptransform, otherwise, use JAVA_SDK_HARNESS_ENVIRONMENT as default.
  Environment executableStageEnv =
      getEnvironmentFromPTransform(executableStageComponents, executableStageTransforms);
  if (executableStageEnv == null) {
    executableStageEnv = Environments.JAVA_SDK_HARNESS_ENVIRONMENT;
  }

  // User state ids can only be resolved once the full set of components has been built.
  Set<UserStateReference> executableStageUserStateReference = new HashSet<>();
  for (UserStateId userStateId : userStateIds) {
    executableStageUserStateReference.add(
        UserStateReference.fromUserStateId(userStateId, executableStageComponents));
  }

  ExecutableStage executableStage =
      ImmutableExecutableStage.ofFullComponents(
          executableStageComponents,
          executableStageEnv,
          executableInput,
          executableStageSideInputs,
          executableStageUserStateReference,
          executableStageTimers,
          executableStageTransforms,
          executableStageOutputs,
          DEFAULT_WIRE_CODER_SETTINGS);
  return ExecutableStageNode.create(
      executableStage,
      ptransformIdToNameContexts.build(),
      ptransformIdToSideInputInfos.build(),
      ptransformIdToPCollectionViews.build());
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId in project beam by apache.
the class UserStateReference method fromUserStateId.
/**
 * Resolves a {@link UserStateId} proto into a {@link UserStateReference}.
 *
 * <p>The referenced transform is looked up in {@code components}, its main input name is obtained
 * from {@link ParDoTranslation#getMainInputName}, and the resulting reference ties together the
 * transform, the state's local name and the main input PCollection.
 *
 * @param userStateId the proto naming the transform and the state's local name
 * @param components the components in which the transform and its main input are looked up
 * @return the resolved reference
 * @throws RuntimeException wrapping the {@link IOException} raised while determining the main
 *     input name
 */
public static UserStateReference fromUserStateId(UserStateId userStateId, RunnerApi.Components components) {
  final String transformId = userStateId.getTransformId();
  final PTransform transform = components.getTransformsOrThrow(transformId);

  final String mainInputCollectionId;
  try {
    final String mainInputName = ParDoTranslation.getMainInputName(transform);
    mainInputCollectionId = transform.getInputsOrThrow(mainInputName);
  } catch (IOException e) {
    throw new RuntimeException(e);
  }

  final PTransformNode transformNode = PipelineNode.pTransform(transformId, transform);
  final String localName = userStateId.getLocalName();
  final PCollectionNode mainInputNode =
      PipelineNode.pCollection(
          mainInputCollectionId, components.getPcollectionsOrThrow(mainInputCollectionId));
  return UserStateReference.of(transformNode, localName, mainInputNode);
}
use of org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.UserStateId in project beam by apache.
the class RemoteExecutionTest method testExecutionWithUserState.
/**
 * Processes one bundle against a stage that uses two bag user states and verifies the emitted
 * elements and the resulting contents of the runner-side state.
 *
 * <p>The DoFn emits every value currently held in the first state, appends the element's value
 * to it, and clears the second state. The runner side is backed by in-memory lists, so after the
 * bundle the first list must end with the appended value and the second list must be empty.
 */
@Test
public void testExecutionWithUserState() throws Exception {
  launchSdkHarness(PipelineOptionsFactory.create());
  Pipeline p = Pipeline.create();

  final String stateId = "foo";
  final String stateId2 = "foo2";

  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "userState",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {
                @StateId(stateId)
                private final StateSpec<BagState<String>> bufferState =
                    StateSpecs.bag(StringUtf8Coder.of());

                @StateId(stateId2)
                private final StateSpec<BagState<String>> bufferState2 =
                    StateSpecs.bag(StringUtf8Coder.of());

                @ProcessElement
                public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> state,
                    @StateId(stateId2) BagState<String> state2,
                    OutputReceiver<KV<String, String>> r) {
                  for (String value : state.read()) {
                    r.output(KV.of(element.getKey(), value));
                  }
                  state.add(element.getValue());
                  state2.clear();
                }
              }))
      .apply("gbk", GroupByKey.create());

  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = optionalStage.get();

  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator);

  // Collect everything the stage emits, keyed by the remote output it was written to.
  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }

  // Encoded fixture values, shared between the initial state and the final assertions.
  ByteString encodedA =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED));
  ByteString encodedB =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED));
  ByteString encodedC =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED));
  ByteString encodedD =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "D", Coder.Context.NESTED));
  ByteString encodedY =
      ByteString.copyFrom(
          CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "Y", Coder.Context.NESTED));

  // The map itself is immutable, but the per-state lists are mutable so the handler below can
  // append to and clear them.
  List<ByteString> initialState = new ArrayList<>(Arrays.asList(encodedA, encodedB, encodedC));
  List<ByteString> initialState2 = new ArrayList<>(Arrays.asList(encodedD));
  Map<String, List<ByteString>> userStateData =
      ImmutableMap.of(stateId, initialState, stateId2, initialState2);

  StateRequestHandler stateRequestHandler =
      StateRequestHandlers.forBagUserStateHandlerFactory(
          descriptor,
          new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
            @Override
            public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                String pTransformId,
                String userStateId,
                Coder<ByteString> keyCoder,
                Coder<Object> valueCoder,
                Coder<BoundedWindow> windowCoder) {
              return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
                @Override
                public Iterable<Object> get(ByteString key, BoundedWindow window) {
                  return (Iterable) userStateData.get(userStateId);
                }

                @Override
                public void append(
                    ByteString key, BoundedWindow window, Iterator<Object> values) {
                  Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                }

                @Override
                public void clear(ByteString key, BoundedWindow window) {
                  userStateData.get(userStateId).clear();
                }
              };
            }
          });

  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Y")));
  }

  // Only the values present before the bundle are emitted; "Y" is appended after the read.
  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C"))));
  }
  assertThat(
      userStateData.get(stateId),
      IsIterableContainingInOrder.contains(encodedA, encodedB, encodedC, encodedY));
  assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
}
Aggregations