Search in sources:

Example 56 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

Class QualityYieldMetricsCollectorSpark, method collectMetrics.

/**
 * Aggregates quality-yield statistics over every read in the supplied RDD and appends the
 * finished {@link QualityYieldMetrics} to this collector's metrics file.
 *
 * @param filteredReads the reads to be analyzed for this collector
 * @param samHeader the SAMFileHeader associated with the reads in {@code filteredReads}
 *                  (not consulted by this implementation)
 */
@Override
public void collectMetrics(final JavaRDD<GATKRead> filteredReads, final SAMFileHeader samHeader) {
    // Starting accumulator, configured to use original or current base qualities as requested.
    final QualityYieldMetrics zeroValue = new QualityYieldMetrics().setUseOriginalQualities(args.useOriginalQualities);
    // Fold each read into a per-partition accumulator, then merge the partial results pairwise.
    final QualityYieldMetrics combined = filteredReads.aggregate(
            zeroValue,
            (accumulator, read) -> accumulator.addRead(read),
            (left, right) -> left.combine(right));
    metricsFile.addMetric(combined.finish());
}
Also used : Header(htsjdk.samtools.metrics.Header) QualityYieldMetrics(org.broadinstitute.hellbender.metrics.QualityYieldMetrics) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) MetricsFile(htsjdk.samtools.metrics.MetricsFile) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) AuthHolder(org.broadinstitute.hellbender.engine.AuthHolder) Serializable(java.io.Serializable) MetricsUtils(org.broadinstitute.hellbender.metrics.MetricsUtils) List(java.util.List) QualityYieldMetricsArgumentCollection(org.broadinstitute.hellbender.metrics.QualityYieldMetricsArgumentCollection) JavaRDD(org.apache.spark.api.java.JavaRDD) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary) QualityYieldMetrics(org.broadinstitute.hellbender.metrics.QualityYieldMetrics)

Example 57 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

Class AddContextDataToReadSpark, method addUsingOverlapsPartitioning.

/**
 * Add context data ({@link ReadContextData}) to reads, using overlaps partitioning to avoid a shuffle.
 * <p>
 * The whole reference is cut into unpadded shards of at most {@code shardSize} bases and the reads are
 * distributed to those shards. Each shard then fetches reference bases for its interval expanded by
 * {@code shardPadding} (within the contig) and pairs every read with those bases plus the variants that
 * overlap it. A read is emitted only by the shard that contains its start position.
 *
 * @param ctx the Spark context
 * @param mappedReads the coordinate-sorted reads
 * @param referenceSource the reference source; broadcast to the executors
 * @param variants the coordinate-sorted variants; collected to the driver and broadcast as a skip list
 * @param sequenceDictionary the sequence dictionary for the reads
 * @param shardSize the maximum size of each shard, in bases
 * @param shardPadding amount of extra context around each shard, in bases
 * @return a RDD of read-context pairs, in coordinate-sorted order
 */
