use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method initializeWorkersWithPCA.
/**
 * Initialize model parameters by performing PCA.
 *
 * <p>The steps, in order, are: (1) zero out {@code m_t}, {@code Psi_t} (and {@code W_tl} when bias
 * covariates are enabled) on every worker; (2) update the read depth posterior expectations;
 * (3) have the workers build PCA initialization data and reduce their covariance contributions by
 * summation; (4) eigen-decompose the reduced matrix; (5) derive an isotropic variance and a set of
 * {@code numLatents} scaled bias covariates from the eigensystem; (6) package these as a
 * {@link CoverageModelParameters} instance, push it to the workers and update the bias latent
 * posterior expectations.</p>
 *
 * <p>The order of the worker map/cache/reduce calls below matters: each step reads state written
 * by the previous one.</p>
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
logger.info("Initializing model parameters using PCA...");
/* initially, set m_t, Psi_t and W_tl to zero to get an estimate of the read depth */
final int numLatents = config.getNumLatents();
/* m_t and Psi_t are reset to 1 x (number of elements in the worker's target space block) zero arrays */
mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t, Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })));
if (biasCovariatesEnabled) {
/* W_tl is reset to a (block elements) x numLatents zero array; only done when bias covariates are in use */
mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl, Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), numLatents })));
}
/* update read depth without taking into account correction from bias covariates */
updateReadDepthPosteriorExpectations(1.0, true);
/* fetch sample covariance matrix */
final int minPCAInitializationReadCount = config.getMinPCAInitializationReadCount();
/* Integer.MAX_VALUE is passed as the second bound -- presumably "no upper limit" on read count; confirm against cloneWithPCAInitializationData */
mapWorkers(cb -> cb.cloneWithPCAInitializationData(minPCAInitializationReadCount, Integer.MAX_VALUE));
cacheWorkers("PCA initialization");
/* each worker contributes a covariance matrix; the contributions are combined by element-wise addition */
final INDArray targetCovarianceMatrix = mapWorkersAndReduce(CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);
/* perform eigen-decomposition on the target covariance matrix */
/* the pair is used below as (eigenvalues, eigenvectors): getLeft() is scaled as eigenvalues, getRight() is read column by column */
final ImmutablePair<INDArray, INDArray> targetCovarianceEigensystem = CoverageModelEMWorkspaceMathUtils.eig(targetCovarianceMatrix, false, logger);
/* the eigenvalues of sample covariance matrix can be immediately inferred by scaling */
final INDArray sampleCovarianceEigenvalues = targetCovarianceEigensystem.getLeft().div(numSamples);
/* estimate the isotropic unexplained variance -- see Bishop 12.46 */
/* NOTE(review): the divisor uses numTargets while the eigenvalue tail below is sliced up to numSamples -- confirm this is intended */
/* NOTE(review): residualDim is not checked for being positive; a zero value would yield a non-finite variance */
final int residualDim = numTargets - numLatents;
/* sum of the eigenvalues from index numLatents up to numSamples (presumably end-exclusive -- confirm NDArrayIndex.interval semantics), divided by residualDim */
final double isotropicVariance = sampleCovarianceEigenvalues.get(NDArrayIndex.interval(numLatents, numSamples)).sumNumber().doubleValue() / residualDim;
logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));
/* estimate bias factors -- see Bishop 12.45 */
/* scale factor for latent l is sqrt(eigenvalue_l - isotropicVariance), using the first numLatents eigenvalues */
/* NOTE(review): assumes those eigenvalues exceed isotropicVariance (i.e. eig returns them in descending order); otherwise sqrt yields NaN -- confirm */
final INDArray scaleFactors = Transforms.sqrt(sampleCovarianceEigenvalues.get(NDArrayIndex.interval(0, numLatents)).sub(isotropicVariance), false);
final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, numLatents });
for (int li = 0; li < numLatents; li++) {
/* li-th eigenvector column */
final INDArray v = targetCovarianceEigensystem.getRight().getColumn(li);
/* calculate [Delta_PCA_st]^T v */
/* note: we do not need to broadcast vec since it is small and lambda capture is just fine */
/* every worker computes its own slice, keyed by its target space block; the slices are then assembled into one array */
final INDArray unnormedBiasCovariate = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(mapWorkersAndCollect(cb -> ImmutablePair.of(cb.getTargetSpaceBlock(), cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st).transpose().mmul(v))), 0);
/* NOTE(review): the vector is normalized by its L1 norm; Bishop 12.45 is written for unit-length (L2) eigenvectors -- confirm L1 is intended */
final double norm = unnormedBiasCovariate.norm1Number().doubleValue();
/* divi/muli operate in place on unnormedBiasCovariate; the result is normalized and then scaled by the li-th scale factor */
final INDArray normedBiasCovariate = unnormedBiasCovariate.divi(norm).muli(scaleFactors.getDouble(li));
biasCovariatesPCA.getColumn(li).assign(normedBiasCovariate);
}
if (ardEnabled) {
/* a better estimate of ARD coefficients */
/* every ARD coefficient is set to the same value: initial relative precision / isotropic variance */
biasCovariatesARDCoefficients.assign(Nd4j.zeros(new int[] { 1, numLatents }).addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance));
}
/* second argument is an all-zero 1 x numTargets array and third is a constant array equal to isotropicVariance -- presumably m_t and Psi_t respectively; confirm against the CoverageModelParameters constructor */
final CoverageModelParameters modelParamsFromPCA = new CoverageModelParameters(processedTargetList, Nd4j.zeros(new int[] { 1, numTargets }), Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance), biasCovariatesPCA, biasCovariatesARDCoefficients);
/* clear PCA initialization data from workers */
mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);
/* push model parameters to workers */
initializeWorkersWithGivenModel(modelParamsFromPCA);
/* update bias latent posterior expectations without admixing */
updateBiasLatentPosteriorExpectations(1.0);
}
use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class CoverageModelWPreconditionerSpark method operate.
/**
 * Applies the preconditioner to {@code W_tl}.
 *
 * <p>Each latent column of the input is taken to Fourier space with a forward real FFT, every
 * Fourier-space row is passed through a linear solve against the matching row block of
 * {@code linOpPairRDD} on Spark, and the result is brought back with an inverse FFT.</p>
 *
 * <p>The Fourier-space array is shipped to the executors as a broadcast variable keyed by
 * Fourier space block. The broadcast is destroyed in a {@code finally} block so that it is
 * released even if the Spark job fails.</p>
 *
 * @param W_tl a rank-2 array of shape {@code numTargets x numLatents}
 * @return a new {@code numTargets x numLatents} array holding the preconditioned input
 * @throws DimensionMismatchException if {@code W_tl} is not a {@code numTargets x numLatents} matrix
 */
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }

    /* forward rfft, one latent column at a time */
    final long startTimeRFFT = System.nanoTime();
    final INDArray W_kl = Nd4j.create(fftSize, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(
                    Nd4j.create(F_tt.getForwardFFT(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))),
                            new int[] { fftSize, 1 })));
    final long endTimeRFFT = System.nanoTime();

    /* apply the preconditioner in the Fourier space */
    final long startTimePrecond = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_kl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(fourierSpaceBlocks, W_kl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_kl_bc = ctx.broadcast(W_kl_map);
    try {
        /* the closure only captures the broadcast handle, so nothing else from this object is serialized */
        final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> preconditionedWRDD = linOpPairRDD.mapToPair(p -> {
            final INDArray W_kl_chunk = W_kl_bc.value().get(p._1);
            final INDArray linOp_chunk = p._2;
            final int blockSize = linOp_chunk.shape()[0];
            /* solve one linear system per row k of the block */
            final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel()
                    .mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(
                            linOp_chunk.get(NDArrayIndex.point(k)),
                            W_kl_chunk.get(NDArrayIndex.point(k))))
                    .collect(Collectors.toList());
            return new Tuple2<>(p._1, Nd4j.vstack(linOpWList));
        });
        /* assembling the blocks triggers the Spark job; the result overwrites W_kl in place */
        W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(preconditionedWRDD, 0));
    } finally {
        /* release the broadcast on both the success and the failure path */
        W_kl_bc.destroy();
    }
    final long endTimePrecond = System.nanoTime();

    /* irfft, one latent column at a time */
    final long startTimeIRFFT = System.nanoTime();
    final INDArray res = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            res.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(
                    F_tt.getInverseFFT(W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    final long endTimeIRFFT = System.nanoTime();

    logger.debug("Local FFT timing: " + (endTimeRFFT - startTimeRFFT + endTimeIRFFT - startTimeIRFFT) / 1000000 + " ms");
    logger.debug("Spark preconditioner application timing: " + (endTimePrecond - startTimePrecond) / 1000000 + " ms");
    return res;
}
use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class CoverageModelWLinearOperatorSpark method operate.
/**
 * Applies the linear operator to {@code W_tl} and returns the sum of two terms.
 *
 * <p>The first term is computed locally: {@code F_tt} is applied to every latent column of
 * {@code W_tl} and the result is right-multiplied by {@code Z_ll}. The second term is computed on
 * Spark: for every target, the cached {@code Q_tll} slice of that target is multiplied with the
 * transposed row of {@code W_tl} for the same target.</p>
 *
 * <p>{@code W_tl} is shipped to the executors as a broadcast variable keyed by target space block
 * (a broadcast hash join). The broadcast is destroyed in a {@code finally} block so that it is
 * released even if the Spark job fails.</p>
 *
 * @param W_tl a rank-2 array of shape {@code numTargets x numLatents}
 * @return the sum of the Spark-computed and the locally computed terms
 * @throws DimensionMismatchException if {@code W_tl} is not a {@code numTargets x numLatents} matrix
 */
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }

    /* Z F W */
    final long startTimeZFW = System.nanoTime();
    final INDArray Z_F_W_tl = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            Z_F_W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(
                    F_tt.operate(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    Z_F_W_tl.assign(Nd4j.gemm(Z_F_W_tl, Z_ll, false, false));
    final long endTimeZFW = System.nanoTime();

    /* perform a broadcast hash join */
    final long startTimeQW = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_tl_map =
            CoverageModelSparkUtils.partitionINDArrayToMap(targetSpaceBlocks, W_tl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_tl_bc = ctx.broadcast(W_tl_map);
    final INDArray Q_W_tl;
    try {
        /* the closure only captures the broadcast handle, so nothing else from this object is serialized */
        Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(computeRDD.mapValues(cb -> {
            final INDArray W_tl_chunk = W_tl_bc.value().get(cb.getTargetSpaceBlock());
            final INDArray Q_tll_chunk = cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Q_tll);
            /* one matrix-vector product per target in the block */
            final Collection<INDArray> W_Q_chunk = IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
                    .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti))
                            .mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
                    .collect(Collectors.toList());
            return Nd4j.vstack(W_Q_chunk);
        }), 0);
    } finally {
        /* release the broadcast on both the success and the failure path */
        W_tl_bc.destroy();
    }
    final long endTimeQW = System.nanoTime();

    logger.debug("Local [Z] [F] [W] timing: " + (endTimeZFW - startTimeZFW) / 1000000 + " ms");
    logger.debug("Spark [Q] [W] timing: " + (endTimeQW - startTimeQW) / 1000000 + " ms");
    /* addi accumulates the local term into the Spark result in place */
    return Q_W_tl.addi(Z_F_W_tl);
}
use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class LocusWalkerSpark method getAlignments.
/**
 * Builds a {@link JavaRDD} of {@link LocusWalkerContext}s for the traversal intervals.
 *
 * The traversal intervals are the ones returned by {@code getIntervals()} when {@code hasIntervals()}
 * is true; otherwise every interval of the reference sequence dictionary is used, so that all
 * alignments are returned. The intervals are cut into shards, the reads are distributed over those
 * shards, and each shard is expanded through {@code getAlignmentsFunction}.
 *
 * The reference (when {@code hasReference()} is true) and the feature manager (when non-null) are
 * handed to the function as broadcast variables; otherwise {@code null} is passed in their place.
 *
 * @param ctx the Spark context used for sharding the reads and creating the broadcast variables
 * @return all alignments as a {@link JavaRDD}, bounded by intervals if specified.
 */
