Search in sources :

Example 66 with Tuple2

use of scala.Tuple2 in project gatk-protected by broadinstitute.

the class HaplotypeCallerSpark method createReadShards.

/**
 * Groups reads into shards according to the shard boundaries they overlap.
 *
 * <p>A read that overlaps several boundaries is emitted once per boundary, so the resulting shards
 * may share reads.</p>
 *
 * @param shardBoundariesBroadcast  broadcast of an {@link OverlapDetector} loaded with the intervals that should be used for creating ReadShards
 * @param reads Rdd of {@link GATKRead}
 * @return an RDD of {@link Shard}s, one per shard boundary that at least one read overlaps
 */
private static JavaRDD<Shard<GATKRead>> createReadShards(final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast, final JavaRDD<GATKRead> reads) {
    /* emit one (boundary, read) pair for every shard boundary the read overlaps */
    final JavaPairRDD<ShardBoundary, GATKRead> readsKeyedByBoundary = reads.flatMapToPair(read ->
            shardBoundariesBroadcast.value().getOverlaps(read).stream()
                    .map(boundary -> new Tuple2<>(boundary, read))
                    .iterator());
    /* collect the reads of each boundary and wrap every group in a shard */
    return readsKeyedByBoundary.groupByKey()
            .map(group -> new SparkReadShard(group._1(), group._2()));
}
Also used : CommandLineProgramProperties(org.broadinstitute.barclay.argparser.CommandLineProgramProperties) SparkProgramGroup(org.broadinstitute.hellbender.cmdline.programgroups.SparkProgramGroup) ReferenceMultiSource(org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource) Argument(org.broadinstitute.barclay.argparser.Argument) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Advanced(org.broadinstitute.barclay.argparser.Advanced) org.broadinstitute.hellbender.cmdline(org.broadinstitute.hellbender.cmdline) ArgumentCollection(org.broadinstitute.barclay.argparser.ArgumentCollection) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SAMFileHeader(htsjdk.samtools.SAMFileHeader) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) Function(java.util.function.Function) ReferenceSequenceFile(htsjdk.samtools.reference.ReferenceSequenceFile) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) ReferenceBases(org.broadinstitute.hellbender.utils.reference.ReferenceBases) HaplotypeCallerArgumentCollection(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerArgumentCollection) SparkReadShard(org.broadinstitute.hellbender.engine.spark.SparkReadShard) StreamSupport(java.util.stream.StreamSupport) JavaRDD(org.apache.spark.api.java.JavaRDD) FlatMapFunction(org.apache.spark.api.java.function.FlatMapFunction) org.broadinstitute.barclay.argparser(org.broadinstitute.barclay.argparser) Broadcast(org.apache.spark.broadcast.Broadcast) OverlapDetector(htsjdk.samtools.util.OverlapDetector) Iterator(java.util.Iterator) SAMSequenceDictionary(htsjdk.samtools.SAMSequenceDictionary) Collection(java.util.Collection) GATKSparkTool(org.broadinstitute.hellbender.engine.spark.GATKSparkTool) IOException(java.io.IOException) ReadFilter(org.broadinstitute.hellbender.engine.filters.ReadFilter) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
HaplotypeCaller(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) HaplotypeCallerEngine(org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerEngine) Serializable(java.io.Serializable) org.broadinstitute.hellbender.engine(org.broadinstitute.hellbender.engine) List(java.util.List) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) VariantContextWriter(htsjdk.variant.variantcontext.writer.VariantContextWriter) VariantContext(htsjdk.variant.variantcontext.VariantContext) Utils(org.broadinstitute.hellbender.utils.Utils) ReferenceSequence(htsjdk.samtools.reference.ReferenceSequence) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Collections(java.util.Collections) GATKRead(org.broadinstitute.hellbender.utils.read.GATKRead) SparkReadShard(org.broadinstitute.hellbender.engine.spark.SparkReadShard) Tuple2(scala.Tuple2)

Example 67 with Tuple2

use of scala.Tuple2 in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsSpark.

