Search in sources :

Example 1 with EmissionCalculationStrategy

use of org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method getCopyRatioSegmentsSpark.

/**
     * Computes the copy ratio segments of every sample from the compute blocks (Spark implementation).
     *
     * For each sample, the copy ratio emission data fetched from the workers is first turned into
     * {@link CopyRatioHMMResults}, and those results are then converted into segments by an
     * {@link HMMSegmentProcessor}.
     *
     * @return one list of {@link HiddenStateSegmentRecord}s per sample, ordered by sample index
     */
private List<List<HiddenStateSegmentRecord<STATE, Target>>> getCopyRatioSegmentsSpark() {
    /* local final copies of member variables for lambda capture */
    final List<Target> targets = new ArrayList<>(this.processedTargetList);
    final List<SexGenotypeData> sexGenotypes = new ArrayList<>(this.processedSampleSexGenotypeData);
    final List<String> sampleNames = new ArrayList<>(this.processedSampleNameList);
    final INDArray readDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final BiFunction<SexGenotypeData, Target, STATE> refStateFactory = this.referenceStateFactory;

    return fetchCopyRatioEmissionDataSpark()
            /* stage 1: emission data -> HMM results, keyed by sample index */
            .mapPartitionsToPair(it -> {
                final List<Tuple2<Integer, CopyRatioHMMResults<CoverageModelCopyRatioEmissionData, STATE>>> hmmResults = new ArrayList<>();
                while (it.hasNext()) {
                    final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> entry = it.next();
                    final int si = entry._1;
                    final CopyRatioCallingMetadata metadata = CopyRatioCallingMetadata.builder()
                            .sampleName(sampleNames.get(si))
                            .sampleSexGenotypeData(sexGenotypes.get(si))
                            .sampleCoverageDepth(readDepths.getDouble(si))
                            .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                            .build();
                    hmmResults.add(new Tuple2<>(si, calculator.getCopyRatioHMMResults(metadata, targets, entry._2)));
                }
                return hmmResults.iterator();
            }, true)
            /* stage 2: HMM results -> segments, still keyed by sample index */
            .mapPartitionsToPair(it -> {
                final List<Tuple2<Integer, List<HiddenStateSegmentRecord<STATE, Target>>>> segments = new ArrayList<>();
                while (it.hasNext()) {
                    final Tuple2<Integer, CopyRatioHMMResults<CoverageModelCopyRatioEmissionData, STATE>> entry = it.next();
                    final int si = entry._1;
                    final CopyRatioHMMResults<CoverageModelCopyRatioEmissionData, STATE> hmmResult = entry._2;
                    /* the processor takes per-sample lists; each one holds exactly this sample's data */
                    final HMMSegmentProcessor<CoverageModelCopyRatioEmissionData, STATE, Target> processor = new HMMSegmentProcessor<>(
                            Collections.singletonList(hmmResult.getMetaData().getSampleName()),
                            Collections.singletonList(hmmResult.getMetaData().getSampleSexGenotypeData()),
                            refStateFactory,
                            Collections.singletonList(new HashedListTargetCollection<>(targets)),
                            Collections.singletonList(hmmResult.getForwardBackwardResult()),
                            Collections.singletonList(hmmResult.getViterbiResult()));
                    segments.add(new Tuple2<>(si, processor.getSegmentsAsList()));
                }
                return segments.iterator();
            })
            /* bring everything to the driver and restore sample order before dropping the index */
            .collect()
            .stream()
            .sorted(Comparator.comparingInt(t -> t._1))
            .map(t -> t._2)
            .collect(Collectors.toList());
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor)

Example 2 with EmissionCalculationStrategy

use of org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy in project gatk-protected by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsLocal.

/**
     * Local implementation of the E-step update of copy ratio posteriors.
     *
     * @param admixingRatio passed through unchanged to {@code cloneWithUpdatedCopyRatioPosteriors}
     *                      on every compute block
     * @return a {@link SubroutineSignal} whose {@code RESIDUAL_ERROR_NORM} entry is the largest
     *         value of that entry among the signals collected from the workers
     */
