Search in sources :

Example 16 with Max

use of org.apache.commons.math3.stat.descriptive.rank.Max in project gatk by broadinstitute.

the class ReadCountCollectionUtilsUnitTest method testRemoveTargetsWithTooManyZeros.

/**
 * Checks {@code ReadCountCollectionUtils.removeTargetsWithTooManyZeros} for every zero-count
 * threshold below the largest per-target zero count found in the input.
 * <p>
 * For each threshold the test verifies that either a {@link UserException.BadInput} is thrown
 * (when no target would survive) or that exactly the targets whose zero count does not exceed
 * the threshold remain, in their original relative order.
 *
 * @param readCount the read-count collection supplied by the {@code tooManyZerosData} provider
 */
@Test(dataProvider = "tooManyZerosData")
public void testRemoveTargetsWithTooManyZeros(final ReadCountCollection readCount) {
    final RealMatrix countMatrix = readCount.counts();
    // number of exact-zero entries in each row (one row per target)
    final int[] zerosPerTarget = IntStream.range(0, countMatrix.getRowDimension())
            .map(row -> (int) DoubleStream.of(countMatrix.getRow(row)).filter(value -> value == 0.0).count())
            .toArray();
    final int largestZeroCount = IntStream.of(zerosPerTarget).max().getAsInt();

    for (int threshold = 0; threshold < largestZeroCount; threshold++) {
        // how many targets are expected to pass the filter at this threshold
        int expectedSurvivors = 0;
        for (final int zeros : zerosPerTarget) {
            if (zeros <= threshold) {
                expectedSurvivors++;
            }
        }

        if (expectedSurvivors == 0) {
            // removing every target is expected to be rejected as bad input
            boolean rejected = false;
            try {
                ReadCountCollectionUtils.removeTargetsWithTooManyZeros(readCount, threshold, false, NULL_LOGGER);
            } catch (final UserException.BadInput expected) {
                rejected = true;
            }
            if (!rejected) {
                Assert.fail("expects an exception");
            }
            continue;
        }

        final ReadCountCollection filtered = ReadCountCollectionUtils.removeTargetsWithTooManyZeros(readCount, threshold, false, NULL_LOGGER);
        Assert.assertEquals(filtered.targets().size(), expectedSurvivors);

        // walk the original targets; survivors must appear in the output at consecutive positions
        int survivorsSeen = 0;
        for (int original = 0; original < readCount.targets().size(); original++) {
            final Target candidate = readCount.targets().get(original);
            final int positionInOutput = filtered.targets().indexOf(candidate);
            if (zerosPerTarget[original] > threshold) {
                Assert.assertEquals(positionInOutput, -1);
            } else {
                Assert.assertTrue(positionInOutput >= 0, " " + zerosPerTarget[original] + " " + threshold);
                Assert.assertEquals(positionInOutput, survivorsSeen++);
            }
        }
        Assert.assertEquals(survivorsSeen, expectedSurvivors);
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DataProvider(org.testng.annotations.DataProvider) Level(org.apache.logging.log4j.Level) Test(org.testng.annotations.Test) Random(java.util.Random) ArrayList(java.util.ArrayList) Message(org.apache.logging.log4j.message.Message) Assert(org.testng.Assert) Median(org.apache.commons.math3.stat.descriptive.rank.Median) Marker(org.apache.logging.log4j.Marker) AbstractLogger(org.apache.logging.log4j.spi.AbstractLogger) PrintWriter(java.io.PrintWriter) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) IOException(java.io.IOException) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) File(java.io.File) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) Percentile(org.apache.commons.math3.stat.descriptive.rank.Percentile) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) UserException(org.broadinstitute.hellbender.exceptions.UserException) Test(org.testng.annotations.Test)

Example 17 with Max

use of org.apache.commons.math3.stat.descriptive.rank.Max in project gatk by broadinstitute.

the class ReadCountCollectionUtilsUnitTest method testRemoveColumnsWithTooManyZeros.