public JavaRDD<LocusWalkerContext> getAlignments(JavaSparkContext ctx) {
    final SAMSequenceDictionary dictionary = getBestAvailableSequenceDictionary();

    // fall back to the whole reference when the user gave no intervals
    final List<SimpleInterval> traversalIntervals;
    if (hasIntervals()) {
        traversalIntervals = getIntervals();
    } else {
        traversalIntervals = IntervalUtils.getAllIntervalsForReference(dictionary);
    }

    // cut every interval into shards and flatten the per-interval lists into a single list
    final List<ShardBoundary> shardBoundaries = traversalIntervals.stream()
            .map(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, dictionary))
            .flatMap(List::stream)
            .collect(Collectors.toList());

    final int maxLocatableLength = Math.min(readShardSize, readShardPadding);
    final JavaRDD<Shard<GATKRead>> readShards = SparkSharder.shard(
            ctx, getReads(), GATKRead.class, dictionary, shardBoundaries, maxLocatableLength, shuffle);

    // optional inputs are broadcast only when present
    final Broadcast<ReferenceMultiSource> referenceBroadcast = hasReference() ? ctx.broadcast(getReference()) : null;
    final Broadcast<FeatureManager> featuresBroadcast = features != null ? ctx.broadcast(features) : null;

    return readShards.flatMap(
            getAlignmentsFunction(referenceBroadcast, featuresBroadcast, dictionary, getHeaderForReads(), getDownsamplingInfo()));
}
use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class VariantWalkerSpark method getVariantsFunction.
/**
 * Creates the function that turns one shard of variants into {@link VariantWalkerContext}s.
 *
 * For every shard, the shard interval is expanded by {@code variantShardPadding} (within the
 * contig), the reference bases for that padded interval are loaded into an in-memory reference
 * source, and one context is emitted for each variant whose start position lies within
 * {@code [shard.getStart(), shard.getEnd()]}. Variants starting outside that range are dropped.
 *
 * @param bReferenceSource broadcast reference, or {@code null} for no reference (contexts then get a {@code null} reference source)
 * @param bFeatureManager broadcast feature manager, or {@code null} for no features
 * @param sequenceDictionary dictionary used to clip the padded interval and to build the in-memory reference
 * @param variantShardPadding number of bases by which the shard interval is expanded on the reference
 * @return a function mapping a shard of variants to an iterator of contexts
 */
