Use of org.apache.spark.api.java.JavaSparkContext in project beam by apache.
From the class TransformTranslator, method combineGlobally.
private static <InputT, AccumT, OutputT>
    TransformEvaluator<Combine.Globally<InputT, OutputT>> combineGlobally() {
  return new TransformEvaluator<Combine.Globally<InputT, OutputT>>() {

    @Override
    public void evaluate(Combine.Globally<InputT, OutputT> transform, EvaluationContext context) {
      final PCollection<InputT> input = context.getInput(transform);
      final Coder<InputT> iCoder = context.getInput(transform).getCoder();
      final Coder<OutputT> oCoder = context.getOutput(transform).getCoder();
      final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
      @SuppressWarnings("unchecked")
      final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn =
          (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>)
              CombineFnUtil.toFnWithContext(transform.getFn());
      final WindowedValue.FullWindowedValueCoder<OutputT> wvoCoder =
          WindowedValue.FullWindowedValueCoder.of(
              oCoder, windowingStrategy.getWindowFn().windowCoder());
      final SparkRuntimeContext runtimeContext = context.getRuntimeContext();
      final boolean hasDefault = transform.isInsertDefault();
      final SparkGlobalCombineFn<InputT, AccumT, OutputT> sparkCombineFn =
          new SparkGlobalCombineFn<>(
              combineFn,
              runtimeContext,
              TranslationUtils.getSideInputs(transform.getSideInputs(), context),
              windowingStrategy);
      final Coder<AccumT> aCoder;
      try {
        aCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), iCoder);
      } catch (CannotProvideCoderException e) {
        throw new IllegalStateException("Could not determine coder for accumulator", e);
      }
      @SuppressWarnings("unchecked")
      JavaRDD<WindowedValue<InputT>> inRdd =
          ((BoundedDataset<InputT>) context.borrowDataset(transform)).getRDD();
      JavaRDD<WindowedValue<OutputT>> outRdd;
      Optional<Iterable<WindowedValue<AccumT>>> maybeAccumulated =
          GroupCombineFunctions.combineGlobally(inRdd, sparkCombineFn, iCoder, aCoder, windowingStrategy);
      if (maybeAccumulated.isPresent()) {
        Iterable<WindowedValue<OutputT>> output = sparkCombineFn.extractOutput(maybeAccumulated.get());
        outRdd = context.getSparkContext()
            .parallelize(CoderHelpers.toByteArrays(output, wvoCoder))
            .map(CoderHelpers.fromByteFunction(wvoCoder));
      } else {
        // Handle an empty input RDD: Spark will not run on empty RDDs, so the
        // execution above is naturally skipped and the output is produced here.
        JavaSparkContext jsc = new JavaSparkContext(inRdd.context());
        if (hasDefault) {
          // Emit the combine fn's default value as a single, globally windowed element.
          OutputT defaultValue = combineFn.defaultValue();
          outRdd = jsc.parallelize(Lists.newArrayList(CoderHelpers.toByteArray(defaultValue, oCoder)))
              .map(CoderHelpers.fromByteFunction(oCoder))
              .map(WindowingHelpers.<OutputT>windowFunction());
        } else {
          outRdd = jsc.emptyRDD();
        }
      }
      context.putDataset(transform, new BoundedDataset<>(outRdd));
    }

    @Override
    public String toNativeString() {
      return "aggregate(..., new <fn>(), ...)";
    }
  };
}
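In the empty-input branch above, the translator does not create a new Spark context; it wraps the input RDD's existing SparkContext in a JavaSparkContext and uses it to parallelize the default value or to build an empty RDD. The standalone sketch below reproduces that pattern outside Beam; the local master, the app name, and the "default" placeholder value are assumptions for illustration, not taken from the Beam code.

import java.util.Collections;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class EmptyInputFallbackSketch {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext(
        new SparkConf().setMaster("local[2]").setAppName("empty-input-fallback"));
    try {
      // Simulate an empty input RDD.
      JavaRDD<String> input = jsc.<String>emptyRDD();
      boolean hasDefault = true; // mirrors transform.isInsertDefault() in the translator

      // Re-wrap the RDD's existing SparkContext, as the translator does with inRdd.context().
      JavaSparkContext sameContext = new JavaSparkContext(input.context());

      JavaRDD<String> output = hasDefault
          ? sameContext.parallelize(Collections.singletonList("default")) // placeholder default value
          : sameContext.<String>emptyRDD();
      System.out.println("output count: " + output.count());
    } finally {
      jsc.stop();
    }
  }
}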
Use of org.apache.spark.api.java.JavaSparkContext in project beam by apache.
From the class GlobalWatermarkHolderTest, method testLowHighWatermarksAdvance.
@Test
public void testLowHighWatermarksAdvance() {
  JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
  Instant instant = new Instant(0);
  // low == high.
  GlobalWatermarkHolder.add(1,
      new SparkWatermarks(instant.plus(Duration.millis(5)), instant.plus(Duration.millis(5)), instant));
  GlobalWatermarkHolder.advance(jsc);
  // low < high.
  GlobalWatermarkHolder.add(1,
      new SparkWatermarks(instant.plus(Duration.millis(10)), instant.plus(Duration.millis(15)),
          instant.plus(Duration.millis(100))));
  GlobalWatermarkHolder.advance(jsc);
  // assert watermarks in Broadcast.
  SparkWatermarks currentWatermarks = GlobalWatermarkHolder.get().getValue().get(1);
  assertThat(currentWatermarks.getLowWatermark(), equalTo(instant.plus(Duration.millis(10))));
  assertThat(currentWatermarks.getHighWatermark(), equalTo(instant.plus(Duration.millis(15))));
  assertThat(currentWatermarks.getSynchronizedProcessingTime(),
      equalTo(instant.plus(Duration.millis(100))));
  // assert illegal watermark advance.
  thrown.expect(IllegalStateException.class);
  thrown.expectMessage(RegexMatcher.matches(
      "Low watermark " + INSTANT_PATTERN + " cannot be later then high watermark " + INSTANT_PATTERN));
  // low > high -> not allowed!
  GlobalWatermarkHolder.add(1,
      new SparkWatermarks(instant.plus(Duration.millis(25)), instant.plus(Duration.millis(20)),
          instant.plus(Duration.millis(200))));
  GlobalWatermarkHolder.advance(jsc);
}
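The watermark tests from GlobalWatermarkHolderTest (this one and the next snippet) rely on scaffolding that is not part of the excerpt: a JUnit ExpectedException rule bound to the thrown field, and pipeline options passed to SparkContextFactory.getSparkContext. A minimal sketch of that scaffolding follows; the exact options type and its construction are assumptions, not copied from the Beam test class.

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.junit.Rule;
import org.junit.rules.ExpectedException;

public class WatermarkTestScaffoldingSketch {
  // Backs the thrown.expect(...) / thrown.expectMessage(...) calls in the tests.
  @Rule
  public final ExpectedException thrown = ExpectedException.none();

  // Handed to SparkContextFactory.getSparkContext(options); assumed here to be the
  // Spark runner's pipeline options created with default values.
  private final SparkPipelineOptions options =
      PipelineOptionsFactory.as(SparkPipelineOptions.class);
}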
Use of org.apache.spark.api.java.JavaSparkContext in project beam by apache.
From the class GlobalWatermarkHolderTest, method testSynchronizedTimeMonotonic.
@Test
public void testSynchronizedTimeMonotonic() {
  JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
  Instant instant = new Instant(0);
  GlobalWatermarkHolder.add(1,
      new SparkWatermarks(instant.plus(Duration.millis(5)), instant.plus(Duration.millis(10)), instant));
  GlobalWatermarkHolder.advance(jsc);
  thrown.expect(IllegalStateException.class);
  thrown.expectMessage("Synchronized processing time must advance.");
  // no actual advancement of watermarks - fine by Watermarks
  // but not by synchronized processing time.
  GlobalWatermarkHolder.add(1,
      new SparkWatermarks(instant.plus(Duration.millis(5)), instant.plus(Duration.millis(10)), instant));
  GlobalWatermarkHolder.advance(jsc);
}
Use of org.apache.spark.api.java.JavaSparkContext in project beam by apache.
From the class ProvidedSparkContextTest, method testWithProvidedContext.
/**
 * Provides a context and calls pipeline run.
 *
 * @throws Exception if the pipeline run fails
 */
@Test
public void testWithProvidedContext() throws Exception {
  JavaSparkContext jsc = new JavaSparkContext("local[*]", "Existing_Context");
  testWithValidProvidedContext(jsc);
  // A provided context must not be stopped after execution.
  assertFalse(jsc.sc().isStopped());
  jsc.stop();
}
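The helper testWithValidProvidedContext(jsc) is not included in this excerpt. As a rough sketch, a provided JavaSparkContext is typically handed to the Spark runner through SparkContextOptions before running a pipeline; the pipeline content below is illustrative only and not the actual helper from the Beam test.

import org.apache.beam.runners.spark.SparkContextOptions;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.spark.api.java.JavaSparkContext;

public class ProvidedContextSketch {
  static void runWithProvidedContext(JavaSparkContext jsc) {
    SparkContextOptions options = PipelineOptionsFactory.as(SparkContextOptions.class);
    options.setRunner(SparkRunner.class);
    options.setUsesProvidedSparkContext(true);
    options.setProvidedSparkContext(jsc);

    Pipeline p = Pipeline.create(options);
    p.apply(Create.of("a", "b", "c")); // illustrative pipeline content
    p.run().waitUntilFinish();
    // The runner uses the provided context and must not stop it when the pipeline finishes.
  }
}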
Use of org.apache.spark.api.java.JavaSparkContext in project cdap by caskdata.
From the class ClassicSparkProgram, method main.
public static void main(String[] args) throws Exception {
  SparkConf sparkConf = new SparkConf();
  sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
  sparkConf.set("spark.kryo.registrator", MyKryoRegistrator.class.getName());
  Schema schema = Schema.recordOf("record",
      Schema.Field.of("name", Schema.of(Schema.Type.STRING)),
      Schema.Field.of("id", Schema.of(Schema.Type.INT)));
  List<StructuredRecord> records = new ArrayList<>();
  for (int i = 1; i <= 10; i++) {
    records.add(StructuredRecord.builder(schema).set("name", "Name" + i).set("id", i).build());
  }
  // This tests serialization of StructuredRecord as well as using the custom Kryo serializer.
  JavaSparkContext jsc = new JavaSparkContext(sparkConf);
  int result = jsc.parallelize(records)
      .mapToPair(new PairFunction<StructuredRecord, MyInt, StructuredRecord>() {
        @Override
        public Tuple2<MyInt, StructuredRecord> call(StructuredRecord record) throws Exception {
          return new Tuple2<>(new MyInt((Integer) record.get("id")), record);
        }
      })
      .map(new Function<Tuple2<MyInt, StructuredRecord>, MyInt>() {
        @Override
        public MyInt call(Tuple2<MyInt, StructuredRecord> tuple) throws Exception {
          return tuple._1;
        }
      })
      .reduce(new Function2<MyInt, MyInt, MyInt>() {
        @Override
        public MyInt call(MyInt v1, MyInt v2) throws Exception {
          return new MyInt(v1.toInt() + v2.toInt());
        }
      })
      .toInt();
  if (result != 55) {
    throw new Exception("Expected result to be 55");
  }
}
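For comparison, the same parallelize / mapToPair / map / reduce shape can be written more compactly with Java 8 lambdas when no CDAP-specific types or Kryo registration are involved. The sketch below uses plain integers instead of StructuredRecord and MyInt; the local master and app name are placeholders, not part of the CDAP program.

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class LambdaSumSketch {
  public static void main(String[] args) throws Exception {
    SparkConf sparkConf = new SparkConf().setMaster("local[2]").setAppName("lambda-sum");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);
    try {
      List<Integer> ids = new ArrayList<>();
      for (int i = 1; i <= 10; i++) {
        ids.add(i);
      }
      // Same shape as the CDAP program: pair each element with a payload,
      // project the key back out, then reduce to a sum.
      int result = jsc.parallelize(ids)
          .mapToPair(id -> new Tuple2<>(id, "Name" + id))
          .map(tuple -> tuple._1)
          .reduce(Integer::sum);
      if (result != 55) {
        throw new Exception("Expected result to be 55");
      }
    } finally {
      jsc.stop();
    }
  }
}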