/**
 * Checks {@code ReadCountCollectionUtils.removeColumnsWithTooManyZeros} for every zero-count
 * threshold below the largest per-column zero count found in the input.
 * <p>
 * For each threshold the test verifies that either a {@link UserException.BadInput} is thrown
 * (when no column would survive) or that exactly the columns whose zero count does not exceed
 * the threshold remain, and that the surviving columns keep their original relative order.
 *
 * @param readCount the read-count collection supplied by the {@code tooManyZerosData} provider
 */
@Test(dataProvider = "tooManyZerosData")
public void testRemoveColumnsWithTooManyZeros(final ReadCountCollection readCount) {
    final RealMatrix counts = readCount.counts();
    // number of exact-zero entries in each column
    final int[] numberOfZeros = IntStream.range(0, counts.getColumnDimension()).map(i -> (int) DoubleStream.of(counts.getColumn(i)).filter(d -> d == 0.0).count()).toArray();
    final int maximumNumberOfZeros = IntStream.of(numberOfZeros).max().getAsInt();
    for (int maxZeros = 0; maxZeros < maximumNumberOfZeros; maxZeros++) {
        final int maxZerosThres = maxZeros;
        final int expectedRemainingCount = (int) IntStream.of(numberOfZeros).filter(i -> i <= maxZerosThres).count();
        if (expectedRemainingCount == 0) {
            try {
                ReadCountCollectionUtils.removeColumnsWithTooManyZeros(readCount, maxZeros, false, NULL_LOGGER);
            } catch (final UserException.BadInput ex) {
                // expected.
                continue;
            }
            Assert.fail("expects an exception");
        }
        final ReadCountCollection rc = ReadCountCollectionUtils.removeColumnsWithTooManyZeros(readCount, maxZeros, false, NULL_LOGGER);
        Assert.assertEquals(rc.columnNames().size(), expectedRemainingCount);
        // positions, within the filtered collection, of the surviving columns in original order
        final int[] newIndices = new int[expectedRemainingCount];
        int nextIndex = 0;
        for (int i = 0; i < readCount.columnNames().size(); i++) {
            final String name = readCount.columnNames().get(i);
            final int newIndex = rc.columnNames().indexOf(name);
            if (numberOfZeros[i] <= maxZeros) {
                Assert.assertTrue(newIndex >= 0);
                // record the position in the OUTPUT; recording the loop index here would make
                // the ordering check below trivially true
                newIndices[nextIndex++] = newIndex;
            } else {
                Assert.assertEquals(newIndex, -1);
            }
        }
        Assert.assertEquals(nextIndex, expectedRemainingCount);
        // surviving columns must appear in the output in the same relative order as in the input
        for (int i = 1; i < newIndices.length; i++) {
            Assert.assertTrue(newIndices[i - 1] < newIndices[i]);
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) DataProvider(org.testng.annotations.DataProvider) Level(org.apache.logging.log4j.Level) Test(org.testng.annotations.Test) Random(java.util.Random) ArrayList(java.util.ArrayList) Message(org.apache.logging.log4j.message.Message) Assert(org.testng.Assert) Median(org.apache.commons.math3.stat.descriptive.rank.Median) Marker(org.apache.logging.log4j.Marker) AbstractLogger(org.apache.logging.log4j.spi.AbstractLogger) PrintWriter(java.io.PrintWriter) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) IOException(java.io.IOException) SimpleInterval(org.broadinstitute.hellbender.utils.SimpleInterval) Collectors(java.util.stream.Collectors) File(java.io.File) DoubleStream(java.util.stream.DoubleStream) List(java.util.List) Percentile(org.apache.commons.math3.stat.descriptive.rank.Percentile) Logger(org.apache.logging.log4j.Logger) Stream(java.util.stream.Stream) UserException(org.broadinstitute.hellbender.exceptions.UserException) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) UserException(org.broadinstitute.hellbender.exceptions.UserException) Test(org.testng.annotations.Test)

Example 18 with Max

use of org.apache.commons.math3.stat.descriptive.rank.Max in project gatk-protected by broadinstitute.

the class PosteriorSummaryUtils method calculatePosteriorMode.

