Search in sources:

Example 1 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project GDSC-SMLM by aherbert.

In the class JumpDistanceAnalysisTest, method canIntegrateProbabilityToCumulativeWithSinglePopulation.

// Disabled: this test always passes.
//@Test
public void canIntegrateProbabilityToCumulativeWithSinglePopulation() {
    // Check that numerically integrating the jump-distance probability function from 0 to an
    // upper limit reproduces the cumulative function at that limit (within 1% relative error).
    final JumpDistanceAnalysis analysis = new JumpDistanceAnalysis();
    analysis.setMinD(0);
    analysis.setMinFraction(0);
    final SimpsonIntegrator integrator = new SimpsonIntegrator(1e-3, 1e-8, 2, SimpsonIntegrator.SIMPSON_MAX_ITERATIONS_COUNT);
    for (final double diffusion : D) {
        final double[] params = { diffusion };
        final JumpDistanceFunction pdf = analysis.new JumpDistanceFunction(null, diffusion);
        final JumpDistanceCumulFunction cdf = analysis.new JumpDistanceCumulFunction(null, null, diffusion);
        final UnivariateFunction integrand = new UnivariateFunction() {

            @Override
            public double value(double r) {
                return pdf.evaluate(r, params);
            }
        };
        // Nine upper limits: d/8, d/4, ..., 32d (doubling each time)
        double upper = diffusion / 8;
        for (int step = 1; step < 10; step++) {
            final double expected = cdf.evaluate(upper, params);
            final double observed = integrator.integrate(10000, integrand, 0, upper);
            Assert.assertEquals("Failed to integrate", expected, observed, expected * 1e-2);
            upper *= 2;
        }
    }
}
Also used : SimpsonIntegrator(org.apache.commons.math3.analysis.integration.SimpsonIntegrator) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) JumpDistanceCumulFunction(gdsc.smlm.fitting.JumpDistanceAnalysis.JumpDistanceCumulFunction) MixedJumpDistanceCumulFunction(gdsc.smlm.fitting.JumpDistanceAnalysis.MixedJumpDistanceCumulFunction) JumpDistanceFunction(gdsc.smlm.fitting.JumpDistanceAnalysis.JumpDistanceFunction) MixedJumpDistanceFunction(gdsc.smlm.fitting.JumpDistanceAnalysis.MixedJumpDistanceFunction)

Example 2 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project GDSC-SMLM by aherbert.

In the class PoissonGammaGaussianFunction, method likelihood.

/**
 * Compute the likelihood of an observed count given an expected count.
 * <p>
 * The evaluation path depends on the configuration:
 * <ul>
 * <li>{@code sigma == 0}: no Gaussian convolution; a Poisson-Gamma evaluation is used.</li>
 * <li>{@code useApproximation}: delegates to {@code mortensenApproximation(cij, eta)}.</li>
 * <li>Otherwise: numerical evaluation of the Poisson-Gamma-Gaussian convolution (equation 7 of the
 * supplementary information of Chao, et al (2013) Nature Methods 10, 335-338). If the integrator
 * throws {@link TooManyEvaluationsException} the Mortensen approximation is returned instead.</li>
 * </ul>
 * On the {@code sigma == 0} and full-evaluation paths the result is clamped to be no lower than
 * {@code minimumProbability}.
 *
 * @param o
 *            The observed count
 * @param e
 *            The expected count
 * @return The likelihood
 */