public SubroutineSignal updateCopyRatioPosteriorExpectationsLocal(final double admixingRatio) {
    /* step 1. fetch copy ratio emission data (indexed by sample) */
    final List<List<CoverageModelCopyRatioEmissionData>> emissionData = fetchCopyRatioEmissionDataLocal();

    /* step 2. calculate the copy ratio posterior expectations of all samples in parallel */
    final INDArray readDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    final List<CopyRatioExpectations> posteriors = sampleIndexStream()
            .parallel()
            .mapToObj(si -> {
                final CopyRatioCallingMetadata metadata = CopyRatioCallingMetadata.builder()
                        .sampleName(processedSampleNameList.get(si))
                        .sampleSexGenotypeData(processedSampleSexGenotypeData.get(si))
                        .sampleCoverageDepth(readDepths.getDouble(si))
                        .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                        .build();
                return copyRatioExpectationsCalculator.getCopyRatioPosteriorExpectations(metadata, processedTargetList, emissionData.get(si));
            })
            .collect(Collectors.toList());

    /* update log chain posterior expectation (stored as a numSamples x 1 array) */
    final double[] logChainPosteriors = posteriors.stream()
            .mapToDouble(CopyRatioExpectations::getLogChainPosteriorProbability)
            .toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(logChainPosteriors, new int[] { numSamples, 1 }));

    /* send the results back to the workers */
    final ImmutablePair<INDArray, INDArray> posteriorPair = convertCopyRatioLatentPosteriorExpectationsToNDArray(posteriors);
    final INDArray log_c_st = posteriorPair.left;
    final INDArray var_log_c_st = posteriorPair.right;

    /* partition the pair of (log_c_st, var_log_c_st) into blocks; every compute block picks up
     * the entry of its own target-space block */
    pushToWorkers(
            mapINDArrayPairToBlocks(log_c_st.transpose(), var_log_c_st.transpose()),
            (p, cb) -> cb.cloneWithUpdatedCopyRatioPosteriors(p.get(cb.getTargetSpaceBlock()).left.transpose(), p.get(cb.getTargetSpaceBlock()).right.transpose(), admixingRatio));
    cacheWorkers("after E-step update of copy ratio posteriors");

    /* collect subroutine signals and report the largest residual error norm */
    final List<SubroutineSignal> signals = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    final List<Double> errorNorms = signals.stream()
            .map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM))
            .collect(Collectors.toList());
    final double errorNormInfinity = Collections.max(errorNorms);
    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Example 3 with EmissionCalculationStrategy

use of org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsSpark.