/**
 * Estimates the mode of a posterior distribution from a list of posterior samples.
 * <p>
 * The posterior density is approximated with mllib kernel density estimation
 * ({@link KernelDensity}), with the bandwidth chosen by Silverman's rule, and the maximum of that
 * density is located with a {@link BrentOptimizer} restricted to the sample range and started
 * at the sample mean.
 * <p>
 * The estimate may be poor when the number of samples is small (poor kernel density
 * estimation) or when the posterior is not unimodal (or is otherwise sufficiently pathological).
 * When all samples are identical, or the sample mean is {@link Double#NaN} (e.g. because the
 * samples contain {@link Double#NaN}), the sample mean is returned directly.
 *
 * @param samples   posterior samples, cannot be {@code null} and number of samples must be greater than 0
 * @param ctx       {@link JavaSparkContext} used by {@link KernelDensity} for mllib kernel density estimation
 * @return the estimated posterior mode, or the sample mean in the degenerate cases described above
 */
public static double calculatePosteriorMode(final List<Double> samples, final JavaSparkContext ctx) {
    Utils.nonNull(samples);
    Utils.validateArg(samples.size() > 0, "Number of samples must be greater than zero.");

    //summary statistics of the samples: range, mean, and standard deviation
    final double lowest = Collections.min(samples);
    final double highest = Collections.max(samples);
    final double[] sampleValues = Doubles.toArray(samples);
    final double mean = new Mean().evaluate(sampleValues);
    final double standardDeviation = new StandardDeviation().evaluate(sampleValues);

    //degenerate input (no spread, or NaN present): the mean is the answer
    final boolean noSpread = standardDeviation == 0.;
    if (noSpread || Double.isNaN(mean)) {
        return mean;
    }

    //Silverman's rule gives the kernel bandwidth from the sample standard deviation and sample count
    //see https://en.wikipedia.org/wiki/Kernel_density_estimation#Practical_estimation_of_the_bandwidth
    final double sizeFactor = Math.pow(samples.size(), SILVERMANS_RULE_EXPONENT);
    final double bandwidth = SILVERMANS_RULE_CONSTANT * standardDeviation * sizeFactor;

    //kernel density estimate of the posterior built from the samples
    final KernelDensity density = new KernelDensity()
            .setSample(ctx.parallelize(samples, 1))
            .setBandwidth(bandwidth);

    //Brent optimization locates the maximum (i.e., mode) of the estimated density;
    //the absolute tolerance is scaled by the width of the sample range
    final double absoluteTolerance = RELATIVE_TOLERANCE * (highest - lowest);
    final BrentOptimizer brent = new BrentOptimizer(RELATIVE_TOLERANCE, absoluteTolerance);
    final UnivariateObjectiveFunction densityAt =
            new UnivariateObjectiveFunction(x -> density.estimate(new double[] { x })[0]);

    //search inside the sample range, beginning near the sample mean
    final SearchInterval range = new SearchInterval(lowest, highest, mean);
    return brent.optimize(densityAt, GoalType.MAXIMIZE, range, BRENT_MAX_EVAL).getPoint();
}
Also used : Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) BrentOptimizer(org.apache.commons.math3.optim.univariate.BrentOptimizer) KernelDensity(org.apache.spark.mllib.stat.KernelDensity) StandardDeviation(org.apache.commons.math3.stat.descriptive.moment.StandardDeviation)

Example 19 with Max

use of org.apache.commons.math3.stat.descriptive.rank.Max in project gatk by broadinstitute.

the class CoverageModelParameters method generateRandomModel.

/**
 * Builds a {@link CoverageModelParameters} instance populated with random values.
 * <p>
 * Random numbers are drawn from a single generator seeded with {@code seed}, in this order:
 * mean log bias, unexplained variance, and (only when {@code numLatents > 0}) bias covariates.
 *
 * @param targetList list of targets
 * @param numLatents number of latent variables; must be non-negative
 * @param seed random seed
 * @param randomMeanLogBiasStandardDeviation std of mean log bias (mean is set to 0); must be non-negative
 * @param randomBiasCovariatesStandardDeviation std of bias covariates (mean is set to 0); must be non-negative
 * @param randomMaxUnexplainedVariance max value of unexplained variance (samples are taken from a uniform
 *                                     distribution [0, {@code randomMaxUnexplainedVariance}]); must be non-negative
 * @param initialBiasCovariatesARDCoefficients initial row vector of ARD coefficients; may be {@code null},
 *                                             otherwise its length must equal a positive {@code numLatents}
 * @return an instance of {@link CoverageModelParameters}
 */