public double likelihood(final double o, final double e) {
    // Use the same variables as the Python code
    final double cij = o;
    // convert to photons (alpha is the inverse of the gain, see g = 1.0 / alpha below)
    final double eta = alpha * e;
    if (sigma == 0) {
        // No convolution with a Gaussian. Simply evaluate for a Poisson-Gamma distribution.
        final double p;
        // Any observed count above zero
        if (cij > 0.0) {
            // The observed count converted to photons
            final double nij = alpha * cij;
            // The Bessel argument is x = 2 * sqrt(eta * nij) and exp(x) overflows a double for
            // x > ~709, so the limit on eta * nij is (709/2)^2 = 125670.25. The switch to the
            // log-space approximation happens well below that limit.
            if (eta * nij > 10000) {
                // Approximate Bessel function i1(x) when using large x:
                // i1(x) ~ exp(x)/sqrt(2*pi*x)
                // With x = 2*sqrt(eta*nij): sqrt(2*pi*x) = 2*sqrt(pi) * (eta*nij)^0.25
                // (twoSqrtPi presumably holds 2*sqrt(pi)).
                // The entire expression is evaluated as a logarithm (transform) and then
                // exponentiated, avoiding overflow of exp(x) for large x.
                final double transform = 0.5 * Math.log(alpha * eta / cij) - nij - eta + 2 * Math.sqrt(eta * nij) - Math.log(twoSqrtPi * Math.pow(eta * nij, 0.25));
                p = FastMath.exp(transform);
            } else {
                // Second part of equation 135
                p = Math.sqrt(alpha * eta / cij) * FastMath.exp(-nij - eta) * Bessel.I1(2 * Math.sqrt(eta * nij));
            }
        } else if (cij == 0.0) {
            // Zero observed count: only the exp(-eta) term remains
            p = FastMath.exp(-eta);
        } else {
            // Negative observed counts have zero probability (clamped below)
            p = 0;
        }
        return (p > minimumProbability) ? p : minimumProbability;
    } else if (useApproximation) {
        return mortensenApproximation(cij, eta);
    } else {
        // This code is the full evaluation of equation 7 from the supplementary information
        // of the paper Chao, et al (2013) Nature Methods 10, 335-338.
        // It is the full evaluation of a Poisson-Gamma-Gaussian convolution PMF.
        // Read noise
        final double sk = sigma;
        // Gain
        final double g = 1.0 / alpha;
        // Observed pixel value
        final double z = o;
        // Expected number of photons
        final double vk = eta;
        // Compute the integral to infinity of:
        // exp( -((z-u)/(sqrt(2)*s)).^2 - u/g ) * sqrt(vk*u/g) .* besseli(1, 2 * sqrt(vk*u/g)) ./ u;
        // vk / g
        final double vk_g = vk * alpha;
        final double sqrt2sigma = Math.sqrt(2) * sk;
        // Specify the function to integrate
        UnivariateFunction f = new UnivariateFunction() {

            public double value(double u) {
                return eval(sqrt2sigma, z, vk_g, g, u);
            }
        };
        // Integrating to infinity is not necessary. The convolution of the function with the
        // Gaussian should be adequately sampled using n x SD around the maximum.
        // Find a bracket containing the maximum by stepping in unit increments from max(1, o).
        double lower, upper;
        double maxU = Math.max(1, o);
        double rLower = maxU;
        double rUpper = maxU + 1;
        double f1 = f.value(rLower);
        double f2 = f.value(rUpper);
        // Calculate the simple integral and the range: every sampled point is accumulated
        // into a unit-step sum, used if the simple integration path is taken below.
        double sum = f1 + f2;
        boolean searchUp = f2 > f1;
        if (searchUp) {
            // Step upwards while the function is still increasing
            while (f2 > f1) {
                f1 = f2;
                rUpper += 1;
                f2 = f.value(rUpper);
                sum += f2;
            }
            // The last sample decreased, so the maximum is at the previous point
            maxU = rUpper - 1;
        } else {
            // Step downwards while the function is still increasing.
            // Ensure that u stays above zero
            while (f1 > f2 && rLower > 1) {
                f2 = f1;
                rLower -= 1;
                f1 = f.value(rLower);
                sum += f1;
            }
            // NOTE(review): if the loop stops at rLower == 1 because f1 <= f2, the maximum is
            // actually at 2 but 1 is used; only the centre of the +/- 5 SD range is affected.
            maxU = (rLower > 1) ? rLower + 1 : rLower;
        }
        // Integration range: maxU +/- 5 * sigma, clipped at zero
        lower = Math.max(0, maxU - 5 * sk);
        upper = maxU + 5 * sk;
        if (useSimpleIntegration && lower > 0) {
            // remaining points in the range, extending the unit-step sum beyond the points
            // already sampled during the bracket search
            for (double u = rLower - 1; u >= lower; u -= 1) {
                sum += f.value(u);
            }
            for (double u = rUpper + 1; u <= upper; u += 1) {
                sum += f.value(u);
            }
        } else {
            // Use Legendre-Gauss integrator; the result replaces the unit-step sum above
            try {
                final double relativeAccuracy = 1e-4;
                final double absoluteAccuracy = 1e-8;
                final int minimalIterationCount = 3;
                final int maximalIterationCount = 32;
                final int integrationPoints = 16;
                // Use an integrator that does not use the boundary since u=0 is undefined (divide by zero)
                UnivariateIntegrator i = new IterativeLegendreGaussIntegrator(integrationPoints, relativeAccuracy, absoluteAccuracy, minimalIterationCount, maximalIterationCount);
                sum = i.integrate(2000, f, lower, upper);
            } catch (TooManyEvaluationsException ex) {
                // Evaluation budget exceeded: fall back to the approximation
                return mortensenApproximation(cij, eta);
            }
        }
        // Compute the final probability.
        // f1 is reused to hold z / (sqrt(2) * sigma); exp(-f1^2) is presumably the zero-photon
        // (read-noise only) Gaussian term, added to the integral (sqrt2pi presumably sqrt(2*pi)).
        f1 = z / sqrt2sigma;
        final double p = (FastMath.exp(-vk) / (sqrt2pi * sk)) * (FastMath.exp(-(f1 * f1)) + sum);
        return (p > minimumProbability) ? p : minimumProbability;
    }
}
Also used : IterativeLegendreGaussIntegrator(org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) UnivariateIntegrator(org.apache.commons.math3.analysis.integration.UnivariateIntegrator)

