Search in sources :

Example 1 with IntervalsSkipList

use of org.broadinstitute.hellbender.utils.collections.IntervalsSkipList in project gatk by broadinstitute.

The method fillVariants of the class AddContextDataToReadSparkOptimized.

/**
 * Builds one {@link ContextShard} per input interval and attaches to it every variant
 * that overlaps the interval after it has been widened by {@code margin} bases on each side.
 *
 * The work is done eagerly, in the calling thread.
 *
 * @param shardedIntervals the intervals to turn into shards
 * @param variants the variants to distribute among the shards
 * @param margin number of bases by which each interval is widened for the variant lookup
 * @return one shard per input interval, in the order of {@code shardedIntervals}
 */
public static ArrayList<ContextShard> fillVariants(List<SimpleInterval> shardedIntervals, List<GATKVariant> variants, int margin) {
    final IntervalsSkipList<GATKVariant> variantIndex = new IntervalsSkipList<>(variants);
    final ArrayList<ContextShard> shards = new ArrayList<>();
    for (final SimpleInterval shardInterval : shardedIntervals) {
        // The lookup window is the shard interval grown by margin on both sides.
        // The left edge is clamped to position 1; the right edge is allowed to run past
        // the end of the contig, since no variant can be found there anyway.
        final SimpleInterval lookupWindow = new SimpleInterval(
                shardInterval.getContig(),
                Math.max(shardInterval.getStart() - margin, 1),
                shardInterval.getEnd() + margin);

        // The shard itself keeps the unexpanded interval, because it is meant to hold the
        // reads that start inside that interval. The variants, however, come from the wider
        // lookup window so that variants overlapping those reads are not missed:
        //
        //   |---- shard interval ----|
        //  -------------- lookup window --------------|
        //            |-- a read starting in the shard --|
        //                          |--- a variant overlapping that read ---|
        //
        // This relies on reads being shorter than margin: under that assumption any variant
        // overlapping a read of this shard also overlaps the lookup window.
        final ContextShard shard = new ContextShard(shardInterval);
        shards.add(shard.withVariants(variantIndex.getOverlapping(lookupWindow)));
    }
    return shards;
}
Also used : ContextShard(org.broadinstitute.hellbender.engine.ContextShard) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) ArrayList(java.util.ArrayList) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval)

Example 2 with IntervalsSkipList

use of org.broadinstitute.hellbender.utils.collections.IntervalsSkipList in project gatk by broadinstitute.

The method join of the class BroadcastJoinReadsWithVariants.

/**
 * Pairs every read with the variants that overlap it.
 *
 * All variants are collected, indexed in an {@link IntervalsSkipList}, and the index is
 * broadcast; each read is then looked up against the broadcast index.
 *
 * @param reads the reads to annotate
 * @param variants the variants to match against the reads
 * @return an RDD of (read, overlapping variants) pairs; a read whose contig/start/end do not
 *         form a valid interval is paired with an empty collection
 */
public static JavaPairRDD<GATKRead, Iterable<GATKVariant>> join(final JavaRDD<GATKRead> reads, final JavaRDD<GATKVariant> variants) {
    final JavaSparkContext sparkContext = new JavaSparkContext(reads.context());
    final IntervalsSkipList<GATKVariant> variantIndex = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> broadcastIndex = sparkContext.broadcast(variantIndex);
    return reads.mapToPair(read -> {
        final IntervalsSkipList<GATKVariant> index = broadcastIndex.getValue();
        // A read that does not describe a valid interval cannot overlap anything.
        if (!SimpleInterval.isValid(read.getContig(), read.getStart(), read.getEnd())) {
            return new Tuple2<>(read, Collections.emptyList());
        }
        return new Tuple2<>(read, index.getOverlapping(new SimpleInterval(read)));
    });
}
Also used : GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Tuple2(scala.Tuple2) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext)

Example 3 with IntervalsSkipList

use of org.broadinstitute.hellbender.utils.collections.IntervalsSkipList in project gatk by broadinstitute.

The method addUsingOverlapsPartitioning of the class AddContextDataToReadSpark.

