
Example 21 with Broadcast

Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

From the class AssemblyRegionWalkerSpark, the method getAssemblyRegionsFunction:

private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
        final Broadcast<ReferenceMultiSource> bReferenceSource, final Broadcast<FeatureManager> bFeatureManager,
        final SAMSequenceDictionary sequenceDictionary, final SAMFileHeader header,
        final AssemblyRegionEvaluator evaluator, final int minAssemblyRegionSize, final int maxAssemblyRegionSize,
        final int assemblyRegionPadding, final double activeProbThreshold, final int maxProbPropagationDistance) {
    return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);
        // The broadcasts may be null when no reference or feature inputs were supplied.
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
        FeatureContext featureContext = new FeatureContext(features, paddedInterval);
        final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(
                shardedRead, header, referenceContext, featureContext, evaluator,
                minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding,
                activeProbThreshold, maxProbPropagationDistance);
        return StreamSupport.stream(assemblyRegions.spliterator(), false)
                .map(assemblyRegion -> new AssemblyRegionWalkerContext(assemblyRegion,
                        new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
                        new FeatureContext(features, assemblyRegion.getExtendedSpan())))
                .iterator();
    };
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Advanced(org.broadinstitute.barclay.argparser.Advanced) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) SAMFileHeader(htsjdk.samtools.SAMFileHeader) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) WellformedReadFilter(org.broadinstitute.hellbender.engine.filters.WellformedReadFilter) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) ReadFilterLibrary(org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary)
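
A minimal sketch of how a driver could wire this up, assuming an existing JavaSparkContext ctx, a JavaRDD<Shard<GATKRead>> named shardedReads, and illustrative values for the remaining arguments; this is not the actual GATK call site:

Broadcast<ReferenceMultiSource> bReferenceSource = ctx.broadcast(referenceSource);
Broadcast<FeatureManager> bFeatureManager = ctx.broadcast(featureManager);
// Region sizes and thresholds below are illustrative, not GATK defaults.
JavaRDD<AssemblyRegionWalkerContext> contexts = shardedReads.flatMap(
        getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, header,
                evaluator, 50, 300, 100, 0.002, 50));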

Example 22 with Broadcast

Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

From the class LocusWalkerSpark, the method getAlignmentsFunction:

/**
 * Return a function that maps a {@link Shard} of reads into a tuple of alignments and their corresponding reference and features.
 * @param bReferenceSource the reference source broadcast
 * @param bFeatureManager the feature manager broadcast
 * @param sequenceDictionary the sequence dictionary for the reads
 * @param header the reads header
 * @param downsamplingInfo the downsampling method for the reads
 * @return a function that maps a {@link Shard} of reads into a tuple of alignments and their corresponding reference and features.
 */
private static FlatMapFunction<Shard<GATKRead>, LocusWalkerContext> getAlignmentsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary, SAMFileHeader header, LIBSDownsamplingInfo downsamplingInfo) {
    return (FlatMapFunction<Shard<GATKRead>, LocusWalkerContext>) shardedRead -> {
        SimpleInterval interval = shardedRead.getInterval();
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        Iterator<GATKRead> readIterator = shardedRead.iterator();
        // The broadcasts may be null when no reference or feature inputs were supplied.
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
        FeatureManager fm = bFeatureManager == null ? null : bFeatureManager.getValue();
        final Set<String> samples = header.getReadGroups().stream()
                .map(SAMReadGroupRecord::getSample)
                .collect(Collectors.toSet());
        LocusIteratorByState libs = new LocusIteratorByState(readIterator, downsamplingInfo, false, samples, header, true, false);
        // Keep only loci that overlap the unpadded shard interval.
        IntervalOverlappingIterator<AlignmentContext> alignmentContexts =
                new IntervalOverlappingIterator<>(libs, ImmutableList.of(interval), sequenceDictionary);
        final Spliterator<AlignmentContext> alignmentContextSpliterator = Spliterators.spliteratorUnknownSize(alignmentContexts, 0);
        return StreamSupport.stream(alignmentContextSpliterator, false).map(alignmentContext -> {
            final SimpleInterval alignmentInterval = new SimpleInterval(alignmentContext);
            return new LocusWalkerContext(alignmentContext, new ReferenceContext(reference, alignmentInterval), new FeatureContext(fm, alignmentInterval));
        }).iterator();
    };
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) java.util(java.util) IntervalOverlappingIterator(org.broadinstitute.hellbender.utils.iterators.IntervalOverlappingIterator) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) LocusIteratorByState(org.broadinstitute.hellbender.utils.locusiterator.LocusIteratorByState) SAMFileHeader(htsjdk.samtools.SAMFileHeader) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) SAMReadGroupRecord(htsjdk.samtools.SAMReadGroupRecord) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ImmutableList(com.google.common.collect.ImmutableList) StreamSupport(java.util.stream.StreamSupport) LIBSDownsamplingInfo(org.broadinstitute.hellbender.utils.locusiterator.LIBSDownsamplingInfo) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) CommandLineException(org.broadinstitute.barclay.argparser.CommandLineException)
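
Stripped of the GATK types, the broadcast pattern all of these walkers share reduces to the following self-contained sketch (class and variable names are illustrative):

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

public class BroadcastSketch {
    public static void main(String[] args) {
        JavaSparkContext ctx = new JavaSparkContext("local[*]", "broadcast-sketch");
        Map<String, Integer> lookup = new HashMap<>();
        lookup.put("a", 1);
        lookup.put("b", 2);
        // Ship the read-only table to each executor once, instead of once per task.
        Broadcast<Map<String, Integer>> bLookup = ctx.broadcast(lookup);
        JavaRDD<String> keys = ctx.parallelize(Arrays.asList("a", "b", "a"));
        // Tasks read the shared value through the broadcast handle.
        JavaRDD<Integer> values = keys.map(k -> bLookup.getValue().getOrDefault(k, 0));
        System.out.println(values.collect());  // [1, 2, 1]
        ctx.stop();
    }
}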

Example 23 with Broadcast

Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

From the class ReadWalkerSpark, the method getReadsFunction:

private static FlatMapFunction<Shard<GATKRead>, ReadWalkerContext> getReadsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary, int readShardPadding) {
    return (FlatMapFunction<Shard<GATKRead>, ReadWalkerContext>) shard -> {
        SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(readShardPadding, sequenceDictionary);
        // The broadcasts may be null when no reference or feature inputs were supplied.
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        return StreamSupport.stream(shard.spliterator(), false).map(r -> {
            final SimpleInterval readInterval = getReadInterval(r);
            return new ReadWalkerContext(r, new ReferenceContext(reference, readInterval), new FeatureContext(features, readInterval));
        }).iterator();
    };
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction)
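
These walker methods leave broadcast lifecycle to the caller. A hedged sketch of the cleanup step, assuming the broadcasts above and that the derived RDDs have already been fully consumed:

// Drop the cached copies on the executors; Spark re-sends the value if the broadcast is used again.
bReferenceSource.unpersist();
// Release the broadcast for good; it must not be used after this call.
bFeatureManager.destroy();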

Example 24 with Broadcast

Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.

From the class CompareDuplicatesSpark, the method runTool:

@Override
protected void runTool(final JavaSparkContext ctx) {
    JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));
    ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
    JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
    // Start by verifying that we have same number of reads and duplicates in each BAM.
    long firstBamSize = firstReads.count();
    long secondBamSize = secondReads.count();
    if (firstBamSize != secondBamSize) {
        throw new UserException("input bams have different numbers of mapped reads: " + firstBamSize + "," + secondBamSize);
    }
    System.out.println("processing bams with " + firstBamSize + " mapped reads");
    long firstDupesCount = firstReads.filter(GATKRead::isDuplicate).count();
    long secondDupesCount = secondReads.filter(GATKRead::isDuplicate).count();
    if (firstDupesCount != secondDupesCount) {
        System.out.println("BAMs have different number of total duplicates: " + firstDupesCount + "," + secondDupesCount);
    }
    System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);
    Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());
    // Group the reads of each BAM by MarkDuplicates key, then pair up the reads for each BAM.
    JavaPairRDD<String, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    JavaPairRDD<String, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    JavaPairRDD<String, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
    // Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
    // which is approximately start position x strand.
    JavaRDD<MatchType> tagged = cogroup.map(v1 -> {
        SAMFileHeader header = bHeader.getValue();
        Iterable<GATKRead> iFirstReads = v1._2()._1();
        Iterable<GATKRead> iSecondReads = v1._2()._2();
        return getDupes(iFirstReads, iSecondReads, header);
    });
    // TODO: We should also produce examples of reads that don't match to make debugging easier (#1263).
    Map<MatchType, Integer> tagCountMap = tagged.mapToPair(v1 -> new Tuple2<>(v1, 1)).reduceByKey((v1, v2) -> v1 + v2).collectAsMap();
    if (tagCountMap.get(MatchType.SIZE_UNEQUAL) != null) {
        throw new UserException("The number of reads by the MarkDuplicates key were unequal, indicating that the BAMs are not the same");
    }
    if (tagCountMap.get(MatchType.READ_MISMATCH) != null) {
        throw new UserException("The reads grouped by the MarkDuplicates key were not the same, indicating that the BAMs are not the same");
    }
    if (printSummary) {
        MatchType[] values = MatchType.values();
        Set<MatchType> matchTypes = Sets.newLinkedHashSet(Sets.newHashSet(values));
        System.out.println("##############################");
        matchTypes.forEach(s -> System.out.println(s + ": " + tagCountMap.getOrDefault(s, 0)));
    }
    if (throwOnDiff) {
        for (MatchType s : MatchType.values()) {
            if (s != MatchType.EQUAL) {
                if (tagCountMap.get(s) != null)
                    throw new UserException("found difference between the two BAMs: " + s + " with count " + tagCountMap.get(s));
            }
        }
    }
}
Also used : GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) Broadcast(org.apache.spark.broadcast.Broadcast) CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) TestSparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.TestSparkProgramGroup) Argument(org.broadinstitute.barclay.argparser.Argument) ReadsKey(org.broadinstitute.hellbender.utils.read.markduplicates.ReadsKey) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) ReadCoordinateComparator(org.broadinstitute.hellbender.utils.read.ReadCoordinateComparator) Set(java.util.Set) Tuple2(scala.Tuple2) SAMFileHeader(htsjdk.samtools.SAMFileHeader) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) Sets(com.google.common.collect.Sets) ReadUtils(org.broadinstitute.hellbender.utils.read.ReadUtils) List(java.util.List) Lists(com.google.common.collect.Lists) ReadsSparkSource(org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSource) UserException(org.broadinstitute.hellbender.exceptions.UserException) Map(java.util.Map) Function(org.apache.spark.api.java.function.Function) JavaRDD(org.apache.spark.api.java.JavaRDD)
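
The heart of the comparison above is the cogroup idiom: key both RDDs identically, then inspect each key's two groups side by side. Below is a simplified sketch against plain types, assuming Guava's Iterables on the classpath; the EQUAL/DIFFERENT strings stand in for the tool's MatchType logic and are purely illustrative:

private static JavaRDD<String> compareByKey(JavaPairRDD<String, Integer> first, JavaPairRDD<String, Integer> second) {
    return first.cogroup(second).map(entry -> {
        Iterable<Integer> firstGroup = entry._2()._1();
        Iterable<Integer> secondGroup = entry._2()._2();
        // Order-sensitive comparison of the two groups (the real tool is more careful).
        return Iterables.elementsEqual(firstGroup, secondGroup) ? "EQUAL" : "DIFFERENT";
    });
}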

Example 25 with Broadcast

Use of org.apache.spark.broadcast.Broadcast in project pyramid by cheng-li.

From the class SparkCBMOptimizer, the method updateBinaryClassifiers:

private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    // Copy fields into local variables so the Spark closure below does not capture and serialize `this`.
    Classifier.ProbabilityEstimator[][] localBinaryClassifiers = cbm.binaryClassifiers;
    double[][] localGammasT = gammasT;
    Broadcast<MultiLabelClfDataSet> localDataSetBroadcast = dataSetBroadCast;
    Broadcast<double[][][]> localTargetsBroadcast = targetDisBroadCast;
    double localVariance = priorVarianceBinary;
    List<BinaryTask> binaryTaskList = new ArrayList<>();
    for (int k = 0; k < cbm.numComponents; k++) {
        for (int l = 0; l < cbm.numLabels; l++) {
            LogisticRegression logisticRegression = (LogisticRegression) localBinaryClassifiers[k][l];
            double[] weights = localGammasT[k];
            binaryTaskList.add(new BinaryTask(k, l, logisticRegression, weights));
        }
    }
    // One task per partition, so every (component, label) regression can run in parallel.
    JavaRDD<BinaryTask> binaryTaskRDD = sparkContext.parallelize(binaryTaskList, binaryTaskList.size());
    List<BinaryTaskResult> results = binaryTaskRDD.map(binaryTask -> {
        int labelIndex = binaryTask.classIndex;
        return updateBinaryLogisticRegression(binaryTask.componentIndex, binaryTask.classIndex,
                binaryTask.logisticRegression, localDataSetBroadcast.value(), binaryTask.weights,
                localTargetsBroadcast.value()[labelIndex], localVariance);
    }).collect();
    for (BinaryTaskResult result : results) {
        cbm.binaryClassifiers[result.componentIndex][result.classIndex] = result.binaryClassifier;
    }
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
Also used : IntStream(java.util.stream.IntStream) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ArrayList(java.util.ArrayList) Classifier(edu.neu.ccs.pyramid.classification.Classifier) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) RidgeLogisticOptimizer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost) LogisticLoss(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss) JavaRDD(org.apache.spark.api.java.JavaRDD) Broadcast(org.apache.spark.broadcast.Broadcast) LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Serializable(java.io.Serializable) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence) List(java.util.List) Logger(org.apache.logging.log4j.Logger) ElasticNetLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer) Entropy(edu.neu.ccs.pyramid.eval.Entropy) Vector(org.apache.mahout.math.Vector) LogManager(org.apache.logging.log4j.LogManager) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)
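
The shape of this optimizer generalizes to any task farm over shared read-only data: broadcast the large input once, parallelize small task descriptors, and let each task read the input through the broadcast. A minimal sketch with illustrative names:

private static List<Double> rowSums(JavaSparkContext sparkContext, double[][] data) {
    // Ship the shared matrix to the executors once.
    Broadcast<double[][]> bData = sparkContext.broadcast(data);
    List<Integer> rowIds = IntStream.range(0, data.length).boxed().collect(Collectors.toList());
    // One row index per partition, so all rows are summed in parallel.
    return sparkContext.parallelize(rowIds, rowIds.size())
            .map(i -> {
                double sum = 0;
                for (double x : bData.value()[i]) {
                    sum += x;
                }
                return sum;
            })
            .collect();
}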

Aggregations

Broadcast (org.apache.spark.broadcast.Broadcast): 25
JavaSparkContext (org.apache.spark.api.java.JavaSparkContext): 23
Collectors (java.util.stream.Collectors): 21
List (java.util.List): 15
JavaRDD (org.apache.spark.api.java.JavaRDD): 15
IntervalUtils (org.broadinstitute.hellbender.utils.IntervalUtils): 15
Argument (org.broadinstitute.barclay.argparser.Argument): 13
JavaPairRDD (org.apache.spark.api.java.JavaPairRDD): 12
Tuple2 (scala.Tuple2): 12
SAMSequenceDictionary (htsjdk.samtools.SAMSequenceDictionary): 11
IntStream (java.util.stream.IntStream): 11
LogManager (org.apache.logging.log4j.LogManager): 11
Logger (org.apache.logging.log4j.Logger): 11
ReferenceMultiSource (org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource): 11
SimpleInterval (org.broadinstitute.hellbender.utils.SimpleInterval): 11
StreamSupport (java.util.stream.StreamSupport): 10
org.broadinstitute.hellbender.engine (org.broadinstitute.hellbender.engine): 10
UserException (org.broadinstitute.hellbender.exceptions.UserException): 10
FlatMapFunction (org.apache.spark.api.java.function.FlatMapFunction): 9
GATKException (org.broadinstitute.hellbender.exceptions.GATKException): 9