Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class ShuffleJoinReadsWithRefBases, method addBases.
/**
* Joins each read of a JavaPairRDD<GATKRead, T> with that read's corresponding reference sequence.
*
* @param referenceDataflowSource The source of the reference sequence information
* @param keyedByRead The read-keyed RDD for which to extract reference sequence information
* @return A JavaPairRDD pairing each read with a tuple of its original value and the corresponding ReferenceBases object
*/
public static <T> JavaPairRDD<GATKRead, Tuple2<T, ReferenceBases>> addBases(final ReferenceMultiSource referenceDataflowSource, final JavaPairRDD<GATKRead, T> keyedByRead) {
    SerializableFunction<GATKRead, SimpleInterval> windowFunction = referenceDataflowSource.getReferenceWindowFunction();
    // Key each (read, value) pair by the reference shard containing the read's window.
    JavaPairRDD<ReferenceShard, Tuple2<GATKRead, T>> shardRead = keyedByRead.mapToPair(pair -> {
        ReferenceShard shard = ReferenceShard.getShardNumberFromInterval(windowFunction.apply(pair._1()));
        return new Tuple2<>(shard, pair);
    });
    // Group the pairs by shard so each shard's reference bases are fetched only once.
    JavaPairRDD<ReferenceShard, Iterable<Tuple2<GATKRead, T>>> shardiRead = shardRead.groupByKey();
    return shardiRead.flatMapToPair(in -> {
        List<Tuple2<GATKRead, Tuple2<T, ReferenceBases>>> out = Lists.newArrayList();
        Iterable<Tuple2<GATKRead, T>> iReads = in._2();
        // Fetch one interval spanning all read windows in the shard, then subset it per read.
        final List<SimpleInterval> readWindows = Utils.stream(iReads).map(pair -> windowFunction.apply(pair._1())).collect(Collectors.toList());
        SimpleInterval interval = IntervalUtils.getSpanningInterval(readWindows);
        ReferenceBases bases = referenceDataflowSource.getReferenceBases(null, interval);
        for (Tuple2<GATKRead, T> p : iReads) {
            final ReferenceBases subset = bases.getSubset(windowFunction.apply(p._1()));
            out.add(new Tuple2<>(p._1(), new Tuple2<>(p._2(), subset)));
        }
        return out.iterator();
    });
}
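A minimal call-site sketch for the shuffle join above. The SparkContext, the input read list, the reference source variable, and the Integer payload are illustrative assumptions, not names from the snippet; only ShuffleJoinReadsWithRefBases.addBases itself comes from the code shown.

// Hedged usage sketch: 'ctx', 'readList', and 'referenceSource' are assumed to exist;
// the Integer payload is a hypothetical stand-in for the per-read value T.
JavaRDD<GATKRead> reads = ctx.parallelize(readList);
JavaPairRDD<GATKRead, Integer> keyedByRead = reads.mapToPair(read -> new Tuple2<>(read, 0));
JavaPairRDD<GATKRead, Tuple2<Integer, ReferenceBases>> joined =
        ShuffleJoinReadsWithRefBases.addBases(referenceSource, keyedByRead);
// Each output record carries the read, its payload, and the bases under its reference window.
joined.foreach(t -> System.out.println(t._1().getName() + " -> " + t._2()._2().getInterval()));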
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class BroadcastJoinReadsWithRefBases, method addBases.
/**
* Joins each read of a JavaRDD<GATKRead> with that read's corresponding reference sequence.
*
* @param referenceDataflowSource The source of the reference sequence information
* @param reads The reads for which to extract reference sequence information
* @return The JavaPairRDD that contains each read along with the corresponding ReferenceBases object
*/
public static JavaPairRDD<GATKRead, ReferenceBases> addBases(final ReferenceMultiSource referenceDataflowSource, final JavaRDD<GATKRead> reads) {
    JavaSparkContext ctx = new JavaSparkContext(reads.context());
    // Broadcast the reference source so each executor queries it locally instead of shuffling reads.
    Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceDataflowSource);
    return reads.mapToPair(read -> {
        SimpleInterval interval = bReferenceSource.getValue().getReferenceWindowFunction().apply(read);
        return new Tuple2<>(read, bReferenceSource.getValue().getReferenceBases(null, interval));
    });
}
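A hedged call-site sketch for the broadcast join; 'referenceSource' and 'reads' are assumed inputs, not names taken from the snippet.

// Hypothetical call site: pair every read with the bases under its reference window.
JavaPairRDD<GATKRead, ReferenceBases> readsWithRef =
        BroadcastJoinReadsWithRefBases.addBases(referenceSource, reads);
// Unlike the shuffle join, the reads never move; only the reference source is broadcast,
// so this strategy suits reference sources that fit comfortably in executor memory.
Map<GATKRead, ReferenceBases> byRead = readsWithRef.collectAsMap();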
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class AddContextDataToReadSparkUnitTest, method addContextDataTest.
@Test(dataProvider = "bases", groups = "spark")
public void addContextDataTest(List<GATKRead> reads, List<GATKVariant> variantList, List<KV<GATKRead, ReadContextData>> expectedReadContextData, JoinStrategy joinStrategy) throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    JavaRDD<GATKRead> rddReads = ctx.parallelize(reads);
    JavaRDD<GATKVariant> rddVariants = ctx.parallelize(variantList);
    // Build a serializable mock reference source whose bases and window function the test controls.
    ReferenceMultiSource mockSource = mock(ReferenceMultiSource.class, withSettings().serializable());
    when(mockSource.getReferenceBases(any(PipelineOptions.class), any())).then(new ReferenceBasesAnswer());
    when(mockSource.getReferenceWindowFunction()).thenReturn(ReferenceWindowFunctions.IDENTITY_FUNCTION);
    SAMSequenceDictionary sd = new SAMSequenceDictionary(Lists.newArrayList(new SAMSequenceRecord("1", 100000), new SAMSequenceRecord("2", 100000)));
    when(mockSource.getReferenceSequenceDictionary(null)).thenReturn(sd);
    JavaPairRDD<GATKRead, ReadContextData> rddActual = AddContextDataToReadSpark.add(ctx, rddReads, mockSource, rddVariants, joinStrategy, sd, 10000, 1000);
    Map<GATKRead, ReadContextData> actual = rddActual.collectAsMap();
    Assert.assertEquals(actual.size(), expectedReadContextData.size());
    for (KV<GATKRead, ReadContextData> kv : expectedReadContextData) {
        ReadContextData readContextData = actual.get(kv.getKey());
        Assert.assertNotNull(readContextData);
        Assert.assertTrue(CollectionUtils.isEqualCollection(Lists.newArrayList(readContextData.getOverlappingVariants()), Lists.newArrayList(kv.getValue().getOverlappingVariants())));
        // The actual bases may cover a wider window; compare only over the expected minimal interval.
        SimpleInterval minimalInterval = kv.getValue().getOverlappingReferenceBases().getInterval();
        ReferenceBases subset = readContextData.getOverlappingReferenceBases().getSubset(minimalInterval);
        Assert.assertEquals(subset, kv.getValue().getOverlappingReferenceBases());
    }
}
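The body of ReferenceBasesAnswer is not shown above. A hypothetical stand-in, assuming Mockito 2's InvocationOnMock.getArgument, could fabricate bases for whatever interval the code under test requests:

// Hypothetical stand-in for ReferenceBasesAnswer (not the project's implementation):
// serve all-'A' bases covering the requested interval.
private static final class FakeReferenceBasesAnswer implements Answer<ReferenceBases>, Serializable {
    private static final long serialVersionUID = 1L;
    @Override
    public ReferenceBases answer(InvocationOnMock invocation) {
        SimpleInterval interval = invocation.getArgument(1); // second argument of getReferenceBases
        byte[] bases = new byte[interval.size()];
        Arrays.fill(bases, (byte) 'A');
        return new ReferenceBases(bases, interval);
    }
}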
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class FindBadGenomicKmersSparkUnitTest, method miniRefTest.
@Test(groups = "spark")
public void miniRefTest() throws IOException {
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final ReferenceMultiSource ref = new ReferenceMultiSource((PipelineOptions) null, REFERENCE_FILE_NAME, ReferenceWindowFunctions.IDENTITY_FUNCTION);
    final SAMSequenceDictionary dict = ref.getReferenceSequenceDictionary(null);
    if (dict == null)
        throw new GATKException("No reference dictionary available.");
    // Count every canonical kmer in the reference, one contig at a time.
    final Map<SVKmer, Long> kmerMap = new LinkedHashMap<>();
    for (final SAMSequenceRecord rec : dict.getSequences()) {
        final SimpleInterval interval = new SimpleInterval(rec.getSequenceName(), 1, rec.getSequenceLength());
        final byte[] bases = ref.getReferenceBases(null, interval).getBases();
        final SVKmerizer kmerizer = new SVKmerizer(bases, KMER_SIZE, new SVKmerLong());
        while (kmerizer.hasNext()) {
            final SVKmer kmer = kmerizer.next().canonical(KMER_SIZE);
            final Long currentCount = kmerMap.getOrDefault(kmer, 0L);
            kmerMap.put(kmer, currentCount + 1);
        }
    }
    // Keep only kmers that occur more often than the "bad kmer" frequency threshold.
    final Iterator<Map.Entry<SVKmer, Long>> kmerIterator = kmerMap.entrySet().iterator();
    while (kmerIterator.hasNext()) {
        if (kmerIterator.next().getValue() <= FindBadGenomicKmersSpark.MAX_KMER_FREQ)
            kmerIterator.remove();
    }
    // The Spark implementation must find exactly the same set, with no duplicates.
    final List<SVKmer> badKmers = FindBadGenomicKmersSpark.findBadGenomicKmers(ctx, KMER_SIZE, Integer.MAX_VALUE, ref, null, null);
    final Set<SVKmer> badKmerSet = new HashSet<>(badKmers);
    Assert.assertEquals(badKmers.size(), badKmerSet.size());
    Assert.assertEquals(badKmerSet, kmerMap.keySet());
}
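The counting and filtering loops above can be written more compactly with Java 8 collection methods; the following is an equivalent rewrite for illustration, not the project's code:

// Count update equivalent to the getOrDefault/put pair inside the kmerizer loop.
kmerMap.merge(kmer, 1L, Long::sum);
// Threshold filter equivalent to the explicit iterator-removal loop.
kmerMap.entrySet().removeIf(entry -> entry.getValue() <= FindBadGenomicKmersSpark.MAX_KMER_FREQ);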
Use of org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource in project gatk by broadinstitute.
The class BaseRecalibratorSparkSharded, method runPipeline.
@Override
protected void runPipeline(JavaSparkContext ctx) {
    if (readArguments.getReadFilesNames().size() != 1) {
        throw new UserException("Sorry, we only support a single reads input for now.");
    }
    final String bam = readArguments.getReadFilesNames().get(0);
    final String referenceURL = referenceArguments.getReferenceFileName();
    auth = getAuthHolder();
    final ReferenceMultiSource rds = new ReferenceMultiSource(auth, referenceURL, BaseRecalibrationEngine.BQSR_REFERENCE_WINDOW_FUNCTION);
    SAMFileHeader readsHeader = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency()).getHeader(bam, referenceURL);
    final SAMSequenceDictionary readsDictionary = readsHeader.getSequenceDictionary();
    final SAMSequenceDictionary refDictionary = rds.getReferenceSequenceDictionary(readsDictionary);
    final ReadFilter readFilterToApply = ReadFilter.fromList(BaseRecalibrator.getStandardBQSRReadFilterList(), readsHeader);
    SequenceDictionaryUtils.validateDictionaries("reference", refDictionary, "reads", readsDictionary);
    Broadcast<SAMFileHeader> readsHeaderBcast = ctx.broadcast(readsHeader);
    Broadcast<SAMSequenceDictionary> refDictionaryBcast = ctx.broadcast(refDictionary);
    List<SimpleInterval> intervals = intervalArgumentCollection.intervalsSpecified() ? intervalArgumentCollection.getIntervals(readsHeader.getSequenceDictionary()) : IntervalUtils.getAllIntervalsForReference(readsHeader.getSequenceDictionary());
    List<String> localVariants = knownVariants;
    localVariants = hackilyCopyFromGCSIfNecessary(localVariants);
    List<GATKVariant> variants = VariantsSource.getVariantsList(localVariants);
    // Join reads, reference, and variants into per-interval context shards.
    JavaRDD<ContextShard> readsWithContext = AddContextDataToReadSparkOptimized.add(ctx, intervals, bam, variants, readFilterToApply, rds);
    // Run the BaseRecalibrator engine on each partition.
    BaseRecalibratorEngineSparkWrapper recal = new BaseRecalibratorEngineSparkWrapper(readsHeaderBcast, refDictionaryBcast, bqsrArgs);
    JavaRDD<RecalibrationTables> tables = readsWithContext.mapPartitions(s -> recal.apply(s));
    // Combine the per-partition tables in a tree whose depth grows with log2 of the partition count.
    final RecalibrationTables emptyRecalibrationTable = new RecalibrationTables(new StandardCovariateList(bqsrArgs, readsHeader));
    final RecalibrationTables table = tables.treeAggregate(emptyRecalibrationTable, RecalibrationTables::inPlaceCombine, RecalibrationTables::inPlaceCombine, Math.max(1, (int) (Math.log(tables.partitions().size()) / Math.log(2))));
    BaseRecalibrationEngine.finalizeRecalibrationTables(table);
    try {
        BaseRecalibratorEngineSparkWrapper.saveTextualReport(outputTablesPath, readsHeader, table, bqsrArgs, auth);
    } catch (IOException e) {
        throw new UserException.CouldNotCreateOutputFile(new File(outputTablesPath), e);
    }
}
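The treeAggregate depth above is derived from the partition count; a small isolated sketch with an assumed partition count shows the same arithmetic:

// Illustrative depth computation; 100 partitions is an assumed example value.
int numPartitions = 100;
int depth = Math.max(1, (int) (Math.log(numPartitions) / Math.log(2))); // log2(100) ~ 6.64, so depth = 6
// Math.max(1, ...) guards against a depth of 0 when there is a single partition (log2(1) = 0).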