use of org.apache.beam.sdk.coders.Coder in project beam by apache.
the class BeamFnMapTaskExecutorTest method generateDataflowStepContext.
/**
 * Generates the bare minimum {@link DataflowStepContext} needed for testing.
 *
 * <p>All names reported by the context's {@link NameContext} are {@code valuesPrefix} followed by
 * a fixed suffix: "Stage", "OriginalName", "SystemName" and "UserName". The step-context methods
 * are stubs:
 *
 * <ul>
 *   <li>timer and state accessors return {@code null};
 *   <li>{@code setStateCleanupTimer} does nothing;
 *   <li>{@code namespacedToUser()} returns the context itself.
 * </ul>
 *
 * @param valuesPrefix prefix for all types of names that are specified in DataflowStepContext.
 * @return new instance of DataflowStepContext
 */
private DataflowStepContext generateDataflowStepContext(String valuesPrefix) {
  final NameContext nameContext =
      new NameContext() {
        @Override
        @Nullable
        public String stageName() {
          return valuesPrefix + "Stage";
        }

        @Override
        @Nullable
        public String originalName() {
          return valuesPrefix + "OriginalName";
        }

        @Override
        @Nullable
        public String systemName() {
          return valuesPrefix + "SystemName";
        }

        @Override
        @Nullable
        public String userName() {
          return valuesPrefix + "UserName";
        }
      };

  return new DataflowStepContext(nameContext) {
    @Override
    @Nullable
    public <W extends BoundedWindow> TimerData getNextFiredTimer(Coder<W> windowCoder) {
      // No timers ever fire in this test stub.
      return null;
    }

    @Override
    public <W extends BoundedWindow> void setStateCleanupTimer(
        String timerId,
        W window,
        Coder<W> windowCoder,
        Instant cleanupTime,
        Instant cleanupOutputTimestamp) {
      // Intentionally a no-op for tests.
    }

    @Override
    public DataflowStepContext namespacedToUser() {
      return this;
    }

    @Override
    public StateInternals stateInternals() {
      return null;
    }

    @Override
    public TimerInternals timerInternals() {
      return null;
    }
  };
}
use of org.apache.beam.sdk.coders.Coder in project beam by apache.
the class ProcessBundleDescriptorsTest method testLengthPrefixingOfKeyCoderInStatefulExecutableStage.
/**
 * Tests that a stateful stage wraps the key coder of a stateful transform in a
 * LengthPrefixCoder.
 *
 * <p>The check is made twice on the generated ProcessBundleDescriptor: once for the key component
 * of the main-input KV coder, and once for the key component of the timer family coder.
 */