/**
 * The Spark implementation of the E-step update of copy ratio posteriors.
 *
 * <p>The per-sample copy ratio posterior expectations are consumed twice: once to refresh the log chain
 * posteriors on the driver, and once to be regrouped by target block and joined into {@code computeRDD}.
 * The RDD holding them is therefore persisted before its first use and released only after the
 * downstream results have been collected.</p>
 *
 * @param admixingRatio forwarded to {@code cloneWithUpdatedCopyRatioPosteriors} on every compute block
 * @return a {@link SubroutineSignal} containing the update size
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateCopyRatioPosteriorExpectationsSpark(final double admixingRatio) {
    /* local final member variables for lambda capture */
    final List<LinearlySpacedIndexBlock> targetBlocks = new ArrayList<>(this.targetBlocks);
    final List<Target> targetList = new ArrayList<>(processedTargetList);
    final List<String> sampleNameList = new ArrayList<>(processedSampleNameList);
    final List<SexGenotypeData> sampleSexGenotypeData = new ArrayList<>(processedSampleSexGenotypeData);
    final int numTargetBlocks = targetBlocks.size();
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final INDArray sampleReadDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    /* make an RDD of copy ratio posterior expectations */
    final JavaPairRDD<Integer, CopyRatioExpectations> copyRatioPosteriorExpectationsPairRDD = /* fetch copy ratio emission data from workers */
    fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(it -> {
        final List<Tuple2<Integer, CopyRatioExpectations>> newPartitionData = new ArrayList<>();
        while (it.hasNext()) {
            final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> prevDatum = it.next();
            final int si = prevDatum._1;
            final CopyRatioCallingMetadata copyRatioCallingMetadata = CopyRatioCallingMetadata.builder()
                    .sampleName(sampleNameList.get(si))
                    .sampleSexGenotypeData(sampleSexGenotypeData.get(si))
                    .sampleCoverageDepth(sampleReadDepths.getDouble(si))
                    .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                    .build();
            newPartitionData.add(new Tuple2<>(prevDatum._1, calculator.getCopyRatioPosteriorExpectations(copyRatioCallingMetadata, targetList, prevDatum._2)));
        }
        return newPartitionData.iterator();
    }, true)
    /* we need to do two things to copyRatioPosteriorExpectationsPairRDD; so we cache it.
     * without this call, each downstream action may recompute the posterior expectations */
    .cache();
    /* step 1. update log chain posterior expectation on the driver node */
    final double[] newSampleLogChainPosteriors = copyRatioPosteriorExpectationsPairRDD
            .mapValues(CopyRatioExpectations::getLogChainPosteriorProbability)
            .collect().stream()
            .sorted(Comparator.comparingInt(t -> t._1))
            .mapToDouble(t -> t._2)
            .toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(newSampleLogChainPosteriors, new int[] { numSamples, 1 }));
    /* step 2. repartition in target space (lazy; evaluated when computeRDD is next materialized) */
    final JavaPairRDD<LinearlySpacedIndexBlock, ImmutablePair<INDArray, INDArray>> blockifiedCopyRatioPosteriorResultsPairRDD = copyRatioPosteriorExpectationsPairRDD
            .flatMapToPair(dat -> targetBlocks.stream()
                    .map(tb -> new Tuple2<>(tb, new Tuple2<>(dat._1, ImmutablePair.of(dat._2.getLogCopyRatioMeans(tb), dat._2.getLogCopyRatioVariances(tb)))))
                    .iterator())
            .combineByKey(
                    /* recipe to create a singleton list */
                    Collections::singletonList,
                    /* recipe to add an element to the list */
                    (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()),
                    /* recipe to concatenate two lists */
                    (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()),
                    /* repartition with respect to target space blocks */
                    new HashPartitioner(numTargetBlocks))
            .mapValues(list -> list.stream()
                    .sorted(Comparator.comparingInt(t -> t._1))
                    .map(p -> p._2)
                    .map(t -> ImmutablePair.of(Nd4j.create(t.left), Nd4j.create(t.right)))
                    .collect(Collectors.toList()))
            .mapValues(CoverageModelEMWorkspace::stackCopyRatioPosteriorDataForAllSamples);
    /* step 3. merge with computeRDD and update */
    computeRDD = computeRDD.join(blockifiedCopyRatioPosteriorResultsPairRDD).mapValues(t -> t._1.cloneWithUpdatedCopyRatioPosteriors(t._2.left, t._2.right, admixingRatio));
    cacheWorkers("after E-step for copy ratio update");
    /* collect subroutine signals */
    final List<SubroutineSignal> sigs = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    /* we do not need copy ratio expectations anymore. the release is deferred to this point because
     * steps 2 and 3 are lazy: unpersisting before they are evaluated would drop the cached partitions
     * before they are read.
     * NOTE(review): assumes mapWorkersAndCollect runs an action on computeRDD -- confirm */
    copyRatioPosteriorExpectationsPairRDD.unpersist();
    final double errorNormInfinity = Collections.max(sigs.stream().map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM)).collect(Collectors.toList()));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HashPartitioner(org.apache.spark.HashPartitioner)