Example 3 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project GDSC-SMLM by aherbert.

In the class FIRE, method findMin.

/**
 * Refine a minimum of {@code f} near {@code qValue} and keep the better of the refined point and
 * {@code current}.
 * <p>
 * A bracket is searched for starting from {@code qValue * factor} and {@code qValue / factor},
 * then refined with the optimiser {@code o} (at most 3000 evaluations). Any failure is treated as
 * "no improvement" and {@code current} is returned unchanged.
 *
 * @param current
 *            the best point found so far (may be null)
 * @param o
 *            the optimiser used to refine the bracketed minimum
 * @param f
 *            the function to minimise
 * @param qValue
 *            the centre of the initial search
 * @param factor
 *            the multiplicative spread of the initial search points around {@code qValue}
 * @return the point with the lower value, or {@code current} if the search failed
 */
private UnivariatePointValuePair findMin(UnivariatePointValuePair current, UnivariateOptimizer o, UnivariateFunction f, double qValue, double factor) {
    final UnivariatePointValuePair next;
    try {
        final BracketFinder finder = new BracketFinder();
        finder.search(f, GoalType.MINIMIZE, qValue * factor, qValue / factor);
        next = o.optimize(GoalType.MINIMIZE, new MaxEval(3000), new SearchInterval(finder.getLo(), finder.getHi(), finder.getMid()), new UnivariateObjectiveFunction(f));
    } catch (Exception ignored) {
        // Best effort: a failed bracket or optimisation leaves the current best unchanged
        return current;
    }
    if (next == null) {
        return current;
    }
    if (current == null) {
        return next;
    }
    return (next.getValue() < current.getValue()) ? next : current;
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) SearchInterval(org.apache.commons.math3.optim.univariate.SearchInterval) UnivariateObjectiveFunction(org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction) BracketFinder(org.apache.commons.math3.optim.univariate.BracketFinder) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException)

Example 4 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project GDSC-SMLM by aherbert.

In the class BoundedNonLinearConjugateGradientOptimizer, method doOptimize.

/**
 * {@inheritDoc}
 * <p>
 * Non-linear conjugate gradient iteration in which the parameter bounds are applied to the start
 * point and after every step. The line search either uses the gradient (bracket + solver) or, when
 * that is disabled or fails with a {@link MathIllegalStateException}, a gradient-free line search.
 * The search direction is reset to the steepest descent every {@code n} iterations or when
 * {@code beta < 0}.
 */