@Test
public void testLengthPrefixingOfKeyCoderInStatefulExecutableStage() throws Exception {
  // Build a stateful stage whose key coder is not a known (model) coder.
  Pipeline sdkPipeline = Pipeline.create();
  Coder<Void> nonStandardKeyCoder = VoidCoder.of();
  assertThat(ModelCoderRegistrar.isKnownCoder(nonStandardKeyCoder), is(false));

  DoFn<byte[], KV<Void, String>> createFn =
      new DoFn<byte[], KV<Void, String>>() {
        @ProcessElement
        public void process(ProcessContext ctxt) {}
      };
  DoFn<KV<Void, String>, KV<Void, String>> userStateFn =
      new DoFn<KV<Void, String>, KV<Void, String>>() {
        @StateId("stateId")
        private final StateSpec<BagState<String>> bufferState =
            StateSpecs.bag(StringUtf8Coder.of());

        @TimerId("timerId")
        private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

        @ProcessElement
        public void processElement(
            @Element KV<Void, String> element,
            @StateId("stateId") BagState<String> state,
            @TimerId("timerId") Timer timer,
            OutputReceiver<KV<Void, String>> r) {}

        @OnTimer("timerId")
        public void onTimer() {}
      };

  sdkPipeline
      .apply("impulse", Impulse.create())
      .apply("create", ParDo.of(createFn))
      .setCoder(KvCoder.of(nonStandardKeyCoder, StringUtf8Coder.of()))
      .apply("userState", ParDo.of(userStateFn))
      .apply("gbk", GroupByKey.create());

  // Fuse the pipeline and locate the stage that owns the "stateId" user state.
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(sdkPipeline);
  FusedPipeline fusedPipeline = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> statefulStage =
      Iterables.tryFind(
          fusedPipeline.getFusedStages(),
          (ExecutableStage candidate) ->
              candidate.getUserStates().stream()
                  .anyMatch(userState -> userState.localName().equals("stateId")));
  checkState(statefulStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = statefulStage.get();
  PipelineNode.PCollectionNode mainInput = stage.getInputPCollection();

  // In the stage's own components the key coder is still the Java-serialized coder
  // (i.e. not wrapped in a LengthPrefixCoder).
  Map<String, RunnerApi.Coder> originalCoders = stage.getComponents().getCodersMap();
  RunnerApi.Coder originalMainInputCoder =
      originalCoders.get(mainInput.getPCollection().getCoderId());
  RunnerApi.Coder originalKeyCoder =
      originalCoders.get(ModelCoders.getKvCoderComponents(originalMainInputCoder).keyCoderId());
  assertThat(originalKeyCoder.getSpec().getUrn(), is(CoderTranslation.JAVA_SERIALIZED_CODER_URN));

  // After building the ProcessBundleDescriptor, the main-input key coder must be length-prefixed.
  BeamFnApi.ProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
              "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
          .getProcessBundleDescriptor();
  Map<String, RunnerApi.Coder> descriptorCoders = descriptor.getCodersMap();
  RunnerApi.Coder mainInputCoder =
      descriptorCoders.get(descriptor.getPcollectionsOrThrow(mainInput.getId()).getCoderId());
  RunnerApi.Coder mainInputKeyCoder =
      descriptorCoders.get(ModelCoders.getKvCoderComponents(mainInputCoder).keyCoderId());
  ensureLengthPrefixed(mainInputKeyCoder, originalKeyCoder, descriptorCoders);

  // The key component (index 0) of the timer family coder must be length-prefixed as well.
  TimerReference timerReference = Iterables.getOnlyElement(stage.getTimers());
  RunnerApi.ParDoPayload payload =
      RunnerApi.ParDoPayload.parseFrom(
          descriptor
              .getTransformsOrThrow(timerReference.transform().getId())
              .getSpec()
              .getPayload());
  RunnerApi.TimerFamilySpec timerFamilySpec =
      payload.getTimerFamilySpecsOrThrow(timerReference.localName());
  RunnerApi.Coder timerFamilyCoder =
      descriptorCoders.get(timerFamilySpec.getTimerFamilyCoderId());
  RunnerApi.Coder timerKeyCoder =
      descriptorCoders.get(timerFamilyCoder.getComponentCoderIds(0));
  ensureLengthPrefixed(timerKeyCoder, originalKeyCoder, descriptorCoders);
}
use of org.apache.beam.sdk.coders.Coder in project beam by apache.
the class RemoteExecutionTest method testExecution.
/**
 * Runs a single fused stage end to end on the SDK harness launched by this test.
 *
 * <p>Pipeline shape:
 *
 * <ol>
 *   <li>impulse
 *   <li>"create" emits "zero", "one", "two"
 *   <li>"len" emits each string's length as a long
 *   <li>"addKeys" keys every value with "foo"
 *   <li>"gbk"
 * </ol>
 *
 * <p>The fuser must produce exactly one fused stage. That stage receives one empty-byte-array
 * element in the global window. Every remote output must then hold three values built from key
 * "foo" and lengths 4, 3 and 3.
 */
@Test
public void testExecution() throws Exception {
launchSdkHarness(PipelineOptionsFactory.create());
Pipeline p = Pipeline.create();
p.apply("impulse", Impulse.create()).apply("create", ParDo.of(new DoFn<byte[], String>() {
@ProcessElement
public void process(ProcessContext ctxt) {
ctxt.output("zero");
ctxt.output("one");
ctxt.output("two");
}
})).apply("len", ParDo.of(new DoFn<String, Long>() {
@ProcessElement
public void process(ProcessContext ctxt) {
ctxt.output((long) ctxt.element().length());
}
})).apply("addKeys", WithKeys.of("foo")).setCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianLongCoder.of())).apply("gbk", GroupByKey.create());
// Translate to the portable proto, fuse, and take the single expected stage.
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
checkState(fused.getFusedStages().size() == 1, "Expected exactly one fused stage");
ExecutableStage stage = fused.getFusedStages().iterator().next();
// Build the bundle descriptor against the test data server and obtain a processor for it.
ExecutableProcessBundleDescriptor descriptor = ProcessBundleDescriptors.fromExecutableStage("my_stage", stage, dataServer.getApiServiceDescriptor());
BundleProcessor processor = controlClient.getProcessor(descriptor.getProcessBundleDescriptor(), descriptor.getRemoteInputDestinations());
Map<String, ? super Coder<WindowedValue<?>>> remoteOutputCoders = descriptor.getRemoteOutputCoders();
Map<String, Collection<? super WindowedValue<?>>> outputValues = new HashMap<>();
Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
// Register one collecting receiver per remote output.
// The lists are synchronized; presumably the receivers may be invoked from a thread other than
// the test thread (TODO confirm).
// The raw (Coder) cast appears to be needed because the map values are typed with a
// `? super` wildcard.
for (Entry<String, ? super Coder<WindowedValue<?>>> remoteOutputCoder : remoteOutputCoders.entrySet()) {
List<? super WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
outputValues.put(remoteOutputCoder.getKey(), outputContents);
outputReceivers.put(remoteOutputCoder.getKey(), RemoteOutputReceiver.of((Coder) remoteOutputCoder.getValue(), (FnDataReceiver<? super WindowedValue<?>>) outputContents::add));
}
// Send one impulse element (an empty byte array in the global window).
// The bundle is closed by try-with-resources before the outputs are inspected.
try (RemoteBundle bundle = processor.newBundle(outputReceivers, BundleProgressHandler.ignored())) {
Iterables.getOnlyElement(bundle.getInputReceivers().values()).accept(valueInGlobalWindow(new byte[0]));
}
// "zero" has length 4; "one" and "two" have length 3.
// NOTE(review): byteValueOf is defined elsewhere in this test class. It presumably builds the
// expected encoded form of KV(key, length) — confirm.
for (Collection<? super WindowedValue<?>> windowedValues : outputValues.values()) {
assertThat(windowedValues, containsInAnyOrder(valueInGlobalWindow(byteValueOf("foo", 4)), valueInGlobalWindow(byteValueOf("foo", 3)), valueInGlobalWindow(byteValueOf("foo", 3))));
}
}
use of org.apache.beam.sdk.coders.Coder in project beam by apache.
the class RemoteExecutionTest method testExecutionWithUserStateCaching.
/**
 * Verifies that bag user state accessed by the SDK harness is cached across bundles.
 *
 * <p>Two bundles are processed for the same key "X".
 *
 * <ul>
 *   <li>First bundle: reads {@code stateId} (A, B, C) and emits those values. {@code stateId2}
 *       still holds D, so it is not empty and gets cleared.
 *   <li>Second bundle: emits A, B, C again. {@code stateId2} was cleared, so it also emits
 *       "Empty".
 * </ul>
 *
 * <p>Only three state requests must reach the runner-side handler: a read of {@code stateId}, a
 * read of {@code stateId2}, and a clear of {@code stateId2}. So the second bundle's reads must not
 * issue new requests.
 */
