use of org.apache.beam.sdk.transforms.Combine in project beam by apache.
the class CombineTranslationTest method testToFromProto.
/**
 * Checks that a {@link Combine.PerKey} survives translation to a {@link CombinePayload}: the
 * accumulator coder and the combine function read back from the proto must equal the ones
 * supplied by the original {@code combineFn}.
 */
@Test
public void testToFromProto() throws Exception {
  PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
  input.apply(Combine.globally(combineFn));

  // Holder written by the anonymous visitor below.
  final AtomicReference<AppliedPTransform<?, ?, Combine.PerKey<?, ?, ?>>> perKeyCombine =
      new AtomicReference<>();
  pipeline.traverseTopologically(
      new PipelineVisitor.Defaults() {
        @Override
        public void leaveCompositeTransform(Node node) {
          if (!(node.getTransform() instanceof Combine.PerKey)) {
            return;
          }
          // At most one Combine.PerKey composite may be seen during the traversal.
          checkState(perKeyCombine.get() == null);
          perKeyCombine.set((AppliedPTransform) node.toAppliedPTransform(getPipeline()));
        }
      });
  // ...and at least one must have been found.
  checkState(perKeyCombine.get() != null);

  SdkComponents components = SdkComponents.create();
  CombinePayload payload = CombineTranslation.toProto(perKeyCombine.get(), components);
  RunnerApi.Components componentsProto = components.toComponents();

  assertEquals(
      combineFn.getAccumulatorCoder(pipeline.getCoderRegistry(), input.getCoder()),
      CombineTranslation.getAccumulatorCoder(payload, componentsProto));
  assertEquals(combineFn, CombineTranslation.getCombineFn(payload));
}
use of org.apache.beam.sdk.transforms.Combine in project beam by apache.
the class StreamingTransformTranslator method combineGrouped.
/**
 * Creates the evaluator that translates {@link Combine.GroupedValues} onto a DStream.
 *
 * <p>The evaluator borrows the input dataset of the transform, maps every RDD of its DStream
 * through {@link TranslationUtils.CombineGroupedValues}, and registers the resulting stream as
 * the transform's output dataset, carrying over the stream sources of the input.
 */
private static <K, InputT, OutputT>
    TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>> combineGrouped() {
  return new TransformEvaluator<Combine.GroupedValues<K, InputT, OutputT>>() {
    @Override
    public void evaluate(
        final Combine.GroupedValues<K, InputT, OutputT> transform, EvaluationContext context) {
      // The input PCollection is only consulted for its windowing strategy.
      PCollection<? extends KV<K, ? extends Iterable<InputT>>> input =
          context.getInput(transform);
      final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();

      // Adapt the transform's combine function to its context-aware form.
      @SuppressWarnings("unchecked")
      final CombineWithContext.CombineFnWithContext<InputT, ?, OutputT> combineFn =
          (CombineWithContext.CombineFnWithContext<InputT, ?, OutputT>)
              CombineFnUtil.toFnWithContext(transform.getFn());

      @SuppressWarnings("unchecked")
      UnboundedDataset<KV<K, Iterable<InputT>>> inputDataset =
          (UnboundedDataset<KV<K, Iterable<InputT>>>) context.borrowDataset(transform);
      JavaDStream<WindowedValue<KV<K, Iterable<InputT>>>> inputStream =
          inputDataset.getDStream();

      final SparkRuntimeContext sparkRuntimeContext = context.getRuntimeContext();
      final SparkPCollectionView sideInputViews = context.getPViews();

      JavaDStream<WindowedValue<KV<K, OutputT>>> combinedStream =
          inputStream.transform(
              new Function<
                  JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>>,
                  JavaRDD<WindowedValue<KV<K, OutputT>>>>() {
                @Override
                public JavaRDD<WindowedValue<KV<K, OutputT>>> call(
                    JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> batch) throws Exception {
                  // Side inputs are looked up on every invocation, using a context
                  // wrapped around the RDD that was handed in.
                  SparkKeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn =
                      new SparkKeyedCombineFn<>(
                          combineFn,
                          sparkRuntimeContext,
                          TranslationUtils.getSideInputs(
                              transform.getSideInputs(),
                              new JavaSparkContext(batch.context()),
                              sideInputViews),
                          windowingStrategy);
                  return batch.map(new TranslationUtils.CombineGroupedValues<>(keyedCombineFn));
                }
              });

      context.putDataset(
          transform, new UnboundedDataset<>(combinedStream, inputDataset.getStreamSources()));
    }

    @Override
    public String toNativeString() {
      return "map(new <fn>())";
    }
  };
}
Aggregations