Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
The class AssemblyRegionWalkerSpark, method getAssemblyRegionsFunction.
private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
        final Broadcast<ReferenceMultiSource> bReferenceSource, final Broadcast<FeatureManager> bFeatureManager,
        final SAMSequenceDictionary sequenceDictionary, final SAMFileHeader header, final AssemblyRegionEvaluator evaluator,
        final int minAssemblyRegionSize, final int maxAssemblyRegionSize, final int assemblyRegionPadding,
        final double activeProbThreshold, final int maxProbPropagationDistance) {
    return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
        FeatureContext featureContext = new FeatureContext(features, paddedInterval);
        final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(shardedRead, header, referenceContext, featureContext,
                evaluator, minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance);
        return StreamSupport.stream(assemblyRegions.spliterator(), false)
                .map(assemblyRegion -> new AssemblyRegionWalkerContext(assemblyRegion,
                        new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
                        new FeatureContext(features, assemblyRegion.getExtendedSpan())))
                .iterator();
    };
}
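A minimal driver-side sketch of how this function could be wired up, assuming a JavaSparkContext ctx, a JavaRDD<Shard<GATKRead>> shardedReads, and the walker's configuration values are already in scope (all variable names below are illustrative, not taken verbatim from the GATK source):
// Hedged sketch: broadcast the reference and feature manager once on the driver,
// then flat-map the sharded reads through the function returned above.
Broadcast<ReferenceMultiSource> bReference = ctx.broadcast(referenceSource);   // assumed variable
Broadcast<FeatureManager> bFeatures = ctx.broadcast(featureManager);           // assumed variable
JavaRDD<AssemblyRegionWalkerContext> contexts = shardedReads.flatMap(
        getAssemblyRegionsFunction(bReference, bFeatures, sequenceDictionary, header, evaluator,
                minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding,
                activeProbThreshold, maxProbPropagationDistance));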
Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
The class LocusWalkerSpark, method getAlignmentsFunction.
/**
* Return a function that maps a {@link Shard} of reads into a tuple of alignments and their corresponding reference and features.
* @param bReferenceSource the reference source broadcast
* @param bFeatureManager the feature manager broadcast
* @param sequenceDictionary the sequence dictionary for the reads
* @param header the reads header
* @param downsamplingInfo the downsampling method for the reads
* @return a function that maps a {@link Shard} of reads into a tuple of alignments and their corresponding reference and features.
*/
private static FlatMapFunction<Shard<GATKRead>, LocusWalkerContext> getAlignmentsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary, SAMFileHeader header, LIBSDownsamplingInfo downsamplingInfo) {
    return (FlatMapFunction<Shard<GATKRead>, LocusWalkerContext>) shardedRead -> {
        SimpleInterval interval = shardedRead.getInterval();
        SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
        Iterator<GATKRead> readIterator = shardedRead.iterator();
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
        FeatureManager fm = bFeatureManager == null ? null : bFeatureManager.getValue();
        final Set<String> samples = header.getReadGroups().stream()
                .map(SAMReadGroupRecord::getSample)
                .collect(Collectors.toSet());
        LocusIteratorByState libs = new LocusIteratorByState(readIterator, downsamplingInfo, false, samples, header, true, false);
        IntervalOverlappingIterator<AlignmentContext> alignmentContexts =
                new IntervalOverlappingIterator<>(libs, ImmutableList.of(interval), sequenceDictionary);
        final Spliterator<AlignmentContext> alignmentContextSpliterator = Spliterators.spliteratorUnknownSize(alignmentContexts, 0);
        return StreamSupport.stream(alignmentContextSpliterator, false)
                .map(alignmentContext -> {
                    final SimpleInterval alignmentInterval = new SimpleInterval(alignmentContext);
                    return new LocusWalkerContext(alignmentContext,
                            new ReferenceContext(reference, alignmentInterval),
                            new FeatureContext(fm, alignmentInterval));
                })
                .iterator();
    };
}
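Usage follows the same pattern as the previous sketch; a brief hedged example with assumed variable names (not from the GATK source):
// Hedged sketch: apply the locus function to an RDD of read shards to get per-locus contexts.
JavaRDD<LocusWalkerContext> loci = shardedReads.flatMap(
        getAlignmentsFunction(bReference, bFeatures, sequenceDictionary, header, downsamplingInfo));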
Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
The class ReadWalkerSpark, method getReadsFunction.
private static FlatMapFunction<Shard<GATKRead>, ReadWalkerContext> getReadsFunction(
        Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
        SAMSequenceDictionary sequenceDictionary, int readShardPadding) {
    return (FlatMapFunction<Shard<GATKRead>, ReadWalkerContext>) shard -> {
        SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(readShardPadding, sequenceDictionary);
        ReferenceDataSource reference = bReferenceSource == null ? null :
                new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
        return StreamSupport.stream(shard.spliterator(), false)
                .map(r -> {
                    final SimpleInterval readInterval = getReadInterval(r);
                    return new ReadWalkerContext(r, new ReferenceContext(reference, readInterval), new FeatureContext(features, readInterval));
                })
                .iterator();
    };
}
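Because the ternary checks above tolerate null broadcasts, a tool given no reference or feature inputs could in principle pass nulls and still get one context per read; a hedged sketch (the shards variable is an assumption):
// Hedged sketch: with no reference or features, null broadcasts simply yield contexts
// backed by empty reference and feature data for each read.
JavaRDD<ReadWalkerContext> readContexts = shards.flatMap(
        getReadsFunction(null, null, sequenceDictionary, readShardPadding));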
Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
The class CompareDuplicatesSpark, method runTool.
@Override
protected void runTool(final JavaSparkContext ctx) {
    JavaRDD<GATKRead> firstReads = filteredReads(getReads(), readArguments.getReadFilesNames().get(0));
    ReadsSparkSource readsSource2 = new ReadsSparkSource(ctx, readArguments.getReadValidationStringency());
    JavaRDD<GATKRead> secondReads = filteredReads(readsSource2.getParallelReads(input2, null, getIntervals(), bamPartitionSplitSize), input2);
    // Start by verifying that we have the same number of reads and duplicates in each BAM.
    long firstBamSize = firstReads.count();
    long secondBamSize = secondReads.count();
    if (firstBamSize != secondBamSize) {
        throw new UserException("input bams have different numbers of mapped reads: " + firstBamSize + "," + secondBamSize);
    }
    System.out.println("processing bams with " + firstBamSize + " mapped reads");
    long firstDupesCount = firstReads.filter(GATKRead::isDuplicate).count();
    long secondDupesCount = secondReads.filter(GATKRead::isDuplicate).count();
    if (firstDupesCount != secondDupesCount) {
        System.out.println("BAMs have different number of total duplicates: " + firstDupesCount + "," + secondDupesCount);
    }
    System.out.println("first and second: " + firstDupesCount + "," + secondDupesCount);
    Broadcast<SAMFileHeader> bHeader = ctx.broadcast(getHeaderForReads());
    // Group the reads of each BAM by MarkDuplicates key, then pair up the reads for each BAM.
    JavaPairRDD<String, GATKRead> firstKeyed = firstReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    JavaPairRDD<String, GATKRead> secondKeyed = secondReads.mapToPair(read -> new Tuple2<>(ReadsKey.keyForFragment(bHeader.getValue(), read), read));
    JavaPairRDD<String, Tuple2<Iterable<GATKRead>, Iterable<GATKRead>>> cogroup = firstKeyed.cogroup(secondKeyed, getRecommendedNumReducers());
    // Produces an RDD of MatchTypes, e.g., EQUAL, DIFFERENT_REPRESENTATIVE_READ, etc. per MarkDuplicates key,
    // which is approximately start position x strand.
    JavaRDD<MatchType> tagged = cogroup.map(v1 -> {
        SAMFileHeader header = bHeader.getValue();
        Iterable<GATKRead> iFirstReads = v1._2()._1();
        Iterable<GATKRead> iSecondReads = v1._2()._2();
        return getDupes(iFirstReads, iSecondReads, header);
    });
    // TODO: We should also produce examples of reads that don't match to make debugging easier (#1263).
    Map<MatchType, Integer> tagCountMap = tagged.mapToPair(v1 -> new Tuple2<>(v1, 1))
            .reduceByKey((v1, v2) -> v1 + v2)
            .collectAsMap();
    if (tagCountMap.get(MatchType.SIZE_UNEQUAL) != null) {
        throw new UserException("The number of reads by the MarkDuplicates key were unequal, indicating that the BAMs are not the same");
    }
    if (tagCountMap.get(MatchType.READ_MISMATCH) != null) {
        throw new UserException("The reads grouped by the MarkDuplicates key were not the same, indicating that the BAMs are not the same");
    }
    if (printSummary) {
        MatchType[] values = MatchType.values();
        Set<MatchType> matchTypes = Sets.newLinkedHashSet(Sets.newHashSet(values));
        System.out.println("##############################");
        matchTypes.forEach(s -> System.out.println(s + ": " + tagCountMap.getOrDefault(s, 0)));
    }
    if (throwOnDiff) {
        for (MatchType s : MatchType.values()) {
            if (s != MatchType.EQUAL) {
                if (tagCountMap.get(s) != null) {
                    throw new UserException("found difference between the two BAMs: " + s + " with count " + tagCountMap.get(s));
                }
            }
        }
    }
}
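The header broadcast above is shipped to each executor once and read with getValue() inside every closure, instead of being serialized with each task. If a tool went on to run further jobs, the broadcast could be released once it is no longer needed; a hedged sketch of that lifecycle (not present in the GATK source):
// Hedged sketch: release the broadcast once no further jobs will reference it.
bHeader.unpersist();   // drop cached copies from the executors
// bHeader.destroy();  // or free it entirely if it will never be used again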
Use of org.apache.spark.broadcast.Broadcast in project pyramid by cheng-li.
The class SparkCBMOptimizer, method updateBinaryClassifiers.
private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    Classifier.ProbabilityEstimator[][] localBinaryClassifiers = cbm.binaryClassifiers;
    double[][] localGammasT = gammasT;
    Broadcast<MultiLabelClfDataSet> localDataSetBroadcast = dataSetBroadCast;
    Broadcast<double[][][]> localTargetsBroadcast = targetDisBroadCast;
    double localVariance = priorVarianceBinary;
    List<BinaryTask> binaryTaskList = new ArrayList<>();
    for (int k = 0; k < cbm.numComponents; k++) {
        for (int l = 0; l < cbm.numLabels; l++) {
            LogisticRegression logisticRegression = (LogisticRegression) localBinaryClassifiers[k][l];
            double[] weights = localGammasT[k];
            binaryTaskList.add(new BinaryTask(k, l, logisticRegression, weights));
        }
    }
    JavaRDD<BinaryTask> binaryTaskRDD = sparkContext.parallelize(binaryTaskList, binaryTaskList.size());
    List<BinaryTaskResult> results = binaryTaskRDD.map(binaryTask -> {
        int labelIndex = binaryTask.classIndex;
        return updateBinaryLogisticRegression(binaryTask.componentIndex, binaryTask.classIndex, binaryTask.logisticRegression,
                localDataSetBroadcast.value(), binaryTask.weights, localTargetsBroadcast.value()[labelIndex], localVariance);
    }).collect();
    for (BinaryTaskResult result : results) {
        cbm.binaryClassifiers[result.componentIndex][result.classIndex] = result.binaryClassifier;
    }
    // IntStream.range(0, cbm.numComponents).forEach(this::updateBinaryClassifiers);
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
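The dataSetBroadCast and targetDisBroadCast fields read above would typically be created once on the driver before the optimization iterations; a hedged sketch with assumed variable names (not taken from the pyramid source):
// Hedged sketch: broadcast the training data and per-label target distributions once,
// so each Spark task reads a local copy instead of shipping them with every closure.
Broadcast<MultiLabelClfDataSet> dataSetBroadCast = sparkContext.broadcast(dataSet);          // assumed variable dataSet
Broadcast<double[][][]> targetDisBroadCast = sparkContext.broadcast(targetDistributions);    // assumed variable targetDistributions
Copying the broadcasts into local variables before the map closure, as the method does, also keeps Spark from trying to serialize the enclosing optimizer object when the lambda is shipped to executors.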