public static CoverageModelParameters generateRandomModel(final List<Target> targetList, final int numLatents, final long seed, final double randomMeanLogBiasStandardDeviation, final double randomBiasCovariatesStandardDeviation, final double randomMaxUnexplainedVariance, final INDArray initialBiasCovariatesARDCoefficients) {
    Utils.validateArg(numLatents >= 0, "Dimension of the bias space must be non-negative");
    Utils.validateArg(randomBiasCovariatesStandardDeviation >= 0,
            "Standard deviation of random bias covariates must be non-negative");
    Utils.validateArg(randomMeanLogBiasStandardDeviation >= 0,
            "Standard deviation of random mean log bias must be non-negative");
    Utils.validateArg(randomMaxUnexplainedVariance >= 0, "Max random unexplained variance must be non-negative");
    /* ARD coefficients are optional; when given they need a positive latent dimension of matching length */
    final boolean ardArgumentsConsistent = initialBiasCovariatesARDCoefficients == null
            || (numLatents > 0 && initialBiasCovariatesARDCoefficients.length() == numLatents);
    Utils.validateArg(ardArgumentsConsistent,
            "If ARD is enabled, the dimension of the bias latent space must be positive and match the length of ARD coeffecient vector");

    final int targetCount = targetList.size();
    final int[] rowVectorShape = new int[] { 1, targetCount };
    final RandomGenerator generator = RandomGeneratorFactory.createRandomGenerator(new Random(seed));

    /* Gaussian random for mean log bias */
    final INDArray meanLogBias = Nd4j.create(
            getNormalRandomNumbers(targetCount, 0, randomMeanLogBiasStandardDeviation, generator), rowVectorShape);

    /* Uniform random for unexplained variance */
    final INDArray unexplainedVariance = Nd4j.create(
            getUniformRandomNumbers(targetCount, 0, randomMaxUnexplainedVariance, generator), rowVectorShape);

    /* Gaussian random for bias covariates, only when a latent space exists; otherwise null */
    final INDArray meanBiasCovariates = numLatents > 0
            ? Nd4j.create(getNormalRandomNumbers(targetCount * numLatents, 0, randomBiasCovariatesStandardDeviation, generator),
                    new int[] { targetCount, numLatents })
            : null;

    return new CoverageModelParameters(targetList, meanLogBias, unexplainedVariance, meanBiasCovariates, initialBiasCovariatesARDCoefficients);
}
Also used : INDArray(org.nd4j.linalg.api.ndarray.INDArray) RandomGenerator(org.apache.commons.math3.random.RandomGenerator)

Example 20 with Max

use of org.apache.commons.math3.stat.descriptive.rank.Max in project gatk by broadinstitute.

the class AlleleFrequencyCalculator method getLog10PNonRef.

//TODO: this should be a class of static methods once the old AFCalculator is gone.
/**
 * Compute the probability of the alleles segregating given the genotype likelihoods of the samples in vc.
 * <p>
 * Allele counts are estimated iteratively: starting from flat allele frequencies, effective allele
 * counts are recomputed and combined with prior pseudocounts until the largest change in any allele
 * count is no greater than {@code THRESHOLD_FOR_ALLELE_COUNT_CONVERGENCE}. The prior pseudocount of each
 * allele is {@code refPseudocount} for the reference allele, {@code snpPseudocount} for alternative
 * alleles whose length equals the reference allele's length, and {@code indelPseudocount} for all
 * other alternative alleles.
 *
 * @param vc the VariantContext holding the alleles and sample information.  The VariantContext
 *           must have at least 1 alternative allele
 * @param defaultPloidy ploidy assumed for genotypes that report a ploidy of 0
 * @param maximumAlternativeAlleles not read by this implementation
 * @param refSnpIndelPseudocounts a total hack.  A length-3 vector containing Dirichlet prior pseudocounts to
 *                                be given to ref, alt SNP, and alt indel alleles.  Hack won't be necessary when we destroy the old AF calculators.
 *                                NOTE: this argument is not read by the method body; the pseudocounts used are
 *                                {@code refPseudocount}, {@code snpPseudocount} and {@code indelPseudocount}
 * @return result (for programming convenience)
 */
