Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
The class VariantWalkerSpark, method getVariants:
/**
* Loads variants and the corresponding reads, reference and features into a {@link JavaRDD} for the intervals specified.
* For the current implementation, the reads context will always be empty.
*
* If no intervals were specified, returns all the variants.
*
* @return all variants as a {@link JavaRDD}, bounded by intervals if specified.
*/
public JavaRDD<VariantWalkerContext> getVariants(JavaSparkContext ctx) {
SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(sequenceDictionary);
// use unpadded shards (padding is only needed for reference bases)
final List<ShardBoundary> intervalShards = intervals.stream()
        .flatMap(interval -> Shard.divideIntervalIntoShards(interval, variantShardSize, 0, sequenceDictionary).stream())
        .collect(Collectors.toList());
JavaRDD<VariantContext> variants = variantsSource.getParallelVariantContexts(drivingVariantFile, getIntervals());
VariantFilter variantFilter = makeVariantFilter();
variants = variants.filter(variantFilter::test);
JavaRDD<Shard<VariantContext>> shardedVariants = SparkSharder.shard(ctx, variants, VariantContext.class, sequenceDictionary, intervalShards, variantShardSize, shuffle);
Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
return shardedVariants.flatMap(getVariantsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, variantShardPadding));
}
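The two broadcasts above hand the reference source and the feature manager to the executors exactly once, instead of serializing them into every task closure; the function returned by getVariantsFunction can then dereference them on the worker side (both may be null when no reference or features were supplied). Below is a minimal, self-contained sketch of that broadcast-and-lookup pattern using plain Java types; the class name, the contig map, and everything else in it are illustrative only, not GATK API.
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class BroadcastLookupExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("BroadcastLookupExample").setMaster("local[*]");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // A read-only lookup table, standing in for the reference/feature sources broadcast above.
            Map<String, String> contigNames = new HashMap<>();
            contigNames.put("1", "chr1");
            contigNames.put("2", "chr2");
            // Broadcast once; each executor caches a single copy of the value.
            Broadcast<Map<String, String>> bContigNames = ctx.broadcast(contigNames);
            JavaRDD<String> contigs = ctx.parallelize(Arrays.asList("1", "2", "1"));
            // Workers read the shared value via value() inside the task.
            JavaRDD<String> resolved = contigs.map(c -> bContigNames.value().getOrDefault(c, "unknown"));
            resolved.collect().forEach(System.out::println);
        }
    }
}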
Use of org.apache.spark.broadcast.Broadcast in project gatk-protected by broadinstitute.
The class CoverageModelEMWorkspace, method initializeWorkersWithPCA:
/**
* Initialize model parameters by performing PCA.
*/
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private void initializeWorkersWithPCA() {
logger.info("Initializing model parameters using PCA...");
/* initially, set m_t, Psi_t and W_tl to zero to get an estimate of the read depth */
final int numLatents = config.getNumLatents();
mapWorkers(cb -> cb
        .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.m_t,
                Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() }))
        .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
                Nd4j.zeros(new int[] { 1, cb.getTargetSpaceBlock().getNumElements() })));
if (biasCovariatesEnabled) {
mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.W_tl,
        Nd4j.zeros(new int[] { cb.getTargetSpaceBlock().getNumElements(), numLatents })));
}
/* update read depth without taking into account correction from bias covariates */
updateReadDepthPosteriorExpectations(1.0, true);
/* fetch sample covariance matrix */
final int minPCAInitializationReadCount = config.getMinPCAInitializationReadCount();
mapWorkers(cb -> cb.cloneWithPCAInitializationData(minPCAInitializationReadCount, Integer.MAX_VALUE));
cacheWorkers("PCA initialization");
final INDArray targetCovarianceMatrix = mapWorkersAndReduce(CoverageModelEMComputeBlock::calculateTargetCovarianceMatrixForPCAInitialization, INDArray::add);
/* perform eigen-decomposition on the target covariance matrix */
final ImmutablePair<INDArray, INDArray> targetCovarianceEigensystem = CoverageModelEMWorkspaceMathUtils.eig(targetCovarianceMatrix, false, logger);
/* the eigenvalues of sample covariance matrix can be immediately inferred by scaling */
final INDArray sampleCovarianceEigenvalues = targetCovarianceEigensystem.getLeft().div(numSamples);
/* estimate the isotropic unexplained variance -- see Bishop 12.46 */
final int residualDim = numTargets - numLatents;
final double isotropicVariance = sampleCovarianceEigenvalues.get(NDArrayIndex.interval(numLatents, numSamples)).sumNumber().doubleValue() / residualDim;
logger.info(String.format("PCA estimate of isotropic unexplained variance: %f", isotropicVariance));
/* estimate bias factors -- see Bishop 12.45 */
final INDArray scaleFactors = Transforms.sqrt(sampleCovarianceEigenvalues.get(NDArrayIndex.interval(0, numLatents)).sub(isotropicVariance), false);
final INDArray biasCovariatesPCA = Nd4j.create(new int[] { numTargets, numLatents });
for (int li = 0; li < numLatents; li++) {
final INDArray v = targetCovarianceEigensystem.getRight().getColumn(li);
/* calculate [Delta_PCA_st]^T v */
/* note: we do not need to broadcast vec since it is small and lambda capture is just fine */
final INDArray unnormedBiasCovariate = CoverageModelSparkUtils.assembleINDArrayBlocksFromCollection(
        mapWorkersAndCollect(cb -> ImmutablePair.of(cb.getTargetSpaceBlock(),
                cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Delta_PCA_st)
                        .transpose().mmul(v))), 0);
final double norm = unnormedBiasCovariate.norm1Number().doubleValue();
final INDArray normedBiasCovariate = unnormedBiasCovariate.divi(norm).muli(scaleFactors.getDouble(li));
biasCovariatesPCA.getColumn(li).assign(normedBiasCovariate);
}
if (ardEnabled) {
/* a better estimate of ARD coefficients */
biasCovariatesARDCoefficients.assign(Nd4j.zeros(new int[] { 1, numLatents }).addi(config.getInitialARDPrecisionRelativeToNoise() / isotropicVariance));
}
final CoverageModelParameters modelParamsFromPCA = new CoverageModelParameters(processedTargetList,
        Nd4j.zeros(new int[] { 1, numTargets }),
        Nd4j.zeros(new int[] { 1, numTargets }).addi(isotropicVariance),
        biasCovariatesPCA, biasCovariatesARDCoefficients);
/* clear PCA initialization data from workers */
mapWorkers(CoverageModelEMComputeBlock::cloneWithRemovedPCAInitializationData);
/* push model parameters to workers */
initializeWorkersWithGivenModel(modelParamsFromPCA);
/* update bias latent posterior expectations without admixing */
updateBiasLatentPosteriorExpectations(1.0);
}
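The two inline citations above refer to the maximum-likelihood solution of probabilistic PCA in Bishop's PRML, Eqs. 12.45 and 12.46. For reference, in LaTeX notation, with D standing for numTargets, M for numLatents, and \lambda_i the eigenvalues of the sample covariance matrix in decreasing order:
\begin{align}
  \sigma^2_{\mathrm{ML}} &= \frac{1}{D - M} \sum_{i = M + 1}^{D} \lambda_i
      && \text{(Bishop 12.46: isotropic unexplained variance)} \\
  W_{\mathrm{ML}} &= U_M \left( L_M - \sigma^2_{\mathrm{ML}} I \right)^{1/2} R
      && \text{(Bishop 12.45: factor loadings, up to an orthogonal rotation } R\text{)}
\end{align}
Here U_M holds the top M eigenvectors and L_M the corresponding eigenvalues on its diagonal. The code sums eigenvalues only up to numSamples, presumably because the sample covariance of numSamples samples has at most numSamples nonzero eigenvalues, and the loop over li assembles the columns of W one at a time with its own normalization convention.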
Use of org.apache.spark.broadcast.Broadcast in project incubator-systemml by apache.
The class SparkExecutionContext, method getBroadcastForVariable:
/**
* TODO So far we only create broadcast variables but never destroy
* them. This is a memory leak which might lead to executor out-of-memory.
* However, in order to handle this, we need to keep track when broadcast
* variables are no longer required.
*
* @param varname variable name
* @return wrapper for broadcast variables
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable(String varname) throws DMLRuntimeException {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
MatrixObject mo = getMatrixObject(varname);
PartitionedBroadcast<MatrixBlock> bret = null;
//reuse existing broadcast handle
if (mo.getBroadcastHandle() != null && mo.getBroadcastHandle().isValid()) {
bret = mo.getBroadcastHandle().getBroadcast();
}
//create new broadcast handle (never created, evicted)
if (bret == null) {
//account for overwritten invalid broadcast (e.g., evicted)
if (mo.getBroadcastHandle() != null)
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
//obtain meta data for matrix
int brlen = (int) mo.getNumRowsPerBlock();
int bclen = (int) mo.getNumColumnsPerBlock();
//create partitioned matrix block and release memory consumed by input
MatrixBlock mb = mo.acquireRead();
PartitionedBlock<MatrixBlock> pmb = new PartitionedBlock<MatrixBlock>(mb, brlen, bclen);
mo.release();
//determine coarse-grained partitioning
int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(), mo.getNumColumns(), brlen, bclen);
int numParts = (int) Math.ceil((double) pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new Broadcast[numParts];
//create coarse-grained partitioned broadcasts
if (numParts > 1) {
for (int i = 0; i < numParts; i++) {
int offset = i * numPerPart;
int numBlks = Math.min(numPerPart, pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() - offset);
PartitionedBlock<MatrixBlock> tmp = pmb.createPartition(offset, numBlks, new MatrixBlock());
ret[i] = getSparkContext().broadcast(tmp);
}
} else {
//single partition
ret[0] = getSparkContext().broadcast(pmb);
}
bret = new PartitionedBroadcast<MatrixBlock>(ret);
BroadcastObject<MatrixBlock> bchandle = new BroadcastObject<MatrixBlock>(bret, varname, OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics()));
mo.setBroadcastHandle(bchandle);
CacheableData.addBroadcastSize(bchandle.getSize());
}
if (DMLScript.STATISTICS) {
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
Statistics.incSparkBroadcastCount(1);
}
return bret;
}
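The loop above splits one large PartitionedBlock into numParts chunks and broadcasts each chunk separately, presumably so that no single broadcast payload has to hold the entire matrix and each stays below Spark's per-broadcast size limits. The following is a minimal, self-contained sketch of that coarse-grained chunking idea; ChunkedBroadcast, getBlock and the rest are hypothetical names for illustration, not SystemML's PartitionedBroadcast API.
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import java.util.ArrayList;
import java.util.List;

// Splits a flat list of blocks into fixed-size chunks and broadcasts each chunk on its own.
// T must be Serializable for the broadcasts to succeed.
public final class ChunkedBroadcast<T> {

    private final List<Broadcast<List<T>>> chunks = new ArrayList<>();
    private final int blocksPerChunk;

    public ChunkedBroadcast(JavaSparkContext ctx, List<T> blocks, int blocksPerChunk) {
        this.blocksPerChunk = blocksPerChunk;
        for (int offset = 0; offset < blocks.size(); offset += blocksPerChunk) {
            int end = Math.min(offset + blocksPerChunk, blocks.size());
            // Each chunk becomes an independent Broadcast variable.
            chunks.add(ctx.broadcast(new ArrayList<>(blocks.subList(offset, end))));
        }
    }

    // Fetch a block by its global index: chunk = index / blocksPerChunk, local = index % blocksPerChunk.
    public T getBlock(int globalIndex) {
        return chunks.get(globalIndex / blocksPerChunk).value().get(globalIndex % blocksPerChunk);
    }
}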
Use of org.apache.spark.broadcast.Broadcast in project incubator-systemml by apache.
The class SparkExecutionContext, method getBroadcastForFrameVariable:
@SuppressWarnings("unchecked")
public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable(String varname) throws DMLRuntimeException {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
FrameObject fo = getFrameObject(varname);
PartitionedBroadcast<FrameBlock> bret = null;
//reuse existing broadcast handle
if (fo.getBroadcastHandle() != null && fo.getBroadcastHandle().isValid()) {
bret = fo.getBroadcastHandle().getBroadcast();
}
//create new broadcast handle (never created, evicted)
if (bret == null) {
//account for overwritten invalid broadcast (e.g., evicted)
if (fo.getBroadcastHandle() != null)
CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getSize());
//obtain meta data for frame
int bclen = (int) fo.getNumColumns();
int brlen = OptimizerUtils.getDefaultFrameSize();
//create partitioned frame block and release memory consumed by input
FrameBlock mb = fo.acquireRead();
PartitionedBlock<FrameBlock> pmb = new PartitionedBlock<FrameBlock>(mb, brlen, bclen);
fo.release();
//determine coarse-grained partitioning
int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(fo.getNumRows(), fo.getNumColumns(), brlen, bclen);
int numParts = (int) Math.ceil((double) pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() / numPerPart);
Broadcast<PartitionedBlock<FrameBlock>>[] ret = new Broadcast[numParts];
//create coarse-grained partitioned broadcasts
if (numParts > 1) {
for (int i = 0; i < numParts; i++) {
int offset = i * numPerPart;
int numBlks = Math.min(numPerPart, pmb.getNumRowBlocks() * pmb.getNumColumnBlocks() - offset);
PartitionedBlock<FrameBlock> tmp = pmb.createPartition(offset, numBlks, new FrameBlock());
ret[i] = getSparkContext().broadcast(tmp);
}
} else {
//single partition
ret[0] = getSparkContext().broadcast(pmb);
}
bret = new PartitionedBroadcast<FrameBlock>(ret);
BroadcastObject<FrameBlock> bchandle = new BroadcastObject<FrameBlock>(bret, varname, OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics()));
fo.setBroadcastHandle(bchandle);
CacheableData.addBroadcastSize(bchandle.getSize());
}
if (DMLScript.STATISTICS) {
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
Statistics.incSparkBroadcastCount(1);
}
return bret;
}
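This frame variant follows the same chunking scheme as the matrix method above, only with frame-specific block sizes (blocks of getDefaultFrameSize() rows spanning the full column range). For both variants, the hypothetical helper below illustrates the arithmetic that would map a 2-D block coordinate onto such a flat partitioning, assuming a row-major linearization of blocks; the real PartitionedBroadcast lookup may order and index blocks differently.
// Hypothetical helper, not SystemML code: locates the broadcast partition holding a block
// and the block's offset inside that partition, assuming row-major block order.
static int[] locateBlock(int rowBlockIndex, int columnBlockIndex, int numColumnBlocks, int numPerPart) {
    final int globalIndex = rowBlockIndex * numColumnBlocks + columnBlockIndex; // row-major position
    final int partition = globalIndex / numPerPart;  // which Broadcast holds the block
    final int localIndex = globalIndex % numPerPart; // offset inside that partition
    return new int[] { partition, localIndex };
}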
Use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
The class HaplotypeCallerSpark, method createReadShards:
/**
* Create an RDD of {@link Shard} from an RDD of {@link GATKRead}
* @param shardBoundariesBroadcast broadcast of an {@link OverlapDetector} loaded with the intervals that should be used for creating ReadShards
* @param reads RDD of {@link GATKRead}
* @return an RDD of reads grouped into potentially overlapping shards
*/
private static JavaRDD<Shard<GATKRead>> createReadShards(final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast, final JavaRDD<GATKRead> reads) {
final JavaPairRDD<ShardBoundary, GATKRead> paired = reads.flatMapToPair(read -> {
final Collection<ShardBoundary> overlappingShards = shardBoundariesBroadcast.value().getOverlaps(read);
return overlappingShards.stream().map(key -> new Tuple2<>(key, read)).iterator();
});
final JavaPairRDD<ShardBoundary, Iterable<GATKRead>> shardsWithReads = paired.groupByKey();
return shardsWithReads.map(shard -> new SparkReadShard(shard._1(), shard._2()));
}
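The pattern here is: broadcast a small lookup structure (the OverlapDetector of shard boundaries), emit one (shard, read) pair for every shard a read overlaps using flatMapToPair, then groupByKey to gather each shard's reads. Below is a minimal, self-contained sketch of the same pattern with integer positions and [start, end) ranges standing in for reads and ShardBoundary objects; all names are illustrative, not GATK API.
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ShardingPatternExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("ShardingPatternExample").setMaster("local[*]");
        try (JavaSparkContext ctx = new JavaSparkContext(conf)) {
            // Shard boundaries as [start, end) ranges, standing in for the broadcast OverlapDetector.
            List<int[]> shardBoundaries = Arrays.asList(new int[] { 0, 100 }, new int[] { 50, 150 });
            Broadcast<List<int[]>> bShards = ctx.broadcast(shardBoundaries);
            // "Reads" reduced to their start positions for this sketch.
            JavaRDD<Integer> reads = ctx.parallelize(Arrays.asList(10, 60, 120));
            // Pair each read with every shard it overlaps; a read can land in several shards.
            JavaPairRDD<Integer, Integer> paired = reads.flatMapToPair(pos -> {
                List<Tuple2<Integer, Integer>> out = new ArrayList<>();
                List<int[]> shards = bShards.value();
                for (int i = 0; i < shards.size(); i++) {
                    if (pos >= shards.get(i)[0] && pos < shards.get(i)[1]) {
                        out.add(new Tuple2<>(i, pos));
                    }
                }
                return out.iterator();
            });
            // Group reads by shard index, mirroring paired.groupByKey() in createReadShards.
            paired.groupByKey().collect()
                    .forEach(kv -> System.out.println("shard " + kv._1() + " -> " + kv._2()));
        }
    }
}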