Use of org.apache.beam.runners.dataflow.util.DoFnInfo in project beam by apache.
The class ProcessBundleHandler, method createDoFnRunner.
/**
 * Builds a {@link DoFnRunner} from an {@link org.apache.beam.fn.v1.BeamFnApi.FunctionSpec}.
 *
 * <p>The spec's data field is unpacked as a {@code BytesValue} and its bytes are deserialized into
 * a {@link DoFnInfo}. The output ids declared by the {@link DoFnInfo} must match the keys of
 * {@code outputMap} once those keys are parsed as longs. Each declared output tag is then wired
 * to the consumers registered under the matching id.
 *
 * @param functionSpec spec whose data field wraps the serialized {@link DoFnInfo}
 * @param outputMap consumers for each output, keyed by the decimal string form of the output id
 * @return a runner that executes the deserialized {@link DoFn} and forwards every output element
 *     to the consumers registered for its tag
 * @throws IllegalArgumentException if the spec's data cannot be unpacked as a {@code BytesValue};
 *     a mismatch between the two sets of output ids is rejected through {@code checkArgument}
 */
private <InputT, OutputT> DoFnRunner<InputT, OutputT> createDoFnRunner(
    BeamFnApi.FunctionSpec functionSpec,
    Map<String, Collection<ThrowingConsumer<WindowedValue<OutputT>>>> outputMap) {
  // Step 1: pull the raw serialized bytes out of the spec.
  ByteString payload;
  try {
    payload = functionSpec.getData().unpack(BytesValue.class).getValue();
  } catch (InvalidProtocolBufferException e) {
    throw new IllegalArgumentException(String.format("Unable to unwrap DoFn %s", functionSpec), e);
  }

  // Step 2: rebuild the DoFnInfo and make sure both sides agree on the set of output ids.
  DoFnInfo<?, ?> doFnInfo =
      (DoFnInfo<?, ?>)
          SerializableUtils.deserializeFromByteArray(payload.toByteArray(), "DoFnInfo");
  Map<Long, TupleTag<?>> outputsById = doFnInfo.getOutputMap();
  HashSet<Long> requestedOutputIds =
      new HashSet<>(Collections2.transform(outputMap.keySet(), Long::parseLong));
  checkArgument(
      Objects.equals(requestedOutputIds, outputsById.keySet()),
      "Unexpected mismatch between transform output map %s and DoFnInfo output map %s.",
      outputMap.keySet(),
      outputsById);

  // Step 3: re-key the consumers from output id to the TupleTag declared for that id.
  ImmutableMultimap.Builder<TupleTag<?>, ThrowingConsumer<WindowedValue<OutputT>>>
      consumersByTag = ImmutableMultimap.builder();
  outputsById.forEach((id, tag) -> consumersByTag.putAll(tag, outputMap.get(Long.toString(id))));
  @SuppressWarnings({"unchecked", "rawtypes"})
  final Map<TupleTag<?>, Collection<ThrowingConsumer<WindowedValue<?>>>> tagBasedOutputMap =
      (Map) consumersByTag.build().asMap();

  // Step 4: an OutputManager that fans each emitted element out to the consumers of its tag.
  OutputManager outputManager =
      new OutputManager() {
        @Override
        public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
          try {
            Collection<ThrowingConsumer<WindowedValue<?>>> receivers = tagBasedOutputMap.get(tag);
            // A missing entry is a normal case, e.g. the DoFn produces an output that nothing
            // consumes; such elements are silently dropped.
            if (receivers != null) {
              for (ThrowingConsumer<WindowedValue<?>> receiver : receivers) {
                receiver.accept(output);
              }
            }
          } catch (Throwable t) {
            // Anything thrown while delivering the element is surfaced as a RuntimeException.
            throw new RuntimeException(t);
          }
        }
      };

  // Step 5: assemble the runner. Default pipeline options and an empty side-input reader are
  // used here; both are marked TODO.
  @SuppressWarnings({"unchecked", "rawtypes", "deprecation"})
  DoFnRunner<InputT, OutputT> runner =
      DoFnRunners.simpleRunner(
          PipelineOptionsFactory.create(), /* TODO */
          (DoFn) doFnInfo.getDoFn(),
          NullSideInputReader.empty(), /* TODO */
          outputManager,
          (TupleTag) outputsById.get(doFnInfo.getMainOutput()),
          new ArrayList<>(outputsById.values()),
          new FakeStepContext(),
          (WindowingStrategy) doFnInfo.getWindowingStrategy());
  return runner;
}
Aggregations