use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk-protected by broadinstitute.
the class CoverageModelEMComputeBlock method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.
/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * <p>For every target, a root of the objective function is sought in {@code [0, psiUpperLimit]}
 * using a {@link RobustBrentSolver}. If the solver cannot bracket a root or exceeds its evaluation
 * budget, the current value of {@code Psi_t} for that target is retained.</p>
 *
 * <p>The returned block carries a signal with the infinity norm of the change in {@code Psi_t}
 * and the largest number of function evaluations spent on any single target.</p>
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads parallelism of the dedicated fork-join pool that runs the per-target solves
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}
 * @throws RuntimeException if the concurrent update is interrupted or one of its tasks fails
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");

    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);

    /* the parallel stream is started from inside a dedicated pool; presumably this is what
     * confines it to numThreads workers rather than the common pool -- TODO confirm */
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* deliberate best-effort: keep the current value for this target */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that callers can observe the interruption */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* release the worker threads; without this every call leaks a pool */
        forkJoinPool.shutdown();
    }

    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* like Collections.max, this throws NoSuchElementException if there are no targets */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().getAsInt();
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk by broadinstitute.
the class CoverageModelEMWorkspace method updateTargetUnexplainedVarianceIsotropic.
/**
 * M-step update of unexplained variance in the isotropic mode.
 *
 * <p>A single value of the unexplained variance is found as a root of the objective function summed
 * over all workers, searched in {@code [0, config.getTargetSpecificVarianceUpperLimit()]} with a
 * {@link RobustBrentSolver}. If the root cannot be bracketed or the solver runs out of evaluations,
 * a warning is logged and the previous value is kept. The resulting value is then written to every
 * element of {@code Psi_t} on all compute blocks.</p>
 *
 * @return a {@link SubroutineSignal} object containing the
 *         {@link StandardSubroutineSignals#RESIDUAL_ERROR_NORM} (absolute change of the isotropic variance)
 *         and {@link StandardSubroutineSignals#ITERATIONS} (number of solver evaluations) fields
 */
@UpdatesRDD
@CachesRDD
private SubroutineSignal updateTargetUnexplainedVarianceIsotropic() {
    mapWorkers(cb -> cb.cloneWithUpdatedCachesByTag(CoverageModelEMComputeBlock.CoverageModelICGCacheTag.M_STEP_PSI));
    cacheWorkers("after M-step update of isotropic unexplained variance initialization");

    /* the current isotropic value is taken as the mean of Psi_t fetched from the workers */
    final double oldIsotropicTargetSpecificVariance = fetchFromWorkers(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, 1).meanNumber().doubleValue();

    /* objective and merit functions are sums of the per-block contributions */
    final UnivariateFunction objFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceObjectiveFunction(psi), (a, b) -> a + b);
    final UnivariateFunction meritFunc = psi -> mapWorkersAndReduce(cb -> cb.calculateSampleTargetSummedTargetSpecificVarianceMeritFunction(psi), (a, b) -> a + b);
    final RobustBrentSolver solver = new RobustBrentSolver(config.getTargetSpecificVarianceRelativeTolerance(), config.getTargetSpecificVarianceAbsoluteTolerance(), CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, config.getTargetSpecificVarianceSolverNumBisections(), config.getTargetSpecificVarianceSolverRefinementDepth());

    double newIsotropicTargetSpecificVariance;
    try {
        newIsotropicTargetSpecificVariance = solver.solve(config.getTargetSpecificVarianceMaxIterations(), objFunc, 0, config.getTargetSpecificVarianceUpperLimit());
    } catch (NoBracketingException e) {
        /* fall back to the previous value when no sign change is found in the search interval */
        logger.warn("Root of M-step optimality equation for isotropic unexplained variance could not be bracketed");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    } catch (TooManyEvaluationsException e) {
        /* fall back to the previous value when the evaluation budget is exhausted */
        logger.warn("Too many evaluations -- increase the number of root-finding iterations for the M-step update" + " of unexplained variance");
        newIsotropicTargetSpecificVariance = oldIsotropicTargetSpecificVariance;
    }

    /* update the compute block(s) */
    final double errNormInfinity = FastMath.abs(newIsotropicTargetSpecificVariance - oldIsotropicTargetSpecificVariance);
    final int maxIterations = solver.getEvaluations();
    /* effectively-final copy required for capture by the lambda below */
    final double finalizedNewIsotropicTargetSpecificVariance = newIsotropicTargetSpecificVariance;
    mapWorkers(cb -> cb.cloneWithUpdatedPrimitive(CoverageModelEMComputeBlock.CoverageModelICGCacheNode.Psi_t, Nd4j.zeros(1, cb.getTargetSpaceBlock().getNumElements()).addi(finalizedNewIsotropicTargetSpecificVariance)));
    return SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build();
}
use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk-protected by broadinstitute.
the class RobustBrentSolverUnitTest method simpleTest.
/**
 * Exercises both solvers on the quartic 30 x (x - 1) (x - 2) (x - 3), whose four real roots are
 * x = 0, 1, 2, 3. The polynomial is positive at both ends of the search interval [-1, 4], so the
 * plain Brent solver reports that the interval does not bracket a root. The robust Brent solver
 * subdivides the interval and locates bracketing sub-intervals instead.
 *
 * The merit function is the anti-derivative of the objective function; the root it ranks as "best"
 * is x = 0. The test requires the robust solver to return x = 0 and the plain solver to fail.
 */