private static JavaPairRDD<GATKRead, ReadContextData> addUsingOverlapsPartitioning(final JavaSparkContext ctx, final JavaRDD<GATKRead> mappedReads, final ReferenceMultiSource referenceSource, final JavaRDD<GATKVariant> variants, final SAMSequenceDictionary sequenceDictionary, final int shardSize, final int shardPadding) {
    // Shards are left unpadded on purpose: padding is only needed for the reference bases fetched per shard.
    final List<ShardBoundary> shardBoundaries = IntervalUtils.getAllIntervalsForReference(sequenceDictionary)
            .stream()
            .flatMap(contig -> Shard.divideIntervalIntoShards(contig, shardSize, 0, sequenceDictionary).stream())
            .collect(Collectors.toList());

    final Broadcast<ReferenceMultiSource> referenceBroadcast = ctx.broadcast(referenceSource);
    // Every variant is pulled to the driver and indexed once, then shipped to each executor.
    final IntervalsSkipList<GATKVariant> variantIndex = new IntervalsSkipList<>(variants.collect());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantIndexBroadcast = ctx.broadcast(variantIndex);

    final int maxReadLength = Math.min(shardSize, shardPadding);
    final JavaRDD<Shard<GATKRead>> readShards = SparkSharder.shard(ctx, mappedReads, GATKRead.class, sequenceDictionary, shardBoundaries, maxReadLength);

    return readShards.flatMapToPair(new PairFlatMapFunction<Shard<GATKRead>, GATKRead, ReadContextData>() {

        private static final long serialVersionUID = 1L;

        @Override
        public Iterator<Tuple2<GATKRead, ReadContextData>> call(final Shard<GATKRead> shard) throws Exception {
            // Reference bases cover the shard plus shardPadding on each side, clipped to the contig.
            final SimpleInterval paddedShardInterval = shard.getInterval().expandWithinContig(shardPadding, sequenceDictionary);
            final ReferenceBases shardBases = referenceBroadcast.getValue().getReferenceBases(null, paddedShardInterval);
            final IntervalsSkipList<GATKVariant> shardVariantIndex = variantIndexBroadcast.getValue();

            final Iterator<Tuple2<GATKRead, ReadContextData>> readsWithContext = Iterators.transform(shard.iterator(), (GATKRead read) -> {
                // Reads that consume no reference bases (e.g. CIGAR 61S90I) do not form a valid interval;
                // for those we report that no variants overlap the read.
                final List<GATKVariant> overlapping = SimpleInterval.isValid(read.getContig(), read.getStart(), read.getEnd())
                        ? shardVariantIndex.getOverlapping(new SimpleInterval(read))
                        : Collections.emptyList();
                return new Tuple2<>(read, new ReadContextData(shardBases, overlapping));
            });

            // Only emit reads that start inside this shard (presumably so that a read spanning a shard
            // boundary is emitted exactly once).
            return Iterators.filter(readsWithContext, pair -> {
                final int readStart = pair._1().getStart();
                return readStart >= shard.getStart() && readStart <= shard.getEnd();
            });
        }
    });
}
Also used : PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Iterators(com.google.common.collect.Iterators) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) JavaRDD(org.apache.spark.api.java.JavaRDD) Nullable(javax.annotation.Nullable) Broadcast(org.apache.spark.broadcast.Broadcast) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Function(com.google.common.base.Function) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) Shard(org.broadinstitute.hellbender.engine.Shard) List(java.util.List) UserException(org.broadinstitute.hellbender.exceptions.UserException) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) Collections(java.util.Collections) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) ShardBoundary(org.broadinstitute.hellbender.engine.ShardBoundary) PairFlatMapFunction(org.apache.spark.api.java.function.PairFlatMapFunction) Function(com.google.common.base.Function) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Iterator(java.util.Iterator) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) 
ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) UserException(org.broadinstitute.hellbender.exceptions.UserException) ReadContextData(org.broadinstitute.hellbender.engine.ReadContextData) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) Tuple2(scala.Tuple2) Shard(org.broadinstitute.hellbender.engine.Shard) Nullable(javax.annotation.Nullable)

Example 58 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

Class IntervalWalkerSpark, method getIntervals.

/**
 * Loads the intervals together with their reads, reference and features into a {@link JavaRDD}.
 * <p>
 * Intervals are not subdivided: each one becomes exactly one shard, so that a single task processes it.
 * The reference and feature sources are broadcast when present; otherwise {@code null} is passed on.
 *
 * @param ctx the Spark context
 * @return all intervals, with their context, as a {@link JavaRDD}
 */
public JavaRDD<IntervalWalkerContext> getIntervals(JavaSparkContext ctx) {
    final SAMSequenceDictionary dictionary = getBestAvailableSequenceDictionary();
    // One shard per interval, using the interval itself as both the shard and its padded extent.
    final List<ShardBoundary> oneShardPerInterval = getIntervals()
            .stream()
            .map(interval -> new ShardBoundary(interval, interval))
            .collect(Collectors.toList());
    final JavaRDD<Shard<GATKRead>> readsByInterval = SparkSharder.shard(ctx, getReads(), GATKRead.class, dictionary, oneShardPerInterval, Integer.MAX_VALUE, shuffle);

    Broadcast<ReferenceMultiSource> referenceBroadcast = null;
    if (hasReference()) {
        referenceBroadcast = ctx.broadcast(getReference());
    }
    final Broadcast<FeatureManager> featuresBroadcast = (features != null) ? ctx.broadcast(features) : null;

    return readsByInterval.map(getIntervalsFunction(referenceBroadcast, featuresBroadcast, dictionary, intervalShardPadding));
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) Iterator(java.util.Iterator) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary)

Example 59 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

Class CompareDuplicatesSpark, method runTool.

/**
 * Compares duplicate marking between the primary input reads and the reads in {@code input2}.
 * <p>
 * Throws a {@link UserException} if the two inputs have different numbers of reads after
 * {@code filteredReads}, or if the reads grouped by MarkDuplicates key differ in group size or content.
 * Otherwise prints read and duplicate counts, optionally a per-{@link MatchType} summary (in enum
 * declaration order), and, when {@code throwOnDiff} is set, throws on any non-{@code EQUAL} match type.
 *
 * @param ctx the Spark context
 */
