use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class CoverageModelEMWorkspace method updateTargetUnexplainedVarianceIsotropic.
/**
 * M-step update of unexplained variance in the isotropic mode.
 *
 * <p>A single scalar value of the target-specific unexplained variance is obtained by root finding
 * on the objective function summed over all compute blocks, searching the interval from 0 to
 * {@code config.getTargetSpecificVarianceUpperLimit()}. If the solver fails (no bracketing or too
 * many evaluations), a warning is logged and the previous value is retained. The resulting value
 * is then assigned to every entry of {@code Psi_t} on each compute block.</p>
 *
 * @return a {@link SubroutineSignal} object containing
 *         {@link StandardSubroutineSignals#RESIDUAL_ERROR_NORM} (absolute change of the isotropic
 *         variance) and {@link StandardSubroutineSignals#ITERATIONS} (number of function
 *         evaluations reported by the solver)
 */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");

    /* the current isotropic value is taken as the mean of Psi_t fetched from the workers */
    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(
            CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();

    /* objective and merit functions are sums of the per-block contributions */
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(
            cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(
            cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);

    final RobustBrentSolver solver = new RobustBrentSolver(
            config.getTargetSpecificVarianceRelativeTolerance(),
            config.getTargetSpecificVarianceAbsoluteTolerance(),
            CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY,
            meritFunc,
            config.getTargetSpecificVarianceSolverNumBisections(),
            config.getTargetSpecificVarianceSolverRefinementDepth());

    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(),
                objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        /* best-effort: keep the previous value if no sign change is found in the search interval */
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        /* best-effort: keep the previous value if the evaluation budget is exhausted */
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update"
                + " of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }

    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int numEvaluations = solver.getEvaluations();
    /* effectively-final copy so that the value can be captured by the lambda below */
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    /* every block receives a 1 x (number of elements in its target-space block) array filled with the new value */
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t,
            Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));

    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity)
            .put(StandardSubroutineSignals.ITERATIONS, numEvaluations)
            .build();
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class CoverageModelEMWorkspace method updateSampleUnexplainedVariance.
/**
 * E-step update of the sample-specific unexplained variance.
 *
 * <p>One root-finding job per sample is registered with a {@link SynchronizedUnivariateSolver}
 * on the interval from 0 to {@code config.getSampleSpecificVarianceUpperLimit()}. The solver
 * queries a compound objective function that evaluates several samples at once by summing the
 * per-block contributions from the workers. The solutions are admixed with the previous values
 * using {@code config.getMeanFieldAdmixingRatio()}, stored locally, and pushed to the workers.</p>
 *
 * @return a {@link SubroutineSignal} containing the update size (key: "error_norm") and the average
 *         number of function evaluations per sample (key: "iterations")
 * @throws RuntimeException if the solver is interrupted; the interrupt flag of the current thread
 *         is restored and the original exception is attached as the cause
 */