@Override
protected PointValuePair doOptimize() {
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    final double[] point = getStartPoint();
    final GoalType goal = getGoalType();
    final int n = point.length;
    sign = (goal == GoalType.MINIMIZE) ? -1 : 1;
    double[] unbounded = point.clone();
    applyBounds(point);
    double[] r = computeObjectiveGradient(point);
    checkGradients(r, unbounded);
    if (goal == GoalType.MINIMIZE) {
        for (int i = 0; i < n; i++) {
            r[i] = -r[i];
        }
    }
    // Initial search direction.
    double[] steepestDescent = preconditioner.precondition(point, r);
    double[] searchDirection = steepestDescent.clone();
    double delta = 0;
    for (int i = 0; i < n; ++i) {
        delta += r[i] * searchDirection[i];
    }
    // Tolerances for the non-gradient line search. When the convergence checker is a
    // SimpleValueChecker they are derived from its relative and absolute thresholds
    // (square root, as per the Powell optimiser). A non-positive threshold only disables that
    // check in the checker and its square root is not a usable tolerance, so the default is kept.
    double rel = 1e-6;
    double abs = 1e-10;
    if (checker instanceof SimpleValueChecker) {
        final SimpleValueChecker valueChecker = (SimpleValueChecker) checker;
        if (valueChecker.getRelativeThreshold() > 0) {
            rel = valueChecker.getRelativeThreshold();
        }
        if (valueChecker.getAbsoluteThreshold() > 0) {
            abs = valueChecker.getAbsoluteThreshold();
        }
    }
    // Used for non-gradient based line search
    final LineSearch line = new LineSearch(Math.sqrt(rel), Math.sqrt(abs));
    PointValuePair current = null;
    int maxEval = getMaxEvaluations();
    while (true) {
        incrementIterationCount();
        final double objective = computeObjectiveValue(point);
        PointValuePair previous = current;
        current = new PointValuePair(point, objective);
        if (previous != null && checker.converged(getIterations(), previous, current)) {
            // We have found an optimum.
            return current;
        }
        double step;
        if (useGradientLineSearch) {
            // Classic code using the gradient function for the line search:
            // Find the optimal step in the search direction.
            final UnivariateFunction lsf = new LineSearchFunction(point, searchDirection);
            final double uB;
            try {
                uB = findUpperBound(lsf, 0, initialStep);
                // Check if the bracket found a minimum. Otherwise just move to the new point.
                if (noBracket) {
                    step = uB;
                } else {
                    // XXX Last parameter is set to a value close to zero in order to
                    // work around the divergence problem in the "testCircleFitting"
                    // unit test (see MATH-439).
                    step = solver.solve(maxEval, lsf, 0, uB, 1e-15);
                    // Subtract used up evaluations.
                    maxEval -= solver.getEvaluations();
                }
            } catch (MathIllegalStateException e) {
                // Bracketing or solving failed: fall back to a line search without gradient
                // (as per Powell optimiser)
                final UnivariatePointValuePair optimum = line.search(point, searchDirection);
                step = optimum.getPoint();
            }
        } else {
            // Line search without gradient (as per Powell optimiser)
            final UnivariatePointValuePair optimum = line.search(point, searchDirection);
            step = optimum.getPoint();
        }
        for (int i = 0; i < n; ++i) {
            point[i] += step * searchDirection[i];
        }
        // Keep the unbounded point so the gradient checks can compare it with the bounded one
        unbounded = point.clone();
        applyBounds(point);
        r = computeObjectiveGradient(point);
        checkGradients(r, unbounded);
        if (goal == GoalType.MINIMIZE) {
            for (int i = 0; i < n; ++i) {
                r[i] = -r[i];
            }
        }
        // Compute beta.
        final double deltaOld = delta;
        final double[] newSteepestDescent = preconditioner.precondition(point, r);
        delta = 0;
        for (int i = 0; i < n; ++i) {
            delta += r[i] * newSteepestDescent[i];
        }
        // A zero (preconditioned) gradient norm: no further progress is possible
        if (delta == 0) {
            return new PointValuePair(point, computeObjectiveValue(point));
        }
        final double beta;
        switch (updateFormula) {
            case FLETCHER_REEVES:
                beta = delta / deltaOld;
                break;
            case POLAK_RIBIERE:
                double deltaMid = 0;
                for (int i = 0; i < r.length; ++i) {
                    deltaMid += r[i] * steepestDescent[i];
                }
                beta = (delta - deltaMid) / deltaOld;
                break;
            default:
                // Should never happen.
                throw new MathInternalError();
        }
        steepestDescent = newSteepestDescent;
        // Compute conjugate search direction.
        if (getIterations() % n == 0 || beta < 0) {
            // Break conjugation: reset search direction.
            searchDirection = steepestDescent.clone();
        } else {
            // Compute new conjugate search direction.
            for (int i = 0; i < n; ++i) {
                searchDirection[i] = steepestDescent[i] + beta * searchDirection[i];
            }
        }
        // The gradient has already been adjusted for the search direction
        checkGradients(searchDirection, unbounded, -sign);
    }
}
Also used : UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) MathIllegalStateException(org.apache.commons.math3.exception.MathIllegalStateException) PointValuePair(org.apache.commons.math3.optim.PointValuePair) UnivariatePointValuePair(org.apache.commons.math3.optim.univariate.UnivariatePointValuePair) MathInternalError(org.apache.commons.math3.exception.MathInternalError)

