Use of org.apache.beam.runners.fnexecution.control.StageBundleFactory in project beam by apache.
From the class SparkExecutableStageFunction, the method call:
@Override
public Iterator<RawUnionValue> call(Iterator<WindowedValue<InputT>> inputs) throws Exception {
  SparkPipelineOptions options = pipelineOptions.get().as(SparkPipelineOptions.class);
  // Register standard file systems.
  FileSystems.setDefaultPipelineOptions(options);
  // Do not call processElements if there are no inputs.
  // Otherwise, this may cause validation errors (e.g. ParDoTest)
  if (!inputs.hasNext()) {
    return Collections.emptyIterator();
  }
  try (ExecutableStageContext stageContext = contextFactory.get(jobInfo)) {
    ExecutableStage executableStage = ExecutableStage.fromPayload(stagePayload);
    try (StageBundleFactory stageBundleFactory =
        stageContext.getStageBundleFactory(executableStage)) {
      ConcurrentLinkedQueue<RawUnionValue> collector = new ConcurrentLinkedQueue<>();
      StateRequestHandler stateRequestHandler =
          getStateRequestHandler(executableStage, stageBundleFactory.getProcessBundleDescriptor());
      if (executableStage.getTimers().size() == 0) {
        ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
        processElements(stateRequestHandler, receiverFactory, null, stageBundleFactory, inputs);
        return collector.iterator();
      }
      // Used with Batch, we know that all the data is available for this key. We can't use the
      // timer manager from the context because it doesn't exist. So we create one and advance
      // time to the end after processing all elements.
      final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
      timerInternals.advanceProcessingTime(Instant.now());
      timerInternals.advanceSynchronizedProcessingTime(Instant.now());
      ReceiverFactory receiverFactory = new ReceiverFactory(collector, outputMap);
      TimerReceiverFactory timerReceiverFactory =
          new TimerReceiverFactory(
              stageBundleFactory,
              (Timer<?> timer, TimerInternals.TimerData timerData) -> {
                currentTimerKey = timer.getUserKey();
                if (timer.getClearBit()) {
                  timerInternals.deleteTimer(timerData);
                } else {
                  timerInternals.setTimer(timerData);
                }
              },
              windowCoder);
      // Process inputs.
      processElements(
          stateRequestHandler, receiverFactory, timerReceiverFactory, stageBundleFactory, inputs);
      // Finish any pending windows by advancing the input watermark to infinity.
      timerInternals.advanceInputWatermark(BoundedWindow.TIMESTAMP_MAX_VALUE);
      // Finally, advance the processing time to infinity to fire any timers.
      timerInternals.advanceProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
      timerInternals.advanceSynchronizedProcessingTime(BoundedWindow.TIMESTAMP_MAX_VALUE);
      // Now fire the timers and process elements generated by timers (which may themselves
      // be timers).
      while (timerInternals.hasPendingTimers()) {
        try (RemoteBundle bundle =
            stageBundleFactory.getBundle(
                receiverFactory,
                timerReceiverFactory,
                stateRequestHandler,
                getBundleProgressHandler())) {
          PipelineTranslatorUtils.fireEligibleTimers(
              timerInternals, bundle.getTimerReceivers(), currentTimerKey);
        }
      }
      return collector.iterator();
    }
  }
}
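The processElements helper is not shown in this excerpt. Conceptually it runs the standard StageBundleFactory bundle lifecycle: open a RemoteBundle, feed every input element to the bundle's input receiver, and close the bundle, which blocks until the SDK harness has flushed all outputs through the receiver factory. Below is a minimal sketch of that pattern, reusing the names from the excerpt above and assuming the stage has exactly one input PCollection; it is an illustration, not the actual Beam implementation.

// Sketch of the bundle lifecycle presumably performed by processElements
// (assumption: single main-input PCollection for the fused stage).
private void processElements(
    StateRequestHandler stateRequestHandler,
    ReceiverFactory receiverFactory,
    TimerReceiverFactory timerReceiverFactory,
    StageBundleFactory stageBundleFactory,
    Iterator<WindowedValue<InputT>> inputs)
    throws Exception {
  // Opening the bundle wires the receiver factory to the stage outputs.
  try (RemoteBundle bundle =
      stageBundleFactory.getBundle(
          receiverFactory, timerReceiverFactory, stateRequestHandler, getBundleProgressHandler())) {
    FnDataReceiver<WindowedValue<InputT>> mainInput =
        (FnDataReceiver) Iterables.getOnlyElement(bundle.getInputReceivers().values());
    while (inputs.hasNext()) {
      mainInput.accept(inputs.next());
    }
  } // close() waits for bundle completion; outputs are now in the collector.
}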
Use of org.apache.beam.runners.fnexecution.control.StageBundleFactory in project beam by apache.
From the class SparkExecutableStageFunctionTest, the method outputsAreTaggedCorrectly:
@Test
public void outputsAreTaggedCorrectly() throws Exception {
  WindowedValue<Integer> three = WindowedValue.valueInGlobalWindow(3);
  WindowedValue<Integer> four = WindowedValue.valueInGlobalWindow(4);
  WindowedValue<Integer> five = WindowedValue.valueInGlobalWindow(5);
  Map<String, Integer> outputTagMap = ImmutableMap.of("one", 1, "two", 2, "three", 3);
  // We use a real StageBundleFactory here in order to exercise the output receiver factory.
  StageBundleFactory stageBundleFactory =
      new StageBundleFactory() {
        private boolean once;

        @Override
        public RemoteBundle getBundle(
            OutputReceiverFactory receiverFactory,
            TimerReceiverFactory timerReceiverFactory,
            StateRequestHandler stateRequestHandler,
            BundleProgressHandler progressHandler,
            BundleFinalizationHandler finalizationHandler,
            BundleCheckpointHandler checkpointHandler) {
          return new RemoteBundle() {
            @Override
            public String getId() {
              return "bundle-id";
            }

            @Override
            public Map<String, FnDataReceiver> getInputReceivers() {
              return ImmutableMap.of(
                  "input",
                  input -> {
                    /* Ignore input */
                  });
            }

            @Override
            public Map<KV<String, String>, FnDataReceiver<Timer>> getTimerReceivers() {
              return Collections.emptyMap();
            }

            @Override
            public void requestProgress() {
              throw new UnsupportedOperationException();
            }

            @Override
            public void split(double fractionOfRemainder) {
              throw new UnsupportedOperationException();
            }

            @Override
            public void close() throws Exception {
              if (once) {
                return;
              }
              // Emit all values to the runner when the bundle is closed.
              receiverFactory.create("one").accept(three);
              receiverFactory.create("two").accept(four);
              receiverFactory.create("three").accept(five);
              once = true;
            }
          };
        }

        @Override
        public ProcessBundleDescriptors.ExecutableProcessBundleDescriptor
            getProcessBundleDescriptor() {
          return Mockito.mock(ProcessBundleDescriptors.ExecutableProcessBundleDescriptor.class);
        }

        @Override
        public InstructionRequestHandler getInstructionRequestHandler() {
          return null;
        }

        @Override
        public void close() {}
      };
  when(stageContext.getStageBundleFactory(any())).thenReturn(stageBundleFactory);
  SparkExecutableStageFunction<Integer, ?> function = getFunction(outputTagMap);
  List<WindowedValue<Integer>> inputs = new ArrayList<>();
  inputs.add(WindowedValue.valueInGlobalWindow(0));
  Iterator<RawUnionValue> iterator = function.call(inputs.iterator());
  Iterable<RawUnionValue> iterable = () -> iterator;
  assertThat(
      iterable,
      contains(
          new RawUnionValue(1, three), new RawUnionValue(2, four), new RawUnionValue(3, five)));
}
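The final assertion depends on the ReceiverFactory constructed inside SparkExecutableStageFunction wrapping each element emitted for a named output in a RawUnionValue whose union tag comes from the output map, so "one" becomes tag 1, and so on. Below is a minimal sketch of that tagging behaviour, assuming a factory that appends to a shared queue; the class name and constructor are hypothetical simplifications, not the real ReceiverFactory.

// Hypothetical stand-in for the ReceiverFactory this test exercises.
static class SketchReceiverFactory implements OutputReceiverFactory {
  private final ConcurrentLinkedQueue<RawUnionValue> collector;
  private final Map<String, Integer> outputMap;

  SketchReceiverFactory(
      ConcurrentLinkedQueue<RawUnionValue> collector, Map<String, Integer> outputMap) {
    this.collector = collector;
    this.outputMap = outputMap;
  }

  @Override
  public <OutputT> FnDataReceiver<OutputT> create(String pCollectionId) {
    // Map the output name to its union tag, e.g. "one" -> 1 in outputTagMap above.
    int unionTag = outputMap.get(pCollectionId);
    return element -> collector.add(new RawUnionValue(unionTag, element));
  }
}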