@Test
public void simpleTest() {
    final UnivariateFunction quartic = x -> 30 * x * (x - 1) * (x - 2) * (x - 3);
    final UnivariateFunction quarticAntiDerivative = x -> 6 * FastMath.pow(x, 5) - 45 * FastMath.pow(x, 4) + 110 * FastMath.pow(x, 3) - 90 * FastMath.pow(x, 2);

    /* NOTE(review): DEF_REL_ACC is passed in both accuracy positions of each constructor while the
     * assertion below uses DEF_ABS_ACC -- confirm whether the second argument was meant to be DEF_ABS_ACC */
    final RobustBrentSolver robustSolver = new RobustBrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC, quarticAntiDerivative, 4, 1);
    final BrentSolver plainSolver = new BrentSolver(DEF_REL_ACC, DEF_REL_ACC, DEF_F_ACC);

    /* the robust solver must pick the root at the origin */
    final double robustRoot = robustSolver.solve(100, quartic, -1, 4);
    Assert.assertEquals(robustRoot, 0, DEF_ABS_ACC);

    /* the plain solver must reject the same interval */
    boolean noBracketingReported;
    try {
        plainSolver.solve(100, quartic, -1, 4);
        noBracketingReported = false;
    } catch (final NoBracketingException ex) {
        noBracketingReported = true;
    }
    Assert.assertTrue(noBracketingReported);
}
use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk by broadinstitute.
the class CoverageModelEMComputeBlock method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.
/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * <p>For every target, a root of the objective function is sought in {@code [0, psiUpperLimit]}
 * using a {@link RobustBrentSolver}. If the solver cannot bracket a root or exceeds its evaluation
 * budget, the current value of {@code Psi_t} for that target is retained.</p>
 *
 * <p>The returned block carries a signal with the infinity norm of the change in {@code Psi_t}
 * and the largest number of function evaluations spent on any single target.</p>
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads parallelism of the dedicated fork-join pool that runs the per-target solves
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}
 * @throws RuntimeException if the concurrent update is interrupted or one of its tasks fails
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");

    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);

    /* the parallel stream is started from inside a dedicated pool; presumably this is what
     * confines it to numThreads workers rather than the common pool -- TODO confirm */
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* deliberate best-effort: keep the current value for this target */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt flag so that callers can observe the interruption */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* release the worker threads; without this every call leaks a pool */
        forkJoinPool.shutdown();
    }

    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* like Collections.max, this throws NoSuchElementException if there are no targets */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().getAsInt();
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
use of org.apache.commons.math3.analysis.UnivariateFunction in project GDSC-SMLM by aherbert.
the class PeakResult method computeI1.
/**
 * Evaluate the function I1 by numerical integration. See Mortensen, et al (2010) Nature Methods 7, 377-383), SI
 * equation 43.
 *
 * <pre>
 * I1 = 1 + sum [ ln(t) / (1 + t/rho) ] dt
 *    = - sum [ t * ln(t) / (t + rho) ] dt
 * </pre>
 *
 * Here "sum" denotes the integral over the interval 0 to 1. The second form is the one that is integrated.
 * When rho is exactly 0 the value 1 is returned directly without integrating.
 *
 * @param rho
 *            the rho value appearing in the formula above
 * @param integrationPoints
 *            the number of integration points for the LegendreGaussIntegrator
 * @return the I1 value
 */
private static double computeI1(final double rho, int integrationPoints) {
    if (rho == 0) {
        return 1;
    }

    // Convergence settings for the iterative integrator
    final double relTolerance = 1e-4;
    final double absTolerance = 1e-8;
    final int minIterations = 3;
    final int maxIterations = 32;

    // Gauss-Legendre is chosen because it does not evaluate at the interval boundary, where log(0) is undefined.
    final UnivariateIntegrator integrator = new IterativeLegendreGaussIntegrator(integrationPoints, relTolerance, absTolerance, minIterations, maxIterations);

    // Integrand: t * ln(t) / (t + rho)
    final UnivariateFunction integrand = new UnivariateFunction() {
        @Override
        public double value(double t) {
            return t * Math.log(t) / (t + rho);
        }
    };

    // Integrate over [0, 1] with at most 2000 function evaluations and negate the result
    return -integrator.integrate(2000, integrand, 0, 1);
}
Aggregations