Example 1 with ShardBoundary

Use of org.broadinstitute.hellbender.engine.ShardBoundary in project gatk by broadinstitute.

From the class SparkSharderUnitTest, method testSingleContig.

@Test
public void testSingleContig() throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    // Consider the following reads (divided into four partitions), and intervals.
    // This test counts the number of reads that overlap each interval.
    //                      1                   2
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    // ---------------------------------------------------------
    // Reads in partition 0
    //   [-----]
    //           [-----]
    //               [-----]
    // ---------------------------------------------------------
    // Reads in partition 1
    //               [-----]
    //               [-----]
    //               [-----]
    // ---------------------------------------------------------
    // Reads in partition 2
    //               [-----]
    //                       [-----]
    //                         [-----]
    // ---------------------------------------------------------
    // Reads in partition 3
    //                                   [-----]
    //                                           [-----]
    //                                                   [-----]
    // ---------------------------------------------------------
    // Per-partition read extents
    //   [-----------------]
    //               [-----]
    //               [---------------]
    //                                   [---------------------]
    // ---------------------------------------------------------
    // Intervals
    //     [-----]
    //                 [---------]
    //                       [-----------------------]
    //
    //                      1                   2
    // ---------------------------------------------------------
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    JavaRDD<TestRead> reads = ctx.parallelize(ImmutableList.of(
            // partition 0
            new TestRead(1, 3), new TestRead(5, 7), new TestRead(7, 9),
            // partition 1
            new TestRead(7, 9), new TestRead(7, 9), new TestRead(7, 9),
            // partition 2
            new TestRead(7, 9), new TestRead(11, 13), new TestRead(12, 14),
            // partition 3
            new TestRead(17, 19), new TestRead(21, 23), new TestRead(25, 27)), 4);
    List<SimpleInterval> intervals = ImmutableList.of(new SimpleInterval("1", 2, 4), new SimpleInterval("1", 8, 12), new SimpleInterval("1", 11, 22));
    List<ShardBoundary> shardBoundaries = intervals.stream().map(si -> new ShardBoundary(si, si)).collect(Collectors.toList()); // unpadded: each shard's interval and padded interval are the same
    ImmutableMap<SimpleInterval, Integer> expectedReadsPerInterval = ImmutableMap.of(intervals.get(0), 1, intervals.get(1), 7, intervals.get(2), 4);
    JavaPairRDD<Locatable, Integer> readsPerInterval = SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, false).flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerInterval.collectAsMap(), expectedReadsPerInterval);
    JavaPairRDD<Locatable, Integer> readsPerIntervalShuffle = SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, true).flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerIntervalShuffle.collectAsMap(), expectedReadsPerInterval);
    try {
        // a max read length smaller than the actual read length should cause an exception
        int maxReadLength = STANDARD_READ_LENGTH - 1;
        SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, maxReadLength, true).flatMapToPair(new CountOverlappingReadsFunction()).collect();
        org.testng.Assert.fail("expected an exception since the declared max read length is smaller than the actual read length");
    } catch (Exception e) {
        assertEquals(e.getCause().getClass(), UserException.class);
    }
}
Also used: Locatable (htsjdk.samtools.util.Locatable), OverlapDetector (htsjdk.samtools.util.OverlapDetector), java.util (java.util), PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction), BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest), SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), Assert.assertEquals (org.testng.Assert.assertEquals), Test (org.testng.annotations.Test), IOException (java.io.IOException), Tuple2 (scala.Tuple2), JavaPairRDD (org.apache.spark.api.java.JavaPairRDD), SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval), Collectors (java.util.stream.Collectors), Shard (org.broadinstitute.hellbender.engine.Shard), Serializable (java.io.Serializable), UserException (org.broadinstitute.hellbender.exceptions.UserException), ShardBoundary (org.broadinstitute.hellbender.engine.ShardBoundary), Assert.assertTrue (org.testng.Assert.assertTrue), SAMSequenceRecord (htsjdk.samtools.SAMSequenceRecord), com.google.common.collect (com.google.common.collect), Assert.assertFalse (org.testng.Assert.assertFalse), JavaRDD (org.apache.spark.api.java.JavaRDD)
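Note that testSingleContig builds unpadded shards: the two-argument ShardBoundary constructor takes the shard's interval and its padded interval, and the test passes the same interval for both. For comparison, here is a minimal sketch of padded boundaries, assuming the same sequenceDictionary and an illustrative 100-base padding (both assumptions, not part of the test):

    // Sketch only: pads each interval by 100 bases on each side, clipped to the contig ends.
    List<ShardBoundary> paddedBoundaries = intervals.stream()
            .map(si -> new ShardBoundary(si, si.expandWithinContig(100, sequenceDictionary)))
            .collect(Collectors.toList());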

Example 2 with ShardBoundary

Use of org.broadinstitute.hellbender.engine.ShardBoundary in project gatk by broadinstitute.

From the class AddContextDataToReadSpark, method addUsingOverlapsPartitioning.

/**
     * Add context data ({@link ReadContextData}) to reads, using overlaps partitioning to avoid a shuffle.
     * @param ctx the Spark context
     * @param mappedReads the coordinate-sorted reads
     * @param referenceSource the reference source
     * @param variants the coordinate-sorted variants
     * @param sequenceDictionary the sequence dictionary for the reads
     * @param shardSize the maximum size of each shard, in bases
     * @param shardPadding amount of extra context around each shard, in bases
     * @return an RDD of read-context pairs, in coordinate-sorted order
     */
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
    final List<SimpleInterval> intervals = IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    // use unpadded shards (padding is only needed for reference bases)
    final List<ShardBoundary> intervalShards = intervals.stream().flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardSize, 0, sequenceDictionary).stream()).collect(Collectors.toList());
    final Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
    final IntervalsSkipList<GATKVariant> variantSkipList = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(variantSkipList);
    int maxLocatableSize = Math.min(shardSize, shardPadding);
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize);
    return shardedReads.flatMapToPair(new PairFlatMapFunction<Shard<GATKRead>, GATKRead, ReadContextData>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Iterator<Tuple2<GATKRead, ReadContextData>> call(Shard<GATKRead> shard) throws Exception {
            // get reference bases for this shard (padded)
            SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
            ReferenceBases referenceBases = bReferenceSource.getValue().getReferenceBases(null, paddedInterval);
            final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsBroadcast.getValue();
            Iterator<Tuple2<GATKRead, ReadContextData>> transform = Iterators.transform(shard.iterator(), new Function<GATKRead, Tuple2<GATKRead, ReadContextData>>() {

                @Nullable
                @Override
                public Tuple2<GATKRead, ReadContextData> apply(@Nullable GATKRead r) {
                    List<GATKVariant> overlappingVariants;
                    if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
                        overlappingVariants = intervalsSkipList.getOverlapping(new SimpleInterval(r));
                    } else {
                        // Sometimes we have reads that do not form valid intervals (reads that do not consume any reference bases, e.g. CIGAR 61S90I).
                        // In those cases, we'll just say that nothing overlaps the read.
                        overlappingVariants = Collections.emptyList();
                    }
                    return new Tuple2<>(r, new ReadContextData(referenceBases, overlappingVariants));
                }
            });
            // only include reads that start in the shard
            return Iterators.filter(transform, r -> r._1().getStart() >= shard.getStart() && r._1().getStart() <= shard.getEnd());
        }
    });
}
Also used: PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction), ReferenceMultiSource (org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), GATKRead (org.broadinstitute.hellbender.utils.read.GATKRead), Iterators (com.google.common.collect.Iterators), IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils), ReferenceBases (org.broadinstitute.hellbender.utils.reference.ReferenceBases), ReadContextData (org.broadinstitute.hellbender.engine.ReadContextData), JavaRDD (org.apache.spark.api.java.JavaRDD), Nullable (javax.annotation.Nullable), Broadcast (org.apache.spark.broadcast.Broadcast), IntervalsSkipList (org.broadinstitute.hellbender.utils.collections.IntervalsSkipList), Function (com.google.common.base.Function), Iterator (java.util.Iterator), SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary), GATKVariant (org.broadinstitute.hellbender.utils.variant.GATKVariant), Tuple2 (scala.Tuple2), JavaPairRDD (org.apache.spark.api.java.JavaPairRDD), SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval), Collectors (java.util.stream.Collectors), Shard (org.broadinstitute.hellbender.engine.Shard), List (java.util.List), UserException (org.broadinstitute.hellbender.exceptions.UserException), ShardBoundary (org.broadinstitute.hellbender.engine.ShardBoundary), Collections (java.util.Collections), ReadFilterLibrary (org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary)
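As the comment in addUsingOverlapsPartitioning notes, the shards are created unpadded; padding is applied only later, when fetching reference bases. To make the sharding step easier to see in isolation, here is a minimal sketch of the same Shard.divideIntervalIntoShards call, where the contig name, extent, and shard size are illustrative assumptions:

    // Illustrative values; only the call shape mirrors the method above.
    SimpleInterval wholeContig = new SimpleInterval("1", 1, 1000000);
    List<ShardBoundary> boundaries =
            Shard.divideIntervalIntoShards(wholeContig, 10000, 0, sequenceDictionary);
    // With a padding of 0, each boundary's interval and padded interval coincide,
    // matching the new ShardBoundary(si, si) pattern in the unit tests above.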

Example 3 with ShardBoundary

Use of org.broadinstitute.hellbender.engine.ShardBoundary in project gatk by broadinstitute.

From the class SparkSharderUnitTest, method testContigBoundary.

@Test
public void testContigBoundary() throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    // Consider the following reads (all in a single partition), and intervals.
    // This test counts the number of reads that overlap each interval.
    //                      1                   2
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    // ---------------------------------------------------------
    // Reads in partition 0
    //   [-----] chr 1
    //           [-----] chr 1
    //               [-----] chr 1
    //   [-----] chr 2
    //     [-----] chr 2
    // ---------------------------------------------------------
    // Per-partition read extents
    //   [-----------------] chr 1
    //   [-------] chr 2
    // ---------------------------------------------------------
    // Intervals
    //     [-----] chr 1
    //                 [---------] chr 1
    //   [-----------------------] chr 2
    // ---------------------------------------------------------
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    JavaRDD<TestRead> reads = ctx.parallelize(ImmutableList.of(
            new TestRead("1", 1, 3), new TestRead("1", 5, 7), new TestRead("1", 7, 9),
            new TestRead("2", 1, 3), new TestRead("2", 2, 4)), 1);
    List<SimpleInterval> intervals = ImmutableList.of(new SimpleInterval("1", 2, 4), new SimpleInterval("1", 8, 12), new SimpleInterval("2", 1, 12));
    List<ShardBoundary> shardBoundaries = intervals.stream().map(si -> new ShardBoundary(si, si)).collect(Collectors.toList());
    ImmutableMap<SimpleInterval, Integer> expectedReadsPerInterval = ImmutableMap.of(intervals.get(0), 1, intervals.get(1), 1, intervals.get(2), 2);
    JavaPairRDD<Locatable, Integer> readsPerInterval = SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, false).flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerInterval.collectAsMap(), expectedReadsPerInterval);
    JavaPairRDD<Locatable, Integer> readsPerIntervalShuffle = SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, true).flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerIntervalShuffle.collectAsMap(), expectedReadsPerInterval);
}
Also used: Locatable (htsjdk.samtools.util.Locatable), OverlapDetector (htsjdk.samtools.util.OverlapDetector), java.util (java.util), PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction), BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest), SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary), JavaSparkContext (org.apache.spark.api.java.JavaSparkContext), Assert.assertEquals (org.testng.Assert.assertEquals), Test (org.testng.annotations.Test), IOException (java.io.IOException), Tuple2 (scala.Tuple2), JavaPairRDD (org.apache.spark.api.java.JavaPairRDD), SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval), Collectors (java.util.stream.Collectors), Shard (org.broadinstitute.hellbender.engine.Shard), Serializable (java.io.Serializable), UserException (org.broadinstitute.hellbender.exceptions.UserException), ShardBoundary (org.broadinstitute.hellbender.engine.ShardBoundary), Assert.assertTrue (org.testng.Assert.assertTrue), SAMSequenceRecord (htsjdk.samtools.SAMSequenceRecord), com.google.common.collect (com.google.common.collect), Assert.assertFalse (org.testng.Assert.assertFalse), JavaRDD (org.apache.spark.api.java.JavaRDD)
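The CountOverlappingReadsFunction used by both tests is not shown on this page. A plausible reconstruction, consistent with the calls above (Shard<TestRead> is iterable over its reads, TestRead implements Locatable, and Spark's PairFlatMapFunction.call returns an Iterator), would emit one (interval, count) pair per shard:

    // Hypothetical sketch; the actual helper in SparkSharderUnitTest is not shown here.
    private static class CountOverlappingReadsFunction
            implements PairFlatMapFunction<Shard<TestRead>, Locatable, Integer> {

        private static final long serialVersionUID = 1L;

        @Override
        public Iterator<Tuple2<Locatable, Integer>> call(Shard<TestRead> shard) {
            SimpleInterval interval = shard.getInterval();
            int count = 0;
            for (TestRead read : shard) {        // Shard<T> extends Iterable<T>
                if (interval.overlaps(read)) {   // count reads overlapping this shard's interval
                    count++;
                }
            }
            return Collections.singletonList(
                    new Tuple2<Locatable, Integer>(interval, count)).iterator();
        }
    }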

Aggregations

SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary): 3
Collectors (java.util.stream.Collectors): 3
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 3
JavaRDD (org.apache.spark.api.java.JavaRDD): 3
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 3
PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction): 3
Shard (org.broadinstitute.hellbender.engine.Shard): 3
ShardBoundary (org.broadinstitute.hellbender.engine.ShardBoundary): 3
UserException (org.broadinstitute.hellbender.exceptions.UserException): 3
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval): 3
Tuple2 (scala.Tuple2): 3
com.google.common.collect (com.google.common.collect): 2
SAMSequenceRecord (htsjdk.samtools.SAMSequenceRecord): 2
Locatable (htsjdk.samtools.util.Locatable): 2
OverlapDetector (htsjdk.samtools.util.OverlapDetector): 2
IOException (java.io.IOException): 2
Serializable (java.io.Serializable): 2
java.util (java.util): 2
BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest): 2
Assert.assertEquals (org.testng.Assert.assertEquals): 2