Use of org.apache.spark.api.java.JavaRDD in project spark-dataflow by cloudera.
The class StreamingTransformTranslator, method window().
private static <T, W extends BoundedWindow> TransformEvaluator<Window.Bound<T>> window() {
    return new TransformEvaluator<Window.Bound<T>>() {
        @Override
        public void evaluate(Window.Bound<T> transform, EvaluationContext context) {
            StreamingEvaluationContext sec = (StreamingEvaluationContext) context;
            //--- first we apply windowing to the stream
            WindowFn<? super T, W> windowFn = WINDOW_FG.get("windowFn", transform);
            @SuppressWarnings("unchecked")
            JavaDStream<WindowedValue<T>> dStream =
                    (JavaDStream<WindowedValue<T>>) sec.getStream(transform);
            if (windowFn instanceof FixedWindows) {
                Duration windowDuration =
                        Durations.milliseconds(((FixedWindows) windowFn).getSize().getMillis());
                sec.setStream(transform, dStream.window(windowDuration));
            } else if (windowFn instanceof SlidingWindows) {
                Duration windowDuration =
                        Durations.milliseconds(((SlidingWindows) windowFn).getSize().getMillis());
                Duration slideDuration =
                        Durations.milliseconds(((SlidingWindows) windowFn).getPeriod().getMillis());
                sec.setStream(transform, dStream.window(windowDuration, slideDuration));
            }
            //--- then we apply windowing to the elements
            DoFn<T, T> addWindowsDoFn = new AssignWindowsDoFn<>(windowFn);
            DoFnFunction<T, T> dofn =
                    new DoFnFunction<>(addWindowsDoFn, sec.getRuntimeContext(), null);
            @SuppressWarnings("unchecked")
            JavaDStreamLike<WindowedValue<T>, ?, JavaRDD<WindowedValue<T>>> dstream =
                    (JavaDStreamLike<WindowedValue<T>, ?, JavaRDD<WindowedValue<T>>>) sec.getStream(transform);
            sec.setStream(transform, dstream.mapPartitions(dofn));
        }
    };
}
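For readers coming from the Dataflow side, the two branches above map Dataflow windowing onto Spark Streaming's DStream.window(): FixedWindows becomes window(size) and SlidingWindows becomes window(size, slide). Below is a minimal, self-contained sketch of those two calls on a toy stream; the socket source, batch interval, and window durations are illustrative assumptions, not values taken from spark-dataflow.

import org.apache.spark.SparkConf;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;

public class WindowingSketch {
    public static void main(String[] args) throws InterruptedException {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("windowing-sketch");
        // 5-second batches; the window sizes below must be multiples of this batch interval.
        JavaStreamingContext jssc = new JavaStreamingContext(conf, Durations.seconds(5));

        JavaDStream<String> lines = jssc.socketTextStream("localhost", 9999);

        // FixedWindows maps to window(size): each generated batch carries the last 30 seconds of data.
        JavaDStream<String> fixed = lines.window(Durations.seconds(30));

        // SlidingWindows maps to window(size, slide): every 10 seconds, look at the last 30 seconds.
        JavaDStream<String> sliding = lines.window(Durations.seconds(30), Durations.seconds(10));

        fixed.count().print();
        sliding.count().print();

        jssc.start();
        jssc.awaitTermination();
    }
}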
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class SparkSharderUnitTest, method testContigBoundary().
@Test
public void testContigBoundary() throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    // Consider the following reads (all in a single partition), and intervals.
    // This test counts the number of reads that overlap each interval.
    //                      1                   2
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    //  ---------------------------------------------------------
    //   Reads in partition 0
    //   [-----]                                                  chr 1
    //           [-----]                                          chr 1
    //               [-----]                                      chr 1
    //   [-----]                                                  chr 2
    //     [-----]                                                chr 2
    //  ---------------------------------------------------------
    //   Per-partition read extents
    //   [-----------------]                                      chr 1
    //   [-------]                                                chr 2
    //  ---------------------------------------------------------
    //   Intervals
    //     [-----]                                                chr 1
    //               [---------]                                  chr 1
    //   [-----------------------]                                chr 2
    //  ---------------------------------------------------------
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    JavaRDD<TestRead> reads = ctx.parallelize(ImmutableList.of(
            new TestRead("1", 1, 3), new TestRead("1", 5, 7), new TestRead("1", 7, 9),
            new TestRead("2", 1, 3), new TestRead("2", 2, 4)), 1);
    List<SimpleInterval> intervals = ImmutableList.of(
            new SimpleInterval("1", 2, 4), new SimpleInterval("1", 8, 12), new SimpleInterval("2", 1, 12));
    List<ShardBoundary> shardBoundaries =
            intervals.stream().map(si -> new ShardBoundary(si, si)).collect(Collectors.toList());
    ImmutableMap<SimpleInterval, Integer> expectedReadsPerInterval =
            ImmutableMap.of(intervals.get(0), 1, intervals.get(1), 1, intervals.get(2), 2);
    JavaPairRDD<Locatable, Integer> readsPerInterval =
            SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, false)
                    .flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerInterval.collectAsMap(), expectedReadsPerInterval);
    JavaPairRDD<Locatable, Integer> readsPerIntervalShuffle =
            SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, true)
                    .flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerIntervalShuffle.collectAsMap(), expectedReadsPerInterval);
}
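The shard-and-count step above leans on GATK classes (SparkSharder, ShardBoundary, CountOverlappingReadsFunction). As a rough, dependency-free illustration of just the counting idea, the sketch below broadcasts the interval list and tallies overlapping reads with flatMapToPair/reduceByKey. The Span class is a hypothetical stand-in for TestRead/SimpleInterval, and this deliberately skips SparkSharder's partition-aware sharding, which is what the test is really exercising.

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import scala.Tuple2;

public class OverlapCountSketch {
    // Hypothetical stand-in for TestRead/SimpleInterval: a contig name plus a 1-based closed range.
    static class Span implements Serializable {
        final String contig; final int start; final int end;
        Span(String contig, int start, int end) { this.contig = contig; this.start = start; this.end = end; }
        boolean overlaps(Span other) {
            return contig.equals(other.contig) && start <= other.end && other.start <= end;
        }
        @Override public String toString() { return contig + ":" + start + "-" + end; }
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("overlap-count-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // Same reads and intervals as the test above.
            JavaRDD<Span> reads = ctx.parallelize(Arrays.asList(
                    new Span("1", 1, 3), new Span("1", 5, 7), new Span("1", 7, 9),
                    new Span("2", 1, 3), new Span("2", 2, 4)));
            List<Span> intervals = Arrays.asList(
                    new Span("1", 2, 4), new Span("1", 8, 12), new Span("2", 1, 12));

            // Broadcast the (small) interval list and count, per interval, the reads that overlap it.
            Broadcast<List<Span>> bIntervals = ctx.broadcast(intervals);
            JavaPairRDD<String, Integer> counts = reads.flatMapToPair(read -> {
                List<Tuple2<String, Integer>> hits = new ArrayList<>();
                for (Span interval : bIntervals.value()) {
                    if (interval.overlaps(read)) {
                        hits.add(new Tuple2<>(interval.toString(), 1));
                    }
                }
                return hits.iterator();
            }).reduceByKey(Integer::sum);

            // Counts come out as 1, 1 and 2 for the three intervals, matching expectedReadsPerInterval.
            Map<String, Integer> result = counts.collectAsMap();
            System.out.println(result);
        }
    }
}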
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class JoinReadsWithVariantsSparkUnitTest, method pairReadsAndVariantsTest().
@Test(dataProvider = "pairedReadsAndVariants", groups = "spark")
public void pairReadsAndVariantsTest(List<GATKRead> reads, List<GATKVariant> variantList,
        List<KV<GATKRead, Iterable<GATKVariant>>> kvReadiVariant, JoinStrategy joinStrategy) {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    JavaRDD<GATKRead> rddReads = ctx.parallelize(reads);
    JavaRDD<GATKVariant> rddVariants = ctx.parallelize(variantList);
    JavaPairRDD<GATKRead, Iterable<GATKVariant>> actual = joinStrategy == JoinStrategy.SHUFFLE
            ? ShuffleJoinReadsWithVariants.join(rddReads, rddVariants)
            : BroadcastJoinReadsWithVariants.join(rddReads, rddVariants);
    Map<GATKRead, Iterable<GATKVariant>> gatkReadIterableMap = actual.collectAsMap();
    Assert.assertEquals(gatkReadIterableMap.size(), kvReadiVariant.size());
    for (KV<GATKRead, Iterable<GATKVariant>> kv : kvReadiVariant) {
        List<GATKVariant> variants = Lists.newArrayList(gatkReadIterableMap.get(kv.getKey()));
        Assert.assertTrue(variants.stream().noneMatch(v -> v == null));
        HashSet<GATKVariant> hashVariants = new LinkedHashSet<>(variants);
        final Iterable<GATKVariant> iVariants = kv.getValue();
        HashSet<GATKVariant> expectedHashVariants = Sets.newLinkedHashSet(iVariants);
        Assert.assertEquals(hashVariants, expectedHashVariants);
    }
}
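The test switches between ShuffleJoinReadsWithVariants and BroadcastJoinReadsWithVariants via the JoinStrategy parameter. The sketch below shows the general difference between the two strategies on plain key-value pairs: a shuffle join repartitions both sides by key, whereas a broadcast join ships the smaller side to every executor so the large side never moves. This is only an analogy with made-up names and data; the real GATK joins match reads to variants by genomic overlap rather than key equality.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import scala.Tuple2;

public class JoinStrategySketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("join-strategy-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            JavaPairRDD<String, String> reads = ctx.parallelizePairs(Arrays.asList(
                    new Tuple2<>("chr1:100", "read1"), new Tuple2<>("chr1:200", "read2")));
            JavaPairRDD<String, String> variants = ctx.parallelizePairs(Arrays.asList(
                    new Tuple2<>("chr1:100", "snpA"), new Tuple2<>("chr1:300", "snpB")));

            // SHUFFLE strategy: both sides are repartitioned by key across the network.
            JavaPairRDD<String, Tuple2<String, String>> shuffled = reads.join(variants);

            // BROADCAST strategy: collect the smaller side and ship it whole to every executor.
            Broadcast<Map<String, String>> smallSide = ctx.broadcast(new HashMap<>(variants.collectAsMap()));
            JavaPairRDD<String, Tuple2<String, String>> broadcastJoined = reads
                    .filter(kv -> smallSide.value().containsKey(kv._1()))
                    .mapToPair(kv -> new Tuple2<String, Tuple2<String, String>>(
                            kv._1(), new Tuple2<>(kv._2(), smallSide.value().get(kv._1()))));

            System.out.println(shuffled.collect());
            System.out.println(broadcastJoined.collect());
        }
    }
}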
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class FindBreakpointEvidenceSpark, method handleAssemblies().
/**
 * Transform all the reads for a supplied set of template names in each interval into FASTQ records
 * for that interval, and apply the supplied handler to each interval's list of FASTQ records
 * (for example, writing it to a file).
 */
@VisibleForTesting
static List<AlignedAssemblyOrExcuse> handleAssemblies(final JavaSparkContext ctx,
        final HopscotchUniqueMultiMap<String, Integer, QNameAndInterval> qNamesMultiMap,
        final JavaRDD<GATKRead> reads, final int nIntervals, final boolean includeMappingLocation,
        final boolean dumpFASTQs, final LocalAssemblyHandler localAssemblyHandler) {
    final Broadcast<HopscotchUniqueMultiMap<String, Integer, QNameAndInterval>> broadcastQNamesMultiMap =
            ctx.broadcast(qNamesMultiMap);
    final List<AlignedAssemblyOrExcuse> intervalDispositions = reads
            .mapPartitionsToPair(readItr ->
                    new ReadsForQNamesFinder(broadcastQNamesMultiMap.value(), nIntervals,
                            includeMappingLocation, dumpFASTQs).call(readItr).iterator(), false)
            .combineByKey(x -> x,
                    FindBreakpointEvidenceSpark::combineLists,
                    FindBreakpointEvidenceSpark::combineLists,
                    new HashPartitioner(nIntervals), false, null)
            .map(localAssemblyHandler::apply)
            .collect();
    broadcastQNamesMultiMap.destroy();
    BwaMemIndexSingleton.closeAllDistributedInstances(ctx);
    return intervalDispositions;
}
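The single chained expression above is dense; its pivotal step is the combineByKey call, which gathers one list of FASTQ records per interval and uses an explicit HashPartitioner so that each interval's records land in their own partition. Below is a toy sketch of that combineByKey shape; combineLists here is a hypothetical stand-in for FindBreakpointEvidenceSpark::combineLists, and the keys and values are invented.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class CombineByKeySketch {
    // Stand-in for FindBreakpointEvidenceSpark::combineLists: concatenate two lists into a new one.
    static List<String> combineLists(List<String> a, List<String> b) {
        List<String> merged = new ArrayList<>(a);
        merged.addAll(b);
        return merged;
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("combine-by-key-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // Toy per-partition results keyed by an interval id (loosely modeled on the pairs emitted above).
            List<Tuple2<Integer, List<String>>> data = Arrays.asList(
                    new Tuple2<>(0, Arrays.asList("recordA")),
                    new Tuple2<>(1, Arrays.asList("recordB")),
                    new Tuple2<>(0, Arrays.asList("recordC")));
            JavaPairRDD<Integer, List<String>> perPartitionResults = ctx.parallelizePairs(data);

            int nIntervals = 2;
            JavaPairRDD<Integer, List<String>> combined = perPartitionResults.combineByKey(
                    x -> x,                              // createCombiner: the first list seen becomes the combiner
                    CombineByKeySketch::combineLists,    // mergeValue: fold another list into the combiner
                    CombineByKeySketch::combineLists,    // mergeCombiners: merge combiners built on different partitions
                    new HashPartitioner(nIntervals));    // interval i hashes to partition i

            System.out.println(combined.collect());      // e.g. (0,[recordA, recordC]), (1,[recordB])
        }
    }
}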
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class SortReadFileSpark, method runTool().
@Override
protected void runTool(final JavaSparkContext ctx) {
    JavaRDD<GATKRead> reads = getReads();
    int numReducers = getRecommendedNumReducers();
    logger.info("Using " + numReducers + " reducers");
    final SAMFileHeader readsHeader = getHeaderForReads();
    ReadCoordinateComparator comparator = new ReadCoordinateComparator(readsHeader);
    JavaRDD<GATKRead> sortedReads;
    if (shardedOutput) {
        sortedReads = reads
                .mapToPair(read -> new Tuple2<>(read, null))
                .sortByKey(comparator, true, numReducers)
                .keys();
    } else {
        // sorting is done by writeReads below
        sortedReads = reads;
    }
    readsHeader.setSortOrder(SAMFileHeader.SortOrder.coordinate);
    writeReads(ctx, outputFile, sortedReads);
}
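When shardedOutput is set, the sort works by making each read its own key and handing a custom comparator to sortByKey; anything passed this way has to be Serializable because it is shipped to the executors. Below is a small, self-contained sketch of the same pattern on toy records; Coord and CoordComparator are invented stand-ins for GATKRead and ReadCoordinateComparator.

import java.io.Serializable;
import java.util.Arrays;
import java.util.Comparator;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class SortByComparatorSketch {
    // Invented record type standing in for GATKRead: just a contig name and a start position.
    static class Coord implements Serializable {
        final String contig; final int start;
        Coord(String contig, int start) { this.contig = contig; this.start = start; }
        @Override public String toString() { return contig + ":" + start; }
    }

    // Comparators given to sortByKey are sent to executors, so they must be Serializable.
    static class CoordComparator implements Comparator<Coord>, Serializable {
        @Override public int compare(Coord a, Coord b) {
            int c = a.contig.compareTo(b.contig);
            return c != 0 ? c : Integer.compare(a.start, b.start);
        }
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("sort-by-comparator-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            JavaRDD<Coord> records = ctx.parallelize(Arrays.asList(
                    new Coord("2", 5), new Coord("1", 30), new Coord("1", 7)));

            int numReducers = 2;  // stands in for getRecommendedNumReducers()
            JavaRDD<Coord> sorted = records
                    .mapToPair(r -> new Tuple2<Coord, Void>(r, null))  // the record itself is the key
                    .sortByKey(new CoordComparator(), true, numReducers)
                    .keys();

            System.out.println(sorted.collect());  // [1:7, 1:30, 2:5]
        }
    }
}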