Use of org.apache.beam.runners.apex.translation.operators.ApexParDoOperator in project beam by apache.
The class ParDoTranslator, method translate:
@Override
public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationContext context) {
  DoFn<InputT, OutputT> doFn = transform.getFn();
  DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());

  // The Apex runner does not yet support splittable DoFns, state, or timers; reject them up front.
  if (signature.processElement().isSplittable()) {
    throw new UnsupportedOperationException(String.format(
        "%s does not support splittable DoFn: %s", ApexRunner.class.getSimpleName(), doFn));
  }
  if (signature.stateDeclarations().size() > 0) {
    throw new UnsupportedOperationException(String.format(
        "Found %s annotations on %s, but %s cannot yet be used with state in the %s.",
        DoFn.StateId.class.getSimpleName(), doFn.getClass().getName(),
        DoFn.class.getSimpleName(), ApexRunner.class.getSimpleName()));
  }
  if (signature.timerDeclarations().size() > 0) {
    throw new UnsupportedOperationException(String.format(
        "Found %s annotations on %s, but %s cannot yet be used with timers in the %s.",
        DoFn.TimerId.class.getSimpleName(), doFn.getClass().getName(),
        DoFn.class.getSimpleName(), ApexRunner.class.getSimpleName()));
  }

  Map<TupleTag<?>, PValue> outputs = context.getOutputs();
  PCollection<InputT> input = context.getInput();
  List<PCollectionView<?>> sideInputs = transform.getSideInputs();
  Coder<InputT> inputCoder = input.getCoder();
  // Elements travel between Apex operators as windowed values, so wrap the element coder.
  WindowedValueCoder<InputT> wvInputCoder = FullWindowedValueCoder.of(
      inputCoder, input.getWindowingStrategy().getWindowFn().windowCoder());

  ApexParDoOperator<InputT, OutputT> operator = new ApexParDoOperator<>(
      context.getPipelineOptions(), doFn, transform.getMainOutputTag(),
      transform.getAdditionalOutputTags().getAll(), input.getWindowingStrategy(),
      sideInputs, wvInputCoder, context.getStateBackend());

  // Map the main output tag to the operator's primary port and each additional
  // output tag to the matching additional output port.
  Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size());
  for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
    checkArgument(output.getValue() instanceof PCollection,
        "%s %s outputs non-PCollection %s of type %s",
        ParDo.MultiOutput.class.getSimpleName(), context.getFullName(),
        output.getValue(), output.getValue().getClass().getSimpleName());
    PCollection<?> pc = (PCollection<?>) output.getValue();
    if (output.getKey().equals(transform.getMainOutputTag())) {
      ports.put(pc, operator.output);
    } else {
      int portIndex = 0;
      for (TupleTag<?> tag : transform.getAdditionalOutputTags().getAll()) {
        if (tag.equals(output.getKey())) {
          ports.put(pc, operator.additionalOutputPorts[portIndex]);
          break;
        }
        portIndex++;
      }
    }
  }
  context.addOperator(operator, ports);
  context.addStream(context.getInput(), operator.input);
  if (!sideInputs.isEmpty()) {
    addSideInputs(operator.sideInput1, sideInputs, context);
  }
}
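The three guard clauses at the top of translate are the only place this translator inspects the DoFn's reflective signature. The standalone helper below is a minimal sketch of that same capability check, not part of the Beam code base; the method name isSupportedByApexRunner is an assumption made here for illustration, while DoFnSignatures and DoFnSignature are the Beam SDK classes already used above.

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;

// Hypothetical helper (illustration only): true iff the DoFn uses none of the
// features that the translate method above rejects for the Apex runner.
static boolean isSupportedByApexRunner(DoFn<?, ?> doFn) {
  DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
  return !signature.processElement().isSplittable() // no splittable @ProcessElement
      && signature.stateDeclarations().isEmpty()    // no @StateId declarations
      && signature.timerDeclarations().isEmpty();   // no @TimerId declarations
}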
Use of org.apache.beam.runners.apex.translation.operators.ApexParDoOperator in project beam by apache.
The class ParDoTranslatorTest, method testSerialization:
@Test
public void testSerialization() throws Exception {
  ApexPipelineOptions options = PipelineOptionsFactory.create().as(ApexPipelineOptions.class);
  options.setRunner(TestApexRunner.class);
  Pipeline pipeline = Pipeline.create(options);
  Coder<WindowedValue<Integer>> coder = WindowedValue.getValueOnlyCoder(VarIntCoder.of());
  PCollectionView<Integer> singletonView =
      pipeline.apply(Create.of(1)).apply(Sum.integersGlobally().asSingletonView());

  ApexParDoOperator<Integer, Integer> operator = new ApexParDoOperator<>(
      options, new Add(singletonView), new TupleTag<Integer>(), TupleTagList.empty().getAll(),
      WindowingStrategy.globalDefault(),
      Collections.<PCollectionView<?>>singletonList(singletonView),
      coder, new ApexStateInternals.ApexStateBackend());
  operator.setup(null);
  operator.beginWindow(0);
  WindowedValue<Integer> wv1 = WindowedValue.valueInGlobalWindow(1);
  WindowedValue<Iterable<?>> sideInput =
      WindowedValue.<Iterable<?>>valueInGlobalWindow(Lists.<Integer>newArrayList(22));
  // pushed back input
  operator.input.process(ApexStreamTuple.DataTuple.of(wv1));

  final List<Object> results = Lists.newArrayList();
  Sink<Object> sink = new Sink<Object>() {
    @Override
    public void put(Object tuple) {
      results.add(tuple);
    }

    @Override
    public int getCount(boolean reset) {
      return 0;
    }
  };

  // verify pushed back input checkpointing
  Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator));
  operator.output.setSink(sink);
  operator.setup(null);
  operator.beginWindow(1);
  WindowedValue<Integer> wv2 = WindowedValue.valueInGlobalWindow(2);
  operator.sideInput1.process(ApexStreamTuple.DataTuple.of(sideInput));
  Assert.assertEquals("number outputs", 1, results.size());
  Assert.assertEquals("result",
      WindowedValue.valueInGlobalWindow(23),
      ((ApexStreamTuple.DataTuple<?>) results.get(0)).getValue());

  // verify side input checkpointing
  results.clear();
  Assert.assertNotNull("Serialization", operator = KryoCloneUtils.cloneObject(operator));
  operator.output.setSink(sink);
  operator.setup(null);
  operator.beginWindow(2);
  operator.input.process(ApexStreamTuple.DataTuple.of(wv2));
  Assert.assertEquals("number outputs", 1, results.size());
  Assert.assertEquals("result",
      WindowedValue.valueInGlobalWindow(24),
      ((ApexStreamTuple.DataTuple<?>) results.get(0)).getValue());
}
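The Add DoFn instantiated above is defined elsewhere in ParDoTranslatorTest and is not shown in this snippet. A minimal sketch of a DoFn with the behaviour the assertions expect (element plus the singleton side input: 1 + 22 = 23, 2 + 22 = 24) could look like the following; the field name and constructor are assumptions, and only the DoFn, @ProcessElement, and sideInput API is taken from the Beam SDK.

// Sketch (assumed shape, not the test's actual source): adds the singleton
// side input value to each incoming element.
static class Add extends DoFn<Integer, Integer> {
  private final PCollectionView<Integer> sideInputView;

  Add(PCollectionView<Integer> sideInputView) {
    this.sideInputView = sideInputView;
  }

  @ProcessElement
  public void processElement(ProcessContext c) {
    // With the test's side input of 22: 1 -> 23 and 2 -> 24, matching the assertions above.
    c.output(c.element() + c.sideInput(sideInputView));
  }
}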