@Test
public void testExecutionWithUserStateCaching() throws Exception {
  Pipeline p = Pipeline.create();
  launchSdkHarness(p.getOptions());
  final String stateId = "foo";
  final String stateId2 = "bar";
  p.apply("impulse", Impulse.create())
      .apply(
          "create",
          ParDo.of(
              new DoFn<byte[], KV<String, String>>() {
                @ProcessElement
                public void process(ProcessContext ctxt) {}
              }))
      .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
      .apply(
          "userState",
          ParDo.of(
              new DoFn<KV<String, String>, KV<String, String>>() {
                @StateId(stateId)
                private final StateSpec<BagState<String>> bufferState =
                    StateSpecs.bag(StringUtf8Coder.of());

                @StateId(stateId2)
                private final StateSpec<BagState<String>> bufferState2 =
                    StateSpecs.bag(StringUtf8Coder.of());

                @ProcessElement
                public void processElement(
                    @Element KV<String, String> element,
                    @StateId(stateId) BagState<String> state,
                    @StateId(stateId2) BagState<String> state2,
                    OutputReceiver<KV<String, String>> r) {
                  for (String value : state.read()) {
                    r.output(KV.of(element.getKey(), value));
                  }
                  ReadableState<Boolean> isEmpty = state2.isEmpty();
                  if (isEmpty.read()) {
                    r.output(KV.of(element.getKey(), "Empty"));
                  } else {
                    state2.clear();
                  }
                }
              }))
      .apply("gbk", GroupByKey.create());
  RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
  FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
  Optional<ExecutableStage> optionalStage =
      Iterables.tryFind(
          fused.getFusedStages(), (ExecutableStage stage) -> !stage.getUserStates().isEmpty());
  checkState(optionalStage.isPresent(), "Expected a stage with user state.");
  ExecutableStage stage = optionalStage.get();
  ExecutableProcessBundleDescriptor descriptor =
      ProcessBundleDescriptors.fromExecutableStage(
          "test_stage",
          stage,
          dataServer.getApiServiceDescriptor(),
          stateServer.getApiServiceDescriptor());
  BundleProcessor processor =
      controlClient.getProcessor(
          descriptor.getProcessBundleDescriptor(),
          descriptor.getRemoteInputDestinations(),
          stateDelegator);
  Map<String, Coder> remoteOutputCoders = descriptor.getRemoteOutputCoders();
  Map<String, Collection<WindowedValue<?>>> outputValues = new HashMap<>();
  Map<String, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
  for (Entry<String, Coder> remoteOutputCoder : remoteOutputCoders.entrySet()) {
    List<WindowedValue<?>> outputContents = Collections.synchronizedList(new ArrayList<>());
    outputValues.put(remoteOutputCoder.getKey(), outputContents);
    outputReceivers.put(
        remoteOutputCoder.getKey(),
        RemoteOutputReceiver.of(
            (Coder<WindowedValue<?>>) remoteOutputCoder.getValue(), outputContents::add));
  }
  // Runner-side backing store for both bags, holding NESTED-encoded UTF-8 strings.
  // stateId starts as [A, B, C] and stateId2 as [D].
  Map<String, List<ByteString>> userStateData =
      ImmutableMap.of(
          stateId,
          new ArrayList(
              Arrays.asList(
                  ByteString.copyFrom(
                      CoderUtils.encodeToByteArray(
                          StringUtf8Coder.of(), "A", Coder.Context.NESTED)),
                  ByteString.copyFrom(
                      CoderUtils.encodeToByteArray(
                          StringUtf8Coder.of(), "B", Coder.Context.NESTED)),
                  ByteString.copyFrom(
                      CoderUtils.encodeToByteArray(
                          StringUtf8Coder.of(), "C", Coder.Context.NESTED)))),
          stateId2,
          new ArrayList(
              Arrays.asList(
                  ByteString.copyFrom(
                      CoderUtils.encodeToByteArray(
                          StringUtf8Coder.of(), "D", Coder.Context.NESTED)))));
  StoringStateRequestHandler stateRequestHandler =
      new StoringStateRequestHandler(
          StateRequestHandlers.forBagUserStateHandlerFactory(
              descriptor,
              new BagUserStateHandlerFactory<ByteString, Object, BoundedWindow>() {
                @Override
                public BagUserStateHandler<ByteString, Object, BoundedWindow> forUserState(
                    String pTransformId,
                    String userStateId,
                    Coder<ByteString> keyCoder,
                    Coder<Object> valueCoder,
                    Coder<BoundedWindow> windowCoder) {
                  return new BagUserStateHandler<ByteString, Object, BoundedWindow>() {
                    @Override
                    public Iterable<Object> get(ByteString key, BoundedWindow window) {
                      return (Iterable) userStateData.get(userStateId);
                    }

                    @Override
                    public void append(
                        ByteString key, BoundedWindow window, Iterator<Object> values) {
                      Iterators.addAll(userStateData.get(userStateId), (Iterator) values);
                    }

                    @Override
                    public void clear(ByteString key, BoundedWindow window) {
                      userStateData.get(userStateId).clear();
                    }
                  };
                }
              }));
  // Two separate bundles for the same key so the second can reuse what the first read.
  try (RemoteBundle bundle =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Y")));
  }
  try (RemoteBundle bundle2 =
      processor.newBundle(outputReceivers, stateRequestHandler, BundleProgressHandler.ignored())) {
    Iterables.getOnlyElement(bundle2.getInputReceivers().values())
        .accept(valueInGlobalWindow(KV.of("X", "Z")));
  }
  for (Collection<WindowedValue<?>> windowedValues : outputValues.values()) {
    assertThat(
        windowedValues,
        containsInAnyOrder(
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "A")),
            valueInGlobalWindow(KV.of("X", "B")),
            valueInGlobalWindow(KV.of("X", "C")),
            valueInGlobalWindow(KV.of("X", "Empty"))));
  }
  assertThat(
      userStateData.get(stateId),
      IsIterableContainingInOrder.contains(
          ByteString.copyFrom(
              CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "A", Coder.Context.NESTED)),
          ByteString.copyFrom(
              CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "B", Coder.Context.NESTED)),
          ByteString.copyFrom(
              CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "C", Coder.Context.NESTED))));
  assertThat(userStateData.get(stateId2), IsEmptyIterable.emptyIterable());
  // 3 Requests expected: state read, state2 read, and state2 clear
  assertEquals(3, stateRequestHandler.getRequestCount());
  ByteString.Output out = ByteString.newOutput();
  StringUtf8Coder.of().encode("X", out);
  // assertEquals takes (expected, actual). Keeping that order for every check makes failure
  // messages report the right values.
  assertEquals(
      stateId,
      stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      out.toByteString(),
      stateRequestHandler.receivedRequests.get(0).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(0).hasGet());
  assertEquals(
      stateId2,
      stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      out.toByteString(),
      stateRequestHandler.receivedRequests.get(1).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(1).hasGet());
  assertEquals(
      stateId2,
      stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getUserStateId());
  assertEquals(
      out.toByteString(),
      stateRequestHandler.receivedRequests.get(2).getStateKey().getBagUserState().getKey());
  assertTrue(stateRequestHandler.receivedRequests.get(2).hasClear());
}
use of org.apache.beam.sdk.coders.Coder in project beam by apache.
the class ProcessBundleDescriptors method fromExecutableStageInternal.
/**
 * Builds an {@link ExecutableProcessBundleDescriptor} for {@code stage}.
 *
 * <p>The stage's components are copied, and their transform set is replaced with only the stage's
 * own transforms. The stage input, outputs and side inputs are then registered against those
 * components, and the user-state and timer specs are derived from them. Finally the resulting
 * coders, environments, PCollections, windowing strategies and transforms are copied into the
 * descriptor.
 *
 * <p>The state API service descriptor is set only when {@code stateEndpoint} is non-null. When the
 * stage has timer specs, the timer API service descriptor is set to {@code dataEndpoint}.
 *
 * @param id id assigned to the resulting ProcessBundleDescriptor
 * @param stage the fused stage to describe
 * @param dataEndpoint data-plane endpoint; also used for timers
 * @param stateEndpoint state API endpoint, or {@code null} if none
 * @return the descriptor plus remote input/output, side-input, user-state and timer metadata
 * @throws IOException propagated from the helper calls
 */
