Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class LocusWalkerSpark, method getAlignments.
/**
* Loads alignments and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
*
* If no intervals were specified, returns all the alignments.
*
* @return all alignments as a {@link JavaRDD}, bounded by intervals if specified.
*/
public JavaRDD<LocusWalkerContext> getAlignments(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    final List<ShardBoundary> intervalShards = intervals.stream()
            .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
            .collect(Collectors.toList());
    int maxLocatableSize = Math.min(readShardSize, readShardPadding);
    JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, maxLocatableSize, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedReads.flatMap(getAlignmentsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(), getDownsamplingInfo()));
}
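To see the shape of this shard-then-flatMap pattern outside the GATK engine, here is a minimal, self-contained sketch. It assumes a local Spark master and the Spark 2.x Java API (iterator-returning flatMap), uses lists of integer positions as stand-ins for Shard<GATKRead> and strings as stand-ins for LocusWalkerContext, and the class name LocusShardSketch and all values are hypothetical.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class LocusShardSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("locus-shard-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // each inner list plays the role of one Shard<GATKRead>: read start positions in an interval
            JavaRDD<List<Integer>> shardedReads = ctx.parallelize(Arrays.asList(
                    Arrays.asList(100, 105, 110),
                    Arrays.asList(200, 201)));

            // broadcast the read-only lookup data once, as getAlignments does for the reference
            // and features; a null broadcast would mean "no reference supplied"
            Broadcast<String> bReferenceSource = ctx.broadcast("reference-source");

            // expand each shard into one per-locus "context", mirroring the
            // Shard<GATKRead> -> LocusWalkerContext flatMap in the original code
            JavaRDD<String> locusContexts = shardedReads.flatMap(shard -> {
                String ref = bReferenceSource == null ? "none" : bReferenceSource.value();
                List<String> contexts = new ArrayList<>();
                for (int pos : shard) {
                    contexts.add("locus=" + pos + " ref=" + ref);
                }
                return contexts.iterator();
            });

            locusContexts.collect().forEach(System.out::println);
        }
    }
}

As in getAlignments, the broadcast keeps the potentially large read-only data out of each task closure, and a null broadcast is how the absence of a reference or feature source is represented.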
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class VariantWalkerSpark, method getVariants.
/**
* Loads variants and the corresponding reads, reference and features into a {@link JavaRDD} for the intervals specified.
* For the current implementation, the reads context will always be empty.
*
* If no intervals were specified, returns all the variants.
*
* @return all variants as a {@link JavaRDD}, bounded by intervals if specified.
*/
public JavaRDD<VariantWalkerContext> getVariants(JavaSparkContext ctx) {
    SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
    List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
    // use unpadded shards (padding is only needed for reference bases)
    final List<ShardBoundary> intervalShards = intervals.stream()
            .flatMap(interval -> Shard.divideIntervalIntoShards(interval, variantShardSize, 0, sequenceDictionary).stream())
            .collect(Collectors.toList());
    JavaRDD<VariantContext> variants = variantsSource.getParallelVariantContexts(drivingVariantFile, getIntervals());
    VariantFilter variantFilter = makeVariantFilter();
    variants = variants.filter(variantFilter::test);
    JavaRDD<Shard<VariantContext>> shardedVariants = SparkSharder.shard(ctx, variants, VariantContext.class, sequenceDictionary, intervalShards, variantShardSize, shuffle);
    Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
    Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
    return shardedVariants.flatMap(getVariantsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, variantShardPadding));
}
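The notable choices above are that the variant shards are created without padding and that the variant filter runs before sharding; the padding (variantShardPadding) is only applied downstream when reference bases are fetched around each variant. A minimal sketch of that filter-then-pad shape, assuming a local Spark master, with integer start positions standing in for VariantContext records (the class name VariantWindowSketch and the numbers are hypothetical):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class VariantWindowSketch {
    public static void main(String[] args) {
        final int padding = 100; // stand-in for variantShardPadding
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("variant-window-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            JavaRDD<Integer> variantStarts = ctx.parallelize(Arrays.asList(150, 3_000, 7_250, 42));

            // the predicate must be serializable, just like variantFilter::test in getVariants
            JavaRDD<Integer> passing = variantStarts.filter(start -> start > 100);

            // derive the padded reference window per variant, only after filtering
            List<String> windows = passing
                    .map(start -> "[" + Math.max(1, start - padding) + ", " + (start + padding) + "]")
                    .collect();

            windows.forEach(System.out::println);
        }
    }
}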
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class FindBadGenomicKmersSpark, method processRefRDD.
/**
* Do a map/reduce on an RDD of genomic sequences:
* Kmerize, map each kmer to a pair <kmer,1>, reduce by summing values by key, filter out pairs <kmer,N> where
* N <= maxKmerFreq, and collect the high-frequency kmers back in the driver.
*/
@VisibleForTesting
static List<SVKmer> processRefRDD(final int kSize, final int maxDUSTScore, final int maxKmerFreq,
                                  final JavaRDD<byte[]> refRDD) {
    final int nPartitions = refRDD.getNumPartitions();
    final int hashSize = 2 * REF_RECORDS_PER_PARTITION;
    final int arrayCap = REF_RECORDS_PER_PARTITION / 100;
    return refRDD.mapPartitions(seqItr -> {
        final HopscotchMap<SVKmer, Integer, KmerAndCount> kmerCounts = new HopscotchMap<>(hashSize);
        while (seqItr.hasNext()) {
            final byte[] seq = seqItr.next();
            SVDUSTFilteredKmerizer.stream(seq, kSize, maxDUSTScore, new SVKmerLong())
                    .map(kmer -> kmer.canonical(kSize))
                    .forEach(kmer -> {
                        final KmerAndCount entry = kmerCounts.find(kmer);
                        if (entry == null)
                            kmerCounts.add(new KmerAndCount((SVKmerLong) kmer));
                        else
                            entry.bumpCount();
                    });
        }
        return kmerCounts.iterator();
    }).mapToPair(entry -> new Tuple2<>(entry.getKey(), entry.getValue()))
      .partitionBy(new HashPartitioner(nPartitions))
      .mapPartitions(pairItr -> {
        final HopscotchMap<SVKmer, Integer, KmerAndCount> kmerCounts = new HopscotchMap<>(hashSize);
        while (pairItr.hasNext()) {
            final Tuple2<SVKmer, Integer> pair = pairItr.next();
            final SVKmer kmer = pair._1();
            final int count = pair._2();
            KmerAndCount entry = kmerCounts.find(kmer);
            if (entry == null)
                kmerCounts.add(new KmerAndCount((SVKmerLong) kmer, count));
            else
                entry.bumpCount(count);
        }
        final List<SVKmer> highFreqKmers = new ArrayList<>(arrayCap);
        for (KmerAndCount kmerAndCount : kmerCounts) {
            if (kmerAndCount.grabCount() > maxKmerFreq)
                highFreqKmers.add(kmerAndCount.getKey());
        }
        return highFreqKmers.iterator();
    }).collect();
}
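Rather than doing a straight reduceByKey over every kmer occurrence, processRefRDD counts inside each partition first, shuffles only the per-partition partial counts, and merges them on the receiving side before filtering. A minimal sketch of that two-stage pattern, assuming a local Spark master and the Spark 2.x Java API, with short strings in place of SVKmers and a plain HashMap in place of HopscotchMap (the class name HighFreqItemSketch and the data are hypothetical):

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class HighFreqItemSketch {
    public static void main(String[] args) {
        final int maxFreq = 2; // stand-in for maxKmerFreq
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("high-freq-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            JavaRDD<String> items = ctx.parallelize(
                    Arrays.asList("ACG", "ACG", "ACG", "CGT", "GTA", "ACG"), 2);
            final int nPartitions = items.getNumPartitions();

            // stage 1: count within each partition, without shuffling raw items
            JavaRDD<Tuple2<String, Integer>> partialCounts = items.mapPartitions(itr -> {
                Map<String, Integer> counts = new HashMap<>();
                itr.forEachRemaining(k -> counts.merge(k, 1, Integer::sum));
                List<Tuple2<String, Integer>> out = new ArrayList<>();
                counts.forEach((k, n) -> out.add(new Tuple2<>(k, n)));
                return out.iterator();
            });

            // stage 2: shuffle the partial counts so each key lands in exactly one partition
            JavaPairRDD<String, Integer> shuffled = partialCounts
                    .mapToPair(t -> new Tuple2<>(t._1(), t._2()))
                    .partitionBy(new HashPartitioner(nPartitions));

            // stage 3: merge partial counts and keep only the high-frequency items
            JavaRDD<String> highFreqRDD = shuffled.mapPartitions(pairItr -> {
                Map<String, Integer> counts = new HashMap<>();
                pairItr.forEachRemaining(p -> counts.merge(p._1(), p._2(), Integer::sum));
                List<String> keep = new ArrayList<>();
                counts.forEach((k, n) -> { if (n > maxFreq) keep.add(k); });
                return keep.iterator();
            });

            System.out.println(highFreqRDD.collect()); // expect [ACG]
        }
    }
}

The payoff of this shape is that each distinct item crosses the network at most once per source partition, rather than once per occurrence.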
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class FindBreakpointEvidenceSpark, method getKmerIntervals.
/** Find kmers for each interval. */
@VisibleForTesting
static Tuple2<List<AlignedAssemblyOrExcuse>, List<KmerAndInterval>> getKmerIntervals(
        final Params params,
        final JavaSparkContext ctx,
        final HopscotchUniqueMultiMap<String, Integer, QNameAndInterval> qNamesMultiMap,
        final int nIntervals,
        final Set<SVKmer> kmerKillSet,
        final JavaRDD<GATKRead> reads,
        final Locations locations) {
    final Broadcast<Set<SVKmer>> broadcastKmerKillSet = ctx.broadcast(kmerKillSet);
    final Broadcast<HopscotchUniqueMultiMap<String, Integer, QNameAndInterval>> broadcastQNameAndIntervalsMultiMap = ctx.broadcast(qNamesMultiMap);
    // given a set of template names with interval IDs and a kill set of ubiquitous kmers,
    // produce a set of interesting kmers for each interval ID
    final int kmersPerPartitionGuess = params.cleanerKmersPerPartitionGuess;
    final int minKmers = params.cleanerMinKmerCount;
    final int maxKmers = params.cleanerMaxKmerCount;
    final int maxIntervals = params.cleanerMaxIntervals;
    final int kSize = params.kSize;
    final int maxDUSTScore = params.maxDUSTScore;
    final List<KmerAndInterval> kmerIntervals =
            reads.mapPartitionsToPair(readItr ->
                            new MapPartitioner<>(readItr,
                                    new QNameKmerizer(broadcastQNameAndIntervalsMultiMap.value(),
                                            broadcastKmerKillSet.value(), kSize, maxDUSTScore)).iterator(), false)
                 .reduceByKey(Integer::sum)
                 .mapPartitions(itr ->
                            new KmerCleaner(itr, kmersPerPartitionGuess, minKmers, maxKmers, maxIntervals).iterator())
                 .collect();
    broadcastQNameAndIntervalsMultiMap.destroy();
    broadcastKmerKillSet.destroy();
    final int[] intervalKmerCounts = new int[nIntervals];
    for (final KmerAndInterval kmerAndInterval : kmerIntervals) {
        intervalKmerCounts[kmerAndInterval.getIntervalId()] += 1;
    }
    final Set<Integer> intervalsToKill = new HashSet<>();
    final List<AlignedAssemblyOrExcuse> intervalDispositions = new ArrayList<>();
    for (int idx = 0; idx != nIntervals; ++idx) {
        if (intervalKmerCounts[idx] < params.minKmersPerInterval) {
            intervalsToKill.add(idx);
            intervalDispositions.add(new AlignedAssemblyOrExcuse(idx, "FASTQ not written -- too few kmers"));
        }
    }
    qNamesMultiMap.removeIf(qNameAndInterval -> intervalsToKill.contains(qNameAndInterval.getIntervalId()));
    final List<KmerAndInterval> filteredKmerIntervals = kmerIntervals.stream()
            .filter(kmerAndInterval -> !intervalsToKill.contains(kmerAndInterval.getIntervalId()))
            .collect(SVUtils.arrayListCollector(kmerIntervals.size()));
    // record the kmers with their interval IDs
    if (locations.kmerFile != null) {
        try (final OutputStreamWriter writer =
                     new OutputStreamWriter(new BufferedOutputStream(BucketUtils.createFile(locations.kmerFile)))) {
            for (final KmerAndInterval kmerAndInterval : filteredKmerIntervals) {
                writer.write(kmerAndInterval.toString(kSize) + " " + kmerAndInterval.getIntervalId() + "\n");
            }
        } catch (final IOException ioe) {
            throw new GATKException("Can't write kmer intervals file " + locations.kmerFile, ioe);
        }
    }
    return new Tuple2<>(intervalDispositions, filteredKmerIntervals);
}
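One pattern worth isolating here is the broadcast lifecycle: the kill set and the qname multimap are broadcast to the executors, consumed inside the transformations, forced by collect(), and then explicitly destroyed once their results live in the driver. A minimal sketch of that lifecycle, assuming a local Spark master, with strings standing in for SVKmer values and a plain Set for the kill set (the class name BroadcastKillSetSketch and the data are hypothetical):

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BroadcastKillSetSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("broadcast-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            JavaRDD<String> kmers = ctx.parallelize(Arrays.asList("ACGT", "AAAA", "CGTA", "TTTT"));

            // ubiquitous kmers that downstream steps should ignore
            Set<String> killSet = new HashSet<>(Arrays.asList("AAAA", "TTTT"));
            Broadcast<Set<String>> broadcastKillSet = ctx.broadcast(killSet);

            // executors read the broadcast value instead of having the set re-shipped per task
            List<String> kept = kmers
                    .filter(k -> !broadcastKillSet.value().contains(k))
                    .collect();

            // collect() has materialized everything we need, so the broadcast can be torn down,
            // mirroring broadcastKmerKillSet.destroy() in getKmerIntervals
            broadcastKillSet.destroy();

            System.out.println(kept); // expect [ACGT, CGTA]
        }
    }
}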
Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.
The class MarkDuplicatesSparkUtils, method generateMetrics.
static JavaPairRDD<String, DuplicationMetrics> generateMetrics(final SAMFileHeader header, final JavaRDD<GATKRead> reads) {
    return reads.filter(read -> !read.isSecondaryAlignment() && !read.isSupplementaryAlignment())
            .mapToPair(read -> {
                final String library = LibraryIdGenerator.getLibraryName(header, read.getReadGroup());
                DuplicationMetrics metrics = new DuplicationMetrics();
                metrics.LIBRARY = library;
                if (read.isUnmapped()) {
                    ++metrics.UNMAPPED_READS;
                } else if (!read.isPaired() || read.mateIsUnmapped()) {
                    ++metrics.UNPAIRED_READS_EXAMINED;
                } else {
                    ++metrics.READ_PAIRS_EXAMINED;
                }
                if (read.isDuplicate()) {
                    if (!read.isPaired() || read.mateIsUnmapped()) {
                        ++metrics.UNPAIRED_READ_DUPLICATES;
                    } else {
                        ++metrics.READ_PAIR_DUPLICATES;
                    }
                }
                if (read.hasAttribute(OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME)) {
                    metrics.READ_PAIR_OPTICAL_DUPLICATES += read.getAttributeAsInteger(OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME);
                }
                return new Tuple2<>(library, metrics);
            }).foldByKey(new DuplicationMetrics(), (metricsSum, m) -> {
                if (metricsSum.LIBRARY == null) {
                    metricsSum.LIBRARY = m.LIBRARY;
                }
                // This should never happen, as we grouped by key using library as the key.
                if (!metricsSum.LIBRARY.equals(m.LIBRARY)) {
                    throw new GATKException("Two different libraries encountered while summing metrics: " + metricsSum.LIBRARY + " and " + m.LIBRARY);
                }
                metricsSum.UNMAPPED_READS += m.UNMAPPED_READS;
                metricsSum.UNPAIRED_READS_EXAMINED += m.UNPAIRED_READS_EXAMINED;
                metricsSum.READ_PAIRS_EXAMINED += m.READ_PAIRS_EXAMINED;
                metricsSum.UNPAIRED_READ_DUPLICATES += m.UNPAIRED_READ_DUPLICATES;
                metricsSum.READ_PAIR_DUPLICATES += m.READ_PAIR_DUPLICATES;
                metricsSum.READ_PAIR_OPTICAL_DUPLICATES += m.READ_PAIR_OPTICAL_DUPLICATES;
                return metricsSum;
            }).mapValues(metrics -> {
                DuplicationMetrics copy = metrics.copy();
                copy.READ_PAIRS_EXAMINED = metrics.READ_PAIRS_EXAMINED / 2;
                copy.READ_PAIR_DUPLICATES = metrics.READ_PAIR_DUPLICATES / 2;
                copy.calculateDerivedMetrics();
                if (copy.ESTIMATED_LIBRARY_SIZE == null) {
                    copy.ESTIMATED_LIBRARY_SIZE = 0L;
                }
                return copy;
            });
}
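The overall shape here is: key each read's partial metrics by library, combine the partials per key with foldByKey, then finalize the folded value with mapValues (halving pair counts and deriving the summary fields). A minimal sketch of that keyed-aggregation shape, assuming a local Spark master, with (library, isDuplicate) pairs standing in for GATKRead and an integer tally standing in for DuplicationMetrics (the class name PerLibraryTallySketch and the data are hypothetical):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class PerLibraryTallySketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("tally-sketch");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // (library, isDuplicate) pairs standing in for reads
            List<Tuple2<String, Boolean>> reads = Arrays.asList(
                    new Tuple2<>("libA", true), new Tuple2<>("libA", false),
                    new Tuple2<>("libB", true), new Tuple2<>("libA", true));

            JavaPairRDD<String, Integer> dupTallies = ctx.parallelizePairs(reads)
                    // one partial tally per read: 1 if it is a duplicate, else 0
                    .mapValues(isDup -> isDup ? 1 : 0)
                    // sum the partial tallies per library, starting from a zero value
                    .foldByKey(0, Integer::sum);

            // finalize per key; the original halves pair counts and derives metrics here,
            // this sketch just labels the result
            JavaPairRDD<String, String> report = dupTallies
                    .mapValues(n -> n + " duplicate read(s)");

            report.collect().forEach(t -> System.out.println(t._1() + ": " + t._2()));
        }
    }
}

Keying by library before aggregating is what lets the per-library DuplicationMetrics in the original be built without ever grouping whole reads in memory.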