use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method instantiateWorkers.
/**
 * Instantiate compute block(s). If Spark is disabled, a single {@link CoverageModelEMComputeBlock} is
 * instantiated. Otherwise, a {@link JavaPairRDD} of compute nodes will be created.
 */
private void instantiateWorkers() {
    if (sparkContextIsAvailable) {
        /* initialize the RDD */
        logger.info("Initializing an RDD of compute blocks");
        computeRDD = ctx.parallelizePairs(targetBlockStream()
                .map(tb -> new Tuple2<>(tb, new CoverageModelEMComputeBlock(tb, numSamples, numLatents, ardEnabled)))
                .collect(Collectors.toList()), numTargetBlocks)
                .partitionBy(new HashPartitioner(numTargetBlocks))
                .cache();
    } else {
        logger.info("Initializing a local compute block");
        localComputeBlock = new CoverageModelEMComputeBlock(targetBlocks.get(0), numSamples, numLatents, ardEnabled);
    }
    prevCheckpointedComputeRDD = null;
    cacheCallCounter = 0;
}
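As a quick, self-contained illustration of the parallelizePairs / partitionBy / cache idiom used above, the sketch below builds a keyed RDD of placeholder blocks. The class name, key type, and String payload are invented for this example; only the Spark calls mirror the GATK code.

// Minimal sketch of the parallelizePairs + partitionBy + cache pattern (illustrative names only).
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class PartitionedBlockRddSketch {

    /** Build one (blockIndex, payload) pair per block, pin each key to a partition, and cache. */
    public static JavaPairRDD<Integer, String> instantiateBlocks(final JavaSparkContext ctx, final int numBlocks) {
        // Materialize the (key, value) tuples on the driver first ...
        final List<Tuple2<Integer, String>> tuples = IntStream.range(0, numBlocks)
                .mapToObj(i -> new Tuple2<>(i, "block-" + i))
                .collect(Collectors.toList());
        // ... then parallelize them, co-locate each key deterministically, and keep the blocks resident for reuse.
        return ctx.parallelizePairs(tuples, numBlocks)
                .partitionBy(new HashPartitioner(numBlocks))
                .cache();
    }
}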
use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
the class SparkSharderUnitTest method testContigBoundary.
@Test
public void testContigBoundary() throws IOException {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    // Consider the following reads (in a single partition), and intervals.
    // This test counts the number of reads that overlap each interval.
    //
    //                      1                   2
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    // ---------------------------------------------------------
    //   Reads in partition 0
    //   [-----]                                               chr 1
    //           [-----]                                       chr 1
    //               [-----]                                   chr 1
    //   [-----]                                               chr 2
    //     [-----]                                             chr 2
    // ---------------------------------------------------------
    //   Per-partition read extents
    //   [-----------------]                                   chr 1
    //   [-------]                                             chr 2
    // ---------------------------------------------------------
    //   Intervals
    //     [-----]                                             chr 1
    //                 [---------]                             chr 1
    //   [-----------------------]                             chr 2
    // ---------------------------------------------------------
    //    1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7
    JavaRDD<TestRead> reads = ctx.parallelize(ImmutableList.of(
            new TestRead("1", 1, 3), new TestRead("1", 5, 7), new TestRead("1", 7, 9),
            new TestRead("2", 1, 3), new TestRead("2", 2, 4)), 1);
    List<SimpleInterval> intervals = ImmutableList.of(
            new SimpleInterval("1", 2, 4), new SimpleInterval("1", 8, 12), new SimpleInterval("2", 1, 12));
    List<ShardBoundary> shardBoundaries = intervals.stream().map(si -> new ShardBoundary(si, si)).collect(Collectors.toList());
    ImmutableMap<SimpleInterval, Integer> expectedReadsPerInterval =
            ImmutableMap.of(intervals.get(0), 1, intervals.get(1), 1, intervals.get(2), 2);
    JavaPairRDD<Locatable, Integer> readsPerInterval =
            SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, false)
                    .flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerInterval.collectAsMap(), expectedReadsPerInterval);
    JavaPairRDD<Locatable, Integer> readsPerIntervalShuffle =
            SparkSharder.shard(ctx, reads, TestRead.class, sequenceDictionary, shardBoundaries, STANDARD_READ_LENGTH, true)
                    .flatMapToPair(new CountOverlappingReadsFunction());
    assertEquals(readsPerIntervalShuffle.collectAsMap(), expectedReadsPerInterval);
}
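The assertions above hinge on two JavaPairRDD operations: flatMapToPair to emit one (interval, count) pair per shard, and collectAsMap to pull the result back to the driver. Below is a stripped-down, hypothetical version of that flow that counts integers falling inside one hard-coded range; it uses none of GATK's sharding classes, only the same Spark 2.x Java API calls (iterator-returning flat-map functions).

// Hypothetical stand-in for the shard-and-count flow: emit (rangeLabel, 1) per overlapping point,
// sum per range, and collect the counts on the driver.
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

public class OverlapCountSketch {

    public static Map<String, Integer> countPointsInRange(final JavaSparkContext ctx) {
        final JavaRDD<Integer> points = ctx.parallelize(Arrays.asList(1, 2, 3, 5, 7, 9));
        final JavaPairRDD<String, Integer> perRange = points.flatMapToPair(p ->
                (p >= 2 && p <= 4)
                        ? Collections.singletonList(new Tuple2<>("2-4", 1)).iterator()
                        : Collections.<Tuple2<String, Integer>>emptyIterator())
                .reduceByKey(Integer::sum);
        // For the sample points above this returns {"2-4" -> 2}.
        return perRange.collectAsMap();
    }
}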
use of org.apache.spark.api.java.JavaPairRDD in project gatk by broadinstitute.
the class JoinReadsWithVariantsSparkUnitTest method pairReadsAndVariantsTest.
@Test(dataProvider = "pairedReadsAndVariants", groups = "spark")
public void pairReadsAndVariantsTest(List<GATKRead> reads, List<GATKVariant> variantList,
                                     List<KV<GATKRead, Iterable<GATKVariant>>> kvReadiVariant, JoinStrategy joinStrategy) {
    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    JavaRDD<GATKRead> rddReads = ctx.parallelize(reads);
    JavaRDD<GATKVariant> rddVariants = ctx.parallelize(variantList);
    JavaPairRDD<GATKRead, Iterable<GATKVariant>> actual = joinStrategy == JoinStrategy.SHUFFLE
            ? ShuffleJoinReadsWithVariants.join(rddReads, rddVariants)
            : BroadcastJoinReadsWithVariants.join(rddReads, rddVariants);
    Map<GATKRead, Iterable<GATKVariant>> gatkReadIterableMap = actual.collectAsMap();
    Assert.assertEquals(gatkReadIterableMap.size(), kvReadiVariant.size());
    for (KV<GATKRead, Iterable<GATKVariant>> kv : kvReadiVariant) {
        List<GATKVariant> variants = Lists.newArrayList(gatkReadIterableMap.get(kv.getKey()));
        Assert.assertTrue(variants.stream().noneMatch(v -> v == null));
        HashSet<GATKVariant> hashVariants = new LinkedHashSet<>(variants);
        final Iterable<GATKVariant> iVariants = kv.getValue();
        HashSet<GATKVariant> expectedHashVariants = Sets.newLinkedHashSet(iVariants);
        Assert.assertEquals(hashVariants, expectedHashVariants);
    }
}
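The test exercises two strategies behind the same JavaPairRDD<GATKRead, Iterable<GATKVariant>> result shape: a shuffle join, which repartitions both sides by key, and a broadcast join, which ships the smaller side to every executor. The sketch below shows the broadcast idiom generically, with String keys and Integer values standing in for reads and variants; it is not GATK's BroadcastJoinReadsWithVariants (which matches by genomic overlap rather than key equality), just the same Spark pattern.

// Generic broadcast join sketch: all class and parameter names are illustrative.
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.Collections;
import java.util.List;
import java.util.Map;

public class BroadcastJoinSketch {

    /** Pair every left-hand key with the matching values from a small right-hand table. */
    public static JavaPairRDD<String, Iterable<Integer>> broadcastJoin(final JavaSparkContext ctx,
                                                                       final JavaRDD<String> leftKeys,
                                                                       final Map<String, List<Integer>> smallRightTable) {
        // Ship the small side (which must be serializable, e.g. a HashMap) to every executor once,
        // instead of shuffling both inputs by key.
        final Broadcast<Map<String, List<Integer>>> rightBc = ctx.broadcast(smallRightTable);
        return leftKeys.mapToPair(k -> new Tuple2<String, Iterable<Integer>>(
                k, rightBc.value().getOrDefault(k, Collections.emptyList())));
    }
}

A shuffle join of the same inputs would instead key both sides and call join, at the cost of shuffling both RDDs; the broadcast variant pays only the cost of copying the small table to each executor.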
use of org.apache.spark.api.java.JavaPairRDD in project gatk-protected by broadinstitute.
the class CoverageModelWLinearOperatorSpark method operate.
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }
    /* Z F W */
    final long startTimeZFW = System.nanoTime();
    final INDArray Z_F_W_tl = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel()
            .forEach(li -> Z_F_W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(F_tt.operate(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    Z_F_W_tl.assign(Nd4j.gemm(Z_F_W_tl, Z_ll, false, false));
    final long endTimeZFW = System.nanoTime();
    /* perform a broadcast hash join */
    final long startTimeQW = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_tl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(targetSpaceBlocks, W_tl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_tl_bc = ctx.broadcast(W_tl_map);
    final INDArray Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(computeRDD.mapValues(cb -> {
        final INDArray W_tl_chunk = W_tl_bc.value().get(cb.getTargetSpaceBlock());
        final INDArray Q_tll_chunk = cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Q_tll);
        final Collection<INDArray> W_Q_chunk = IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
                .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti))
                        .mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
                .collect(Collectors.toList());
        return Nd4j.vstack(W_Q_chunk);
    }), 0);
    W_tl_bc.destroy();
    // final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> W_tl_RDD = CoverageModelSparkUtils.rddFromINDArray(W_tl,
    //         targetSpaceBlocks, ctx, true);
    // final INDArray Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocks(
    //         computeRDD.join(W_tl_RDD).mapValues(p -> {
    //             final CoverageModelEMComputeBlock cb = p._1;
    //             final INDArray W_tl_chunk = p._2;
    //             final INDArray Q_tll_chunk = cb.getINDArrayFromCache("Q_tll");
    //             return Nd4j.vstack(IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
    //                     .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti)).mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
    //                     .collect(Collectors.toList()));
    //         }), false);
    // W_tl_RDD.unpersist();
    final long endTimeQW = System.nanoTime();
    logger.debug("Local [Z] [F] [W] timing: " + (endTimeZFW - startTimeZFW) / 1000000 + " ms");
    logger.debug("Spark [Q] [W] timing: " + (endTimeQW - startTimeQW) / 1000000 + " ms");
    return Q_W_tl.addi(Z_F_W_tl);
}
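The live code path above performs a broadcast hash join by hand: the driver partitions W into per-block chunks, broadcasts the map, and each compute block picks up its own chunk inside mapValues, while the commented-out alternative does the same thing with an RDD-to-RDD join. The sketch below shows the broadcast-and-combine step generically, using plain double[] blocks and a made-up per-block scaling in place of the ND4J matrix products; all names are illustrative.

// Generic broadcast hash join: look up a driver-built value for each block key, no shuffle required.
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.Map;

public class BroadcastHashJoinSketch {

    /** Scale each block's payload by a factor looked up in a broadcast, driver-built map. */
    public static JavaPairRDD<Integer, double[]> applyFactors(final JavaSparkContext ctx,
                                                              final JavaPairRDD<Integer, double[]> blocks,
                                                              final Map<Integer, Double> factorPerBlock) {
        // Ship the small map to every executor once; the large keyed RDD stays where it is.
        final Broadcast<Map<Integer, Double>> factorsBc = ctx.broadcast(factorPerBlock);
        return blocks.mapToPair(block -> {
            final double factor = factorsBc.value().get(block._1);
            final double[] scaled = new double[block._2.length];
            for (int i = 0; i < block._2.length; i++) {
                scaled[i] = factor * block._2[i];
            }
            return new Tuple2<>(block._1, scaled);
        });
    }
}

The sketch uses mapToPair so it can read the key; the production code can use mapValues instead because each CoverageModelEMComputeBlock reports its own block via getTargetSpaceBlock(), and mapValues additionally preserves the RDD's partitioner. Destroying the broadcast afterwards frees the executor-side copies, mirroring W_tl_bc.destroy() above.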
use of org.apache.spark.api.java.JavaPairRDD in project gatk-protected by broadinstitute.
the class CoverageModelWPreconditionerSpark method operate.
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }
    long startTimeRFFT = System.nanoTime();
    /* forward rfft */
    final INDArray W_kl = Nd4j.create(fftSize, numLatents);
    IntStream.range(0, numLatents).parallel()
            .forEach(li -> W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(Nd4j.create(F_tt.getForwardFFT(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))),
                            new int[] { fftSize, 1 })));
    long endTimeRFFT = System.nanoTime();
    /* apply the preconditioner in the Fourier space */
    long startTimePrecond = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_kl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(fourierSpaceBlocks, W_kl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_kl_bc = ctx.broadcast(W_kl_map);
    final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> preconditionedWRDD = linOpPairRDD.mapToPair(p -> {
        final INDArray W_kl_chunk = W_kl_bc.value().get(p._1);
        final INDArray linOp_chunk = p._2;
        final int blockSize = linOp_chunk.shape()[0];
        final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel()
                .mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(linOp_chunk.get(NDArrayIndex.point(k)),
                        W_kl_chunk.get(NDArrayIndex.point(k))))
                .collect(Collectors.toList());
        return new Tuple2<>(p._1, Nd4j.vstack(linOpWList));
    });
    W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(preconditionedWRDD, 0));
    W_kl_bc.destroy();
    // final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> W_kl_RDD = CoverageModelSparkUtils.rddFromINDArray(W_kl,
    //         fourierSpaceBlocks, ctx, true);
    // W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocks(linOpPairRDD.join((W_kl_RDD))
    //         .mapValues(p -> {
    //             final INDArray linOp = p._1;
    //             final INDArray W = p._2;
    //             final int blockSize = linOp.shape()[0];
    //             final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel()
    //                     .mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(linOp.get(NDArrayIndex.point(k)),
    //                             W.get(NDArrayIndex.point(k))))
    //                     .collect(Collectors.toList());
    //             return Nd4j.vstack(linOpWList);
    //         }), false));
    // W_kl_RDD.unpersist();
    long endTimePrecond = System.nanoTime();
    /* irfft */
    long startTimeIRFFT = System.nanoTime();
    final INDArray res = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel()
            .forEach(li -> res.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(F_tt.getInverseFFT(W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    long endTimeIRFFT = System.nanoTime();
    logger.debug("Local FFT timing: " + (endTimeRFFT - startTimeRFFT + endTimeIRFFT - startTimeIRFFT) / 1000000 + " ms");
    logger.debug("Spark preconditioner application timing: " + (endTimePrecond - startTimePrecond) / 1000000 + " ms");
    return res;
}
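Both operators finish by pulling the per-block results back to the driver and stitching them into one array (assembleINDArrayBlocksFromRDD in the code above, whose exact signature is internal to GATK). The sketch below shows a generic version of that reassembly step for double[] blocks keyed by block index; the names are illustrative.

// Generic block reassembly: collect keyed blocks and concatenate them in ascending key order.
import org.apache.spark.api.java.JavaPairRDD;

import java.util.Map;
import java.util.TreeMap;

public class BlockAssemblySketch {

    /** Collect keyed blocks to the driver and concatenate them deterministically. */
    public static double[] assembleBlocks(final JavaPairRDD<Integer, double[]> blocks) {
        // A TreeMap sorts by block index, so the concatenation order does not depend on
        // which partition reports back first.
        final Map<Integer, double[]> collected = new TreeMap<>(blocks.collectAsMap());
        int totalLength = 0;
        for (final double[] block : collected.values()) {
            totalLength += block.length;
        }
        final double[] assembled = new double[totalLength];
        int offset = 0;
        for (final double[] block : collected.values()) {
            System.arraycopy(block, 0, assembled, offset, block.length);
            offset += block.length;
        }
        return assembled;
    }
}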