@Override
protected void runTool(final JavaSparkContext ctx) {
    final JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));
    final ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
    final JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
    // Start by verifying that we have same number of reads and duplicates in each BAM.
    final long firstBamSize = firstReads.count();
    final long secondBamSize = secondReads.count();
    if (firstBamSize != secondBamSize) {
        throw new UserException("input bams have different numbers of mapped reads: " + firstBamSize + "," + secondBamSize);
    }
    System.out.println("processing bams with " + firstBamSize + " mapped reads");
    final long firstDupesCount = firstReads.filter(GATKRead::isDuplicate).count();
    final long secondDupesCount = secondReads.filter(GATKRead::isDuplicate).count();
    if (firstDupesCount != secondDupesCount) {
        System.out.println("BAMs have different number of total duplicates: " + firstDupesCount + "," + secondDupesCount);
    }
    System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);
    final Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());
    // Group the reads of each BAM by MarkDuplicates key, then pair up the reads for each BAM.
    final JavaPairRDD<String, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    final JavaPairRDD<String, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    final JavaPairRDD<String, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
    // Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
    // which is approximately start position x strand.
    final JavaRDD<MatchType> tagged = cogroup.map(v1 -> {
        final SAMFileHeader header = bHeader.getValue();
        final Iterable<GATKRead> iFirstReads = v1._2()._1();
        final Iterable<GATKRead> iSecondReads = v1._2()._2();
        return getDupes(iFirstReads, iSecondReads, header);
    });
    // TODO: We should also produce examples of reads that don't match to make debugging easier (#1263).
    // Counts come from summing 1s per key, so every present key maps to a non-null count.
    final Map<MatchType, Integer> tagCountMap = tagged.mapToPair(v1 -> new Tuple2<>(v1, 1)).reduceByKey((v1, v2) -> v1 + v2).collectAsMap();
    if (tagCountMap.containsKey(MatchType.SIZE_UNEQUAL)) {
        throw new UserException("The number of reads by the MarkDuplicates key were unequal, indicating that the BAMs are not the same");
    }
    if (tagCountMap.containsKey(MatchType.READ_MISMATCH)) {
        throw new UserException("The reads grouped by the MarkDuplicates key were not the same, indicating that the BAMs are not the same");
    }
    if (printSummary) {
        System.out.println("##############################");
        // Iterate the enum constants directly so the summary is printed in declaration order.
        // (Routing them through a HashSet first made the order depend on identity hash codes.)
        for (final MatchType matchType : MatchType.values()) {
            System.out.println(matchType + ": " + tagCountMap.getOrDefault(matchType, 0));
        }
    }
    if (throwOnDiff) {
        for (final MatchType matchType : MatchType.values()) {
            if (matchType != MatchType.EQUAL && tagCountMap.containsKey(matchType)) {
                throw new UserException("found difference between the two BAMs: " + matchType + " with count " + tagCountMap.get(matchType));
            }
        }
    }
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) TestSparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.TestSparkProgramGroup) Argument(org.broadinstitute.barclay.argparser.Argument) ReadsKey(org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) ReadCoordinateComparator(org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator) Set(java.util.Set) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Tuple2(scala.Tuple2) SAMFileHeader(htsjdk.samtools.SAMFileHeader) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) Sets(com.google.common.collect.Sets) ReadUtils(org.broadinstitute.hellbender.utils.read.ReadUtils) List(java.util.List) Lists(com.google.common.collect.Lists) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) UserException(org.broadinstitute.hellbender.exceptions.UserException) Map(java.util.Map) Function(org.apache.spark.api.java.function.Function) JavaRDD(org.apache.spark.api.java.JavaRDD) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) Tuple2(scala.Tuple2) UserException(org.broadinstitute.hellbender.exceptions.UserException) SAMFileHeader(htsjdk.samtools.SAMFileHeader)

Example 60 with JavaRDD

Use of org.apache.spark.api.java.JavaRDD in project gatk by broadinstitute.

Class InsertSizeMetricsCollectorSparkUnitTest, method test.

/**
 * Runs {@link InsertSizeMetricsCollectorSpark} over one input file on two partitions and compares the
 * written metrics with the expected results file (ignoring lines starting with "#").
 *
 * @param fileName input reads file, relative to {@code TEST_DATA_DIR}
 * @param referenceName reference file path, or {@code null} if none is needed
 * @param allLevels if true, accumulate at ALL_READS, SAMPLE, LIBRARY and READ_GROUP levels
 * @param expectedResultsFile expected metrics file, relative to {@code TEST_DATA_DIR}
 */
