use of org.apache.spark.broadcast.Broadcast in project pyramid by cheng-li.
the class SparkCBMOptimizer method updateBinaryClassifiers.
private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    Classifier.ProbabilityEstimator[][] localBinaryClassifiers = cbm.binaryClassifiers;
    double[][] localGammasT = gammasT;
    Broadcast<MultiLabelClfDataSet> localDataSetBroadcast = dataSetBroadCast;
    Broadcast<double[][][]> localTargetsBroadcast = targetDisBroadCast;
    double localVariance = priorVarianceBinary;
    List<BinaryTask> binaryTaskList = new ArrayList<>();
    for (int k = 0; k < cbm.numComponents; k++) {
        for (int l = 0; l < cbm.numLabels; l++) {
            LogisticRegression logisticRegression = (LogisticRegression) localBinaryClassifiers[k][l];
            double[] weights = localGammasT[k];
            binaryTaskList.add(new BinaryTask(k, l, logisticRegression, weights));
        }
    }
    JavaRDD<BinaryTask> binaryTaskRDD = sparkContext.parallelize(binaryTaskList, binaryTaskList.size());
    List<BinaryTaskResult> results = binaryTaskRDD.map(binaryTask -> {
        int labelIndex = binaryTask.classIndex;
        return updateBinaryLogisticRegression(binaryTask.componentIndex, binaryTask.classIndex,
                binaryTask.logisticRegression, localDataSetBroadcast.value(), binaryTask.weights,
                localTargetsBroadcast.value()[labelIndex], localVariance);
    }).collect();
    for (BinaryTaskResult result : results) {
        cbm.binaryClassifiers[result.componentIndex][result.classIndex] = result.binaryClassifier;
    }
    // IntStream.range(0, cbm.numComponents).forEach(this::updateBinaryClassifiers);
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
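The pyramid snippet copies the Broadcast fields (dataSetBroadCast, targetDisBroadCast) into local variables before building the lambda, so Spark serializes only the small broadcast handles rather than the enclosing optimizer object. A minimal, self-contained sketch of that broadcast-and-lookup pattern (hypothetical class and data, not part of pyramid) might look like this:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;

import java.util.Arrays;
import java.util.List;

public class BroadcastLookupExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("BroadcastLookupExample").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // read-only data shipped once to each executor
            double[] sharedWeights = {0.5, 1.5, 2.5};
            Broadcast<double[]> weightsBroadcast = sc.broadcast(sharedWeights);

            // the lambda captures only the lightweight Broadcast handle;
            // executors call value() to get their local copy of the array
            List<Double> results = sc.parallelize(Arrays.asList(0, 1, 2), 3)
                    .map(i -> weightsBroadcast.value()[i] * 10.0)
                    .collect();
            System.out.println(results);  // [5.0, 15.0, 25.0]
        }
    }
}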
use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class HaplotypeCallerSpark method createReadShards.
/**
* Create an RDD of {@link Shard} from an RDD of {@link GATKRead}
* @param shardBoundariesBroadcast broadcast of an {@link OverlapDetector} loaded with the intervals that should be used for creating ReadShards
* @param reads RDD of {@link GATKRead}
* @return an RDD of reads grouped into potentially overlapping shards
*/
private static JavaRDD<Shard<GATKRead>> createReadShards(final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast,
                                                         final JavaRDD<GATKRead> reads) {
    final JavaPairRDD<ShardBoundary, GATKRead> paired = reads.flatMapToPair(read -> {
        final Collection<ShardBoundary> overlappingShards = shardBoundariesBroadcast.value().getOverlaps(read);
        return overlappingShards.stream().map(key -> new Tuple2<>(key, read)).iterator();
    });
    final JavaPairRDD<ShardBoundary, Iterable<GATKRead>> shardsWithReads = paired.groupByKey();
    return shardsWithReads.map(shard -> new SparkReadShard(shard._1(), shard._2()));
}
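createReadShards pairs each read with every shard boundary it overlaps and then groups by shard key. A hedged stand-in for the same flatMapToPair/groupByKey shape, with intervals reduced to plain int ranges, integer positions standing in for reads, and all names hypothetical (it assumes the Spark 2.x Java API, where the pair function returns an Iterator), could be:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ShardByOverlapExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("ShardByOverlapExample").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // hypothetical "shard boundaries": each int[] is a {start, end} interval
            List<int[]> boundaries = Arrays.asList(new int[]{0, 100}, new int[]{50, 150});
            Broadcast<List<int[]>> boundariesBroadcast = sc.broadcast(boundaries);

            // "reads" reduced to their start positions for the sketch
            JavaRDD<Integer> reads = sc.parallelize(Arrays.asList(10, 60, 120));

            // pair each position with every shard whose interval contains it
            JavaPairRDD<Integer, Integer> paired = reads.flatMapToPair(pos -> {
                List<Tuple2<Integer, Integer>> pairs = new ArrayList<>();
                List<int[]> shards = boundariesBroadcast.value();
                for (int shardIndex = 0; shardIndex < shards.size(); shardIndex++) {
                    int[] shard = shards.get(shardIndex);
                    if (pos >= shard[0] && pos < shard[1]) {
                        pairs.add(new Tuple2<>(shardIndex, pos));
                    }
                }
                return pairs.iterator();
            });

            // group all positions that landed in the same shard
            System.out.println(paired.groupByKey().collectAsMap());  // each shard index with its grouped positions
        }
    }
}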
use of org.apache.spark.broadcast.Broadcast in project gatk by broadinstitute.
the class FindBreakpointEvidenceSpark method handleAssemblies.
/**
* Transform all the reads for a supplied set of template names in each interval into FASTQ records,
* and apply the supplied handler to each interval's list of FASTQ records (for example, write it to a file).
*/
@VisibleForTesting
static List<AlignedAssemblyOrExcuse> handleAssemblies(final JavaSparkContext ctx, final HopscotchUniqueMultiMap<String, Integer, QNameAndInterval> qNamesMultiMap, final JavaRDD<GATKRead> reads, final int nIntervals, final boolean includeMappingLocation, final boolean dumpFASTQs, final LocalAssemblyHandler localAssemblyHandler) {
    final Broadcast<HopscotchUniqueMultiMap<String, Integer, QNameAndInterval>> broadcastQNamesMultiMap = ctx.broadcast(qNamesMultiMap);
    final List<AlignedAssemblyOrExcuse> intervalDispositions = reads
            .mapPartitionsToPair(readItr ->
                    new ReadsForQNamesFinder(broadcastQNamesMultiMap.value(), nIntervals, includeMappingLocation, dumpFASTQs)
                            .call(readItr).iterator(), false)
            .combineByKey(x -> x,
                    FindBreakpointEvidenceSpark::combineLists, FindBreakpointEvidenceSpark::combineLists,
                    new HashPartitioner(nIntervals), false, null)
            .map(localAssemblyHandler::apply)
            .collect();
    broadcastQNamesMultiMap.destroy();
    BwaMemIndexSingleton.closeAllDistributedInstances(ctx);
    return intervalDispositions;
}
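handleAssemblies fans reads out by interval, merges the per-interval lists with combineByKey over a HashPartitioner, and destroys the broadcast only after the collect has finished. A small sketch of that combineByKey pattern with toy data (hypothetical class and values, not the GATK pipeline) could look like this:

import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class CombineByKeyExample {

    private static List<String> newList(String value) {
        List<String> list = new ArrayList<>();
        list.add(value);
        return list;
    }

    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("CombineByKeyExample").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // broadcast the small lookup every task consults (stand-in for the qname multimap)
            Broadcast<Set<String>> wantedNames = sc.broadcast(new HashSet<>(Arrays.asList("a", "b")));

            // (intervalIndex, name) pairs, keeping only names present in the broadcast set
            JavaPairRDD<Integer, String> keyed = sc.parallelizePairs(Arrays.asList(
                            new Tuple2<>(0, "a"), new Tuple2<>(1, "b"), new Tuple2<>(0, "a"), new Tuple2<>(1, "c")))
                    .filter(kv -> wantedNames.value().contains(kv._2()));

            // combineByKey builds one list per key; the HashPartitioner fixes the partition count
            Function<String, List<String>> createCombiner = CombineByKeyExample::newList;
            Function2<List<String>, String, List<String>> mergeValue = (list, v) -> { list.add(v); return list; };
            Function2<List<String>, List<String>, List<String>> mergeCombiners = (a, b) -> { a.addAll(b); return a; };
            Map<Integer, List<String>> grouped = keyed
                    .combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(2))
                    .collectAsMap();
            System.out.println(grouped);  // {0=[a, a], 1=[b]} (map order may vary)

            // the broadcast can be destroyed once the job that read it has completed
            wantedNames.destroy();
        }
    }
}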
use of org.apache.spark.broadcast.Broadcast in project gatk-protected by broadinstitute.
the class CoverageModelWLinearOperatorSpark method operate.
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents)
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    /* Z F W */
    final long startTimeZFW = System.nanoTime();
    final INDArray Z_F_W_tl = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            Z_F_W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(F_tt.operate(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    Z_F_W_tl.assign(Nd4j.gemm(Z_F_W_tl, Z_ll, false, false));
    final long endTimeZFW = System.nanoTime();
    /* perform a broadcast hash join */
    final long startTimeQW = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_tl_map = CoverageModelSparkUtils.partitionINDArrayToMap(targetSpaceBlocks, W_tl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_tl_bc = ctx.broadcast(W_tl_map);
    final INDArray Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(computeRDD.mapValues(cb -> {
        final INDArray W_tl_chunk = W_tl_bc.value().get(cb.getTargetSpaceBlock());
        final INDArray Q_tll_chunk = cb.getINDArrayFromCache(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Q_tll);
        final Collection<INDArray> W_Q_chunk = IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
                .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti))
                        .mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
                .collect(Collectors.toList());
        return Nd4j.vstack(W_Q_chunk);
    }), 0);
    W_tl_bc.destroy();
    // final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> W_tl_RDD = CoverageModelSparkUtils.rddFromINDArray(W_tl,
    //         targetSpaceBlocks, ctx, true);
    // final INDArray Q_W_tl = CoverageModelSparkUtils.assembleINDArrayBlocks(
    //         computeRDD.join(W_tl_RDD).mapValues(p -> {
    //             final CoverageModelEMComputeBlock cb = p._1;
    //             final INDArray W_tl_chunk = p._2;
    //             final INDArray Q_tll_chunk = cb.getINDArrayFromCache("Q_tll");
    //             return Nd4j.vstack(IntStream.range(0, cb.getTargetSpaceBlock().getNumElements()).parallel()
    //                     .mapToObj(ti -> Q_tll_chunk.get(NDArrayIndex.point(ti)).mmul(W_tl_chunk.get(NDArrayIndex.point(ti)).transpose()))
    //                     .collect(Collectors.toList()));
    //         }), false);
    // W_tl_RDD.unpersist();
    final long endTimeQW = System.nanoTime();
    logger.debug("Local [Z] [F] [W] timing: " + (endTimeZFW - startTimeZFW) / 1000000 + " ms");
    logger.debug("Spark [Q] [W] timing: " + (endTimeQW - startTimeQW) / 1000000 + " ms");
    return Q_W_tl.addi(Z_F_W_tl);
}
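The commented-out alternative keeps W as an RDD and joins it with a shuffle; the live code instead broadcasts the partitioned map of W chunks and looks each chunk up by block key inside the Spark job, i.e. a broadcast hash join. A reduced sketch of that join, with plain double[] arrays in place of INDArray blocks and hypothetical keys and values, could be the following (the GATK-protected code uses mapValues rather than mapToPair because each compute-block value already carries its own block key):

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import scala.Tuple2;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class BroadcastHashJoinExample {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("BroadcastHashJoinExample").setMaster("local[*]");
        try (JavaSparkContext sc = new JavaSparkContext(conf)) {
            // small side of the join: one coefficient chunk per block index, broadcast to all executors
            Map<Integer, double[]> chunksByBlock = new HashMap<>();
            chunksByBlock.put(0, new double[]{1.0, 2.0});
            chunksByBlock.put(1, new double[]{3.0, 4.0});
            Broadcast<Map<Integer, double[]>> chunksBroadcast = sc.broadcast(chunksByBlock);

            // large side: per-block data that stays distributed
            JavaPairRDD<Integer, double[]> blockData = sc.parallelizePairs(Arrays.asList(
                    new Tuple2<>(0, new double[]{10.0, 20.0}),
                    new Tuple2<>(1, new double[]{30.0, 40.0})));

            // "join" by looking up the broadcast chunk for each block key, then combine (here: a dot product)
            JavaPairRDD<Integer, Double> joined = blockData.mapToPair(kv -> {
                double[] chunk = chunksBroadcast.value().get(kv._1());
                double[] values = kv._2();
                double dot = 0.0;
                for (int i = 0; i < values.length; i++) {
                    dot += chunk[i] * values[i];
                }
                return new Tuple2<>(kv._1(), dot);
            });
            System.out.println(joined.collectAsMap());  // {0=50.0, 1=250.0} (map order may vary)

            chunksBroadcast.destroy();
        }
    }
}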
use of org.apache.spark.broadcast.Broadcast in project gatk-protected by broadinstitute.
the class CoverageModelWPreconditionerSpark method operate.
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }
    long startTimeRFFT = System.nanoTime();
    /* forward rfft */
    final INDArray W_kl = Nd4j.create(fftSize, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(Nd4j.create(F_tt.getForwardFFT(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))),
                            new int[] { fftSize, 1 })));
    long endTimeRFFT = System.nanoTime();
    /* apply the preconditioner in the Fourier space */
    long startTimePrecond = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_kl_map = CoverageModelSparkUtils.partitionINDArrayToMap(fourierSpaceBlocks, W_kl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_kl_bc = ctx.broadcast(W_kl_map);
    final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> preconditionedWRDD = linOpPairRDD.mapToPair(p -> {
        final INDArray W_kl_chunk = W_kl_bc.value().get(p._1);
        final INDArray linOp_chunk = p._2;
        final int blockSize = linOp_chunk.shape()[0];
        final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel()
                .mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(linOp_chunk.get(NDArrayIndex.point(k)),
                        W_kl_chunk.get(NDArrayIndex.point(k))))
                .collect(Collectors.toList());
        return new Tuple2<>(p._1, Nd4j.vstack(linOpWList));
    });
    W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(preconditionedWRDD, 0));
    W_kl_bc.destroy();
    // final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> W_kl_RDD = CoverageModelSparkUtils.rddFromINDArray(W_kl,
    //         fourierSpaceBlocks, ctx, true);
    // W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocks(linOpPairRDD.join((W_kl_RDD))
    //         .mapValues(p -> {
    //             final INDArray linOp = p._1;
    //             final INDArray W = p._2;
    //             final int blockSize = linOp.shape()[0];
    //             final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel().mapToObj(k ->
    //                     CoverageModelEMWorkspaceMathUtils.linsolve(linOp.get(NDArrayIndex.point(k)),
    //                             W.get(NDArrayIndex.point(k))))
    //                     .collect(Collectors.toList());
    //             return Nd4j.vstack(linOpWList);
    //         }), false));
    // W_kl_RDD.unpersist();
    long endTimePrecond = System.nanoTime();
    /* irfft */
    long startTimeIRFFT = System.nanoTime();
    final INDArray res = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li ->
            res.get(NDArrayIndex.all(), NDArrayIndex.point(li))
                    .assign(F_tt.getInverseFFT(W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    long endTimeIRFFT = System.nanoTime();
    logger.debug("Local FFT timing: " + (endTimeRFFT - startTimeRFFT + endTimeIRFFT - startTimeIRFFT) / 1000000 + " ms");
    logger.debug("Spark preconditioner application timing: " + (endTimePrecond - startTimePrecond) / 1000000 + " ms");
    return res;
}
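In both gatk-protected methods the per-block results are pulled back and reassembled with CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD, which is project code not shown here. As a rough, hedged illustration of what such a reassembly step amounts to, here is a hypothetical stand-in helper (plain double[] chunks instead of INDArray blocks) that collects the keyed chunks, orders them by block index, and concatenates them on the driver:

import org.apache.spark.api.java.JavaPairRDD;

import java.util.Map;
import java.util.TreeMap;

public final class BlockAssembly {

    private BlockAssembly() {}

    /** Collects the keyed chunks to the driver and concatenates them in ascending key order. */
    public static double[] assembleBlocks(final JavaPairRDD<Integer, double[]> blocks) {
        // TreeMap sorts by block index so the chunks are concatenated in order
        final Map<Integer, double[]> ordered = new TreeMap<>(blocks.collectAsMap());
        final int totalLength = ordered.values().stream().mapToInt(chunk -> chunk.length).sum();
        final double[] result = new double[totalLength];
        int offset = 0;
        for (final double[] chunk : ordered.values()) {
            System.arraycopy(chunk, 0, result, offset, chunk.length);
            offset += chunk.length;
        }
        return result;
    }
}

A caller would invoke BlockAssembly.assembleBlocks(preconditionedChunks) on a JavaPairRDD<Integer, double[]> produced by a job like the ones sketched above.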