use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
the class LocusWalkerSpark method getAlignmentsFunction.
/**
 * Builds the per-shard mapper used to turn reads into locus-level contexts.
 *
 * For each {@link Shard} the returned function:
 * <ol>
 *   <li>loads the reference bases for the shard's padded interval into a {@link ReferenceMemorySource}
 *       (skipped when no reference broadcast was supplied),</li>
 *   <li>runs a {@link LocusIteratorByState} over the shard's reads, using the sample names taken from the
 *       header's read groups,</li>
 *   <li>wraps that iterator in an {@link IntervalOverlappingIterator} built from the shard's own
 *       (unpadded) interval, and</li>
 *   <li>pairs every resulting {@link AlignmentContext} with a {@link ReferenceContext} and a
 *       {@link FeatureContext} for the same locus.</li>
 * </ol>
 *
 * @param bReferenceSource broadcast of the reference source; may be {@code null}, in which case the
 *                         reference contexts are built around a {@code null} data source
 * @param bFeatureManager broadcast of the feature manager; may be {@code null}, in which case the
 *                        feature contexts are built around a {@code null} manager
 * @param sequenceDictionary sequence dictionary for the reads
 * @param header header of the reads; supplies the read groups from which sample names are collected
 * @param downsamplingInfo downsampling settings handed to the locus iterator
 * @return a function mapping a {@link Shard} of reads to an iterator of {@link LocusWalkerContext}s
 */
private static FlatMapFunction<Shard<GATKRead>, LocusWalkerContext> getAlignmentsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource,
        Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary,
        SAMFileHeader header,
        LIBSDownsamplingInfo downsamplingInfo) {
    return shard -> {
        final SimpleInterval shardInterval = shard.getInterval();
        final SimpleInterval paddedShardInterval = shard.getPaddedInterval();
        final Iterator<GATKRead> reads = shard.iterator();

        // Reference bases are materialised in memory for the padded interval only.
        final ReferenceDataSource referenceSource;
        if (bReferenceSource == null) {
            referenceSource = null;
        } else {
            referenceSource = new ReferenceMemorySource(
                    bReferenceSource.getValue().getReferenceBases(null, paddedShardInterval),
                    sequenceDictionary);
        }

        final FeatureManager featureManager;
        if (bFeatureManager == null) {
            featureManager = null;
        } else {
            featureManager = bFeatureManager.getValue();
        }

        final Set<String> sampleNames = header.getReadGroups()
                .stream()
                .map(SAMReadGroupRecord::getSample)
                .collect(Collectors.toSet());

        final LocusIteratorByState locusIterator =
                new LocusIteratorByState(reads, downsamplingInfo, false, sampleNames, header, true, false);
        final IntervalOverlappingIterator<AlignmentContext> lociInShard =
                new IntervalOverlappingIterator<>(locusIterator, ImmutableList.of(shardInterval), sequenceDictionary);

        // Sequential (non-parallel) stream over an iterator of unknown size.
        final Spliterator<AlignmentContext> lociSpliterator = Spliterators.spliteratorUnknownSize(lociInShard, 0);
        return StreamSupport.stream(lociSpliterator, false)
                .map(pileup -> {
                    final SimpleInterval locus = new SimpleInterval(pileup);
                    final ReferenceContext referenceContext = new ReferenceContext(referenceSource, locus);
                    final FeatureContext featureContext = new FeatureContext(featureManager, locus);
                    return new LocusWalkerContext(pileup, referenceContext, featureContext);
                })
                .iterator();
    };
}
use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
the class ReadWalkerSpark method getReadsFunction.
/**
 * Builds the per-shard mapper used to turn reads into read-level contexts.
 *
 * For each {@link Shard} the returned function expands the shard interval by {@code readShardPadding}
 * (clamped to the contig via the sequence dictionary), loads the reference bases for that padded
 * interval into a {@link ReferenceMemorySource} when a reference broadcast is available, and then
 * pairs every read with a {@link ReferenceContext} and {@link FeatureContext} built from the
 * interval returned by {@code getReadInterval} for that read.
 *
 * @param bReferenceSource broadcast of the reference source; may be {@code null}
 * @param bFeatureManager broadcast of the feature manager; may be {@code null}
 * @param sequenceDictionary sequence dictionary for the reads
 * @param readShardPadding amount by which each shard interval is expanded before reference bases are fetched
 * @return a function mapping a {@link Shard} of reads to an iterator of {@link ReadWalkerContext}s
 */
private static FlatMapFunction<Shard<GATKRead>, ReadWalkerContext> getReadsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource,
        Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary,
        int readShardPadding) {
    return readShard -> {
        final SimpleInterval paddedShardInterval =
                readShard.getInterval().expandWithinContig(readShardPadding, sequenceDictionary);

        final ReferenceDataSource referenceSource;
        if (bReferenceSource == null) {
            referenceSource = null;
        } else {
            referenceSource = new ReferenceMemorySource(
                    bReferenceSource.getValue().getReferenceBases(null, paddedShardInterval),
                    sequenceDictionary);
        }

        final FeatureManager featureManager;
        if (bFeatureManager == null) {
            featureManager = null;
        } else {
            featureManager = bFeatureManager.getValue();
        }

        return StreamSupport.stream(readShard.spliterator(), false)
                .map(read -> {
                    final SimpleInterval span = getReadInterval(read);
                    final ReferenceContext referenceContext = new ReferenceContext(referenceSource, span);
                    final FeatureContext featureContext = new FeatureContext(featureManager, span);
                    return new ReadWalkerContext(read, referenceContext, featureContext);
                })
                .iterator();
    };
}
use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk-protected by broadinstitute.
the class HaplotypeCallerSparkIntegrationTest method testReferenceMultiSourceIsSerializable.
/**
 * Round-trips a {@link ReferenceMultiSource} built over the b37 chr20/21 2bit reference through Kryo,
 * using the configuration of the shared test Spark context.
 */