/**
     * The Spark implementation of the E-step update of copy ratio posteriors.
     *
     * The per-sample copy ratio posterior expectations are consumed twice: once to update the
     * log chain posteriors on the driver, and once to build the per-target-block data that is
     * joined into {@code computeRDD}. The RDD holding them is therefore cached, and released
     * only after the updated compute blocks have been evaluated.
     *
     * @param admixingRatio passed through unchanged to {@code cloneWithUpdatedCopyRatioPosteriors}
     *                      on every compute block
     * @return a {@link SubroutineSignal} containing the update size (the largest
     *         {@code RESIDUAL_ERROR_NORM} among the signals collected from the workers)
     */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateCopyRatioPosteriorExpectationsSpark(final double admixingRatio) {
    /* local final copies of member variables for lambda capture */
    final List<LinearlySpacedIndexBlock> targetBlocks = new ArrayList<>(this.targetBlocks);
    final List<Target> targetList = new ArrayList<>(processedTargetList);
    final List<String> sampleNameList = new ArrayList<>(processedSampleNameList);
    final List<SexGenotypeData> sampleSexGenotypeData = new ArrayList<>(processedSampleSexGenotypeData);
    final int numTargetBlocks = targetBlocks.size();
    final CopyRatioExpectationsCalculator<CoverageModelCopyRatioEmissionData, STATE> calculator = this.copyRatioExpectationsCalculator;
    final INDArray sampleReadDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    /* make an RDD of copy ratio posterior expectations */
    final JavaPairRDD<Integer, CopyRatioExpectations> copyRatioPosteriorExpectationsPairRDD = /* fetch copy ratio emission data from workers */
    fetchCopyRatioEmissionDataSpark().mapPartitionsToPair(it -> {
        final List<Tuple2<Integer, CopyRatioExpectations>> newPartitionData = new ArrayList<>();
        while (it.hasNext()) {
            final Tuple2<Integer, List<CoverageModelCopyRatioEmissionData>> prevDatum = it.next();
            final int si = prevDatum._1;
            final CopyRatioCallingMetadata copyRatioCallingMetadata = CopyRatioCallingMetadata.builder().sampleName(sampleNameList.get(si)).sampleSexGenotypeData(sampleSexGenotypeData.get(si)).sampleCoverageDepth(sampleReadDepths.getDouble(si)).emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN).build();
            newPartitionData.add(new Tuple2<>(prevDatum._1, calculator.getCopyRatioPosteriorExpectations(copyRatioCallingMetadata, targetList, prevDatum._2)));
        }
        return newPartitionData.iterator();
    }, true)
    /* we need to do two things to copyRatioPosteriorExpectationsPairRDD; so we cache it.
     * Without this, the posterior calculation above would run once for the collect() in step 1
     * and a second time when the lineage built in step 2 is evaluated. */
    .cache();
    /* step 1. update log chain posterior expectation on the driver node */
    final double[] newSampleLogChainPosteriors = copyRatioPosteriorExpectationsPairRDD.mapValues(CopyRatioExpectations::getLogChainPosteriorProbability).collect().stream().sorted(Comparator.comparingInt(t -> t._1)).mapToDouble(t -> t._2).toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(newSampleLogChainPosteriors, new int[] { numSamples, 1 }));
    /* step 2. repartition in target space */
    final JavaPairRDD<LinearlySpacedIndexBlock, ImmutablePair<INDArray, INDArray>> blockifiedCopyRatioPosteriorResultsPairRDD = copyRatioPosteriorExpectationsPairRDD.flatMapToPair(dat -> targetBlocks.stream().map(tb -> new Tuple2<>(tb, new Tuple2<>(dat._1, ImmutablePair.of(dat._2.getLogCopyRatioMeans(tb), dat._2.getLogCopyRatioVariances(tb))))).iterator()).combineByKey(/* recipe to create an singleton list */
    Collections::singletonList, /* recipe to add an element to the list */
    (list, element) -> Stream.concat(list.stream(), Stream.of(element)).collect(Collectors.toList()), /* recipe to concatenate two lists */
    (list1, list2) -> Stream.concat(list1.stream(), list2.stream()).collect(Collectors.toList()), /* repartition with respect to target space blocks */
    new HashPartitioner(numTargetBlocks)).mapValues(list -> list.stream().sorted(Comparator.comparingInt(t -> t._1)).map(p -> p._2).map(t -> ImmutablePair.of(Nd4j.create(t.left), Nd4j.create(t.right))).collect(Collectors.toList())).mapValues(CoverageModelEMWorkspace::stackCopyRatioPosteriorDataForAllSamples);
    /* step 3. merge with computeRDD and update */
    computeRDD = computeRDD.join(blockifiedCopyRatioPosteriorResultsPairRDD).mapValues(t -> t._1.cloneWithUpdatedCopyRatioPosteriors(t._2.left, t._2.right, admixingRatio));
    cacheWorkers("after E-step for copy ratio update");
    /* collect subroutine signals; returning a list to the driver requires the updated compute
     * blocks (and hence the step 2 lineage) to have been evaluated */
    final List<SubroutineSignal> sigs = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    /* we do not need copy ratio expectations anymore. The release happens only here because
     * steps 2 and 3 are lazy transformations: unpersisting before they run would discard the
     * cached data before its second use.
     * NOTE(review): presumably cacheWorkers persists the new computeRDD so that later actions do
     * not walk back to this RDD -- confirm against its implementation */
    copyRatioPosteriorExpectationsPairRDD.unpersist();
    final double errorNormInfinity = Collections.max(sigs.stream().map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM)).collect(Collectors.toList()));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity).build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Tuple2(scala.Tuple2) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) HashPartitioner(org.apache.spark.HashPartitioner)

