Search in sources :

Example 26 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class HDF5PCACoveragePoNCreationUtils method calculateTargetVariances.

/**
 * Computes one variance value per panel target of the PoN.
 *
 * <p>The normals are first tangent-normalized against the supplied reduction, and the row
 * variances of the resulting count matrix are then returned.
 *
 * @param normalizedCounts normalized read counts of the normals (must not be {@code null})
 * @param panelTargetNames names of the panel targets (must not be {@code null})
 * @param reduction        reduction result supplying the reduced counts and reduced pseudo-inverse
 *                         (must not be {@code null})
 * @param ctx              Spark context forwarded to the tangent-normalization step
 * @return array of doubles where each double corresponds to a target in the PoN (panel targets)
 */
private static double[] calculateTargetVariances(final ReadCountCollection normalizedCounts, final List<String> panelTargetNames, final ReductionResult reduction, final JavaSparkContext ctx) {
    Utils.nonNull(panelTargetNames);
    Utils.nonNull(normalizedCounts);
    Utils.nonNull(reduction);

    // Tangent-normalize every normal in the panel against the reduced panel.
    final PCATangentNormalizationResult tangentNormalizedNormals =
            PCATangentNormalizationUtils.tangentNormalizeNormalsInPoN(
                    normalizedCounts,
                    panelTargetNames,
                    reduction.getReducedCounts(),
                    reduction.getReducedPseudoInverse(),
                    ctx);

    // Per-row variances of the tangent-normalized counts are the result.
    final RealMatrix projectedCounts = tangentNormalizedNormals.getTangentNormalized().counts();
    final double[] targetVariances = MatrixSummaryUtils.getRowVariances(projectedCounts);
    return targetVariances;
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix)

Example 27 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class CoverageModelEMComputeBlock method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.

/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * <p>Every target is solved independently with a {@link RobustBrentSolver}; the per-target
 * solves run in parallel on a dedicated {@link ForkJoinPool} of {@code numThreads} workers.
 * The pool is created for this call only and is always shut down before the method returns.
 * If the solver fails on a target (no bracketing or too many evaluations), the current value
 * of the unexplained variance for that target is retained.
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads number of threads used for the concurrent per-target solves (must be positive)
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}
 * @throws RuntimeException if the concurrent update is interrupted or one of its tasks fails;
 *                          the original exception is attached as the cause
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");
    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);
    /* a dedicated pool bounds the parallelism of the stream below to numThreads */
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* deliberate best-effort: keep the current value when the solver cannot improve it */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt status so that callers further up the stack can observe it */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* the pool is private to this call; without a shutdown its worker threads would linger */
        forkJoinPool.shutdown();
    }
    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* largest number of solver evaluations spent on any single target */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().getAsInt();
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) Map(java.util.Map) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) Nd4j(org.nd4j.linalg.factory.Nd4j) FastMath(org.apache.commons.math3.util.FastMath) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Serializable(java.io.Serializable) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) org.broadinstitute.hellbender.tools.coveragemodel.cachemanager(org.broadinstitute.hellbender.tools.coveragemodel.cachemanager) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ForkJoinPool(java.util.concurrent.ForkJoinPool) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) Collections(java.util.Collections) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ExecutionException(java.util.concurrent.ExecutionException) ForkJoinPool(java.util.concurrent.ForkJoinPool)

Example 28 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.

the class GATKProtectedMathUtils method rowStdDevs.

/**
 * Computes, for every row of a matrix, the standard deviation of that row's entries.
 *
 * <p>Each value is the square root of the row variance as evaluated by a default-constructed
 * {@link Variance}.
 *
 * @param matrix the input matrix.
 * @return never {@code null}, an array with as many positions as rows in {@code matrix}.
 * @throws IllegalArgumentException if {@code matrix} is {@code null}.
 */
public static double[] rowStdDevs(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance rowVariance = new Variance();
    final int rowCount = matrix.getRowDimension();
    final double[] stdDevs = new double[rowCount];
    for (int row = 0; row < rowCount; row++) {
        final double[] rowValues = matrix.getRow(row);
        stdDevs[row] = Math.sqrt(rowVariance.evaluate(rowValues));
    }
    return stdDevs;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Pair(org.apache.commons.math3.util.Pair) MathArrays(org.apache.commons.math3.util.MathArrays) Collection(java.util.Collection) FastMath(org.apache.commons.math3.util.FastMath) EnumeratedDistribution(org.apache.commons.math3.distribution.EnumeratedDistribution) GATKException(org.broadinstitute.hellbender.exceptions.GATKException) FourierLinearOperator(org.broadinstitute.hellbender.tools.coveragemodel.linalg.FourierLinearOperator) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) Serializable(java.io.Serializable) List(java.util.List) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Median(org.apache.commons.math3.stat.descriptive.rank.Median) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Nonnull(javax.annotation.Nonnull) Collections(java.util.Collections) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance)