Example 68 with Tuple2

use of scala.Tuple2 in project gatk by broadinstitute.

the class CoverageModelWPreconditionerSpark method operate.

/**
 * Applies this operator to {@code W_tl} in three stages: a forward FFT of every latent column on the
 * driver (via {@code F_tt.getForwardFFT}), a per-row linear solve against the blocks of
 * {@code linOpPairRDD} on Spark, and an inverse FFT of every latent column back on the driver.
 *
 * @param W_tl a rank-2 array of shape {@code [numTargets, numLatents]}
 * @return a newly allocated array of shape {@code [numTargets, numLatents]}
 * @throws DimensionMismatchException if {@code W_tl} is not rank 2 or its shape differs from
 *                                    {@code [numTargets, numLatents]}
 */
@Override
public INDArray operate(@Nonnull final INDArray W_tl) throws DimensionMismatchException {
    if (W_tl.rank() != 2 || W_tl.shape()[0] != numTargets || W_tl.shape()[1] != numLatents) {
        throw new DimensionMismatchException(W_tl.length(), numTargets * numLatents);
    }
    /* forward rfft */
    final long startTimeRFFT = System.nanoTime();
    final INDArray W_kl = Nd4j.create(fftSize, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li -> W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(Nd4j.create(F_tt.getForwardFFT(W_tl.get(NDArrayIndex.all(), NDArrayIndex.point(li))), new int[] { fftSize, 1 })));
    final long endTimeRFFT = System.nanoTime();
    /* apply the preconditioner in the Fourier space */
    final long startTimePrecond = System.nanoTime();
    final Map<LinearlySpacedIndexBlock, INDArray> W_kl_map = CoverageModelSparkUtils.partitionINDArrayToMap(fourierSpaceBlocks, W_kl);
    final Broadcast<Map<LinearlySpacedIndexBlock, INDArray>> W_kl_bc = ctx.broadcast(W_kl_map);
    try {
        final JavaPairRDD<LinearlySpacedIndexBlock, INDArray> preconditionedWRDD = linOpPairRDD.mapToPair(p -> {
            final INDArray W_kl_chunk = W_kl_bc.value().get(p._1);
            final INDArray linOp_chunk = p._2;
            final int blockSize = linOp_chunk.shape()[0];
            /* solve one linear system per row of the block, then stack the solutions in row order */
            final List<INDArray> linOpWList = IntStream.range(0, blockSize).parallel()
                    .mapToObj(k -> CoverageModelEMWorkspaceMathUtils.linsolve(linOp_chunk.get(NDArrayIndex.point(k)), W_kl_chunk.get(NDArrayIndex.point(k))))
                    .collect(Collectors.toList());
            return new Tuple2<>(p._1, Nd4j.vstack(linOpWList));
        });
        W_kl.assign(CoverageModelSparkUtils.assembleINDArrayBlocksFromRDD(preconditionedWRDD, 0));
    } finally {
        /* release the broadcast even if the Spark stage above fails */
        W_kl_bc.destroy();
    }
    final long endTimePrecond = System.nanoTime();
    /* irfft */
    final long startTimeIRFFT = System.nanoTime();
    final INDArray res = Nd4j.create(numTargets, numLatents);
    IntStream.range(0, numLatents).parallel().forEach(li -> res.get(NDArrayIndex.all(), NDArrayIndex.point(li)).assign(F_tt.getInverseFFT(W_kl.get(NDArrayIndex.all(), NDArrayIndex.point(li)))));
    final long endTimeIRFFT = System.nanoTime();
    logger.debug("Local FFT timing: " + (endTimeRFFT - startTimeRFFT + endTimeIRFFT - startTimeIRFFT) / 1000000 + " ms");
    logger.debug("Spark preconditioner application timing: " + (endTimePrecond - startTimePrecond) / 1000000 + " ms");
    return res;
}
Also used : Broadcast(org.apache.spark.broadcast.Broadcast) IntStream(java.util.stream.IntStream) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) Tuple2(scala.Tuple2) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) Collectors(java.util.stream.Collectors) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) List(java.util.List) Logger(org.apache.logging.log4j.Logger) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Map(java.util.Map) LogManager(org.apache.logging.log4j.LogManager) Nonnull(javax.annotation.Nonnull) DimensionMismatchException(org.apache.commons.math3.exception.DimensionMismatchException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) Map(java.util.Map)