Example 4 with EmissionCalculationStrategy

use of org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method updateCopyRatioPosteriorExpectationsLocal.

/**
     * Local implementation of the E-step update of copy ratio posteriors.
     *
     * @param admixingRatio passed through unchanged to {@code cloneWithUpdatedCopyRatioPosteriors}
     *                      on every compute block
     * @return a {@link SubroutineSignal} whose {@code RESIDUAL_ERROR_NORM} entry is the largest
     *         value of that entry among the signals collected from the workers
     */
public SubroutineSignal updateCopyRatioPosteriorExpectationsLocal(final double admixingRatio) {
    /* step 1. fetch copy ratio emission data (indexed by sample) */
    final List<List<CoverageModelCopyRatioEmissionData>> emissionData = fetchCopyRatioEmissionDataLocal();

    /* step 2. calculate the copy ratio posterior expectations of all samples in parallel */
    final INDArray readDepths = Transforms.exp(sampleMeanLogReadDepths, true);
    final List<CopyRatioExpectations> posteriors = sampleIndexStream()
            .parallel()
            .mapToObj(si -> {
                final CopyRatioCallingMetadata metadata = CopyRatioCallingMetadata.builder()
                        .sampleName(processedSampleNameList.get(si))
                        .sampleSexGenotypeData(processedSampleSexGenotypeData.get(si))
                        .sampleCoverageDepth(readDepths.getDouble(si))
                        .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                        .build();
                return copyRatioExpectationsCalculator.getCopyRatioPosteriorExpectations(metadata, processedTargetList, emissionData.get(si));
            })
            .collect(Collectors.toList());

    /* update log chain posterior expectation (stored as a numSamples x 1 array) */
    final double[] logChainPosteriors = posteriors.stream()
            .mapToDouble(CopyRatioExpectations::getLogChainPosteriorProbability)
            .toArray();
    sampleLogChainPosteriors.assign(Nd4j.create(logChainPosteriors, new int[] { numSamples, 1 }));

    /* send the results back to the workers */
    final ImmutablePair<INDArray, INDArray> posteriorPair = convertCopyRatioLatentPosteriorExpectationsToNDArray(posteriors);
    final INDArray log_c_st = posteriorPair.left;
    final INDArray var_log_c_st = posteriorPair.right;

    /* partition the pair of (log_c_st, var_log_c_st) into blocks; every compute block picks up
     * the entry of its own target-space block */
    pushToWorkers(
            mapINDArrayPairToBlocks(log_c_st.transpose(), var_log_c_st.transpose()),
            (p, cb) -> cb.cloneWithUpdatedCopyRatioPosteriors(p.get(cb.getTargetSpaceBlock()).left.transpose(), p.get(cb.getTargetSpaceBlock()).right.transpose(), admixingRatio));
    cacheWorkers("after E-step update of copy ratio posteriors");

    /* collect subroutine signals and report the largest residual error norm */
    final List<SubroutineSignal> signals = mapWorkersAndCollect(CoverageModelEMComputeBlock::getLatestMStepSignal);
    final List<Double> errorNorms = signals.stream()
            .map(sig -> sig.<Double>get(StandardSubroutineSignals.RESIDUAL_ERROR_NORM))
            .collect(Collectors.toList());
    final double errorNormInfinity = Collections.max(errorNorms);
    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .build();
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray)