/**
 * Attaches context data ({@link ReadContextData}: reference bases plus overlapping variants)
 * to each read. Reads are grouped into shards with overlaps partitioning, which avoids a shuffle.
 *
 * @param ctx the Spark context
 * @param mappedReads the coordinate-sorted reads
 * @param referenceSource the reference source
 * @param variants the coordinate-sorted variants
 * @param sequenceDictionary the sequence dictionary for the reads
 * @param shardSize the maximum size of each shard, in bases
 * @param shardPadding amount of extra context around each shard, in bases
 * @return a RDD of read-context pairs, in coordinate-sorted order
 */
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
    // Cut every reference interval into shards with zero padding; padding is only
    // needed later, when the reference bases are fetched.
    final List<ShardBoundary> intervalShards = IntervalUtils.getAllIntervalsForReference(sequenceDictionary)
            .stream()
            .flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardSize, 0, sequenceDictionary).stream())
            .collect(Collectors.toList());

    // Make the reference and an index of all the variants available through broadcasts.
    final Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
    final IntervalsSkipList<GATKVariant> variantIndex = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(variantIndex);

    // The smaller of shardSize and shardPadding is handed to the sharder as the maximum locatable size.
    final JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, intervalShards, Math.min(shardSize, shardPadding));

    return shardedReads.flatMapToPair(new PairFlatMapFunction<Shard<GATKRead>, GATKRead, ReadContextData>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Iterator<Tuple2<GATKRead, ReadContextData>> call(Shard<GATKRead> readShard) throws Exception {
            // Reference bases are fetched once per shard, over the shard interval padded by shardPadding.
            final ReferenceBases shardReferenceBases = bReferenceSource.getValue().getReferenceBases(
                    null, readShard.getInterval().expandWithinContig(shardPadding, sequenceDictionary));
            final IntervalsSkipList<GATKVariant> broadcastIndex = variantsBroadcast.getValue();

            final Iterator<Tuple2<GATKRead, ReadContextData>> readsWithContext = Iterators.transform(readShard.iterator(), new Function<GATKRead, Tuple2<GATKRead, ReadContextData>>() {

                @Nullable
                @Override
                public Tuple2<GATKRead, ReadContextData> apply(@Nullable GATKRead read) {
                    // Some reads do not form a valid interval (reads that do not consume any
                    // reference bases, e.g. CIGAR 61S90I); such reads are treated as
                    // overlapping no variants.
                    final List<GATKVariant> variantsOnRead = SimpleInterval.isValid(read.getContig(), read.getStart(), read.getEnd())
                            ? broadcastIndex.getOverlapping(new SimpleInterval(read))
                            : Collections.emptyList();
                    return new Tuple2<>(read, new ReadContextData(shardReferenceBases, variantsOnRead));
                }
            });

            // Keep only the reads whose start position lies within the shard.
            return Iterators.filter(readsWithContext, pair -> {
                final int readStart = pair._1().getStart();
                return readStart >= readShard.getStart() && readStart <= readShard.getEnd();
            });
        }
    });
}
Also used : PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Iterators(com.google.common.collect.Iterators) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) Broadcast(org.apache.spark.broadcast.Broadcast) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Function(com.google.common.base.Function) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) Shard(org.broadinstitute.hellbender.engine.Shard) List(java.util.List) UserException(org.broadinstitute.hellbender.exceptions.UserException) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) Collections(java.util.Collections) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Function(com.google.common.base.Function) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Iterator(java.util.Iterator) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) 
ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) UserException(org.broadinstitute.hellbender.exceptions.UserException) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) Tuple2(scala.Tuple2) Shard(org.broadinstitute.hellbender.engine.Shard) Nullable(javax.annotation.Nullable)

Aggregations

SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)3 IntervalsSkipList (org.broadinstitute.hellbender.utils.collections.IntervalsSkipList)3 GATKVariant (org.broadinstitute.hellbender.utils.variant.GATKVariant)3 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)2 Tuple2 (scala.Tuple2)2 Function (com.google.common.base.Function)1 Iterators (com.google.common.collect.Iterators)1 SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary)1 ArrayList (java.util.ArrayList)1 Collections (java.util.Collections)1 Iterator (java.util.Iterator)1 List (java.util.List)1 Collectors (java.util.stream.Collectors)1 Nullable (javax.annotation.Nullable)1 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)1 JavaRDD (org.apache.spark.api.java.JavaRDD)1 PairFlatMapFunction (org.apache.spark.api.java.function.PairFlatMapFunction)1 Broadcast (org.apache.spark.broadcast.Broadcast)1 ContextShard (org.broadinstitute.hellbender.engine.ContextShard)1 ReadContextData (org.broadinstitute.hellbender.engine.ReadContextData)1