private static ExecutableProcessBundleDescriptor fromExecutableStageInternal(
    String id,
    ExecutableStage stage,
    ApiServiceDescriptor dataEndpoint,
    @Nullable ApiServiceDescriptor stateEndpoint)
    throws IOException {
  // Keep all of the stage's components, but only its own processing transforms.
  // TODO: Remove the unreachable subcomponents if the size of the descriptor matters.
  Map<String, PTransform> transformsById =
      stage.getTransforms().stream()
          .collect(Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform));
  Components.Builder componentsBuilder =
      stage.getComponents().toBuilder().clearTransforms().putAllTransforms(transformsById);

  ImmutableList.Builder<RemoteInputDestination> inputDestinations = ImmutableList.builder();
  ImmutableMap.Builder<String, Coder> remoteOutputCoders = ImmutableMap.builder();

  String inputPCollectionId = stage.getInputPCollection().getId();
  WireCoderSetting inputWireCoderSetting =
      stage.getWireCoderSettings().stream()
          .filter(setting -> setting.getInputOrOutputId().equals(inputPCollectionId))
          .findAny()
          .orElse(WireCoderSetting.getDefaultInstance());

  // The order of these does not matter.
  inputDestinations.add(
      addStageInput(
          dataEndpoint, stage.getInputPCollection(), componentsBuilder, inputWireCoderSetting));
  remoteOutputCoders.putAll(
      addStageOutputs(
          dataEndpoint,
          stage.getOutputPCollections(),
          componentsBuilder,
          stage.getWireCoderSettings()));

  // These helpers read from (and some add to) the components, so keep this call order.
  Map<String, Map<String, SideInputSpec>> sideInputSpecsByTransform =
      addSideInputs(stage, componentsBuilder);
  Map<String, Map<String, BagUserStateSpec>> bagUserStateSpecsByTransform =
      forBagUserStates(stage, componentsBuilder.build());
  Map<String, Map<String, TimerSpec>> timerSpecsByTransform =
      forTimerSpecs(stage, componentsBuilder);
  lengthPrefixAnyInputCoder(inputPCollectionId, componentsBuilder);

  // Copy the accumulated components into the ProcessBundleDescriptor.
  ProcessBundleDescriptor.Builder descriptorBuilder =
      ProcessBundleDescriptor.newBuilder().setId(id);
  if (stateEndpoint != null) {
    descriptorBuilder.setStateApiServiceDescriptor(stateEndpoint);
  }
  if (!timerSpecsByTransform.isEmpty()) {
    // Timers share the data endpoint; a dedicated timer endpoint may be supported in the future.
    descriptorBuilder.setTimerApiServiceDescriptor(dataEndpoint);
  }
  descriptorBuilder.putAllCoders(componentsBuilder.getCodersMap());
  descriptorBuilder.putAllEnvironments(componentsBuilder.getEnvironmentsMap());
  descriptorBuilder.putAllPcollections(componentsBuilder.getPcollectionsMap());
  descriptorBuilder.putAllWindowingStrategies(componentsBuilder.getWindowingStrategiesMap());
  descriptorBuilder.putAllTransforms(componentsBuilder.getTransformsMap());

  return ExecutableProcessBundleDescriptor.of(
      descriptorBuilder.build(),
      inputDestinations.build(),
      remoteOutputCoders.build(),
      sideInputSpecsByTransform,
      bagUserStateSpecsByTransform,
      timerSpecsByTransform);
}
Aggregations