Example 5 with EmissionCalculationStrategy

use of org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy in project gatk by broadinstitute.

the class CoverageModelEMWorkspace method getCopyRatioSegmentsLocal.

/**
 * Fetch copy ratio segments (local implementation)
 *
 * For every sample index produced by {@code sampleIndexStream()}, HMM results are obtained from
 * {@code copyRatioExpectationsCalculator} and then converted to segments with a single-sample
 * {@link HMMSegmentProcessor}.
 *
 * @return a list with one entry per sample index, each entry being the list of
 *         {@link HiddenStateSegmentRecord}s produced for that sample
 */
private List<List<HiddenStateSegmentRecord<STATE, Target>>> getCopyRatioSegmentsLocal() {
    final List<List<CoverageModelCopyRatioEmissionData>> emissionDataPerSample = fetchCopyRatioEmissionDataLocal();
    /* exponentiate the mean log read depths (ND4J "dup" flag set to true, i.e. work on a copy) */
    final INDArray readDepthPerSample = Transforms.exp(sampleMeanLogReadDepths, true);
    return sampleIndexStream()
            .mapToObj(sampleIndex -> {
                /* per-sample metadata handed to the copy ratio expectations calculator */
                final CopyRatioCallingMetadata callingMetadata = CopyRatioCallingMetadata.builder()
                        .sampleName(processedSampleNameList.get(sampleIndex))
                        .sampleSexGenotypeData(processedSampleSexGenotypeData.get(sampleIndex))
                        .sampleCoverageDepth(readDepthPerSample.getDouble(sampleIndex))
                        .emissionCalculationStrategy(EmissionCalculationStrategy.HYBRID_POISSON_GAUSSIAN)
                        .build();
                return copyRatioExpectationsCalculator.getCopyRatioHMMResults(
                        callingMetadata, processedTargetList, emissionDataPerSample.get(sampleIndex));
            })
            .map(hmmResult -> {
                /* the segment processor takes per-sample lists; wrap this sample's data in singletons */
                final List<String> sampleNames =
                        Collections.singletonList(hmmResult.getMetaData().getSampleName());
                final List<SexGenotypeData> sampleSexGenotypes =
                        Collections.singletonList(hmmResult.getMetaData().getSampleSexGenotypeData());
                return new HMMSegmentProcessor<CoverageModelCopyRatioEmissionData, STATE, Target>(
                        sampleNames,
                        sampleSexGenotypes,
                        referenceStateFactory,
                        Collections.singletonList(new HashedListTargetCollection<>(processedTargetList)),
                        Collections.singletonList(hmmResult.getForwardBackwardResult()),
                        Collections.singletonList(hmmResult.getViterbiResult()))
                        .getSegmentsAsList();
            })
            .collect(Collectors.toList());
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) 
Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) 
ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor)

Aggregations

VisibleForTesting (com.google.common.annotations.VisibleForTesting)8 Sets (com.google.common.collect.Sets)8 File (java.io.File)8 IOException (java.io.IOException)8 java.util (java.util)8 BiFunction (java.util.function.BiFunction)8 Predicate (java.util.function.Predicate)8 Collectors (java.util.stream.Collectors)8 IntStream (java.util.stream.IntStream)8 Stream (java.util.stream.Stream)8 Nonnull (javax.annotation.Nonnull)8 Nullable (javax.annotation.Nullable)8 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)8 ImmutableTriple (org.apache.commons.lang3.tuple.ImmutableTriple)8 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)8 AbstractUnivariateSolver (org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver)8 NoBracketingException (org.apache.commons.math3.exception.NoBracketingException)8 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)8 FastMath (org.apache.commons.math3.util.FastMath)8 LogManager (org.apache.logging.log4j.LogManager)8