@Test
public void testReferenceMultiSourceIsSerializable() {
    final ReferenceMultiSource referenceSource = new ReferenceMultiSource(
            (PipelineOptions) null,
            BaseTest.b37_2bit_reference_20_21,
            ReferenceWindowFunctions.IDENTITY_FUNCTION);
    SparkTestUtils.roundTripInKryo(
            referenceSource,
            ReferenceMultiSource.class,
            SparkContextFactory.getTestSparkContext().getConf());
}
use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
the class JoinReadsWithRefBasesSparkUnitTest method refBasesShuffleTest.
/**
 * Checks that {@code ShuffleJoinReadsWithRefBases.addBases} pairs each read with the expected reference bases.
 *
 * The reference source is a serializable Mockito mock that answers {@code getReferenceBases} for each of
 * the supplied intervals with {@code FakeReferenceSource.bases(interval)} and reports the identity
 * reference window function.
 *
 * @param reads reads to join
 * @param kvReadRefBases expected (read, reference bases) pairs
 * @param intervals intervals for which the mock reference source is stubbed
 * @throws IOException declared because the stubbed {@code getReferenceBases} call is declared to throw it
 */
@Test(dataProvider = "bases", groups = "spark")
public void refBasesShuffleTest(List<GATKRead> reads, List<KV<GATKRead, ReferenceBases>> kvReadRefBases, List<SimpleInterval> intervals) throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    JavaRDD<GATKRead> rddReads = ctx.parallelize(reads);

    // The mock must be serializable because it is shipped to Spark tasks.
    ReferenceMultiSource mockSource = mock(ReferenceMultiSource.class, withSettings().serializable());
    for (SimpleInterval i : intervals) {
        when(mockSource.getReferenceBases(any(PipelineOptions.class), eq(i))).thenReturn(FakeReferenceSource.bases(i));
    }
    when(mockSource.getReferenceWindowFunction()).thenReturn(ReferenceWindowFunctions.IDENTITY_FUNCTION);

    JavaPairRDD<GATKRead, ReferenceBases> rddResult = ShuffleJoinReadsWithRefBases.addBases(mockSource, rddReads);
    Map<GATKRead, ReferenceBases> result = rddResult.collectAsMap();

    for (KV<GATKRead, ReferenceBases> kv : kvReadRefBases) {
        ReferenceBases referenceBases = result.get(kv.getKey());
        Assert.assertNotNull(referenceBases);
        // TestNG's signature is assertEquals(actual, expected): the joined value is the actual one,
        // so failure messages report "expected"/"found" the right way round.
        Assert.assertEquals(referenceBases, kv.getValue());
    }
}
use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
the class JoinReadsWithRefBasesSparkUnitTest method refBasesBroadcastTest.
/**
 * Checks that {@code BroadcastJoinReadsWithRefBases.addBases} pairs each read with the expected reference bases.
 *
 * The reference source is a serializable Mockito mock that answers {@code getReferenceBases} for each of
 * the supplied intervals with {@code FakeReferenceSource.bases(interval)} and reports the identity
 * reference window function.
 *
 * @param reads reads to join
 * @param kvReadRefBases expected (read, reference bases) pairs
 * @param intervals intervals for which the mock reference source is stubbed
 * @throws IOException declared because the stubbed {@code getReferenceBases} call is declared to throw it
 */
@Test(dataProvider = "bases", groups = "spark")
public void refBasesBroadcastTest(List<GATKRead> reads, List<KV<GATKRead, ReferenceBases>> kvReadRefBases, List<SimpleInterval> intervals) throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    JavaRDD<GATKRead> rddReads = ctx.parallelize(reads);

    // The mock must be serializable because it is shipped to Spark tasks.
    ReferenceMultiSource mockSource = mock(ReferenceMultiSource.class, withSettings().serializable());
    for (SimpleInterval i : intervals) {
        when(mockSource.getReferenceBases(any(PipelineOptions.class), eq(i))).thenReturn(FakeReferenceSource.bases(i));
    }
    when(mockSource.getReferenceWindowFunction()).thenReturn(ReferenceWindowFunctions.IDENTITY_FUNCTION);

    JavaPairRDD<GATKRead, ReferenceBases> rddResult = BroadcastJoinReadsWithRefBases.addBases(mockSource, rddReads);
    Map<GATKRead, ReferenceBases> result = rddResult.collectAsMap();

    for (KV<GATKRead, ReferenceBases> kv : kvReadRefBases) {
        ReferenceBases referenceBases = result.get(kv.getKey());
        Assert.assertNotNull(referenceBases);
        // TestNG's signature is assertEquals(actual, expected): the joined value is the actual one,
        // so failure messages report "expected"/"found" the right way round.
        Assert.assertEquals(referenceBases, kv.getValue());
    }
}
Aggregations