Use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
The class HDF5PCACoveragePoNCreationUtils, method calculateTargetVariances.
/**
 * Determine the variance for each target in the PoN (panel targets).
 *
 * @param normalizedCounts the normalized read counts
 * @param panelTargetNames names of the panel (PoN) targets
 * @param reduction the result of the PoN reduction step
 * @param ctx the Spark context
 * @return array of doubles where each double corresponds to a target in the PoN (panel targets)
 */
private static double[] calculateTargetVariances(final ReadCountCollection normalizedCounts,
                                                 final List<String> panelTargetNames,
                                                 final ReductionResult reduction,
                                                 final JavaSparkContext ctx) {
    Utils.nonNull(panelTargetNames);
    Utils.nonNull(normalizedCounts);
    Utils.nonNull(reduction);
    final PCATangentNormalizationResult allNormals =
            PCATangentNormalizationUtils.tangentNormalizeNormalsInPoN(normalizedCounts, panelTargetNames,
                    reduction.getReducedCounts(), reduction.getReducedPseudoInverse(), ctx);
    final RealMatrix allSampleProjectedTargets = allNormals.getTangentNormalized().counts();
    return MatrixSummaryUtils.getRowVariances(allSampleProjectedTargets);
}
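The variance computation itself is delegated to MatrixSummaryUtils.getRowVariances. As a minimal sketch of what such a row-variance helper can look like using only Commons Math (the hypothetical RowVarianceSketch class below is an illustration, not the actual GATK implementation, which may differ):

    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.stat.descriptive.moment.Variance;
    import java.util.stream.IntStream;

    // Hypothetical stand-in for MatrixSummaryUtils.getRowVariances: computes the
    // bias-corrected sample variance of every row of a RealMatrix.
    public final class RowVarianceSketch {
        public static double[] getRowVariances(final RealMatrix m) {
            final Variance variance = new Variance();  // n - 1 denominator by default
            return IntStream.range(0, m.getRowDimension())
                    .mapToDouble(r -> variance.evaluate(m.getRow(r)))
                    .toArray();
        }
    }

Each entry of the returned array then corresponds to one panel target, i.e. one row of the tangent-normalized count matrix.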
Use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.
The class CoverageModelEMComputeBlock, method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.
/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads number of execution threads
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters,
                                                                                           final double psiUpperLimit,
                                                                                           final double absTol,
                                                                                           final double relTol,
                                                                                           final int numBisections,
                                                                                           final int refinementDepth,
                                                                                           final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");
    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> IntStream.range(0, numTargets).parallel()
                .mapToObj(ti -> {
                    final UnivariateFunction objFunc = psi ->
                            calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                    final UnivariateFunction meritFunc = psi ->
                            calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                    final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol,
                            CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY,
                            meritFunc, numBisections, refinementDepth);
                    double newPsi;
                    try {
                        newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                    } catch (NoBracketingException | TooManyEvaluationsException e) {
                        newPsi = Psi_t.getDouble(ti);
                    }
                    return new ImmutablePair<>(newPsi, solver.getEvaluations());
                })
                .collect(Collectors.toList())).get();
    } catch (InterruptedException | ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance");
    }
    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    final int maxIterations = Collections.max(res.stream().mapToInt(p -> p.right).boxed().collect(Collectors.toList()));
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t,
            SubroutineSignal.builder()
                    .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity)
                    .put(StandardSubroutineSignals.ITERATIONS, maxIterations)
                    .build());
}
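The concurrency trick here is that a parallel stream submitted from inside a ForkJoinPool task runs on that pool, so its parallelism is bounded by numThreads rather than by the JVM-wide common pool. A standalone sketch of the same pattern, assuming Commons Math's plain BrentSolver in place of GATK's RobustBrentSolver and a made-up objective function (everything below is illustrative, not GATK code):

    import org.apache.commons.math3.analysis.UnivariateFunction;
    import org.apache.commons.math3.analysis.solvers.BrentSolver;
    import org.apache.commons.math3.exception.NoBracketingException;
    import org.apache.commons.math3.exception.TooManyEvaluationsException;
    import java.util.List;
    import java.util.concurrent.ExecutionException;
    import java.util.concurrent.ForkJoinPool;
    import java.util.stream.Collectors;
    import java.util.stream.IntStream;

    // Sketch: a parallel stream submitted to a dedicated ForkJoinPool is bounded
    // to numThreads; each per-target task falls back to its previous value if
    // root finding fails, instead of aborting the whole update.
    public final class BoundedParallelSolveSketch {
        public static List<Double> solveAll(final int numTargets, final double[] previous,
                                            final int numThreads, final int maxEvals) {
            final ForkJoinPool pool = new ForkJoinPool(numThreads);
            try {
                return pool.submit(() -> IntStream.range(0, numTargets).parallel()
                        .mapToObj(ti -> {
                            // toy objective; the real M-step uses a per-target variance objective
                            final UnivariateFunction objective = psi -> Math.cos(psi + ti) - 0.1 * psi;
                            final BrentSolver solver = new BrentSolver(1e-10, 1e-10);
                            try {
                                return solver.solve(maxEvals, objective, 0.0, 10.0);
                            } catch (NoBracketingException | TooManyEvaluationsException e) {
                                return previous[ti];  // keep the old estimate on failure
                            }
                        })
                        .collect(Collectors.toList())).get();
            } catch (InterruptedException | ExecutionException ex) {
                throw new RuntimeException("Concurrent solve failed", ex);
            } finally {
                pool.shutdown();
            }
        }
    }

As in the GATK method, a solver failure for a single target keeps the previous estimate rather than failing the entire M-step.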
Use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk by broadinstitute.
The class GATKProtectedMathUtils, method rowStdDevs.
/**
 * Calculates the standard deviation per row from a matrix.
 *
 * @param matrix the input matrix.
 * @return never {@code null}, an array with as many positions as rows in {@code matrix}.
 * @throws IllegalArgumentException if {@code matrix} is {@code null}.
 */
public static double[] rowStdDevs(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance varianceEvaluator = new Variance();
    return IntStream.range(0, matrix.getRowDimension())
            .mapToDouble(r -> Math.sqrt(varianceEvaluator.evaluate(matrix.getRow(r))))
            .toArray();
}
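A brief usage sketch (not from GATK) of the same computation on a small Array2DRowRealMatrix; note that Commons Math's Variance is bias-corrected by default (n - 1 denominator), so the values printed are sample standard deviations:

    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.stat.descriptive.moment.Variance;

    public final class RowStdDevsDemo {
        public static void main(String[] args) {
            final RealMatrix matrix = new Array2DRowRealMatrix(new double[][] {
                    {1.0, 2.0, 3.0},   // sample variance 1.0 -> std dev 1.0
                    {2.0, 2.0, 2.0}    // constant row -> std dev 0.0
            });
            final Variance variance = new Variance();  // bias-corrected by default
            for (int r = 0; r < matrix.getRowDimension(); r++) {
                System.out.printf("row %d std dev = %.3f%n", r,
                        Math.sqrt(variance.evaluate(matrix.getRow(r))));
            }
        }
    }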
Use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
The class SliceSamplerUnitTest, method testSliceSamplingOfMonotonicBetaDistribution.
/**
 * Test slice sampling of a monotonic beta distribution as an example of sampling of a bounded random variable.
 * Checks that input mean and variance are recovered by 10000 samples to a relative error of 0.5% and 2%,
 * respectively.
 */
@Test
public void testSliceSamplingOfMonotonicBetaDistribution() {
    rng.setSeed(RANDOM_SEED);
    final double alpha = 10.;
    final double beta = 1.;
    final BetaDistribution betaDistribution = new BetaDistribution(alpha, beta);
    final Function<Double, Double> betaLogPDF = betaDistribution::logDensity;
    final double xInitial = 0.5;
    final double xMin = 0.;
    final double xMax = 1.;
    final double width = 0.1;
    final int numSamples = 10000;
    final SliceSampler betaSampler = new SliceSampler(rng, betaLogPDF, xMin, xMax, width);
    final double[] samples = Doubles.toArray(betaSampler.sample(xInitial, numSamples));
    final double mean = betaDistribution.getNumericalMean();
    final double variance = betaDistribution.getNumericalVariance();
    final double sampleMean = new Mean().evaluate(samples);
    final double sampleVariance = new Variance().evaluate(samples);
    Assert.assertEquals(relativeError(sampleMean, mean), 0., 0.005);
    Assert.assertEquals(relativeError(sampleVariance, variance), 0., 0.02);
}
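The moment check itself does not depend on the slice sampler. A self-contained sketch of the same comparison, sampling directly from Commons Math's BetaDistribution; the relativeError helper is not shown in the test excerpt above, so its definition here is an assumption:

    import org.apache.commons.math3.distribution.BetaDistribution;
    import org.apache.commons.math3.stat.descriptive.moment.Mean;
    import org.apache.commons.math3.stat.descriptive.moment.Variance;

    public final class BetaMomentCheck {
        // Hypothetical relative-error helper, assumed to match the test's intent.
        private static double relativeError(final double x, final double xTrue) {
            return Math.abs(x - xTrue) / Math.abs(xTrue);
        }

        public static void main(String[] args) {
            final BetaDistribution beta = new BetaDistribution(10., 1.);
            beta.reseedRandomGenerator(13L);                      // fixed seed for reproducibility
            final double[] samples = beta.sample(10000);
            final double sampleMean = new Mean().evaluate(samples);
            final double sampleVariance = new Variance().evaluate(samples);
            System.out.println(relativeError(sampleMean, beta.getNumericalMean()));
            System.out.println(relativeError(sampleVariance, beta.getNumericalVariance()));
        }
    }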
Use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
The class SliceSamplerUnitTest, method testSliceSamplingOfPeakedBetaDistribution.
/**
 * Test slice sampling of a peaked beta distribution as an example of sampling of a bounded random variable.
 * Checks that input mean and variance are recovered by 10000 samples to a relative error of 0.5% and 2%,
 * respectively.
 */
@Test
public void testSliceSamplingOfPeakedBetaDistribution() {
    rng.setSeed(RANDOM_SEED);
    final double alpha = 10.;
    final double beta = 4.;
    final BetaDistribution betaDistribution = new BetaDistribution(alpha, beta);
    final Function<Double, Double> betaLogPDF = betaDistribution::logDensity;
    final double xInitial = 0.5;
    final double xMin = 0.;
    final double xMax = 1.;
    final double width = 0.1;
    final int numSamples = 10000;
    final SliceSampler betaSampler = new SliceSampler(rng, betaLogPDF, xMin, xMax, width);
    final double[] samples = Doubles.toArray(betaSampler.sample(xInitial, numSamples));
    final double mean = betaDistribution.getNumericalMean();
    final double variance = betaDistribution.getNumericalVariance();
    final double sampleMean = new Mean().evaluate(samples);
    final double sampleVariance = new Variance().evaluate(samples);
    Assert.assertEquals(relativeError(sampleMean, mean), 0., 0.005);
    Assert.assertEquals(relativeError(sampleVariance, variance), 0., 0.02);
}
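For reference (not part of the test), getNumericalMean and getNumericalVariance for a Beta(alpha, beta) distribution reduce to the closed-form moments alpha / (alpha + beta) and alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1)); a quick check for the peaked case alpha = 10, beta = 4:

    import org.apache.commons.math3.distribution.BetaDistribution;

    public final class BetaClosedFormMoments {
        public static void main(String[] args) {
            final double alpha = 10., beta = 4.;
            final BetaDistribution dist = new BetaDistribution(alpha, beta);
            final double mean = alpha / (alpha + beta);                     // ~0.714
            final double variance = alpha * beta
                    / ((alpha + beta) * (alpha + beta) * (alpha + beta + 1));  // ~0.0136
            System.out.println(mean + " vs " + dist.getNumericalMean());
            System.out.println(variance + " vs " + dist.getNumericalVariance());
        }
    }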