Example 29 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class SliceSamplerUnitTest method testSliceSamplingOfMonotonicBetaDistribution.

/**
 * Slice-samples a monotonic beta distribution (alpha = 10, beta = 1) as an example of a
 * bounded random variable on [0, 1].
 *
 * <p>Draws 10000 samples and verifies that the sample mean and sample variance match the
 * distribution's numerical mean and variance to within a relative error of 0.5% and 2%,
 * respectively.
 */
@Test
public void testSliceSamplingOfMonotonicBetaDistribution() {
    rng.setSeed(RANDOM_SEED);

    // Target distribution and its log density.
    final BetaDistribution target = new BetaDistribution(10., 1.);
    final Function<Double, Double> logDensity = target::logDensity;

    // Sampler configuration: support bounds, slice width, starting point and sample count.
    final double lowerBound = 0.;
    final double upperBound = 1.;
    final double sliceWidth = 0.1;
    final double startingPoint = 0.5;
    final int sampleCount = 10000;

    final SliceSampler sampler = new SliceSampler(rng, logDensity, lowerBound, upperBound, sliceWidth);
    final double[] draws = Doubles.toArray(sampler.sample(startingPoint, sampleCount));

    // Compare empirical moments against the distribution's own moments.
    final double expectedMean = target.getNumericalMean();
    final double expectedVariance = target.getNumericalVariance();
    final double observedMean = new Mean().evaluate(draws);
    final double observedVariance = new Variance().evaluate(draws);

    Assert.assertEquals(relativeError(observedMean, expectedMean), 0., 0.005);
    Assert.assertEquals(relativeError(observedVariance, expectedVariance), 0., 0.02);
}
Also used : BetaDistribution(org.apache.commons.math3.distribution.BetaDistribution) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) Test(org.testng.annotations.Test)

Example 30 with Variance

use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.

the class SliceSamplerUnitTest method testSliceSamplingOfPeakedBetaDistribution.

/**
 * Slice-samples a peaked beta distribution (alpha = 10, beta = 4) as an example of a
 * bounded random variable on [0, 1].
 *
 * <p>Draws 10000 samples and verifies that the sample mean and sample variance match the
 * distribution's numerical mean and variance to within a relative error of 0.5% and 2%,
 * respectively.
 */
@Test
public void testSliceSamplingOfPeakedBetaDistribution() {
    rng.setSeed(RANDOM_SEED);

    // Target distribution and its log density.
    final BetaDistribution target = new BetaDistribution(10., 4.);
    final Function<Double, Double> logDensity = target::logDensity;

    // Sampler configuration: support bounds, slice width, starting point and sample count.
    final double lowerBound = 0.;
    final double upperBound = 1.;
    final double sliceWidth = 0.1;
    final double startingPoint = 0.5;
    final int sampleCount = 10000;

    final SliceSampler sampler = new SliceSampler(rng, logDensity, lowerBound, upperBound, sliceWidth);
    final double[] draws = Doubles.toArray(sampler.sample(startingPoint, sampleCount));

    // Compare empirical moments against the distribution's own moments.
    final double expectedMean = target.getNumericalMean();
    final double expectedVariance = target.getNumericalVariance();
    final double observedMean = new Mean().evaluate(draws);
    final double observedVariance = new Variance().evaluate(draws);

    Assert.assertEquals(relativeError(observedMean, expectedMean), 0., 0.005);
    Assert.assertEquals(relativeError(observedVariance, expectedVariance), 0., 0.02);
}
Also used : BetaDistribution(org.apache.commons.math3.distribution.BetaDistribution) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) Variance(org.apache.commons.math3.stat.descriptive.moment.Variance) Test(org.testng.annotations.Test)

Aggregations

Collectors (java.util.stream.Collectors)24 IntStream (java.util.stream.IntStream)24 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)22 Nonnull (javax.annotation.Nonnull)20 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)18 Variance (org.apache.commons.math3.stat.descriptive.moment.Variance)18 List (java.util.List)16 FastMath (org.apache.commons.math3.util.FastMath)16 Utils (org.broadinstitute.hellbender.utils.Utils)16 INDArray (org.nd4j.linalg.api.ndarray.INDArray)16 Function (java.util.function.Function)15 Arrays (java.util.Arrays)14 Nullable (javax.annotation.Nullable)14 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)14 RealMatrix (org.apache.commons.math3.linear.RealMatrix)14 Logger (org.apache.logging.log4j.Logger)14 GATKException (org.broadinstitute.hellbender.exceptions.GATKException)14 UserException (org.broadinstitute.hellbender.exceptions.UserException)14 Nd4j (org.nd4j.linalg.factory.Nd4j)14 NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex)14