Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
The class CoverageModelWLinearOperatorSpark, method operate.
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents)
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);

    /* Z F W */
    final long startTimeZFW = System.nanoTime();
    final INDArray Z_F_W_tl = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel()
            .forEach(li -> Z_F_W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(F_tt.operate(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    Z_F_W_tl.assign(Nd4j.gemm(Z_F_W_tl, Z_ll, false, false));
    final long endTimeZFW = System.nanoTime();

    /* perform a broadcast hash join */
    final long startTimeQW = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_tl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(targetSpaceBlocks, W_tl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_tl_bc = ctx.broadcast(W_tl_map);
    final INDArray Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(computeRDD.mapValues(cb -> {
        final INDArray W_tl_chunk = W_tl_bc.value().get(cb.getTargetSpaceBlock());
        final INDArray Q_tll_chunk = cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Q_tll);
        final Collection<INDArray> W_Q_chunk = IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
                .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti))
                        .mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
                .collect(Collectors.toList());
        return Nd4j.vstack(W_Q_chunk);
    }), 0);
    W_tl_bc.destroy();

    // final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> W_tl_RDD = CoverageModelSparkUtils.rddFromINDArray(W_tl,
    //         targetSpaceBlocks, ctx, true);
    // final INDArray Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocks(
    //         computeRDD.join(W_tl_RDD).mapValues(p -> {
    //             final CoverageModelEMComputeBlock cb = p._1;
    //             final INDArray W_tl_chunk = p._2;
    //             final INDArray Q_tll_chunk = cb.getINDArrayFromCache("Q_tll");
    //             return Nd4j.vstack(IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
    //                     .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti)).mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
    //                     .collect(Collectors.toList()));
    //         }), false);
    // W_tl_RDD.unpersist();

    final long endTimeQW = System.nanoTime();
    logger.debug("Local [Z] [F] [W] timing: " + (endTimeZFW - startTimeZFW) / 1000000 + " ms");
    logger.debug("Spark [Q] [W] timing: " + (endTimeQW - startTimeQW) / 1000000 + " ms");
    return Q_W_tl.addi(Z_F_W_tl);
}
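The method above avoids a shuffle join by broadcasting the small per-block map of W_tl chunks and looking it up by key inside mapValues (the commented-out alternative shows the shuffle-join version it replaced). Below is a minimal, self-contained sketch of that broadcast hash join pattern using only JavaPairRDD and Broadcast; the block ids, scaling factors, and class name are hypothetical placeholders, not GATK types.

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import scala.Tuple2;

/**
 * Minimal sketch of a broadcast hash join: the small side is broadcast to the executors
 * and looked up by key inside a map transformation, so no shuffle join is needed.
 */
public final class BroadcastHashJoinSketch {
    public static void main(final String[] args) {
        final SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("broadcast-hash-join-sketch");
        try (final JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // Large side: per-block partial results, keyed by block index.
            final JavaPairRDD<Integer, Double> blockSums = ctx.parallelizePairs(Arrays.asList(
                    new Tuple2<>(0, 1.0), new Tuple2<>(1, 2.0), new Tuple2<>(2, 3.0)));
            // Small side: fits comfortably in memory, so broadcast it instead of making it an RDD.
            final Map<Integer, Double> scaleByBlock = new HashMap<>();
            scaleByBlock.put(0, 10.0);
            scaleByBlock.put(1, 20.0);
            scaleByBlock.put(2, 30.0);
            final Broadcast<Map<Integer, Double>> scaleBc = ctx.broadcast(scaleByBlock);
            // The "join" is a plain map-side lookup into the broadcast value.
            final JavaPairRDD<Integer, Double> scaled = blockSums.mapToPair(
                    t -> new Tuple2<>(t._1, t._2 * scaleBc.value().get(t._1)));
            scaled.collectAsMap().forEach((block, value) -> System.out.println(block + " -> " + value));
            scaleBc.destroy();
        }
    }
}

As in the GATK method, destroying the broadcast once the result has been assembled releases the copies held on the executors.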
Use of org.apache.spark.api.java.JavaPairRDD in project gatk-protected by broadinstitute.
The class CoverageModelEMWorkspace, method updateCopyRatioPosteriorExpectationsSpark.
/**
 * The Spark implementation of the E-step update of copy ratio posteriors
 *
 * @return a {@link SubroutineSignal} containing the update size
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateCopyRatioPosteriorExpectationsSpark(final double admixingRatio) {
    /* local final member variables for lambda capture */
    final List<LinearlySpacedIndexBlock> targetBlocks = new ArrayList<>();
    targetBlocks.addAll(this.targetBlocks);
    final List<Target> targetList = new ArrayList<>();
    targetList.addAll(processedTargetList);
    final List<String> sampleNameList = new ArrayList<>();
    sampleNameList.addAll(processedSampleNameList);
    final List<SexGenotypeData> sampleSexGenotypeData = new ArrayList<>();
    sampleSexGenotypeData.addAll(processedSampleSexGenotypeData);
    final int numTargetBlocks = targetBlocks.size();
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator =
            this.copyRatioExpectationsCalculator;
    final INDArray sampleReadDepths = Transforms.exp(sampleMeanLogReadDepths, true);

    /* make an RDD of copy ratio posterior expectations */
    final JavaPairRDD<Integer, CopyRatioExpectations> copyRatioPosteriorExpectationsPairRDD =
            /* fetch copy ratio emission data from workers */
            fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(it -> {
                final List<Tuple2<Integer, CopyRatioExpectations>> newPartitionData = new ArrayList<>();
                while (it.hasNext()) {
                    final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> prevDatum = it.next();
                    final int si = prevDatum._1;
                    final CopyRatioCallingMetadata copyRatioCallingMetadata = CopyRatioCallingMetadata.builder()
                            .sampleName(sampleNameList.get(si))
                            .sampleSexGenotypeData(sampleSexGenotypeData.get(si))
                            .sampleCoverageDepth(sampleReadDepths.getDouble(si))
                            .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                            .build();
                    newPartitionData.add(new Tuple2<>(prevDatum._1,
                            calculator.getCopyRatioPosteriorExpectations(copyRatioCallingMetadata, targetList, prevDatum._2)));
                }
                return newPartitionData.iterator();
            }, true);

    /* we need to do two things to copyRatioPosteriorExpectationsPairRDD; so we cache it */
    /* step 1. update log chain posterior expectation on the driver node */
    final double[] newSampleLogChainPosteriors = copyRatioPosteriorExpectationsPairRDD
            .mapValues(CopyRatioExpectations::getLogChainPosteriorProbability)
            .collect().stream()
            .sorted(Comparator.comparingInt(t -> t._1))
            .mapToDouble(t -> t._2)
            .toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(newSampleLogChainPosteriors, new int[] { numSamples, 1 }));

    /* step 2. repartition in target space */
    final JavaPairRDD<LinearlySpacedIndexBlock, ImmutablePair<INDArray, INDArray>> blockifiedCopyRatioPosteriorResultsPairRDD =
            copyRatioPosteriorExpectationsPairRDD
                    .flatMapToPair(dat -> targetBlocks.stream()
                            .map(tb -> new Tuple2<>(tb, new Tuple2<>(dat._1,
                                    ImmutablePair.of(dat._2.getLogCopyRatioMeans(tb), dat._2.getLogCopyRatioVariances(tb)))))
                            .iterator())
                    .combineByKey(
                            /* recipe to create a singleton list */
                            Collections::singletonList,
                            /* recipe to add an element to the list */
                            (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()),
                            /* recipe to concatenate two lists */
                            (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()),
                            /* repartition with respect to target space blocks */
                            new HashPartitioner(numTargetBlocks))
                    .mapValues(list -> list.stream()
                            .sorted(Comparator.comparingInt(t -> t._1))
                            .map(p -> p._2)
                            .map(t -> ImmutablePair.of(Nd4j.create(t.left), Nd4j.create(t.right)))
                            .collect(Collectors.toList()))
                    .mapValues(CoverageModelEMWorkspace::stackCopyRatioPosteriorDataForAllSamples);

    /* we do not need copy ratio expectations anymore */
    copyRatioPosteriorExpectationsPairRDD.unpersist();

    /* step 3. merge with computeRDD and update */
    computeRDD = computeRDD.join(blockifiedCopyRatioPosteriorResultsPairRDD)
            .mapValues(t -> t._1.cloneWithUpdatedCopyRatioPosteriors(t._2.left, t._2.right, admixingRatio));
    cacheWorkers("after E-step for copy ratio update");

    /* collect subroutine signals */
    final List<SubroutineSignal> sigs = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    final double errorNormInfinity = Collections.max(sigs.stream()
            .map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM))
            .collect(Collectors.toList()));
    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .build();
}
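Step 2 above uses combineByKey with a HashPartitioner to regroup per-sample results into per-block lists and then sorts each list by sample index. The following sketch isolates that repartitioning idiom on toy data; the block ids, string payloads, and class name are hypothetical stand-ins for the GATK-specific types.

import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

/**
 * Minimal sketch of the "repartition in target space" idiom: per-sample records are exploded
 * into (blockId, (sampleIndex, payload)) pairs, grouped per block with combineByKey plus a
 * HashPartitioner, and sorted by sample index inside each block.
 */
public final class CombineByKeyRepartitionSketch {
    public static void main(final String[] args) {
        final SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("combine-by-key-sketch");
        try (final JavaSparkContext ctx = new JavaSparkContext(conf)) {
            final List<Integer> blockIds = Arrays.asList(0, 1, 2);
            // Per-sample results keyed by sample index.
            final JavaPairRDD<Integer, String> perSample = ctx.parallelizePairs(Arrays.asList(
                    new Tuple2<>(0, "s0"), new Tuple2<>(1, "s1"), new Tuple2<>(2, "s2")));
            // Explode each sample into one record per block, then regroup by block.
            final JavaPairRDD<Integer, List<Tuple2<Integer, String>>> perBlock = perSample
                    .flatMapToPair(dat -> blockIds.stream()
                            .map(b -> new Tuple2<>(b, new Tuple2<>(dat._1, dat._2 + "-block" + b)))
                            .iterator())
                    .combineByKey(
                            Collections::singletonList,
                            (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()),
                            (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()),
                            new HashPartitioner(blockIds.size()))
                    // Keep a deterministic per-sample order inside each block.
                    .mapValues(list -> list.stream()
                            .sorted(Comparator.comparingInt(t -> t._1))
                            .collect(Collectors.toList()));
            perBlock.collectAsMap().forEach((block, list) -> System.out.println(block + " -> " + list));
        }
    }
}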
Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
The class MarkDuplicatesSparkUtils, method generateMetrics.
static JavaPairRDD<String, DuplicationMetrics> generateMetrics(final SAMFileHeader header, final JavaRDD<GATKRead> reads) {
    return reads.filter(read -> !read.isSecondaryAlignment() && !read.isSupplementaryAlignment())
            .mapToPair(read -> {
                final String library = LibraryIdGenerator.getLibraryName(header, read.getReadGroup());
                DuplicationMetrics metrics = new DuplicationMetrics();
                metrics.LIBRARY = library;
                if (read.isUnmapped()) {
                    ++metrics.UNMAPPED_READS;
                } else if (!read.isPaired() || read.mateIsUnmapped()) {
                    ++metrics.UNPAIRED_READS_EXAMINED;
                } else {
                    ++metrics.READ_PAIRS_EXAMINED;
                }
                if (read.isDuplicate()) {
                    if (!read.isPaired() || read.mateIsUnmapped()) {
                        ++metrics.UNPAIRED_READ_DUPLICATES;
                    } else {
                        ++metrics.READ_PAIR_DUPLICATES;
                    }
                }
                if (read.hasAttribute(OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME)) {
                    metrics.READ_PAIR_OPTICAL_DUPLICATES += read.getAttributeAsInteger(OPTICAL_DUPLICATE_TOTAL_ATTRIBUTE_NAME);
                }
                return new Tuple2<>(library, metrics);
            })
            .foldByKey(new DuplicationMetrics(), (metricsSum, m) -> {
                if (metricsSum.LIBRARY == null) {
                    metricsSum.LIBRARY = m.LIBRARY;
                }
                // This should never happen, as we grouped by key using library as the key.
                if (!metricsSum.LIBRARY.equals(m.LIBRARY)) {
                    throw new GATKException("Two different libraries encountered while summing metrics: "
                            + metricsSum.LIBRARY + " and " + m.LIBRARY);
                }
                metricsSum.UNMAPPED_READS += m.UNMAPPED_READS;
                metricsSum.UNPAIRED_READS_EXAMINED += m.UNPAIRED_READS_EXAMINED;
                metricsSum.READ_PAIRS_EXAMINED += m.READ_PAIRS_EXAMINED;
                metricsSum.UNPAIRED_READ_DUPLICATES += m.UNPAIRED_READ_DUPLICATES;
                metricsSum.READ_PAIR_DUPLICATES += m.READ_PAIR_DUPLICATES;
                metricsSum.READ_PAIR_OPTICAL_DUPLICATES += m.READ_PAIR_OPTICAL_DUPLICATES;
                return metricsSum;
            })
            .mapValues(metrics -> {
                DuplicationMetrics copy = metrics.copy();
                // Pair counters were incremented once per mate, so halve them to count pairs.
                copy.READ_PAIRS_EXAMINED = metrics.READ_PAIRS_EXAMINED / 2;
                copy.READ_PAIR_DUPLICATES = metrics.READ_PAIR_DUPLICATES / 2;
                copy.calculateDerivedMetrics();
                if (copy.ESTIMATED_LIBRARY_SIZE == null) {
                    copy.ESTIMATED_LIBRARY_SIZE = 0L;
                }
                return copy;
            });
}
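generateMetrics builds one partial DuplicationMetrics per read and then sums the partials per library with foldByKey, starting from a neutral zero object. The sketch below shows the same mapToPair plus foldByKey accumulation with a hypothetical Counter class instead of DuplicationMetrics; the "library,flag" record format is an assumption made for the example.

import java.io.Serializable;
import java.util.Map;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;

import scala.Tuple2;

/**
 * Minimal sketch of per-key metric aggregation with foldByKey: each record becomes a
 * (key, partial-counter) pair and the counters are summed per key.
 */
public final class FoldByKeySketch {

    /** A tiny serializable accumulator, analogous to one metrics row per library. */
    public static final class Counter implements Serializable {
        private static final long serialVersionUID = 1L;
        public long total;
        public long flagged;

        @Override
        public String toString() {
            return "total=" + total + ", flagged=" + flagged;
        }
    }

    /** Counts total and flagged records per library; records are "library,flag" strings. */
    public static Map<String, Counter> countPerLibrary(final JavaRDD<String> records) {
        final JavaPairRDD<String, Counter> perRecord = records.mapToPair(line -> {
            final String[] fields = line.split(",");
            final Counter c = new Counter();
            c.total = 1;
            c.flagged = "1".equals(fields[1]) ? 1 : 0;
            return new Tuple2<>(fields[0], c);
        });
        // foldByKey starts from a zero-valued Counter and merges partial counters pairwise,
        // first within partitions and then across partitions.
        return perRecord.foldByKey(new Counter(), (sum, c) -> {
            sum.total += c.total;
            sum.flagged += c.flagged;
            return sum;
        }).collectAsMap();
    }
}

Mutating and returning the accumulator inside foldByKey, as the GATK code does with metricsSum, is safe because each task works on its own deserialized copy of the zero value.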
Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
The class SparkSharderUnitTest, method testSingleContig.
@Test
public void testSingleContig() throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();

    // Consider the following reads (divided into four partitions), and intervals.
    // This test counts the number of reads that overlap each interval.
    //                      1                   2
    //   1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    // ---------------------------------------------------------
    // Reads in partition 0
    //  [-----]
    //          [-----]
    //              [-----]
    // ---------------------------------------------------------
    // Reads in partition 1
    //              [-----]
    //              [-----]
    //              [-----]
    // ---------------------------------------------------------
    // Reads in partition 2
    //              [-----]
    //                      [-----]
    //                        [-----]
    // ---------------------------------------------------------
    // Reads in partition 3
    //                                  [-----]
    //                                          [-----]
    //                                                  [-----]
    // ---------------------------------------------------------
    // Per-partition read extents
    //  [-----------------]
    //              [-----]
    //              [---------------]
    //                                  [---------------------]
    // ---------------------------------------------------------
    // Intervals
    //    [-----]
    //                [---------]
    //                      [-----------------------]
    //
    //                      1                   2
    // ---------------------------------------------------------
    //   1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    JavaRDD<TestRead> reads = ctx.parallelize(ImmutableList.of(
            new TestRead(1, 3), new TestRead(5, 7), new TestRead(7, 9), new TestRead(7, 9),
            new TestRead(7, 9), new TestRead(7, 9), new TestRead(7, 9), new TestRead(11, 13),
            new TestRead(12, 14), new TestRead(17, 19), new TestRead(21, 23), new TestRead(25, 27)), 4);
    List<SimpleInterval> intervals = ImmutableList.of(
            new SimpleInterval("1", 2, 4), new SimpleInterval("1", 8, 12), new SimpleInterval("1", 11, 22));
    List<ShardBoundary> shardBoundaries = intervals.stream().map(si -> new ShardBoundary(si, si)).collect(Collectors.toList());
    ImmutableMap<SimpleInterval, Integer> expectedReadsPerInterval =
            ImmutableMap.of(intervals.get(0), 1, intervals.get(1), 7, intervals.get(2), 4);
    JavaPairRDD<Locatable, Integer> readsPerInterval =
            SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, false)
                    .flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerInterval.collectAsMap(), expectedReadsPerInterval);
    JavaPairRDD<Locatable, Integer> readsPerIntervalShuffle =
            SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, true)
                    .flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerIntervalShuffle.collectAsMap(), expectedReadsPerInterval);
    try {
        // max read length less than actual causes exception
        int maxReadLength = STANDARD_READ_LENGTH - 1;
        SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, maxReadLength, true)
                .flatMapToPair(new CountOverlappingReadsFunction()).collect();
    } catch (Exception e) {
        assertEquals(e.getCause().getClass(), UserException.class);
    }
}
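The expected counts in this test (1, 7, and 4 reads per interval) can be reproduced without SparkSharder by exploding each read into an (interval, 1) pair for every interval it overlaps and summing with reduceByKey. The sketch below does that as a brute-force cross-check; the int[] {start, end} interval representation and the method name are hypothetical simplifications, not the SparkSharder API, and the intervals list is assumed to be a serializable ArrayList.

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

/**
 * Minimal sketch of overlap counting with flatMapToPair, reduceByKey, and collectAsMap,
 * i.e. the shape of the result the unit test asserts on.
 */
public final class OverlapCountSketch {

    /** Reads and intervals are closed int ranges given as {start, end}; keys are boxed as List<Integer>. */
    public static Map<List<Integer>, Integer> countOverlaps(final JavaSparkContext ctx,
                                                            final List<int[]> reads,
                                                            final List<int[]> intervals) {
        final JavaRDD<int[]> readRdd = ctx.parallelize(reads, 4);
        final JavaPairRDD<List<Integer>, Integer> ones = readRdd.flatMapToPair(read ->
                intervals.stream()
                        .filter(iv -> read[0] <= iv[1] && iv[0] <= read[1])  // closed-interval overlap test
                        .map(iv -> new Tuple2<>(Arrays.asList(iv[0], iv[1]), 1))
                        .iterator());
        // Sum the per-read contributions for each interval and bring the result to the driver.
        return ones.reduceByKey(Integer::sum).collectAsMap();
    }
}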
Use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
The class ShuffleJoinReadsWithRefBases, method addBases.
/**
 * Joins each read of an RDD<GATKRead, T> with that read's corresponding reference sequence.
 *
 * @param referenceDataflowSource The source of the reference sequence information
 * @param keyedByRead The read-keyed RDD for which to extract reference sequence information
 * @return The JavaPairRDD that contains each read along with the corresponding ReferenceBases object and the value
 */
public static <T> JavaPairRDD<GATKRead, Tuple2<T, ReferenceBases>> addBases(final ReferenceMultiSource referenceDataflowSource,
                                                                            final JavaPairRDD<GATKRead, T> keyedByRead) {
    SerializableFunction<GATKRead, SimpleInterval> windowFunction = referenceDataflowSource.getReferenceWindowFunction();
    JavaPairRDD<ReferenceShard, Tuple2<GATKRead, T>> shardRead = keyedByRead.mapToPair(pair -> {
        ReferenceShard shard = ReferenceShard.getShardNumberFromInterval(windowFunction.apply(pair._1()));
        return new Tuple2<>(shard, pair);
    });
    JavaPairRDD<ReferenceShard, Iterable<Tuple2<GATKRead, T>>> shardiRead = shardRead.groupByKey();
    return shardiRead.flatMapToPair(in -> {
        List<Tuple2<GATKRead, Tuple2<T, ReferenceBases>>> out = Lists.newArrayList();
        Iterable<Tuple2<GATKRead, T>> iReads = in._2();
        final List<SimpleInterval> readWindows = Utils.stream(iReads)
                .map(pair -> windowFunction.apply(pair._1()))
                .collect(Collectors.toList());
        SimpleInterval interval = IntervalUtils.getSpanningInterval(readWindows);
        ReferenceBases bases = referenceDataflowSource.getReferenceBases(null, interval);
        for (Tuple2<GATKRead, T> p : iReads) {
            final ReferenceBases subset = bases.getSubset(windowFunction.apply(p._1()));
            out.add(new Tuple2<>(p._1(), new Tuple2<>(p._2(), subset)));
        }
        return out.iterator();
    });
}
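addBases is a shuffle join: reads are re-keyed by ReferenceShard, grouped with groupByKey, and the reference bases for each shard's spanning interval are fetched once and then subset per read. The generic sketch below captures that shard, group, and re-emit shape; the integer positions, shard size, and the span string standing in for the shared reference lookup are hypothetical placeholders.

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;

import scala.Tuple2;

import com.google.common.collect.Lists;

/**
 * Minimal sketch of a shard-grouped lookup: records are re-keyed by a coarse shard id,
 * grouped so one expensive per-shard computation can be shared, then re-emitted per record.
 */
public final class ShardedLookupSketch {

    private static final int SHARD_SIZE = 10_000;

    /** Attaches a per-shard annotation (here just the shard's position span) to every (position, payload) pair. */
    public static <T> JavaPairRDD<Integer, Tuple2<T, String>> withShardSpan(final JavaPairRDD<Integer, T> keyedByPosition) {
        // 1. Re-key each record by its shard index.
        final JavaPairRDD<Integer, Tuple2<Integer, T>> byShard =
                keyedByPosition.mapToPair(p -> new Tuple2<>(p._1 / SHARD_SIZE, p));
        // 2. One group per shard; in ShuffleJoinReadsWithRefBases this is where the reference
        //    bases for the shard's spanning interval are fetched once.
        return byShard.groupByKey().flatMapToPair(shard -> {
            final List<Tuple2<Integer, Tuple2<T, String>>> out = new ArrayList<>();
            final List<Tuple2<Integer, T>> records = Lists.newArrayList(shard._2);
            int min = Integer.MAX_VALUE;
            int max = Integer.MIN_VALUE;
            for (final Tuple2<Integer, T> r : records) {
                min = Math.min(min, r._1);
                max = Math.max(max, r._1);
            }
            final String span = "[" + min + ", " + max + "]";  // shared, shard-level result
            // 3. Emit one output record per input record, restoring the original key.
            for (final Tuple2<Integer, T> r : records) {
                out.add(new Tuple2<>(r._1, new Tuple2<>(r._2, span)));
            }
            return out.iterator();
        });
    }
}

Compared with the broadcast hash join in the first example, this variant shuffles the data but never materializes the small side on the driver, which is the right trade-off when the per-shard lookup is what is expensive.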