Example 69 with Tuple2

use of scala.Tuple2 in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method initializeWorkerPosteriors.

/**
 * Computes copy ratio prior expectations for every sample on the driver, stores their log chain
 * posterior probabilities in {@code sampleLogChainPosteriors}, and pushes the per-target-block
 * prior means/variances together with the per-sample read depth, unexplained variance and (when
 * bias covariates are enabled) bias latent posterior moments to the compute blocks.
 */
@UpdatesRDD
private void initializeWorkerPosteriors() {
    /* local references for lambda capture (these are small, so no broadcasting is necessary) */
    final INDArray logReadDepthMeans = this.sampleMeanLogReadDepths;
    final INDArray logReadDepthVariances = this.sampleVarLogReadDepths;
    final INDArray unexplainedVariance = this.sampleUnexplainedVariance;
    final INDArray biasLatentFirstMoments = this.sampleBiasLatentPosteriorFirstMoments;
    final INDArray biasLatentSecondMoments = this.sampleBiasLatentPosteriorSecondMoments;

    /* calculate copy ratio prior expectations */
    logger.info("Calculating copy ratio priors on the driver node...");
    final List<CopyRatioExpectations> priorExpectations = sampleIndexStream()
            .mapToObj(si -> {
                final CopyRatioCallingMetadata metadata = CopyRatioCallingMetadata.builder()
                        .sampleName(processedSampleNameList.get(si))
                        .sampleSexGenotypeData(processedSampleSexGenotypeData.get(si))
                        .sampleAverageMappingErrorProbability(config.getMappingErrorRate())
                        .build();
                return copyRatioExpectationsCalculator.getCopyRatioPriorExpectations(metadata, processedTargetList);
            })
            .collect(Collectors.toList());

    /* update log chain posterior expectation */
    final double[] logChainPosteriors = priorExpectations.stream()
            .mapToDouble(CopyRatioExpectations::getLogChainPosteriorProbability)
            .toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(logChainPosteriors, new int[] { numSamples, 1 }));

    /* push per-target copy ratio expectations to workers */
    final List<Tuple2<LinearlySpacedIndexBlock, Tuple2<INDArray, INDArray>>> copyRatioPriorsList = new ArrayList<>();
    final int numPriors = priorExpectations.size();
    for (final LinearlySpacedIndexBlock tb : targetBlocks) {
        final int blockLength = tb.getNumElements();
        final int blockOffset = tb.getBegIndex();

        /* flatten target-major: all samples of the first target, then all samples of the second, ... */
        final double[] meansFlat = new double[blockLength * numPriors];
        int cursor = 0;
        for (int rti = 0; rti < blockLength; rti++) {
            for (final CopyRatioExpectations cre : priorExpectations) {
                meansFlat[cursor++] = cre.getLogCopyRatioMeans()[rti + blockOffset];
            }
        }
        final double[] variancesFlat = new double[blockLength * numPriors];
        cursor = 0;
        for (int rti = 0; rti < blockLength; rti++) {
            for (final CopyRatioExpectations cre : priorExpectations) {
                variancesFlat[cursor++] = cre.getLogCopyRatioVariances()[rti + blockOffset];
            }
        }

        /* we do not need to take care of log copy ratio means and variances on masked targets here.
         * potential NaNs will be rectified in the compute blocks by calling the method
         * {@link CoverageModelEMComputeBlock#cloneWithInitializedData} */
        final INDArray meansBlock = Nd4j.create(meansFlat, new int[] { numSamples, blockLength }, 'f');
        final INDArray variancesBlock = Nd4j.create(variancesFlat, new int[] { numSamples, blockLength }, 'f');
        final Tuple2<INDArray, INDArray> blockPriors = new Tuple2<INDArray, INDArray>(meansBlock, variancesBlock);
        copyRatioPriorsList.add(new Tuple2<>(tb, blockPriors));
    }

    /* push to compute blocks */
    logger.info("Pushing posteriors to worker(s)...");
    /* copy ratio priors */
    joinWithWorkersAndMap(copyRatioPriorsList, p -> p._1.cloneWithUpdateCopyRatioPriors(p._2._1, p._2._2));
    /* read depth and sample-specific unexplained variance */
    mapWorkers(cb -> cb
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.log_d_s, logReadDepthMeans)
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.var_log_d_s, logReadDepthVariances)
            .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, unexplainedVariance));
    /* if bias covariates are enabled, initialize E[z] and E[z z^T] as well */
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb
                .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.z_sl, biasLatentFirstMoments)
                .cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.zz_sll, biasLatentSecondMoments));
    }
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2)