@Override
public AFCalculationResult getLog10PNonRef(final VariantContext vc, final int defaultPloidy, final int maximumAlternativeAlleles, final double[] refSnpIndelPseudocounts) {
    Utils.nonNull(vc, "VariantContext cannot be null");
    final int numAlleles = vc.getNAlleles();
    final List<Allele> alleles = vc.getAlleles();
    Utils.validateArg(numAlleles > 1, () -> "VariantContext has only a single reference allele, but getLog10PNonRef requires at least one at all " + vc);
    // an alt allele is SNP-like when it has the same length as the reference allele; testing
    // "length > 1" would hand the indel prior to single-base SNPs and the SNP prior to insertions
    final int refLength = vc.getReference().length();
    final double[] priorPseudocounts = alleles.stream().mapToDouble(a -> a.isReference() ? refPseudocount : (a.length() == refLength ? snpPseudocount : indelPseudocount)).toArray();
    double[] alleleCounts = new double[numAlleles];
    // log10(1/numAlleles)
    final double flatLog10AlleleFrequency = -MathUtils.log10(numAlleles);
    double[] log10AlleleFrequencies = new IndexRange(0, numAlleles).mapToDouble(n -> flatLog10AlleleFrequency);
    double alleleCountsMaximumDifference = Double.POSITIVE_INFINITY;
    while (alleleCountsMaximumDifference > THRESHOLD_FOR_ALLELE_COUNT_CONVERGENCE) {
        final double[] newAlleleCounts = effectiveAlleleCounts(vc, log10AlleleFrequencies);
        // largest absolute change of any allele count relative to the previous iteration
        alleleCountsMaximumDifference = Arrays.stream(MathArrays.ebeSubtract(alleleCounts, newAlleleCounts)).map(Math::abs).max().getAsDouble();
        alleleCounts = newAlleleCounts;
        final double[] posteriorPseudocounts = MathArrays.ebeAdd(priorPseudocounts, alleleCounts);
        // first iteration uses flat prior in order to avoid local minimum where the prior + no pseudocounts gives such a low
        // effective allele frequency that it overwhelms the genotype likelihood of a real variant
        // basically, we want a chance to get non-zero pseudocounts before using a prior that's biased against a variant
        log10AlleleFrequencies = new Dirichlet(posteriorPseudocounts).log10MeanWeights();
    }
    double[] log10POfZeroCountsByAllele = new double[numAlleles];
    double log10PNoVariant = 0;
    for (final Genotype g : vc.getGenotypes()) {
        // genotypes without likelihoods contribute nothing
        if (!g.hasLikelihoods()) {
            continue;
        }
        final int ploidy = g.getPloidy() == 0 ? defaultPloidy : g.getPloidy();
        final GenotypeLikelihoodCalculator glCalc = GL_CALCS.getInstance(ploidy, numAlleles);
        final double[] log10GenotypePosteriors = log10NormalizedGenotypePosteriors(g, glCalc, log10AlleleFrequencies);
        //the total probability
        log10PNoVariant += log10GenotypePosteriors[HOM_REF_GENOTYPE_INDEX];
        // per allele non-log space probabilities of zero counts for this sample
        // for each allele calculate the total probability of genotypes containing at least one copy of the allele
        final double[] log10ProbabilityOfNonZeroAltAlleles = new double[numAlleles];
        Arrays.fill(log10ProbabilityOfNonZeroAltAlleles, Double.NEGATIVE_INFINITY);
        for (int genotype = 0; genotype < glCalc.genotypeCount(); genotype++) {
            final double log10GenotypePosterior = log10GenotypePosteriors[genotype];
            glCalc.genotypeAlleleCountsAt(genotype).forEachAlleleIndexAndCount((alleleIndex, count) -> log10ProbabilityOfNonZeroAltAlleles[alleleIndex] = MathUtils.log10SumLog10(log10ProbabilityOfNonZeroAltAlleles[alleleIndex], log10GenotypePosterior));
        }
        for (int allele = 0; allele < numAlleles; allele++) {
            // if prob of non hom ref == 1 up to numerical precision, short-circuit to avoid NaN
            if (log10ProbabilityOfNonZeroAltAlleles[allele] >= 0) {
                log10POfZeroCountsByAllele[allele] = Double.NEGATIVE_INFINITY;
            } else {
                log10POfZeroCountsByAllele[allele] += MathUtils.log10OneMinusPow10(log10ProbabilityOfNonZeroAltAlleles[allele]);
            }
        }
    }
    // unfortunately AFCalculationResult expects integers for the MLE.  We really should emit the EM no-integer values
    // which are valuable (eg in CombineGVCFs) as the sufficient statistics of the Dirichlet posterior on allele frequencies
    final int[] integerAlleleCounts = Arrays.stream(alleleCounts).mapToInt(x -> (int) Math.round(x)).toArray();
    final int[] integerAltAlleleCounts = Arrays.copyOfRange(integerAlleleCounts, 1, numAlleles);
    //skip the ref allele (index 0)
    final Map<Allele, Double> log10PRefByAllele = IntStream.range(1, numAlleles).boxed().collect(Collectors.toMap(alleles::get, a -> log10POfZeroCountsByAllele[a]));
    // we compute posteriors here and don't have the same prior that AFCalculationResult expects.  Therefore, we
    // give it our posterior as its "likelihood" along with a flat dummy prior
    //TODO: HACK must be negative for AFCalcResult
    final double[] dummyFlatPrior = { -1e-10, -1e-10 };
    final double[] log10PosteriorOfNoVariantYesVariant = { log10PNoVariant, MathUtils.log10OneMinusPow10(log10PNoVariant) };
    return new AFCalculationResult(integerAltAlleleCounts, alleles, log10PosteriorOfNoVariantYesVariant, dummyFlatPrior, log10PRefByAllele);
}
Also used : Genotype(htsjdk.variant.variantcontext.Genotype) IntStream(java.util.stream.IntStream) Allele(htsjdk.variant.variantcontext.Allele) Arrays(java.util.Arrays) MathArrays(org.apache.commons.math3.util.MathArrays) Dirichlet(org.broadinstitute.hellbender.utils.Dirichlet) GenotypeAlleleCounts(org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeAlleleCounts) Collectors(java.util.stream.Collectors) GenotypeLikelihoodCalculator(org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculator) GenotypeLikelihoodCalculators(org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculators) List(java.util.List) MathUtils(org.broadinstitute.hellbender.utils.MathUtils) Map(java.util.Map) VariantContext(htsjdk.variant.variantcontext.VariantContext) Utils(org.broadinstitute.hellbender.utils.Utils) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Genotype(htsjdk.variant.variantcontext.Genotype) IndexRange(org.broadinstitute.hellbender.utils.IndexRange) Allele(htsjdk.variant.variantcontext.Allele) Dirichlet(org.broadinstitute.hellbender.utils.Dirichlet) GenotypeLikelihoodCalculator(org.broadinstitute.hellbender.tools.walkers.genotyper.GenotypeLikelihoodCalculator)

Aggregations

ArrayList (java.util.ArrayList)26 List (java.util.List)19 Collectors (java.util.stream.Collectors)13 DescriptiveStatistics (org.apache.commons.math3.stat.descriptive.DescriptiveStatistics)13 Arrays (java.util.Arrays)11 Map (java.util.Map)11 IntStream (java.util.stream.IntStream)10 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)10 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)9 RealMatrix (org.apache.commons.math3.linear.RealMatrix)9 Plot2 (ij.gui.Plot2)8 File (java.io.File)8 IOException (java.io.IOException)8 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)7 Test (org.testng.annotations.Test)7 StoredDataStatistics (gdsc.core.utils.StoredDataStatistics)6 Collections (java.util.Collections)6 HashMap (java.util.HashMap)6 Random (java.util.Random)6 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)6