Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class AddContextDataToReadSpark, method addUsingOverlapsPartitioning.
/**
* Add context data ({@link ReadContextData}) to reads, using overlaps partitioning to avoid a shuffle.
* @param ctx the Spark context
* @param mappedReads the coordinate-sorted reads
* @param referenceSource the reference source
* @param variants the coordinate-sorted variants
* @param sequenceDictionary the sequence dictionary for the reads
* @param shardSize the maximum size of each shard, in bases
* @param shardPadding amount of extra context around each shard, in bases
* @return an RDD of read-context pairs, in coordinate-sorted order
*/
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
    final List<SimpleInterval> intervals = IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    // use unpadded shards (padding is only needed for reference bases)
    final List<ShardBoundary> intervalShards = intervals.stream().flatMap(interval -> Shard.divideIntervalIntoShards(interval, shardSize, 0, sequenceDictionary).stream()).collect(Collectors.toList());
    final Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
    final IntervalsSkipList<GATKVariant> variantSkipList = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast = ctx.broadcast(variantSkipList);
    int maxLocatableSize = Math.min(shardSize, shardPadding);
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize);
    return shardedReads.flatMapToPair(new PairFlatMapFunction<Shard<GATKRead>, GATKRead, ReadContextData>() {
        private static final long serialVersionUID = 1L;

        @Override
        public Iterator<Tuple2<GATKRead, ReadContextData>> call(Shard<GATKRead> shard) throws Exception {
            // get reference bases for this shard (padded)
            SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
            ReferenceBases referenceBases = bReferenceSource.getValue().getReferenceBases(null, paddedInterval);
            final IntervalsSkipList<GATKVariant> intervalsSkipList = variantsBroadcast.getValue();
            Iterator<Tuple2<GATKRead, ReadContextData>> transform = Iterators.transform(shard.iterator(), new Function<GATKRead, Tuple2<GATKRead, ReadContextData>>() {
                @Nullable
                @Override
                public Tuple2<GATKRead, ReadContextData> apply(@Nullable GATKRead r) {
                    List<GATKVariant> overlappingVariants;
                    if (SimpleInterval.isValid(r.getContig(), r.getStart(), r.getEnd())) {
                        overlappingVariants = intervalsSkipList.getOverlapping(new SimpleInterval(r));
                    } else {
                        // Sometimes reads do not form valid intervals (reads that do not consume
                        // any reference bases, e.g. CIGAR 61S90I). In those cases, we'll just say
                        // that nothing overlaps the read.
                        overlappingVariants = Collections.emptyList();
                    }
                    return new Tuple2<>(r, new ReadContextData(referenceBases, overlappingVariants));
                }
            });
            // only include reads that start in the shard
            return Iterators.filter(transform, r -> r._1().getStart() >= shard.getStart() && r._1().getStart() <= shard.getEnd());
        }
    });
}
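The final filter is what makes overlaps partitioning safe without a shuffle: a read near a boundary is seen by every shard it overlaps, but only the shard containing its start position emits it, so no read is counted twice. A minimal standalone sketch of that de-duplication rule (hypothetical Read and Shard records, not GATK types):

import java.util.List;
import java.util.stream.Collectors;

public class ShardStartFilterSketch {

    // hypothetical minimal types: 1-based, inclusive coordinates on a single contig
    record Read(String name, int start, int end) {}
    record Shard(int start, int end) {}

    // a read is claimed by the shard containing its start, even though
    // neighboring shards may also see it among their overlapping reads
    static List<Read> claimedBy(List<Read> overlapping, Shard shard) {
        return overlapping.stream()
                .filter(r -> r.start() >= shard.start() && r.start() <= shard.end())
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        Shard shard = new Shard(1000, 1999);
        List<Read> overlapping = List.of(
                new Read("a", 950, 1050),    // starts before the shard: claimed by the previous shard
                new Read("b", 1000, 1100),   // starts on the shard's first base: claimed here
                new Read("c", 1990, 2090));  // starts in the shard, ends past it: still claimed here
        System.out.println(claimedBy(overlapping, shard)); // prints reads b and c only
    }
}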
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class IntervalWalkerSpark, method getIntervals.
/**
* Loads intervals and the corresponding reads, reference and features into a {@link JavaRDD}.
*
* @return all intervals as a {@link JavaRDD}.
*/
public JavaRDD<IntervalWalkerContext> getIntervals(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    // don't shard the intervals themselves, since we want each interval to be processed by a single task
    final List<ShardBoundary> intervalShardBoundaries = getIntervals().stream().map(i -> new ShardBoundary(i, i)).collect(Collectors.toList());
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShardBoundaries, Integer.MAX_VALUE, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedReads.map(getIntervalsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, intervalShardPadding));
}
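A concrete tool extending IntervalWalkerSpark would consume this RDD in its run method. A hedged sketch of such a consumer (the runTool override point is patterned on GATK's other Spark walkers; the getInterval() accessor on IntervalWalkerContext and the size() accessor on the interval are assumptions, not taken from this source):

@Override
protected void runTool(JavaSparkContext ctx) {
    JavaRDD<IntervalWalkerContext> intervals = getIntervals(ctx);
    // sum the spans of all intervals; getInterval() and size() are assumed accessors
    long totalBases = intervals
            .map(context -> (long) context.getInterval().size())
            .reduce(Long::sum);
    logger.info("Total bases spanned by the intervals: " + totalBases);
}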
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class AssemblyRegionWalkerSpark, method getAssemblyRegions.
/**
* Loads assembly regions and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
*
* If no intervals were specified, returns all the assembly regions.
*
* @return all assembly regions as a {@link JavaRDD}, bounded by intervals if specified.
*/
protected JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedReads.flatMap(getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(), assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance));
}
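Downstream, a concrete assembly-region walker maps each context to its per-region result. A hedged sketch (getAssemblyRegion() is assumed from the AssemblyRegionWalkerContext constructor used in the next snippet; getSpan() and getReads() are assumed AssemblyRegion accessors):

JavaRDD<AssemblyRegionWalkerContext> regions = getAssemblyRegions(ctx);
// summarize each region as "span <tab> number of reads assigned to it"
JavaRDD<String> summaries = regions.map(context -> {
    AssemblyRegion region = context.getAssemblyRegion(); // assumed accessor
    return region.getSpan() + "\t" + region.getReads().size();
});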
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class AssemblyRegionWalkerSpark, method getAssemblyRegionsFunction.
private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(final Broadcast<ReferenceMultiSource> bReferenceSource, final Broadcast<FeatureManager> bFeatureManager, final SAMSequenceDictionary sequenceDictionary, final SAMFileHeader header, final AssemblyRegionEvaluator evaluator, final int minAssemblyRegionSize, final int maxAssemblyRegionSize, final int assemblyRegionPadding, final double activeProbThreshold, final int maxProbPropagationDistance) {
    return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);
        ReferenceDataSource reference = bReferenceSource == null ? null : new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
        FeatureContext featureContext = new FeatureContext(features, paddedInterval);
        final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(shardedRead, header, referenceContext, featureContext, evaluator, minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance);
        return StreamSupport.stream(assemblyRegions.spliterator(), false)
                .map(assemblyRegion -> new AssemblyRegionWalkerContext(assemblyRegion, new ReferenceContext(reference, assemblyRegion.getExtendedSpan()), new FeatureContext(features, assemblyRegion.getExtendedSpan())))
                .iterator();
    };
}
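Note the two layers of padding here: the shard's already-padded interval is expanded once more by assemblyRegionPadding before fetching reference bases, so regions that reach the shard edge still see enough reference context, and expandWithinContig clamps the result at contig boundaries. A standalone sketch of that clamping rule (a hypothetical helper, not the GATK implementation):

// clamped expansion of a 1-based, inclusive interval; mirrors what
// SimpleInterval.expandWithinContig does conceptually
static int[] expandWithinContig(int start, int end, int padding, int contigLength) {
    int newStart = Math.max(1, start - padding);         // never before the first base
    int newEnd = Math.min(contigLength, end + padding);  // never past the contig end
    return new int[] { newStart, newEnd };
}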
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class BroadcastJoinReadsWithRefBases, method addBases.
/**
 * Joins each read of an RDD<GATKRead, T> with the reference bases corresponding to that read (the key).
 *
 * @param referenceDataflowSource The source of the reference sequence information
 * @param keyedByRead The read-keyed RDD for which to extract reference sequence information
 * @return A JavaPairRDD that pairs each read with its original value of type T and the corresponding ReferenceBases object
 */
public static <T> JavaPairRDD<GATKRead, Tuple2<T, ReferenceBases>> addBases(final ReferenceMultiSource referenceDataflowSource, final JavaPairRDD<GATKRead, T> keyedByRead) {
    JavaSparkContext ctx = new JavaSparkContext(keyedByRead.context());
    Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceDataflowSource);
    return keyedByRead.mapToPair(pair -> {
        SimpleInterval interval = bReferenceSource.getValue().getReferenceWindowFunction().apply(pair._1());
        return new Tuple2<>(pair._1(), new Tuple2<>(pair._2(), bReferenceSource.getValue().getReferenceBases(null, interval)));
    });
}
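A hedged usage sketch: key each read by some per-read value, then attach reference bases with the broadcast join. The reads RDD and referenceSource are placeholders for inputs produced elsewhere in a pipeline, and getLength() is assumed to be a GATKRead accessor:

// pair each read with a value of interest (here, its length)
JavaPairRDD<GATKRead, Integer> keyedByRead =
        reads.mapToPair(read -> new Tuple2<>(read, read.getLength()));
// broadcast join: every executor resolves reference bases locally, with no shuffle
JavaPairRDD<GATKRead, Tuple2<Integer, ReferenceBases>> withBases =
        BroadcastJoinReadsWithRefBases.addBases(referenceSource, keyedByRead);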