Example 51 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class CoverageModelParameters method adaptModelToReadCountCollection.

/**
     * This method "adapts" a model to a read count collection in the following sense:
     *
     *     - removes targets that are not included in the model from the read counts collection
     *     - removes targets that are not included in the read count collection from the model
     *     - rearranges model targets in the same order as read count collection targets
     *
     * The modifications are not done in-place and the original input parameters remain intact.
     *
     * @param model a model
     * @param readCounts a read count collection
     * @return a pair of model and read count collection
     */
public static ImmutablePair<CoverageModelParameters, ReadCountCollection> adaptModelToReadCountCollection(@Nonnull final CoverageModelParameters model, @Nonnull final ReadCountCollection readCounts, @Nonnull final Logger logger) {
    logger.info("Adapting model to read counts...");
    Utils.nonNull(model, "The model parameters must be non-null");
    Utils.nonNull(readCounts, "The read count collection must be non-null");
    Utils.nonNull(logger, "The logger must be non-null");
    final List<Target> modelTargetList = model.getTargetList();
    final List<Target> readCountsTargetList = readCounts.targets();
    final Set<Target> mutualTargetSet = Sets.intersection(new HashSet<>(modelTargetList), new HashSet<>(readCountsTargetList));
    final List<Target> mutualTargetList = readCountsTargetList.stream().filter(mutualTargetSet::contains).collect(Collectors.toList());
    logger.info("Number of mutual targets: " + mutualTargetList.size());
    Utils.validateArg(mutualTargetList.size() > 0, "The intersection between model targets and targets from read count" + " collection is empty. Please check that the model is compatible with the given read count" + " collection.");
    if (modelTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from the model: " + Sets.difference(new HashSet<>(modelTargetList), mutualTargetSet).stream().map(Target::getName).collect(Collectors.joining(", ", "[", "]")));
    }
    if (readCountsTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from read counts: " + Sets.difference(new HashSet<>(readCountsTargetList), mutualTargetSet).stream().map(Target::getName).collect(Collectors.joining(", ", "[", "]")));
    }
    /* the targets in {@code subsetReadCounts} follow the original order of targets in {@code readCounts} */
    final ReadCountCollection subsetReadCounts = readCounts.subsetTargets(mutualTargetSet);
    /* fetch original model parameters */
    final INDArray originalModelTargetMeanBias = model.getTargetMeanLogBias();
    final INDArray originalModelTargetUnexplainedVariance = model.getTargetUnexplainedVariance();
    final INDArray originalModelMeanBiasCovariates = model.getMeanBiasCovariates();
    /* re-arrange targets, mean log bias, and target-specific unexplained variance */
    final Map<Target, Integer> modelTargetsToIndexMap = IntStream.range(0, modelTargetList.size()).mapToObj(ti -> ImmutablePair.of(modelTargetList.get(ti), ti)).collect(Collectors.toMap(Pair<Target, Integer>::getLeft, Pair<Target, Integer>::getRight));
    final int[] newTargetIndicesInOriginalModel = mutualTargetList.stream().mapToInt(modelTargetsToIndexMap::get).toArray();
    final INDArray newModelTargetMeanBias = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    final INDArray newModelTargetUnexplainedVariance = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
        newModelTargetMeanBias.put(0, ti, originalModelTargetMeanBias.getDouble(0, newTargetIndicesInOriginalModel[ti]));
        newModelTargetUnexplainedVariance.put(0, ti, originalModelTargetUnexplainedVariance.getDouble(0, newTargetIndicesInOriginalModel[ti]));
    });
    /* if model has bias covariates and/or ARD, re-arrange mean/var of bias covariates as well */
    final INDArray newModelMeanBiasCovariates;
    if (model.isBiasCovariatesEnabled()) {
        newModelMeanBiasCovariates = Nd4j.create(new int[] { mutualTargetList.size(), model.getNumLatents() });
        IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
            newModelMeanBiasCovariates.get(NDArrayIndex.point(ti), NDArrayIndex.all()).assign(originalModelMeanBiasCovariates.get(NDArrayIndex.point(newTargetIndicesInOriginalModel[ti]), NDArrayIndex.all()));
        });
    } else {
        newModelMeanBiasCovariates = null;
    }
    return ImmutablePair.of(new CoverageModelParameters(mutualTargetList, newModelTargetMeanBias, newModelTargetUnexplainedVariance, newModelMeanBiasCovariates, model.getBiasCovariateARDCoefficients()), subsetReadCounts);
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) Nd4j(org.nd4j.linalg.factory.Nd4j) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) Sets(com.google.cloud.dataflow.sdk.repackaged.com.google.common.collect.Sets) Logger(org.apache.logging.log4j.Logger) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection) Pair(org.apache.commons.lang3.tuple.Pair) UserException(org.broadinstitute.hellbender.exceptions.UserException) java.io(java.io) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) RandomGeneratorFactory(org.apache.commons.math3.random.RandomGeneratorFactory) Target(org.broadinstitute.hellbender.tools.exome.Target) TargetTableReader(org.broadinstitute.hellbender.tools.exome.TargetTableReader) INDArray(org.nd4j.linalg.api.ndarray.INDArray) TargetWriter(org.broadinstitute.hellbender.tools.exome.TargetWriter) Utils(org.broadinstitute.hellbender.utils.Utils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Target(org.broadinstitute.hellbender.tools.exome.Target) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection)
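
A minimal usage sketch for the method above (an illustration, not GATK's own driver code): the model and read counts are assumed to have been loaded elsewhere, and only the public static adaptModelToReadCountCollection call itself comes from the snippet; the wrapper class, method name, and import of CoverageModelParameters' package are hypothetical.

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelParameters; // package assumed from the surrounding snippets
import org.broadinstitute.hellbender.tools.exome.ReadCountCollection;

public final class AdaptModelDemo {
    private static final Logger logger = LogManager.getLogger(AdaptModelDemo.class);

    /** Restrict a model and a read count collection to their mutual targets. */
    public static ReadCountCollection adaptAndGetCounts(final CoverageModelParameters model,
                                                        final ReadCountCollection readCounts) {
        final ImmutablePair<CoverageModelParameters, ReadCountCollection> adapted =
                CoverageModelParameters.adaptModelToReadCountCollection(model, readCounts, logger);
        // adapted.getLeft() is the model restricted and re-ordered to the mutual targets;
        // adapted.getRight() is the read count collection restricted to the same targets,
        // still in the original read-count target order.
        return adapted.getRight();
    }
}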

Example 52 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class CoverageModelEMWorkspace method initializeWorkerPosteriors.

@UpdatesRDD
private void initializeWorkerPosteriors() {
    /* make a local copy for lambda capture (these are small, so no broadcasting is necessary) */
    final INDArray sampleMeanLogReadDepths = this.sampleMeanLogReadDepths;
    final INDArray sampleVarLogReadDepths = this.sampleVarLogReadDepths;
    final INDArray sampleUnexplainedVariance = this.sampleUnexplainedVariance;
    final INDArray sampleBiasLatentPosteriorFirstMoments = this.sampleBiasLatentPosteriorFirstMoments;
    final INDArray sampleBiasLatentPosteriorSecondMoments = this.sampleBiasLatentPosteriorSecondMoments;
    /* calculate copy ratio prior expectations */
    logger.info("Calculating copy ratio priors on the driver node...");
    final List<CopyRatioExpectations> copyRatioPriorExpectationsList = sampleIndexStream().mapToObj(si -> copyRatioExpectationsCalculator.getCopyRatioPriorExpectations(CopyRatioCallingMetadata.builder().sampleName(processedSampleNameList.get(si)).sampleSexGenotypeData(processedSampleSexGenotypeData.get(si)).sampleAverageMappingErrorProbability(config.getMappingErrorRate()).build(), processedTargetList)).collect(Collectors.toList());
    /* update log chain posterior expectation */
    sampleLogChainPosteriors.assign(Nd4j.create(copyRatioPriorExpectationsList.stream().mapToDouble(CopyRatioExpectations::getLogChainPosteriorProbability).toArray(), new int[] { numSamples, 1 }));
    /* push per-target copy ratio expectations to workers */
    final List<Tuple2<LinearlySpacedIndexBlock, Tuple2<INDArray, INDArray>>> copyRatioPriorsList = new ArrayList<>();
    for (final LinearlySpacedIndexBlock tb : targetBlocks) {
        final double[] logCopyRatioPriorMeansBlock = IntStream.range(0, tb.getNumElements()).mapToObj(rti -> copyRatioPriorExpectationsList.stream().mapToDouble(cre -> cre.getLogCopyRatioMeans()[rti + tb.getBegIndex()]).toArray()).flatMapToDouble(Arrays::stream).toArray();
        final double[] logCopyRatioPriorVariancesBlock = IntStream.range(0, tb.getNumElements()).mapToObj(rti -> copyRatioPriorExpectationsList.stream().mapToDouble(cre -> cre.getLogCopyRatioVariances()[rti + tb.getBegIndex()]).toArray()).flatMapToDouble(Arrays::stream).toArray();
        /* we do not need to take care of log copy ratio means and variances on masked targets here.
             * potential NaNs will be rectified in the compute blocks by calling the method
             * {@link CoverageModelEMComputeBlock#cloneWithInitializedData} */
        copyRatioPriorsList.add(new Tuple2<>(tb, new Tuple2<INDArray, INDArray>(Nd4j.create(logCopyRatioPriorMeansBlock, new int[] { numSamples, tb.getNumElements() }, 'f'), Nd4j.create(logCopyRatioPriorVariancesBlock, new int[] { numSamples, tb.getNumElements() }, 'f'))));
    }
    /* push to compute blocks */
    logger.info("Pushing posteriors to worker(s)...");
    /* copy ratio priors */
    joinWithWorkersAndMap(copyRatioPriorsList, p -> p._1.cloneWithUpdateCopyRatioPriors(p._2._1, p._2._2));
    /* read depth and sample-specific unexplained variance */
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.log_d_s, sampleMeanLogReadDepths).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.var_log_d_s, sampleVarLogReadDepths).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, sampleUnexplainedVariance));
    /* if bias covariates are enabled, initialize E[z] and E[z z^T] as well */
    if (biasCovariatesEnabled) {
        mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.z_sl, sampleBiasLatentPosteriorFirstMoments).cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.zz_sll, sampleBiasLatentPosteriorSecondMoments));
    }
}
Also used : ScalarProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.ScalarProducer) Function2(org.apache.spark.api.java.function.Function2) HMMSegmentProcessor(org.broadinstitute.hellbender.utils.hmm.segmentation.HMMSegmentProcessor) GermlinePloidyAnnotatedTargetCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.GermlinePloidyAnnotatedTargetCollection) HiddenStateSegmentRecordWriter(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecordWriter) BiFunction(java.util.function.BiFunction) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) SexGenotypeData(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeData) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) CallStringProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.CallStringProducer) StorageLevel(org.apache.spark.storage.StorageLevel) SynchronizedUnivariateSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.SynchronizedUnivariateSolver) CopyRatioExpectationsCalculator(org.broadinstitute.hellbender.tools.coveragemodel.interfaces.CopyRatioExpectationsCalculator) UnivariateSolverSpecifications(org.broadinstitute.hellbender.tools.coveragemodel.math.UnivariateSolverSpecifications) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Broadcast(org.apache.spark.broadcast.Broadcast) ExitStatus(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray.ExitStatus) SexGenotypeDataCollection(org.broadinstitute.hellbender.tools.exome.sexgenotyper.SexGenotypeDataCollection) HashPartitioner(org.apache.spark.HashPartitioner) Predicate(java.util.function.Predicate) GeneralLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.GeneralLinearOperator) Nd4j(org.nd4j.linalg.factory.Nd4j) INDArrayIndex(org.nd4j.linalg.indexing.INDArrayIndex) FastMath(org.apache.commons.math3.util.FastMath) org.broadinstitute.hellbender.tools.exome(org.broadinstitute.hellbender.tools.exome) Tuple2(scala.Tuple2) Collectors(java.util.stream.Collectors) Sets(com.google.common.collect.Sets) AbstractUnivariateSolver(org.apache.commons.math3.analysis.solvers.AbstractUnivariateSolver) FourierLinearOperatorNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperatorNDArray) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) Function(org.apache.spark.api.java.function.Function) DataBuffer(org.nd4j.linalg.api.buffer.DataBuffer) IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) AlleleMetadataProducer(org.broadinstitute.hellbender.utils.hmm.interfaces.AlleleMetadataProducer) EmissionCalculationStrategy(org.broadinstitute.hellbender.tools.coveragemodel.CoverageModelCopyRatioEmissionProbabilityCalculator.EmissionCalculationStrategy) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) IntervalUtils(org.broadinstitute.hellbender.utils.IntervalUtils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) HiddenStateSegmentRecord(org.broadinstitute.hellbender.utils.hmm.segmentation.HiddenStateSegmentRecord) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) IterativeLinearSolverNDArray(org.broadinstitute.hellbender.tools.coveragemodel.linalg.IterativeLinearSolverNDArray) GATKProtectedMathUtils(org.broadinstitute.hellbender.utils.GATKProtectedMathUtils) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) IOException(java.io.IOException) JavaPairRDD(org.apache.spark.api.java.JavaPairRDD) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) File(java.io.File) INDArray(org.nd4j.linalg.api.ndarray.INDArray) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) LogManager(org.apache.logging.log4j.LogManager) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2)
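
One detail worth unpacking from initializeWorkerPosteriors: the per-block prior arrays are flattened target-by-target and then wrapped with Nd4j.create(..., 'f'), i.e. Fortran (column-major) order, so for shape [numSamples, numElements] the sample index varies fastest within each target column. A self-contained plain-Java sketch of that index mapping (toy values; no Spark or Nd4j required):

import java.util.Arrays;
import java.util.stream.IntStream;

public final class ColumnMajorFlatteningDemo {
    public static void main(String[] args) {
        final int numSamples = 3;
        final int numTargets = 2;
        // meansPerSample[s][t]: per-sample log copy-ratio means for one target block (toy values)
        final double[][] meansPerSample = { { 0.0, 0.1 }, { 1.0, 1.1 }, { 2.0, 2.1 } };

        // Flatten as in initializeWorkerPosteriors: for each target, collect the value
        // from every sample, then concatenate the per-target arrays.
        final double[] flat = IntStream.range(0, numTargets)
                .mapToObj(t -> IntStream.range(0, numSamples)
                        .mapToDouble(s -> meansPerSample[s][t])
                        .toArray())
                .flatMapToDouble(Arrays::stream)
                .toArray();

        // With shape [numSamples, numTargets] in 'f' (column-major) order,
        // element (s, t) sits at flat[t * numSamples + s]: samples vary fastest.
        System.out.println(Arrays.toString(flat)); // [0.0, 1.0, 2.0, 0.1, 1.1, 2.1]
        final int s = 2, t = 1;
        System.out.println(flat[t * numSamples + s] == meansPerSample[s][t]); // true
    }
}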

Example 53 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class CoverageModelParameters method adaptModelToReadCountCollection.

/**
     * This method "adapts" a model to a read count collection in the following sense:
     *
     *     - removes targets that are not included in the model from the read counts collection
     *     - removes targets that are not included in the read count collection from the model
     *     - rearranges model targets in the same order as read count collection targets
     *
     * The modifications are not done in-place and the original input parameters remain intact.
     *
     * @param model a model
     * @param readCounts a read count collection
     * @return a pair of model and read count collection
     */
public static ImmutablePair<CoverageModelParameters, ReadCountCollection> adaptModelToReadCountCollection(@Nonnull final CoverageModelParameters model, @Nonnull final ReadCountCollection readCounts, @Nonnull final Logger logger) {
    logger.info("Adapting model to read counts...");
    Utils.nonNull(model, "The model parameters must be non-null");
    Utils.nonNull(readCounts, "The read count collection must be non-null");
    Utils.nonNull(logger, "The logger must be non-null");
    final List<Target> modelTargetList = model.getTargetList();
    final List<Target> readCountsTargetList = readCounts.targets();
    final Set<Target> mutualTargetSet = Sets.intersection(new HashSet<>(modelTargetList), new HashSet<>(readCountsTargetList));
    final List<Target> mutualTargetList = readCountsTargetList.stream().filter(mutualTargetSet::contains).collect(Collectors.toList());
    logger.info("Number of mutual targets: " + mutualTargetList.size());
    Utils.validateArg(mutualTargetList.size() > 0, "The intersection between model targets and targets from read count" + " collection is empty. Please check that the model is compatible with the given read count" + " collection.");
    if (modelTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from the model: " + Sets.difference(new HashSet<>(modelTargetList), mutualTargetSet).stream().map(Target::getName).collect(Collectors.joining(", ", "[", "]")));
    }
    if (readCountsTargetList.size() > mutualTargetList.size()) {
        logger.info("The following targets dropped from read counts: " + Sets.difference(new HashSet<>(readCountsTargetList), mutualTargetSet).stream().map(Target::getName).collect(Collectors.joining(", ", "[", "]")));
    }
    /* the targets in {@code subsetReadCounts} follow the original order of targets in {@code readCounts} */
    final ReadCountCollection subsetReadCounts = readCounts.subsetTargets(mutualTargetSet);
    /* fetch original model parameters */
    final INDArray originalModelTargetMeanBias = model.getTargetMeanLogBias();
    final INDArray originalModelTargetUnexplainedVariance = model.getTargetUnexplainedVariance();
    final INDArray originalModelMeanBiasCovariates = model.getMeanBiasCovariates();
    /* re-arrange targets, mean log bias, and target-specific unexplained variance */
    final Map<Target, Integer> modelTargetsToIndexMap = IntStream.range(0, modelTargetList.size()).mapToObj(ti -> ImmutablePair.of(modelTargetList.get(ti), ti)).collect(Collectors.toMap(Pair<Target, Integer>::getLeft, Pair<Target, Integer>::getRight));
    final int[] newTargetIndicesInOriginalModel = mutualTargetList.stream().mapToInt(modelTargetsToIndexMap::get).toArray();
    final INDArray newModelTargetMeanBias = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    final INDArray newModelTargetUnexplainedVariance = Nd4j.create(new int[] { 1, mutualTargetList.size() });
    IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
        newModelTargetMeanBias.put(0, ti, originalModelTargetMeanBias.getDouble(0, newTargetIndicesInOriginalModel[ti]));
        newModelTargetUnexplainedVariance.put(0, ti, originalModelTargetUnexplainedVariance.getDouble(0, newTargetIndicesInOriginalModel[ti]));
    });
    /* if model has bias covariates and/or ARD, re-arrange mean/var of bias covariates as well */
    final INDArray newModelMeanBiasCovariates;
    if (model.isBiasCovariatesEnabled()) {
        newModelMeanBiasCovariates = Nd4j.create(new int[] { mutualTargetList.size(), model.getNumLatents() });
        IntStream.range(0, mutualTargetList.size()).forEach(ti -> {
            newModelMeanBiasCovariates.get(NDArrayIndex.point(ti), NDArrayIndex.all()).assign(originalModelMeanBiasCovariates.get(NDArrayIndex.point(newTargetIndicesInOriginalModel[ti]), NDArrayIndex.all()));
        });
    } else {
        newModelMeanBiasCovariates = null;
    }
    return ImmutablePair.of(new CoverageModelParameters(mutualTargetList, newModelTargetMeanBias, newModelTargetUnexplainedVariance, newModelMeanBiasCovariates, model.getBiasCovariateARDCoefficients()), subsetReadCounts);
}
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) Nd4jIOUtils(org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils) Nd4j(org.nd4j.linalg.factory.Nd4j) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) Sets(com.google.cloud.dataflow.sdk.repackaged.com.google.common.collect.Sets) Logger(org.apache.logging.log4j.Logger) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection) Pair(org.apache.commons.lang3.tuple.Pair) UserException(org.broadinstitute.hellbender.exceptions.UserException) java.io(java.io) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) RandomGeneratorFactory(org.apache.commons.math3.random.RandomGeneratorFactory) Target(org.broadinstitute.hellbender.tools.exome.Target) TargetTableReader(org.broadinstitute.hellbender.tools.exome.TargetTableReader) INDArray(org.nd4j.linalg.api.ndarray.INDArray) TargetWriter(org.broadinstitute.hellbender.tools.exome.TargetWriter) Utils(org.broadinstitute.hellbender.utils.Utils) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Target(org.broadinstitute.hellbender.tools.exome.Target) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection)
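
The target re-ordering in this method boils down to two steps: build a map from each model target to its original index, then turn the mutual target list (kept in read-count order) into a permutation over the model's per-target arrays. A self-contained sketch with strings standing in for Target objects (toy data, not GATK types):

import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public final class TargetReindexDemo {
    public static void main(String[] args) {
        final List<String> modelTargets  = List.of("t1", "t2", "t3", "t4");
        final List<String> mutualTargets = List.of("t3", "t1");    // read-count order
        final double[] modelMeanBias = { 0.10, 0.20, 0.30, 0.40 }; // aligned with modelTargets

        // Map each model target to its position in the model's target list.
        final Map<String, Integer> indexMap = IntStream.range(0, modelTargets.size()).boxed()
                .collect(Collectors.toMap(modelTargets::get, Function.identity()));

        // Permutation: for each mutual target, where it sat in the original model.
        final int[] newIndices = mutualTargets.stream().mapToInt(indexMap::get).toArray();

        // Re-arrange the per-target parameter into read-count order.
        final double[] newMeanBias = IntStream.range(0, mutualTargets.size())
                .mapToDouble(ti -> modelMeanBias[newIndices[ti]]).toArray();

        System.out.println(java.util.Arrays.toString(newMeanBias)); // [0.3, 0.1]
    }
}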

Example 54 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class CopyRatioModellerUnitTest method testRunMCMCOnCopyRatioSegmentedGenome.

/**
     * Tests Bayesian inference of the copy-ratio model via MCMC.
     * <p>
     *     Recovery of input values for the variance and outlier-probability global parameters is checked.
     *     In particular, the true input value of the variance must fall within
     *     {@link CopyRatioModellerUnitTest#MULTIPLES_OF_SD_THRESHOLD}
     *     standard deviations of the posterior mean and the standard deviation of the posterior must agree
     *     with the analytic value to within a relative error of
     *     {@link CopyRatioModellerUnitTest#RELATIVE_ERROR_THRESHOLD} for 250 samples
     *     (after 250 burn-in samples have been discarded).  Similar criteria are applied
     *     to the recovery of the true input value for the outlier probability.
     * </p>
     * <p>
     *     Furthermore, the number of truth values for the segment-level means falling outside confidence intervals of
     *     1-sigma, 2-sigma, and 3-sigma given by the posteriors in each segment should be roughly consistent with
     *     a normal distribution (i.e., ~32, ~5, and ~0, respectively; we allow for errors of
     *     {@link CopyRatioModellerUnitTest#DELTA_NUMBER_OF_MEANS_ALLOWED_OUTSIDE_1_SIGMA},
     *     {@link CopyRatioModellerUnitTest#DELTA_NUMBER_OF_MEANS_ALLOWED_OUTSIDE_2_SIGMA}, and
     *     {@link CopyRatioModellerUnitTest#DELTA_NUMBER_OF_MEANS_ALLOWED_OUTSIDE_3_SIGMA}, respectively).
     *     The mean of the standard deviations of the posteriors for the segment-level means should also be
     *     recovered to within a relative error of {@link CopyRatioModellerUnitTest#RELATIVE_ERROR_THRESHOLD}.
     * </p>
     * <p>
     *     Finally, the recovered values for the latent outlier-indicator parameters should agree with those used to
     *     generate the data.  For each indicator, the recovered value (i.e., outlier or non-outlier) is taken to be
     *     that given by the majority of posterior samples.  We require that at least
     *     {@link CopyRatioModellerUnitTest#FRACTION_OF_OUTLIER_INDICATORS_CORRECT_THRESHOLD}
     *     of the 10000 indicators are recovered correctly.
     * </p>
     * <p>
     *     With these specifications, this unit test is not overly brittle (i.e., it should pass for a large majority
     *     of randomly generated data sets), but it is still brittle enough to check for correctness of the sampling
     *     (for example, specifying a sufficiently incorrect likelihood will cause the test to fail).
     * </p>
     */
@Test
public void testRunMCMCOnCopyRatioSegmentedGenome() throws IOException {
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    LoggingUtils.setLoggingLevel(Log.LogLevel.INFO);
    //load data (coverages and number of targets in each segment)
    final ReadCountCollection coverage = ReadCountCollectionUtils.parse(COVERAGES_FILE);
    //Genome with no SNPs
    final Genome genome = new Genome(coverage, Collections.emptyList());
    final SegmentedGenome segmentedGenome = new SegmentedGenome(SEGMENT_FILE, genome);
    //run MCMC
    final CopyRatioModeller modeller = new CopyRatioModeller(segmentedGenome);
    modeller.fitMCMC(NUM_SAMPLES, NUM_BURN_IN);
    //check statistics of global-parameter posterior samples (i.e., posterior mode and standard deviation)
    final Map<CopyRatioParameter, PosteriorSummary> globalParameterPosteriorSummaries = modeller.getGlobalParameterPosteriorSummaries(CREDIBLE_INTERVAL_ALPHA, ctx);
    final PosteriorSummary variancePosteriorSummary = globalParameterPosteriorSummaries.get(CopyRatioParameter.VARIANCE);
    final double variancePosteriorCenter = variancePosteriorSummary.getCenter();
    final double variancePosteriorStandardDeviation = (variancePosteriorSummary.getUpper() - variancePosteriorSummary.getLower()) / 2;
    Assert.assertEquals(Math.abs(variancePosteriorCenter - VARIANCE_TRUTH), 0., MULTIPLES_OF_SD_THRESHOLD * VARIANCE_POSTERIOR_STANDARD_DEVIATION_TRUTH);
    Assert.assertEquals(relativeError(variancePosteriorStandardDeviation, VARIANCE_POSTERIOR_STANDARD_DEVIATION_TRUTH), 0., RELATIVE_ERROR_THRESHOLD);
    final PosteriorSummary outlierProbabilityPosteriorSummary = globalParameterPosteriorSummaries.get(CopyRatioParameter.OUTLIER_PROBABILITY);
    final double outlierProbabilityPosteriorCenter = outlierProbabilityPosteriorSummary.getCenter();
    final double outlierProbabilityPosteriorStandardDeviation = (outlierProbabilityPosteriorSummary.getUpper() - outlierProbabilityPosteriorSummary.getLower()) / 2;
    Assert.assertEquals(Math.abs(outlierProbabilityPosteriorCenter - OUTLIER_PROBABILITY_TRUTH), 0., MULTIPLES_OF_SD_THRESHOLD * OUTLIER_PROBABILITY_POSTERIOR_STANDARD_DEVIATION_TRUTH);
    Assert.assertEquals(relativeError(outlierProbabilityPosteriorStandardDeviation, OUTLIER_PROBABILITY_POSTERIOR_STANDARD_DEVIATION_TRUTH), 0., RELATIVE_ERROR_THRESHOLD);
    //check statistics of segment-mean posterior samples (i.e., posterior means and standard deviations)
    final List<Double> meansTruth = loadList(MEANS_TRUTH_FILE, Double::parseDouble);
    int numMeansOutsideOneSigma = 0;
    int numMeansOutsideTwoSigma = 0;
    int numMeansOutsideThreeSigma = 0;
    final int numSegments = meansTruth.size();
    //segment-mean posteriors are expected to be Gaussian, so PosteriorSummary for
    // {@link CopyRatioModellerUnitTest#CREDIBLE_INTERVAL_ALPHA}=0.32 is
    //(posterior mean, posterior mean - posterior standard deviation, posterior mean + posterior standard deviation)
    final List<PosteriorSummary> meanPosteriorSummaries = modeller.getSegmentMeansPosteriorSummaries(CREDIBLE_INTERVAL_ALPHA, ctx);
    final double[] meanPosteriorStandardDeviations = new double[numSegments];
    for (int segment = 0; segment < numSegments; segment++) {
        final double meanPosteriorCenter = meanPosteriorSummaries.get(segment).getCenter();
        final double meanPosteriorStandardDeviation = (meanPosteriorSummaries.get(segment).getUpper() - meanPosteriorSummaries.get(segment).getLower()) / 2.;
        meanPosteriorStandardDeviations[segment] = meanPosteriorStandardDeviation;
        final double absoluteDifferenceFromTruth = Math.abs(meanPosteriorCenter - meansTruth.get(segment));
        if (absoluteDifferenceFromTruth > meanPosteriorStandardDeviation) {
            numMeansOutsideOneSigma++;
        }
        if (absoluteDifferenceFromTruth > 2 * meanPosteriorStandardDeviation) {
            numMeansOutsideTwoSigma++;
        }
        if (absoluteDifferenceFromTruth > 3 * meanPosteriorStandardDeviation) {
            numMeansOutsideThreeSigma++;
        }
    }
    final double meanPosteriorStandardDeviationsMean = new Mean().evaluate(meanPosteriorStandardDeviations);
    Assert.assertEquals(numMeansOutsideOneSigma, 100 - 68, DELTA_NUMBER_OF_MEANS_ALLOWED_OUTSIDE_1_SIGMA);
    Assert.assertEquals(numMeansOutsideTwoSigma, 100 - 95, DELTA_NUMBER_OF_MEANS_ALLOWED_OUTSIDE_2_SIGMA);
    Assert.assertTrue(numMeansOutsideThreeSigma <= DELTA_NUMBER_OF_MEANS_ALLOWED_OUTSIDE_3_SIGMA);
    Assert.assertEquals(relativeError(meanPosteriorStandardDeviationsMean, MEAN_POSTERIOR_STANDARD_DEVIATION_MEAN_TRUTH), 0., RELATIVE_ERROR_THRESHOLD);
    //check accuracy of latent outlier-indicator posterior samples
    final List<CopyRatioState.OutlierIndicators> outlierIndicatorSamples = modeller.getOutlierIndicatorsSamples();
    int numIndicatorsCorrect = 0;
    final int numIndicatorSamples = outlierIndicatorSamples.size();
    final List<Integer> outlierIndicatorsTruthAsInt = loadList(OUTLIER_INDICATORS_TRUTH_FILE, Integer::parseInt);
    final List<Boolean> outlierIndicatorsTruth = outlierIndicatorsTruthAsInt.stream().map(i -> i == 1).collect(Collectors.toList());
    for (int target = 0; target < coverage.targets().size(); target++) {
        int numSamplesOutliers = 0;
        for (final CopyRatioState.OutlierIndicators sample : outlierIndicatorSamples) {
            if (sample.get(target)) {
                numSamplesOutliers++;
            }
        }
        //take predicted state of indicator to be given by the majority of samples
        if ((numSamplesOutliers >= numIndicatorSamples / 2.) == outlierIndicatorsTruth.get(target)) {
            numIndicatorsCorrect++;
        }
    }
    final double fractionOfOutlierIndicatorsCorrect = (double) numIndicatorsCorrect / coverage.targets().size();
    Assert.assertTrue(fractionOfOutlierIndicatorsCorrect >= FRACTION_OF_OUTLIER_INDICATORS_CORRECT_THRESHOLD);
}
Also used : BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Genome(org.broadinstitute.hellbender.tools.exome.Genome) FileUtils(org.apache.commons.io.FileUtils) Test(org.testng.annotations.Test) IOException(java.io.IOException) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) File(java.io.File) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) List(java.util.List) Log(htsjdk.samtools.util.Log) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection) UserException(org.broadinstitute.hellbender.exceptions.UserException) Assert(org.testng.Assert) PosteriorSummary(org.broadinstitute.hellbender.utils.mcmc.PosteriorSummary) ReadCountCollectionUtils(org.broadinstitute.hellbender.tools.exome.ReadCountCollectionUtils) Map(java.util.Map) SparkContextFactory(org.broadinstitute.hellbender.engine.spark.SparkContextFactory) SegmentedGenome(org.broadinstitute.hellbender.tools.exome.SegmentedGenome) LoggingUtils(org.broadinstitute.hellbender.utils.LoggingUtils) Collections(java.util.Collections) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) ReadCountCollection(org.broadinstitute.hellbender.tools.exome.ReadCountCollection) PosteriorSummary(org.broadinstitute.hellbender.utils.mcmc.PosteriorSummary) SegmentedGenome(org.broadinstitute.hellbender.tools.exome.SegmentedGenome) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) Genome(org.broadinstitute.hellbender.tools.exome.Genome) SegmentedGenome(org.broadinstitute.hellbender.tools.exome.SegmentedGenome) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test)
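
Two conversions in this test are worth spelling out: relativeError (reconstructed here from its usage in the assertions; treat the exact definition as an assumption) and the Gaussian shortcut that turns a central 68% credible interval (CREDIBLE_INTERVAL_ALPHA = 0.32) into a posterior standard deviation via (upper - lower) / 2. The helper names below are hypothetical; the test inlines this arithmetic.

public final class PosteriorCheckDemo {
    // 68% central credible interval spans roughly mean +/- 1 sigma for a Gaussian posterior
    private static final double CREDIBLE_INTERVAL_ALPHA = 0.32;

    /** |estimated - truth| / |truth|, matching how relativeError is used in the assertions above. */
    static double relativeError(final double estimated, final double truth) {
        return Math.abs(estimated - truth) / Math.abs(truth);
    }

    /** Gaussian-approximation posterior SD recovered from a 68% central credible interval. */
    static double posteriorStandardDeviation(final double lower, final double upper) {
        return (upper - lower) / 2.;
    }

    public static void main(String[] args) {
        // Toy numbers: a 68% interval of [0.8, 1.2] implies a posterior SD of ~0.2.
        final double sd = posteriorStandardDeviation(0.8, 1.2);
        System.out.println(sd);                        // 0.2
        System.out.println(relativeError(0.21, 0.2));  // ~0.05
    }
}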

Example 55 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class MatrixSummaryUtils method getRowVariances.

/**
     * Return an array containing the variance for each row in the given matrix.
     * @param m Not {@code null}.  Size MxN.  If any entry in a row is NaN, that row will have a
     *          variance of NaN.
     * @return array of size M.  Never {@code null}.  If there is only one column (i.e., only one
     *          entry per row), each row variance is zero.
     */
public static double[] getRowVariances(final RealMatrix m) {
    Utils.nonNull(m, "Cannot calculate variances on a null matrix.");
    final StandardDeviation std = new StandardDeviation();
    return IntStream.range(0, m.getRowDimension()).boxed().mapToDouble(i -> Math.pow(std.evaluate(m.getRow(i)), 2)).toArray();
}
Also used : IntStream(java.util.stream.IntStream) Median(org.apache.commons.math3.stat.descriptive.rank.Median) StandardDeviation(org.apache.commons.math3.stat.descriptive.moment.StandardDeviation) RealMatrix(org.apache.commons.math3.linear.RealMatrix) StandardDeviation(org.apache.commons.math3.stat.descriptive.moment.StandardDeviation)
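
getRowVariances squares a StandardDeviation rather than using Variance directly; both are bias-corrected (n - 1 denominator) by default in Commons Math, so the two routes agree up to floating-point rounding. A small sketch of the equivalent computation with Variance itself (the class this page indexes):

import java.util.stream.IntStream;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.descriptive.moment.Variance;

public final class RowVarianceDemo {
    /** Row variances computed with Variance directly (sample variance, n - 1 denominator). */
    public static double[] getRowVariances(final RealMatrix m) {
        final Variance variance = new Variance(); // bias-corrected by default, like StandardDeviation
        return IntStream.range(0, m.getRowDimension())
                .mapToDouble(i -> variance.evaluate(m.getRow(i)))
                .toArray();
    }

    public static void main(String[] args) {
        final RealMatrix m = new Array2DRowRealMatrix(new double[][] { { 1, 2, 3 }, { 4, 4, 4 } });
        System.out.println(java.util.Arrays.toString(getRowVariances(m))); // [1.0, 0.0]
    }
}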

Aggregations

Collectors (java.util.stream.Collectors) 24
IntStream (java.util.stream.IntStream) 24
ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils) 22
Nonnull (javax.annotation.Nonnull) 20
RandomGenerator (org.apache.commons.math3.random.RandomGenerator) 18
Variance (org.apache.commons.math3.stat.descriptive.moment.Variance) 18
List (java.util.List) 16
FastMath (org.apache.commons.math3.util.FastMath) 16
Utils (org.broadinstitute.hellbender.utils.Utils) 16
INDArray (org.nd4j.linalg.api.ndarray.INDArray) 16
Function (java.util.function.Function) 15
Arrays (java.util.Arrays) 14
Nullable (javax.annotation.Nullable) 14
ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair) 14
RealMatrix (org.apache.commons.math3.linear.RealMatrix) 14
Logger (org.apache.logging.log4j.Logger) 14
GATKException (org.broadinstitute.hellbender.exceptions.GATKException) 14
UserException (org.broadinstitute.hellbender.exceptions.UserException) 14
Nd4j (org.nd4j.linalg.factory.Nd4j) 14
NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex) 14