@EvaluatesRDD
@UpdatesRDD
@CachesRDD
public SubroutineSignal updateSampleUnexplainedVariance() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.E_STEP_GAMMA));
    cacheWorkers("after E-step for sample unexplained variance initialization");

    /* create a compound objective function for simultaneous multi-sample queries:
     * maps (sample index -> trial gamma) to (sample index -> objective function value) */
    final java.util.function.Function<Map<Integer, Double>, Map<Integer, Double>> objFunc = arg -> {
        if (arg.isEmpty()) {
            return Collections.emptyMap();
        }
        final int[] sampleIndices = arg.keySet().stream().mapToInt(i -> i).toArray();
        /* column vector of trial values, ordered like sampleIndices */
        final INDArray gammaValues = Nd4j.create(Arrays.stream(sampleIndices).mapToDouble(arg::get).toArray(),
                new int[] { sampleIndices.length, 1 });
        final INDArray eval = mapWorkersAndReduce(
                cb -> cb.calculateSampleSpecificVarianceObjectiveFunctionMultiSample(sampleIndices, gammaValues),
                INDArray::add);
        final Map<Integer, Double> output = new HashMap<>();
        IntStream.range(0, sampleIndices.length)
                .forEach(evalIdx -> output.put(sampleIndices[evalIdx], eval.getDouble(evalIdx)));
        return output;
    };

    final java.util.function.Function<UnivariateSolverSpecifications, AbstractUnivariateSolver> solverFactory =
            spec -> new RobustBrentSolver(spec.getRelativeAccuracy(), spec.getAbsoluteAccuracy(),
                    spec.getFunctionValueAccuracy(), null,
                    config.getSampleSpecificVarianceSolverNumBisections(),
                    config.getSampleSpecificVarianceSolverRefinementDepth());

    /* instantiate a synchronized multi-sample root finder and add jobs */
    final SynchronizedUnivariateSolver syncSolver = new SynchronizedUnivariateSolver(objFunc, solverFactory, numSamples);
    IntStream.range(0, numSamples).forEach(si -> {
        /* the midpoint of the search interval is passed as x0 */
        final double x0 = 0.5 * config.getSampleSpecificVarianceUpperLimit();
        syncSolver.add(si, 0, config.getSampleSpecificVarianceUpperLimit(), x0,
                config.getSampleSpecificVarianceAbsoluteTolerance(),
                config.getSampleSpecificVarianceRelativeTolerance(),
                config.getSampleSpecificVarianceMaximumIterations());
    });

    /* solve and collect statistics */
    final INDArray newSampleUnexplainedVariance = Nd4j.create(numSamples, 1);
    final List<Integer> numberOfEvaluations = new ArrayList<>(numSamples);
    try {
        final Map<Integer, SynchronizedUnivariateSolver.UnivariateSolverSummary> newSampleSpecificVarianceMap =
                syncSolver.solve();
        newSampleSpecificVarianceMap.entrySet().forEach(entry -> {
            final int sampleIndex = entry.getKey();
            final SynchronizedUnivariateSolver.UnivariateSolverSummary summary = entry.getValue();
            /* any status other than SUCCESS leaves the value at 0 */
            double val = 0;
            switch (summary.status) {
                case SUCCESS:
                    val = summary.x;
                    break;
                case TOO_MANY_EVALUATIONS:
                    logger.warn("Could not locate the root of gamma -- increase the maximum number of"
                            + " function evaluations");
                    break;
            }
            newSampleUnexplainedVariance.put(sampleIndex, 0, val);
            numberOfEvaluations.add(summary.evaluations);
        });
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that callers further up the stack can observe it */
        Thread.currentThread().interrupt();
        throw new RuntimeException("The update of sample unexplained variance was interrupted -- can not continue", ex);
    }

    /* admix */
    final INDArray newSampleUnexplainedVarianceAdmixed = newSampleUnexplainedVariance
            .mul(config.getMeanFieldAdmixingRatio())
            .addi(sampleUnexplainedVariance.mul(1 - config.getMeanFieldAdmixingRatio()));

    /* calculate the error */
    final double errorNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(
            newSampleUnexplainedVarianceAdmixed.sub(sampleUnexplainedVariance));

    /* update local copy */
    sampleUnexplainedVariance.assign(newSampleUnexplainedVarianceAdmixed);

    /* push to workers */
    pushToWorkers(newSampleUnexplainedVarianceAdmixed, (arr, cb) -> cb.cloneWithUpdatedPrimitive(
            CoverageModelEMComputeBlock.CoverageModelICGCacheNode.gamma_s, newSampleUnexplainedVarianceAdmixed));

    /* average number of function evaluations per sample, truncated to an integer */
    final int iterations = (int) (numberOfEvaluations.stream().mapToDouble(d -> d).sum() / numSamples);
    return SubroutineSignal.builder()
            .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errorNormInfinity)
            .put(StandardSubroutineSignals.ITERATIONS, iterations)
            .build();
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class CoverageModelEMComputeBlock method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.
/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * <p>A separate root-finding problem is solved for each target on the interval from 0 to
 * {@code psiUpperLimit}. The targets are processed by a parallel stream that is submitted to a
 * dedicated {@link ForkJoinPool} with {@code numThreads} threads; the pool is shut down before
 * this method returns. If the solver fails for a target (no bracketing or too many evaluations),
 * the current value of {@code Psi_t} for that target is retained.</p>
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads number of threads of the pool that runs the per-target solvers
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}; the attached signal carries the
 *         infinity norm of the change in {@code Psi_t} and the largest number of function
 *         evaluations used for any single target
 * @throws RuntimeException if the concurrent computation fails or is interrupted; the original
 *         exception is attached as the cause
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");

    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);

    /* the parallel stream is submitted to a dedicated pool so that it runs on numThreads threads */
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi ->
                        calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi ->
                        calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                /* a fresh solver per target; its evaluation count is reported alongside the root */
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol,
                        CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc,
                        numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* best-effort: keep the current value for this target */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that callers further up the stack can observe it */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* release the worker threads; a new pool is created on every call */
        forkJoinPool.shutdown();
    }

    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().orElse(0);
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t,
            SubroutineSignal.builder()
                    .put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity)
                    .put(StandardSubroutineSignals.ITERATIONS, maxIterations)
                    .build());
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class CoverageModelCopyRatioEmissionProbabilityCalculator method logLikelihoodPoisson.
/**
 * Evaluates the emission log probability of the observed read count under a plain Poisson model.
 *
 * TODO github/gatk-protected issue #854
 *
 * @implNote The variance of the log bias ($\Psi$) is not taken into account here at all; only the
 * mean log bias enters, through exp(mu). A better treatment would approximate, with a few-point
 * numerical integration, the following integral:
 *
 *     \int_{-\infty}^{+\infty} db Poisson(n | \lambda = d*c*exp(b) + eps*d)
 *                                 * Normal(b | \mu, \Psi)
 *
 * @param emissionData an instance of {@link CoverageModelCopyRatioEmissionData}
 * @param copyRatio the copy ratio that the emission probability is conditioned on
 * @return the Poisson log probability of the read count in {@code emissionData}
 */