Example 5 with UnivariateFunction

Use of org.apache.commons.math3.analysis.UnivariateFunction in project gatk by broadinstitute.

In the class CoverageModelEMComputeBlock, method cloneWithUpdatedTargetUnexplainedVarianceTargetResolved.

/**
 * Performs the M-step for target-specific unexplained variance and clones the compute block
 * with the updated value.
 *
 * <p>For each target, a root of the target-specific objective function is searched for in
 * {@code [0, psiUpperLimit]} with a {@link RobustBrentSolver}. If the solver throws
 * {@link NoBracketingException} or {@link TooManyEvaluationsException}, that target keeps its
 * current value. Targets are processed in parallel on a dedicated {@link ForkJoinPool} that is
 * shut down before this method returns.</p>
 *
 * @param maxIters maximum number of iterations
 * @param psiUpperLimit upper limit for the unexplained variance
 * @param absTol absolute error tolerance (used in root finding)
 * @param relTol relative error tolerance (used in root finding)
 * @param numBisections number of bisections (used in root finding)
 * @param refinementDepth depth of search (used in root finding)
 * @param numThreads number of threads used to process the targets in parallel
 *
 * @return a new instance of {@link CoverageModelEMComputeBlock}
 * @throws RuntimeException if the concurrent update is interrupted or fails; the original
 *         exception is attached as the cause
 */