@Test(dataProvider = "metricsfiles", groups = "spark")
public void test(final String fileName, final String referenceName, final boolean allLevels, final String expectedResultsFile) throws IOException {
    final String inputPath = new File(TEST_DATA_DIR, fileName).getAbsolutePath();
    final String referencePath = referenceName != null ? new File(referenceName).getAbsolutePath() : null;
    final File outfile = BaseTest.createTempFile("test", ".insert_size_metrics");
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.DEFAULT_STRINGENCY);
    final SAMFileHeader samHeader = readSource.getHeader(inputPath, referencePath);
    JavaRDD<GATKRead> rddParallelReads = readSource.getParallelReads(inputPath, referencePath);
    final InsertSizeMetricsArgumentCollection isArgs = new InsertSizeMetricsArgumentCollection();
    isArgs.output = outfile.getAbsolutePath();
    if (allLevels) {
        isArgs.metricAccumulationLevel.accumulationLevels = new HashSet<>();
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.ALL_READS);
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.SAMPLE);
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.LIBRARY);
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.READ_GROUP);
    }
    final InsertSizeMetricsCollectorSpark isSpark = new InsertSizeMetricsCollectorSpark();
    isSpark.initialize(isArgs, samHeader, null);
    // Since we're bypassing the framework in order to force this test to run on multiple partitions, we
    // need to build the read filter manually: there is no plugin descriptor to do it for us.
    // NOTE(review): this test previously tried to drop the default FirstOfPairReadFilter and add a
    // second-of-pair filter. It did so with a stream whose result was discarded, so the collector's
    // default filters were always used unchanged, and the expected results presumably reflect them.
    // If the second-of-pair behaviour is really wanted, implement it explicitly and re-verify the
    // expected files.
    final List<ReadFilter> readFilters = isSpark.getDefaultReadFilters();
    final ReadFilter rf = ReadFilter.fromList(readFilters, samHeader);
    // Force the input RDD to be split into two partitions to ensure that the
    // reduce/combiners run
    rddParallelReads = rddParallelReads.repartition(2);
    isSpark.collectMetrics(rddParallelReads.filter(r -> rf.test(r)), samHeader);
    isSpark.saveMetrics(fileName, null);
    IntegrationTestSpec.assertEqualTextFiles(outfile, new File(TEST_DATA_DIR, expectedResultsFile), "#");
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) DataProvider(org.testng.annotations.DataProvider) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Test(org.testng.annotations.Test) IOException(java.io.IOException) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) IntegrationTestSpec(org.broadinstitute.hellbender.utils.test.IntegrationTestSpec) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) ValidationStringency(htsjdk.samtools.ValidationStringency) CommandLineProgramTest(org.broadinstitute.hellbender.CommandLineProgramTest) File(java.io.File) HashSet(java.util.HashSet) List(java.util.List) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) InsertSizeMetricsArgumentCollection(org.broadinstitute.hellbender.metrics.InsertSizeMetricsArgumentCollection) MetricAccumulationLevel(org.broadinstitute.hellbender.metrics.MetricAccumulationLevel) SparkContextFactory(org.broadinstitute.hellbender.engine.spark.SparkContextFactory) JavaRDD(org.apache.spark.api.java.JavaRDD) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) InsertSizeMetricsArgumentCollection(org.broadinstitute.hellbender.metrics.InsertSizeMetricsArgumentCollection) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SAMFileHeader(htsjdk.samtools.SAMFileHeader) File(java.io.File) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test) CommandLineProgramTest(org.broadinstitute.hellbender.CommandLineProgramTest)

Aggregations

JavaRDD (org.apache.spark.api.java.JavaRDD)63 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)33 List (java.util.List)24 GATKRead (org.broadinstitute.hellbender.utils.read.GATKRead)24 Collectors (java.util.stream.Collectors)20 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)20 Tuple2 (scala.Tuple2)20 Argument (org.broadinstitute.barclay.argparser.Argument)17 Broadcast (org.apache.spark.broadcast.Broadcast)15 SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval)15 SAMFileHeader (htsjdk.samtools.SAMFileHeader)14 SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary)14 IOException (java.io.IOException)14 UserException (org.broadinstitute.hellbender.exceptions.UserException)14 CommandLineProgramProperties (org.broadinstitute.barclay.argparser.CommandLineProgramProperties)13 GATKSparkTool (org.broadinstitute.hellbender.engine.spark.GATKSparkTool)13 Serializable (java.io.Serializable)12 IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils)12 java.util (java.util)11 ArrayList (java.util.ArrayList)11