Use of org.apache.beam.sdk.coders.CannotProvideCoderException in project beam by apache.

The following example is the combineGlobally method of the TransformTranslator class (Spark runner), which catches CannotProvideCoderException while resolving the accumulator coder.
private static <InputT, AccumT, OutputT>
    TransformEvaluator<Combine.Globally<InputT, OutputT>> combineGlobally() {
  return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() {

    @Override
    public void evaluate(Combine.Globally<InputT, OutputT> transform, EvaluationContext context) {
      final PCollection<InputT> input = context.getInput(transform);
      final Coder<InputT> iCoder = context.getInput(transform).getCoder();
      final Coder<OutputT> oCoder = context.getOutput(transform).getCoder();
      final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
      @SuppressWarnings("unchecked")
      final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn =
          (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>)
              CombineFnUtil.toFnWithContext(transform.getFn());
      final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder =
          WindowedValue.FullWindowedValueCoder.of(
              oCoder, windowingStrategy.getWindowFn().windowCoder());
      final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
      final boolean hasDefault = transform.isInsertDefault();

      final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn =
          new SparkGlobalCombineFn<>(
              combineFn,
              runtimeContext,
              TranslationUtils.getSideInputs(transform.getSideInputs(), context),
              windowingStrategy);

      final Coder<AccumT> aCoder;
      try {
        aCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder);
      } catch (CannotProvideCoderException e) {
        throw new IllegalStateException("Could not determine coder for accumulator", e);
      }

      @SuppressWarnings("unchecked")
      JavaRDD<WindowedValue<InputT>> inRdd =
          ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();

      JavaRDD<WindowedValue<OutputT>> outRdd;

      Optional<Iterable<WindowedValue<AccumT>>> maybeAccumulated =
          GroupCombineFunctions.combineGlobally(
              inRdd, sparkCombineFn, iCoder, aCoder, windowingStrategy);

      if (maybeAccumulated.isPresent()) {
        Iterable<WindowedValue<OutputT>> output =
            sparkCombineFn.extractOutput(maybeAccumulated.get());
        outRdd =
            context.getSparkContext()
                .parallelize(CoderHelpers.toByteArrays(output, wvoCoder))
                .map(CoderHelpers.fromByteFunction(wvoCoder));
      } else {
        // handle empty input RDD, which will naturally skip the entire execution
        // as Spark will not run on empty RDDs.
        JavaSparkContext jsc = new JavaSparkContext(inRdd.context());
        if (hasDefault) {
          OutputT defaultValue = combineFn.defaultValue();
          outRdd =
              jsc.parallelize(Lists.newArrayList(CoderHelpers.toByteArray(defaultValue, oCoder)))
                  .map(CoderHelpers.fromByteFunction(oCoder))
                  .map(WindowingHelpers.<OutputT>windowFunction());
        } else {
          outRdd = jsc.emptyRDD();
        }
      }

      context.putDataset(transform, new BoundedDataset<>(outRdd));
    }

    @Override
    public String toNativeString() {
      return "aggregate(..., new <fn>(), ...)";
    }
  };
}
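For context, the CannotProvideCoderException handling above follows a common pattern: ask the CoderRegistry to infer a coder for the accumulator type via getAccumulatorCoder, and rethrow the checked exception as an unchecked one if inference fails. The sketch below is a minimal, hypothetical illustration of that pattern outside a runner; the AccumulatorCoderExample class and the MySum combine fn are invented for illustration and are not part of the Beam source shown above.

import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.transforms.Combine;

public class AccumulatorCoderExample {

  // Hypothetical combine fn: sums longs, using a Long accumulator.
  static class MySum extends Combine.CombineFn<Long, Long, Long> {
    @Override public Long createAccumulator() { return 0L; }
    @Override public Long addInput(Long accum, Long input) { return accum + input; }
    @Override public Long mergeAccumulators(Iterable<Long> accums) {
      long sum = 0L;
      for (Long a : accums) { sum += a; }
      return sum;
    }
    @Override public Long extractOutput(Long accum) { return accum; }
  }

  public static void main(String[] args) {
    CoderRegistry registry = CoderRegistry.createDefault();
    Coder<Long> inputCoder = VarLongCoder.of();
    MySum fn = new MySum();

    final Coder<Long> accumCoder;
    try {
      // Same call the runner makes above: let the registry infer a coder
      // for the accumulator type from the input coder.
      accumCoder = fn.getAccumulatorCoder(registry, inputCoder);
    } catch (CannotProvideCoderException e) {
      // The checked exception is rethrown unchecked, mirroring the runner code.
      throw new IllegalStateException("Could not determine coder for accumulator", e);
    }
    System.out.println("Accumulator coder: " + accumCoder);
  }
}

Wrapping the checked exception in IllegalStateException keeps translation-time failures loud: if no accumulator coder can be provided, the pipeline cannot be translated correctly, so failing fast is preferable to deferring the error.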