private static FlatMapFunction<Shard<VariantContext>, VariantWalkerContext> getVariantsFunction(final Broadcast<ReferenceMultiSource> bReferenceSource, final Broadcast<FeatureManager> bFeatureManager, final SAMSequenceDictionary sequenceDictionary, final int variantShardPadding) {
    return variantShard -> {
        final SimpleInterval expandedInterval = variantShard.getInterval().expandWithinContig(variantShardPadding, sequenceDictionary);

        // materialize only the padded slice of the reference, when a reference was broadcast
        final ReferenceDataSource referenceSource;
        if (bReferenceSource == null) {
            referenceSource = null;
        } else {
            referenceSource = new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, expandedInterval), sequenceDictionary);
        }

        final FeatureManager featureManager;
        if (bFeatureManager == null) {
            featureManager = null;
        } else {
            featureManager = bFeatureManager.getValue();
        }

        return StreamSupport.stream(variantShard.spliterator(), false)
                // keep only the variants that start inside the (unpadded) shard
                .filter(variant -> {
                    final int variantStart = variant.getStart();
                    return variantShard.getStart() <= variantStart && variantStart <= variantShard.getEnd();
                })
                .map(variant -> {
                    final SimpleInterval locus = new SimpleInterval(variant);
                    return new VariantWalkerContext(
                            variant,
                            new ReadsContext(),
                            new ReferenceContext(referenceSource, locus),
                            new FeatureContext(featureManager, locus));
                })
                .iterator();
    };
}
Aggregations