@QueriesICG
public CoverageModelEMComputeBlock cloneWithUpdatedTargetUnexplainedVarianceTargetResolved(final int maxIters, final double psiUpperLimit, final double absTol, final double relTol, final int numBisections, final int refinementDepth, final int numThreads) {
    Utils.validateArg(maxIters > 0, "At least one iteration is required");
    Utils.validateArg(psiUpperLimit >= 0, "The upper limit must be non-negative");
    Utils.validateArg(absTol >= 0, "The absolute error tolerance must be non-negative");
    Utils.validateArg(relTol >= 0, "The relative error tolerance must be non-negative");
    Utils.validateArg(numBisections >= 0, "The number of bisections must be non-negative");
    Utils.validateArg(refinementDepth >= 0, "The refinement depth must be non-negative");
    Utils.validateArg(numThreads > 0, "Number of execution threads must be positive");
    /* fetch the required caches */
    final INDArray Psi_t = getINDArrayFromCache(CoverageModelICGCacheNode.Psi_t);
    final INDArray M_st = getINDArrayFromCache(CoverageModelICGCacheNode.M_st);
    final INDArray Sigma_st = getINDArrayFromCache(CoverageModelICGCacheNode.Sigma_st);
    final INDArray gamma_s = getINDArrayFromCache(CoverageModelICGCacheNode.gamma_s);
    final INDArray B_st = getINDArrayFromCache(CoverageModelICGCacheNode.B_st);
    /* a dedicated pool bounds the parallelism of the parallel stream below to numThreads */
    final ForkJoinPool forkJoinPool = new ForkJoinPool(numThreads);
    final List<ImmutablePair<Double, Integer>> res;
    try {
        res = forkJoinPool.submit(() -> {
            return IntStream.range(0, numTargets).parallel().mapToObj(ti -> {
                final UnivariateFunction objFunc = psi -> calculateTargetSpecificVarianceSolverObjectiveFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final UnivariateFunction meritFunc = psi -> calculateTargetSpecificVarianceSolverMeritFunction(ti, psi, M_st, Sigma_st, gamma_s, B_st);
                final RobustBrentSolver solver = new RobustBrentSolver(relTol, absTol, CoverageModelGlobalConstants.DEFAULT_FUNCTION_EVALUATION_ACCURACY, meritFunc, numBisections, refinementDepth);
                double newPsi;
                try {
                    newPsi = solver.solve(maxIters, objFunc, 0, psiUpperLimit);
                } catch (NoBracketingException | TooManyEvaluationsException e) {
                    /* keep the current value for this target */
                    newPsi = Psi_t.getDouble(ti);
                }
                return new ImmutablePair<>(newPsi, solver.getEvaluations());
            }).collect(Collectors.toList());
        }).get();
    } catch (final InterruptedException ex) {
        /* restore the interrupt status so callers can observe the interruption */
        Thread.currentThread().interrupt();
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } catch (final ExecutionException ex) {
        throw new RuntimeException("Failure in concurrent update of target-specific unexplained variance", ex);
    } finally {
        /* a new pool is created on every call; release its worker threads */
        forkJoinPool.shutdown();
    }
    final INDArray newPsi_t = Nd4j.create(res.stream().mapToDouble(p -> p.left).toArray(), Psi_t.shape());
    /* report the largest number of solver evaluations used by any target */
    final int maxIterations = res.stream().mapToInt(p -> p.right).max().getAsInt();
    final double errNormInfinity = CoverageModelEMWorkspaceMathUtils.getINDArrayNormInfinity(newPsi_t.sub(Psi_t));
    return cloneWithUpdatedPrimitiveAndSignal(CoverageModelICGCacheNode.Psi_t, newPsi_t, SubroutineSignal.builder().put(StandardSubroutineSignals.RESIDUAL_ERROR_NORM, errNormInfinity).put(StandardSubroutineSignals.ITERATIONS, maxIterations).build());
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) NDArrayIndex(org.nd4j.linalg.indexing.NDArrayIndex) ParamUtils(org.broadinstitute.hellbender.utils.param.ParamUtils) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) Map(java.util.Map) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) ImmutableTriple(org.apache.commons.lang3.tuple.ImmutableTriple) Nd4j(org.nd4j.linalg.factory.Nd4j) FastMath(org.apache.commons.math3.util.FastMath) Collectors(java.util.stream.Collectors) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) Serializable(java.io.Serializable) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) org.broadinstitute.hellbender.tools.coveragemodel.cachemanager(org.broadinstitute.hellbender.tools.coveragemodel.cachemanager) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ForkJoinPool(java.util.concurrent.ForkJoinPool) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Utils(org.broadinstitute.hellbender.utils.Utils) VisibleForTesting(com.google.common.annotations.VisibleForTesting) Transforms(org.nd4j.linalg.ops.transforms.Transforms) Collections(java.util.Collections) NoBracketingException(org.apache.commons.math3.exception.NoBracketingException) UnivariateFunction(org.apache.commons.math3.analysis.UnivariateFunction) RobustBrentSolver(org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair) ExecutionException(java.util.concurrent.ExecutionException) ForkJoinPool(java.util.concurrent.ForkJoinPool)

Aggregations

UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)17 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)7 SimpsonIntegrator (org.apache.commons.math3.analysis.integration.SimpsonIntegrator)6 NoBracketingException (org.apache.commons.math3.exception.NoBracketingException)6 FastMath (org.apache.commons.math3.util.FastMath)6 VisibleForTesting (com.google.common.annotations.VisibleForTesting)4 List (java.util.List)4 Collectors (java.util.stream.Collectors)4 IntStream (java.util.stream.IntStream)4 Nonnull (javax.annotation.Nonnull)4 Nullable (javax.annotation.Nullable)4 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)4 ImmutableTriple (org.apache.commons.lang3.tuple.ImmutableTriple)4 UnivariateIntegrator (org.apache.commons.math3.analysis.integration.UnivariateIntegrator)4 RobustBrentSolver (org.broadinstitute.hellbender.tools.coveragemodel.math.RobustBrentSolver)4 Utils (org.broadinstitute.hellbender.utils.Utils)4 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)4 INDArray (org.nd4j.linalg.api.ndarray.INDArray)4 Nd4j (org.nd4j.linalg.factory.Nd4j)4 NDArrayIndex (org.nd4j.linalg.indexing.NDArrayIndex)4