Example 70 with Tuple2

use of scala.Tuple2 in project gatk by broadinstitute.

the class BroadcastJoinReadsWithVariants method join.

/**
 * Pairs every read with the variants that overlap it.
 *
 * <p>All variants are collected to the driver, indexed in an {@link IntervalsSkipList} and broadcast;
 * each read is then looked up against that broadcast index. Reads whose contig/start/end do not form
 * a valid {@link SimpleInterval} are paired with an empty list.</p>
 *
 * @param reads    the reads to annotate
 * @param variants the variants to join against; collected in full on the driver
 * @return a pair RDD mapping each read to its overlapping variants
 */
public static JavaPairRDD<GATKRead, Iterable<GATKVariant>> join(final JavaRDD<GATKRead> reads, final JavaRDD<GATKVariant> variants) {
    final JavaSparkContext sparkContext = new JavaSparkContext(reads.context());
    final Broadcast<IntervalsSkipList<GATKVariant>> variantsBroadcast =
            sparkContext.broadcast(new IntervalsSkipList<>(variants.collect()));
    return reads.mapToPair(read -> {
        final IntervalsSkipList<GATKVariant> variantIndex = variantsBroadcast.getValue();
        /* reads without a valid interval cannot overlap anything */
        if (!SimpleInterval.isValid(read.getContig(), read.getStart(), read.getEnd())) {
            return new Tuple2<>(read, Collections.emptyList());
        }
        return new Tuple2<>(read, variantIndex.getOverlapping(new SimpleInterval(read)));
    });
}
Also used : GATKVariant(org.broadinstitute.hellbender.utils.variant.GATKVariant) IntervalsSkipList(org.broadinstitute.hellbender.utils.collections.IntervalsSkipList) Tuple2(scala.Tuple2) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext)

Aggregations

Tuple2 (scala.Tuple2)183 JavaSparkContext (org.apache.spark.api.java.JavaSparkContext)57 ArrayList (java.util.ArrayList)44 IOException (java.io.IOException)32 Test (org.junit.Test)32 INDArray (org.nd4j.linalg.api.ndarray.INDArray)28 JavaPairRDD (org.apache.spark.api.java.JavaPairRDD)23 List (java.util.List)22 Function (org.apache.spark.api.java.function.Function)19 File (java.io.File)18 Collectors (java.util.stream.Collectors)18 MatrixBlock (org.apache.sysml.runtime.matrix.data.MatrixBlock)18 MatrixIndexes (org.apache.sysml.runtime.matrix.data.MatrixIndexes)18 GATKException (org.broadinstitute.hellbender.exceptions.GATKException)18 Configuration (org.apache.hadoop.conf.Configuration)17 UserException (org.broadinstitute.hellbender.exceptions.UserException)17 Broadcast (org.apache.spark.broadcast.Broadcast)16 SparkConf (org.apache.spark.SparkConf)15 JavaRDD (org.apache.spark.api.java.JavaRDD)15 VisibleForTesting (com.google.common.annotations.VisibleForTesting)14