private double logLikelihoodPoisson(@Nonnull CoverageModelCopyRatioEmissionData emissionData, double copyRatio) {
    final double biasFactor = FastMath.exp(emissionData.getMu());
    final double depth = emissionData.getCopyRatioCallingMetadata().getSampleCoverageDepth();
    final double mappingError = emissionData.getMappingErrorProbability();

    /* per-unit-depth rate: correctly mapped contribution plus the mapping error term */
    final double correctlyMappedRate = (1 - mappingError) * copyRatio * biasFactor;
    final double lambda = depth * (correctlyMappedRate + mappingError);

    final PoissonDistribution poisson = new PoissonDistribution(null, lambda,
            PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
    return poisson.logProbability(emissionData.getReadCount());
}
use of org.apache.commons.math3.stat.descriptive.moment.Variance in project gatk-protected by broadinstitute.
the class GATKProtectedMathUtils method columnStdDevs.
/**
 * Computes, for every column of a matrix, the square root of the variance of that column's
 * entries as evaluated by {@link Variance}.
 *
 * @param matrix the input matrix.
 * @return never {@code null}, an array holding one value per column of {@code matrix}, in column order.
 * @throws IllegalArgumentException if {@code matrix} is {@code null}.
 */
public static double[] columnStdDevs(final RealMatrix matrix) {
    Utils.nonNull(matrix);
    final Variance variance = new Variance();
    final int columnCount = matrix.getColumnDimension();
    final double[] stdDevs = new double[columnCount];
    for (int column = 0; column < columnCount; column++) {
        final double[] columnValues = matrix.getColumn(column);
        stdDevs[column] = Math.sqrt(variance.evaluate(columnValues));
    }
    return stdDevs;
}
Aggregations