Search in sources:

Example 1 with RandomGeneratorAdapter

Use of uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter in project GDSC-SMLM by aherbert.

The class PcPalmFitting, method runBoundedOptimiser.

/**
 * Minimise the sum-of-squares model function within the given bounds.
 *
 * <p>A deterministic bounded non-linear conjugate gradient optimiser is run first. Each of the
 * {@code settings.fitRestarts} restarts continues from the previous optimum, and the loop stops at
 * the first failure. The random CMA-ES optimiser is then run {@code settings.fitRestarts + 1}
 * times. Each run starts from the initial solution and then again from the best optimum found so
 * far. The solution with the lowest value is kept.
 *
 * <p>The total number of function evaluations used by all optimisation runs is stored in
 * {@code boundedEvaluations}.
 *
 * @param initialSolution the initial guess
 * @param lowerB the lower bound for each parameter
 * @param upperB the upper bound for each parameter
 * @param function the sum-of-squares model function
 * @return the best solution found (may be {@code null} if no optimiser produced a solution)
 */
private PointValuePair runBoundedOptimiser(double[] initialSolution, double[] lowerB, double[] upperB, SumOfSquaresModelFunction function) {
    // Create the functions to optimise
    final ObjectiveFunction objective = new ObjectiveFunction(new SumOfSquaresMultivariateFunction(function));
    final ObjectiveFunctionGradient gradient = new ObjectiveFunctionGradient(new SumOfSquaresMultivariateVectorFunction(function));
    final boolean debug = false;
    // Try a gradient optimiser since this will produce a deterministic solution
    PointValuePair optimum = null;
    boundedEvaluations = 0;
    final MaxEval maxEvaluations = new MaxEval(2000);
    for (int iteration = 0; iteration <= settings.fitRestarts; iteration++) {
        // Scoped to this iteration so the finally block never re-counts a previous optimiser
        MultivariateOptimizer opt = null;
        try {
            final double relativeThreshold = 1e-6;
            opt = new BoundedNonLinearConjugateGradientOptimizer(BoundedNonLinearConjugateGradientOptimizer.Formula.FLETCHER_REEVES, new SimpleValueChecker(relativeThreshold, -1));
            // Each restart continues from the previous optimum
            optimum = opt.optimize(maxEvaluations, gradient, objective, GoalType.MINIMIZE, new InitialGuess((optimum == null) ? initialSolution : optimum.getPointRef()), new SimpleBounds(lowerB, upperB));
            if (debug) {
                System.out.printf("Bounded Iter %d = %g (%d)\n", iteration, optimum.getValue(), opt.getEvaluations());
            }
        } catch (final RuntimeException ex) {
            // No need to restart
            break;
        } finally {
            if (opt != null) {
                boundedEvaluations += opt.getEvaluations();
            }
        }
    }
    // Try a CMAES optimiser which is non-deterministic. To overcome this we perform restarts.
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    final double stopFitness = 0;
    final boolean isActiveCma = true;
    final int diagonalOnly = 0;
    final int checkFeasableCount = 1;
    final RandomGenerator random = new RandomGeneratorAdapter(UniformRandomProviders.create());
    final boolean generateStatistics = false;
    final ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial
    // search region.
    final double[] range = new double[lowerB.length];
    for (int i = 0; i < lowerB.length; i++) {
        range[i] = (upperB[i] - lowerB[i]) / 3;
    }
    final OptimizationData sigma = new CMAESOptimizer.Sigma(range);
    // Default CMA-ES population size: 4 + floor(3 ln N)
    final OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(initialSolution.length))));
    final SimpleBounds bounds = new SimpleBounds(lowerB, upperB);
    final MultivariateOptimizer cmaes = new CMAESOptimizer(maxEvaluations.getMaxEval(), stopFitness, isActiveCma, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
    // Restart the optimiser several times and take the best answer.
    for (int iteration = 0; iteration <= settings.fitRestarts; iteration++) {
        try {
            // Start from the initial solution
            final PointValuePair constrainedSolution = cmaes.optimize(new InitialGuess(initialSolution), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug) {
                System.out.printf("CMAES Iter %d initial = %g (%d)\n", iteration, constrainedSolution.getValue(), cmaes.getEvaluations());
            }
            if (optimum == null || constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
            // Ignore: the run hit its evaluation/iteration limit without a solution
        } finally {
            // Count the evaluations actually used by this run. The optimiser resets its counter
            // at the start of each optimize call, so this is exactly once per run, both on
            // success and on failure.
            boundedEvaluations += cmaes.getEvaluations();
        }
        if (optimum == null) {
            continue;
        }
        try {
            // Also restart from the current optimum
            final PointValuePair constrainedSolution = cmaes.optimize(new InitialGuess(optimum.getPointRef()), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug) {
                System.out.printf("CMAES Iter %d restart = %g (%d)\n", iteration, constrainedSolution.getValue(), cmaes.getEvaluations());
            }
            if (constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
            // Ignore: the run hit its evaluation/iteration limit without a solution
        } finally {
            // Same accounting as the run from the initial solution
            boundedEvaluations += cmaes.getEvaluations();
        }
    }
    return optimum;
}
Also used: MultivariateOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer) MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) BoundedNonLinearConjugateGradientOptimizer(uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) OptimizationData(org.apache.commons.math3.optim.OptimizationData)

Example 2 with RandomGeneratorAdapter

Use of uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter in project GDSC-SMLM by aherbert.

The class TrackPopulationAnalysisTest, method canComputeExponentialLogLikelihood.

/**
 * Checks the exponential data function built from discretised exponential samples. Its value must
 * equal the summed exponential log-density of the samples at several candidate means. Its
 * analytic gradient must agree with a central finite-difference estimate.
 */
@Test
void canComputeExponentialLogLikelihood() {
    final RandomGenerator generator = new RandomGeneratorAdapter(RngUtils.createWithFixedSeed());
    final double step = 1e-6;
    final double[] sampleMeans = { 2, 10 };
    final double[] offsets = { -0.1, 0, 0.1 };
    for (final double sampleMean : sampleMeans) {
        // Draw exponential samples and snap them onto a grid with spacing mean / 10
        final double[] data = new ExponentialDistribution(generator, sampleMean).sample(1000);
        final double binWidth = sampleMean / 10;
        for (int k = 0; k < data.length; k++) {
            data[k] = Math.round(data[k] / binWidth) * binWidth;
        }
        // The function under test is built from the cumulative histogram of the snapped data
        final double[][] cumulative = MathUtils.cumulativeHistogram(data, false);
        final ExponentialDataFunction dataFunction = ExponentialDataFunction.fromCdf(cumulative);
        // Evaluate at means slightly below, at, and slightly above the sampling mean
        for (final double offset : offsets) {
            final double testMean = sampleMean + offset;
            final ExponentialDistribution model = new ExponentialDistribution(generator, testMean);
            double expectedLl = 0;
            for (int k = 0; k < data.length; k++) {
                expectedLl += model.logDensity(data[k]);
            }
            Assertions.assertEquals(expectedLl, dataFunction.value(testMean), 1e-10, "Log likelihood");
            // Central finite-difference estimate of the gradient
            final double above = dataFunction.value(testMean + step);
            final double below = dataFunction.value(testMean - step);
            final double numeric = (above - below) / (2 * step);
            Assertions.assertEquals(numeric, dataFunction.gradient(testMean), Math.abs(numeric) * 1e-4, "Gradient");
        }
    }
}
Also used: RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) ExponentialDataFunction(uk.ac.sussex.gdsc.smlm.ij.plugins.TrackPopulationAnalysis.ExponentialDataFunction) ExponentialDistribution(org.apache.commons.math3.distribution.ExponentialDistribution) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) Test(org.junit.jupiter.api.Test)

Example 3 with RandomGeneratorAdapter

Use of uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter in project GDSC-SMLM by aherbert.

The class BinomialFitter, method fitBinomial.

/**
 * Fit the binomial distribution (n,p) to the cumulative histogram. Performs fitting assuming a
 * fixed n value and attempts to optimise p.
 *
 * <p>The probability p is fitted with a bounded CMA-ES optimiser, restarted {@code fitRestarts}
 * times. When maximum likelihood fitting is enabled the objective is maximised; otherwise it is
 * minimised. Candidate solutions are always compared in the same direction as the optimisation
 * goal. When not using maximum likelihood and more than one histogram point is fitted, the
 * solution is refined with a Levenberg-Marquardt least-squares fit. The refined fit is kept only
 * if it lowers the sum-of-squares.
 *
 * <p>Note: when fitting a zero-truncated binomial to a histogram with {@code histogram[0] > 0},
 * the histogram array is modified in place. {@code histogram[0]} is set to zero and the remaining
 * values are divided by their sum.
 *
 * @param histogram The input histogram
 * @param mean The histogram mean (used to estimate p). Calculated if NaN.
 * @param n The n to evaluate
 * @param zeroTruncated True if the model should ignore n=0 (zero-truncated binomial)
 * @return The best fit; the point contains p and the value is a sum-of-squares measure of the
 *         fit. Returns null if there are no points to fit or the fit failed.
 * @throws IllegalArgumentException If any of the input data values are negative
 * @throws IllegalArgumentException If fitting a zero truncated binomial and there are no
 *         values above zero
 */
public PointValuePair fitBinomial(double[] histogram, double mean, int n, boolean zeroTruncated) {
    if (Double.isNaN(mean)) {
        mean = getMean(histogram);
    }
    if (zeroTruncated && histogram[0] > 0) {
        log("Fitting zero-truncated histogram but there are zero values - " + "Renormalising to ignore zero");
        double cumul = 0;
        for (int i = 1; i < histogram.length; i++) {
            cumul += histogram[i];
        }
        if (cumul == 0) {
            throw new IllegalArgumentException("Fitting zero-truncated histogram but there are no non-zero values");
        }
        // Note: this modifies the caller's array
        histogram[0] = 0;
        for (int i = 1; i < histogram.length; i++) {
            histogram[i] /= cumul;
        }
    }
    final int nFittedPoints = Math.min(histogram.length, n + 1) - ((zeroTruncated) ? 1 : 0);
    if (nFittedPoints < 1) {
        log("No points to fit (%d): Histogram.length = %d, n = %d, zero-truncated = %b", nFittedPoints, histogram.length, n, zeroTruncated);
        return null;
    }
    // The model is only fitting the probability p
    // For a binomial n*p = mean => p = mean/n
    final double[] initialSolution = new double[] { Math.min(mean / n, 1) };
    // Create the function
    final BinomialModelFunction function = new BinomialModelFunction(histogram, n, zeroTruncated);
    final double[] lB = new double[1];
    final double[] uB = new double[] { 1 };
    final SimpleBounds bounds = new SimpleBounds(lB, uB);
    // Fit
    // CMAESOptimizer or BOBYQAOptimizer support bounds
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    final int maxIterations = 2000;
    final double stopFitness = 0;
    final boolean isActiveCma = true;
    final int diagonalOnly = 0;
    final int checkFeasableCount = 1;
    final RandomGenerator random = new RandomGeneratorAdapter(UniformRandomProviders.create());
    final boolean generateStatistics = false;
    final ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial
    // search region.
    final OptimizationData sigma = new CMAESOptimizer.Sigma(new double[] { (uB[0] - lB[0]) / 3 });
    final OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(2))));
    try {
        PointValuePair solution = null;
        boolean noRefit = maximumLikelihood;
        if (n == 1 && zeroTruncated) {
            // No need to fit
            solution = new PointValuePair(new double[] { 1 }, 0);
            noRefit = true;
        } else {
            final GoalType goalType = (maximumLikelihood) ? GoalType.MAXIMIZE : GoalType.MINIMIZE;
            // Candidate solutions must be ranked in the same direction as the optimisation goal:
            // a larger value is better when maximising the likelihood.
            final boolean maximise = goalType == GoalType.MAXIMIZE;
            // Iteratively fit
            final CMAESOptimizer opt = new CMAESOptimizer(maxIterations, stopFitness, isActiveCma, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
            for (int iteration = 0; iteration <= fitRestarts; iteration++) {
                try {
                    // Start from the initial solution
                    final PointValuePair result = opt.optimize(new InitialGuess(initialSolution), new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (solution == null || isImprovement(result.getValue(), solution.getValue(), maximise)) {
                        solution = result;
                    }
                } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
                    // No solution within the limits; the optimiser is random so keep trying
                }
                if (solution == null) {
                    continue;
                }
                try {
                    // Also restart from the current optimum
                    final PointValuePair result = opt.optimize(new InitialGuess(solution.getPointRef()), new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (isImprovement(result.getValue(), solution.getValue(), maximise)) {
                        solution = result;
                    }
                } catch (final TooManyEvaluationsException | TooManyIterationsException ex) {
                    // No solution within the limits; keep the current best
                }
            }
            if (solution == null) {
                return null;
            }
        }
        if (noRefit) {
            // Although we fit the log-likelihood, return the sum-of-squares to allow
            // comparison across different n
            final double p = solution.getPointRef()[0];
            double ss = 0;
            final double[] obs = function.pvalues;
            final double[] exp = function.getP(p);
            for (int i = 0; i < obs.length; i++) {
                ss += (obs[i] - exp[i]) * (obs[i] - exp[i]);
            }
            return new PointValuePair(solution.getPointRef(), ss);
        // We can do a LVM refit if the number of fitted points is more than 1.
        } else if (nFittedPoints > 1) {
            // Improve SS fit with a gradient based LVM optimizer
            final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
            try {
                final BinomialModelFunctionGradient gradientFunction = new BinomialModelFunctionGradient(histogram, n, zeroTruncated);
                // @formatter:off
                final LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(solution.getPointRef()).target(gradientFunction.pvalues).weight(new DiagonalMatrix(gradientFunction.getWeights())).model(gradientFunction, gradientFunction::jacobian).build();
                // @formatter:on
                final Optimum lvmSolution = optimizer.optimize(problem);
                // Check the pValue is valid since the LVM is not bounded.
                final double p = lvmSolution.getPoint().getEntry(0);
                if (p <= 1 && p >= 0) {
                    // True if the weights are 1
                    final double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
                    if (ss < solution.getValue()) {
                        return new PointValuePair(lvmSolution.getPoint().toArray(), ss);
                    }
                }
            } catch (final TooManyIterationsException ex) {
                log("Failed to re-fit: Too many iterations: %s", ex.getMessage());
            } catch (final ConvergenceException ex) {
                log("Failed to re-fit: %s", ex.getMessage());
            } catch (final Exception ex) {
                // The refit is best-effort: keep the CMA-ES solution but record why it failed
                log("Failed to re-fit: %s", ex.getMessage());
            }
        }
        return solution;
    } catch (final RuntimeException ex) {
        log("Failed to fit Binomial distribution with N=%d : %s", n, ex.getMessage());
    }
    return null;
}

/**
 * Test if a candidate objective value is strictly better than the current best value.
 *
 * @param candidate the candidate value
 * @param best the current best value
 * @param maximise true if larger values are better; otherwise smaller values are better
 * @return true if the candidate improves on the best value
 */
private static boolean isImprovement(double candidate, double best, boolean maximise) {
    return maximise ? candidate > best : candidate < best;
}
Also used: InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 4 with RandomGeneratorAdapter

Use of uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter in project GDSC-SMLM by aherbert.

The class MaximumLikelihoodFitter, method computeFit.

/**
 * Fits the function to the data by optimising a likelihood function. The optimiser is selected by
 * {@code searchMethod}: a Powell variant, BOBYQA, CMA-ES, or a bounded non-linear conjugate
 * gradient optimiser.
 *
 * <p>All optimisers minimise the objective (described in the code comments as the negative
 * log-likelihood). On success the field {@code value} is set to the negated optimum value. The
 * iteration and evaluation counts of the optimiser used are added to the {@code iterations} and
 * {@code evaluations} fields. The CMA-ES branch resets {@code evaluations} to zero before its
 * restart loop.
 *
 * @param y the observed data
 * @param fx not referenced by this implementation
 * @param a the function parameters; used to create the likelihood wrapper and the initial
 *        solution, and passed to {@code setSolution} together with the fitted solution
 * @param parametersVariance if not null, passed to {@code setDeviations} together with the
 *        Fisher information matrix computed at the solution
 * @return {@code FitStatus.OK} on success. Otherwise one of:
 *         <ul>
 *         <li>FAILED_TO_CONVERGE: no optimum was produced
 *         <li>TOO_MANY_ITERATIONS or TOO_MANY_EVALUATIONS: the matching exception was thrown
 *         <li>SINGULAR_NON_LINEAR_MODEL: a ConvergenceException was thrown
 *         <li>UNKNOWN: any other exception, which is logged
 *         <li>INVALID_LIKELIHOOD: the score is infinite or NaN
 *         </ul>
 */
@Override
public FitStatus computeFit(double[] y, double[] fx, double[] a, double[] parametersVariance) {
    final int n = y.length;
    final LikelihoodWrapper maximumLikelihoodFunction = createLikelihoodWrapper((NonLinearFunction) function, n, y, a);
    // The optimiser in use; its iteration/evaluation counts are added in the finally block
    @SuppressWarnings("rawtypes") BaseOptimizer baseOptimiser = null;
    try {
        final double[] startPoint = getInitialSolution(a);
        PointValuePair optimum = null;
        if (searchMethod == SearchMethod.POWELL || searchMethod == SearchMethod.POWELL_BOUNDED || searchMethod == SearchMethod.POWELL_ADAPTER) {
            // Non-differentiable version using Powell Optimiser
            // Background: see Numerical Recipes 10.5 (Direction Set (Powell's) method).
            // The optimiser could be extended to implement bounds on the directions moved.
            // However the mapping adapter seems to work OK.
            final boolean basisConvergence = false;
            // Perhaps these thresholds should be tighter?
            // The default is to use the sqrt() of the overall tolerance
            // final double lineRel = Math.sqrt(relativeThreshold);
            // final double lineAbs = Math.sqrt(absoluteThreshold);
            // final double lineRel = relativeThreshold * 1e2;
            // final double lineAbs = absoluteThreshold * 1e2;
            // Since we are fitting only a small number of parameters then just use the same tolerance
            // for each search direction
            final double lineRel = relativeThreshold;
            final double lineAbs = absoluteThreshold;
            final CustomPowellOptimizer o = new CustomPowellOptimizer(relativeThreshold, absoluteThreshold, lineRel, lineAbs, null, basisConvergence);
            baseOptimiser = o;
            // Left null (passed to optimize as-is) when no iteration limit is configured
            OptimizationData maxIterationData = null;
            if (getMaxIterations() > 0) {
                maxIterationData = new MaxIter(getMaxIterations());
            }
            if (searchMethod == SearchMethod.POWELL_ADAPTER) {
                // Try using the mapping adapter for a bounded Powell search
                // The search runs in the adapter's unbounded space; the optimum is mapped back
                // to the bounded space afterwards.
                final MultivariateFunctionMappingAdapter adapter = new MultivariateFunctionMappingAdapter(new MultivariateLikelihood(maximumLikelihoodFunction), lower, upper);
                optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(adapter), GoalType.MINIMIZE, new InitialGuess(adapter.boundedToUnbounded(startPoint)));
                final double[] solution = adapter.unboundedToBounded(optimum.getPointRef());
                optimum = new PointValuePair(solution, optimum.getValue());
            } else {
                // The Powell function wrapper is cached in a field and reused across fits
                if (powellFunction == null) {
                    powellFunction = new MultivariateLikelihood(maximumLikelihoodFunction);
                }
                // Update the maximum likelihood function in the Powell function wrapper
                powellFunction.fun = maximumLikelihoodFunction;
                final OptimizationData positionChecker = null;
                // new org.apache.commons.math3.optim.PositionChecker(relativeThreshold,
                // absoluteThreshold);
                SimpleBounds simpleBounds = null;
                if (powellFunction.isMapped()) {
                    // The search runs on mapped parameters: map the start point and bounds in,
                    // then unmap the optimum.
                    final MappedMultivariateLikelihood adapter = (MappedMultivariateLikelihood) powellFunction;
                    if (searchMethod == SearchMethod.POWELL_BOUNDED) {
                        simpleBounds = new SimpleBounds(adapter.map(lower), adapter.map(upper));
                    }
                    optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(adapter.map(startPoint)), positionChecker, simpleBounds);
                    final double[] solution = adapter.unmap(optimum.getPointRef());
                    optimum = new PointValuePair(solution, optimum.getValue());
                } else {
                    if (searchMethod == SearchMethod.POWELL_BOUNDED) {
                        simpleBounds = new SimpleBounds(lower, upper);
                    }
                    optimum = o.optimize(maxIterationData, new MaxEval(getMaxEvaluations()), new ObjectiveFunction(powellFunction), GoalType.MINIMIZE, new InitialGuess(startPoint), positionChecker, simpleBounds);
                }
            }
        } else if (searchMethod == SearchMethod.BOBYQA) {
            // Differentiable approximation using Powell's BOBYQA algorithm.
            // This is slower than the Powell optimiser and requires a high number of evaluations.
            // (fitted parameters + 2) is the smallest number of interpolation points BOBYQA
            // accepts, assuming the fitted parameter count equals startPoint.length.
            final int numberOfInterpolationpoints = this.getNumberOfFittedParameters() + 2;
            final BOBYQAOptimizer o = new BOBYQAOptimizer(numberOfInterpolationpoints);
            baseOptimiser = o;
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lower, upper));
        } else if (searchMethod == SearchMethod.CMAES) {
            // TODO - Understand why the CMAES optimiser does not fit very well on test data. It appears
            // to converge too early and the likelihood scores are not as low as the other optimisers.
            // The CMAESOptimiser is based on Matlab code:
            // https://www.lri.fr/~hansen/cmaes.m
            // Take the defaults from the Matlab documentation
            final double stopFitness = 0;
            final boolean isActiveCma = true;
            final int diagonalOnly = 0;
            final int checkFeasableCount = 1;
            final RandomGenerator random = new RandomGeneratorAdapter(UniformRandomProviders.create());
            final boolean generateStatistics = false;
            // The sigma determines the search range for the variables. It should be 1/3 of the initial
            // search region.
            final double[] sigma = new double[lower.length];
            for (int i = 0; i < sigma.length; i++) {
                sigma[i] = (upper[i] - lower[i]) / 3;
            }
            // Default CMA-ES population size: 4 + floor(3 ln N)
            int popSize = (int) (4 + Math.floor(3 * Math.log(sigma.length)));
            // The CMAES optimiser is random and restarting can overcome problems with quick
            // convergence.
            // The Apache commons documentations states that convergence should occur between 30N and
            // 300N^2
            // function evaluations
            // NOTE(review): despite the name, n30 is 30 * N^2 (capped at half the maximum
            // evaluations), not 30N - confirm which budget is intended.
            final int n30 = Math.min(sigma.length * sigma.length * 30, getMaxEvaluations() / 2);
            // Reset: the field is used below as the running evaluation count for the restart loop
            evaluations = 0;
            // Entries 0 (start point) and 1 (population size) are replaced on later restarts
            final OptimizationData[] data = new OptimizationData[] { new InitialGuess(startPoint), new CMAESOptimizer.PopulationSize(popSize), new MaxEval(getMaxEvaluations()), new CMAESOptimizer.Sigma(sigma), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new SimpleBounds(lower, upper) };
            // Iterate to prevent early convergence
            int repeat = 0;
            while (evaluations < n30) {
                // repeat is post-incremented, so this is false for the first two runs. From the
                // third run onwards each run starts from the best optimum so far and doubles the
                // population size.
                if (repeat++ > 1) {
                    // Update the start point and population size
                    if (optimum != null) {
                        data[0] = new InitialGuess(optimum.getPointRef());
                    }
                    popSize *= 2;
                    data[1] = new CMAESOptimizer.PopulationSize(popSize);
                }
                final CMAESOptimizer o = new CMAESOptimizer(getMaxIterations(), stopFitness, isActiveCma, diagonalOnly, checkFeasableCount, random, generateStatistics, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
                // Set so that, if optimize throws, the finally block still counts this run
                baseOptimiser = o;
                final PointValuePair result = o.optimize(data);
                // Successful runs are counted here directly
                iterations += o.getIterations();
                evaluations += o.getEvaluations();
                if (optimum == null || result.getValue() < optimum.getValue()) {
                    optimum = result;
                }
            }
            // Prevent incrementing the iterations again
            baseOptimiser = null;
        } else {
            // The line search algorithm often fails. This is due to searching into a region where the
            // function evaluates to a negative so has been clipped. This means the upper bound of the
            // line cannot be found.
            // Note that running it on an easy problem (200 photons with fixed fitting (no background))
            // the algorithm does sometimes produces results better than the Powell algorithm but it is
            // slower.
            final BoundedNonLinearConjugateGradientOptimizer o = new BoundedNonLinearConjugateGradientOptimizer((searchMethod == SearchMethod.CONJUGATE_GRADIENT_FR) ? Formula.FLETCHER_REEVES : Formula.POLAK_RIBIERE, new SimpleValueChecker(relativeThreshold, absoluteThreshold));
            baseOptimiser = o;
            // Note: The gradients may become unstable at the edge of the bounds. Or they will not
            // change direction if the true solution is on the bounds since the gradient will always
            // continue towards the bounds. This is key to the conjugate gradient method. It searches
            // along a vector until the direction of the gradient is in the opposite direction (using
            // dot products, i.e. cosine of angle between them)
            // NR 10.7 states there is no advantage of the variable metric DFP or BFGS methods over
            // conjugate gradient methods. So I will try these first.
            // Try this:
            // Adapt the conjugate gradient optimiser to use the gradient to pick the search direction
            // and then for the line minimisation. However if the function is out of bounds then clip
            // the variables at the bounds and continue.
            // If the current point is at the bounds and the gradient is to continue out of bounds then
            // clip the gradient too.
            // Or: just use the gradient for the search direction then use the line minimisation/rest
            // as per the Powell optimiser. The bounds should limit the search.
            // I tried a Bounded conjugate gradient optimiser with clipped variables:
            // This sometimes works. However when the variables go a long way out of the expected range
            // the gradients can have vastly different magnitudes. This results in the algorithm
            // stalling since the gradients can be close to zero and the some of the parameters are no
            // longer adjusted. Perhaps this can be looked for and the algorithm then gives up and
            // resorts to a Powell optimiser from the current point.
            // Changed the bracketing step to very small (default is 1, changed to 0.001). This improves
            // the performance. The gradient direction is very sensitive to small changes in the
            // coordinates so a tighter bracketing of the line search helps.
            // Tried using a non-gradient method for the line search copied from the Powell optimiser:
            // This also works when the bracketing step is small but the number of iterations is higher.
            // 24.10.2014: I have tried to get conjugate gradient to work but the gradient function
            // must not behave suitably for the optimiser. In the current state both methods of using a
            // Bounded Conjugate Gradient Optimiser perform poorly relative to other optimisers:
            // Simulated : n=1000, signal=200, x=0.53, y=0.47
            // LVM : n=1000, signal=171, x=0.537, y=0.471 (1.003s)
            // Powell : n=1000, signal=187, x=0.537, y=0.48 (1.238s)
            // Gradient based PR (constrained): n=858, signal=161, x=0.533, y=0.474 (2.54s)
            // Gradient based PR (bounded): n=948, signal=161, x=0.533, y=0.473 (2.67s)
            // Non-gradient based : n=1000, signal=151.47, x=0.535, y=0.474 (1.626s)
            // The conjugate optimisers are slower, under predict the signal by the most and in the case
            // of the gradient based optimiser, fail to converge on some problems. This is worse when
            // constrained fitting is used and not tightly bounded fitting.
            // I will leave the code in as an option but would not recommend using it. I may remove it
            // in the future.
            // Note: It is strange that the non-gradient based line minimisation is more successful.
            // It may be that the gradient function is not accurate (due to round off error) or that it
            // is simply wrong when far from the optimum. My JUnit tests only evaluate the function
            // within the expected range of the answer.
            // Note the default step size on the Powell optimiser is 1 but the initial directions are
            // unit vectors.
            // So our bracketing step should be a minimum of 1 / average length of the first gradient
            // vector to prevent the first step being too large when bracketing.
            final double[] gradient = new double[startPoint.length];
            maximumLikelihoodFunction.likelihood(startPoint, gradient);
            // NOTE(review): 'length' is the squared norm of the gradient (no sqrt is taken)
            double length = 0;
            for (final double d : gradient) {
                length += d * d;
            }
            // Math.min caps the step at 0.001; it is smaller only when 1 / length < 0.001
            final double bracketingStep = Math.min(0.001, ((length > 1) ? 1.0 / length : 1));
            o.setUseGradientLineSearch(gradientLineMinimisation);
            optimum = o.optimize(new MaxEval(getMaxEvaluations()), new ObjectiveFunctionGradient(new MultivariateVectorLikelihood(maximumLikelihoodFunction)), new ObjectiveFunction(new MultivariateLikelihood(maximumLikelihoodFunction)), GoalType.MINIMIZE, new InitialGuess(startPoint), new SimpleBounds(lowerConstraint, upperConstraint), new BoundedNonLinearConjugateGradientOptimizer.BracketingStep(bracketingStep));
        }
        if (optimum == null) {
            return FitStatus.FAILED_TO_CONVERGE;
        }
        final double[] solution = optimum.getPointRef();
        setSolution(a, solution);
        if (parametersVariance != null) {
            // Compute assuming a Poisson process.
            // Note:
            // If using a Poisson-Gamma-Gaussian model then these will be incorrect.
            // However the variance for the position estimates of a 2D PSF can be
            // scaled by a factor of 2 as in Mortensen, et al (2010) Nature Methods 7, 377-383, SI 4.3.
            // Since the type of function is unknown this cannot be done here.
            final FisherInformationMatrix m = new FisherInformationMatrix(maximumLikelihoodFunction.fisherInformation(solution));
            setDeviations(parametersVariance, m);
        }
        // Reverse negative log likelihood for maximum likelihood score
        value = -optimum.getValue();
    } catch (final TooManyIterationsException ex) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (final TooManyEvaluationsException ex) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (final ConvergenceException ex) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (final Exception ex) {
        Logger.getLogger(getClass().getName()).log(Level.SEVERE, "Failed to fit", ex);
        return FitStatus.UNKNOWN;
    } finally {
        // Runs on every exit path, including the early returns above
        if (baseOptimiser != null) {
            iterations += baseOptimiser.getIterations();
            evaluations += baseOptimiser.getEvaluations();
        }
    }
    // Check this as likelihood functions can go wrong
    if (Double.isInfinite(value) || Double.isNaN(value)) {
        return FitStatus.INVALID_LIKELIHOOD;
    }
    return FitStatus.OK;
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) BOBYQAOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) BoundedNonLinearConjugateGradientOptimizer(uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) BaseOptimizer(org.apache.commons.math3.optim.BaseOptimizer) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) FisherInformationMatrix(uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix) LikelihoodWrapper(uk.ac.sussex.gdsc.smlm.function.LikelihoodWrapper) PoissonLikelihoodWrapper(uk.ac.sussex.gdsc.smlm.function.PoissonLikelihoodWrapper) PoissonGammaGaussianLikelihoodWrapper(uk.ac.sussex.gdsc.smlm.function.PoissonGammaGaussianLikelihoodWrapper) PoissonGaussianLikelihoodWrapper(uk.ac.sussex.gdsc.smlm.function.PoissonGaussianLikelihoodWrapper) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) 
MultivariateFunctionMappingAdapter(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter) OptimizationData(org.apache.commons.math3.optim.OptimizationData) CustomPowellOptimizer(uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 5 with RandomGeneratorAdapter

use of uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter in project GDSC-SMLM by aherbert.

the class JumpDistanceAnalysis method createCmaesOptimizer.

private static CMAESOptimizer createCmaesOptimizer() {
    // Fresh random source for CMA-ES candidate sampling. It is built from
    // UniformRandomProviders.create(); presumably randomly seeded, so repeated fits may differ -- confirm.
    final RandomGenerator rng = new RandomGeneratorAdapter(UniformRandomProviders.create());
    // Value-based convergence: relative tolerance 1e-8, absolute tolerance 1e-10.
    final ConvergenceChecker<PointValuePair> valueChecker = new SimpleValueChecker(1e-8, 1e-10);
    // Arguments follow the Commons Math CMAESOptimizer constructor order.
    return new CMAESOptimizer(
        2000,   // maxIterations
        0.0,    // stopFitness
        true,   // isActiveCMA
        20,     // diagonalOnly
        1,      // checkFeasableCount
        rng,    // random
        false,  // generateStatistics
        valueChecker);
}
Also used : RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Aggregations

RandomGeneratorAdapter (uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter)7 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)5 PointValuePair (org.apache.commons.math3.optim.PointValuePair)4 SimpleValueChecker (org.apache.commons.math3.optim.SimpleValueChecker)4 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)4 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)3 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)3 InitialGuess (org.apache.commons.math3.optim.InitialGuess)3 MaxEval (org.apache.commons.math3.optim.MaxEval)3 OptimizationData (org.apache.commons.math3.optim.OptimizationData)3 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)3 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)3 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)2 MaxIter (org.apache.commons.math3.optim.MaxIter)2 ObjectiveFunctionGradient (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient)2 BoundedNonLinearConjugateGradientOptimizer (uk.ac.sussex.gdsc.smlm.math3.optim.nonlinear.scalar.gradient.BoundedNonLinearConjugateGradientOptimizer)2 BufferedReader (java.io.BufferedReader)1 File (java.io.File)1 FileInputStream (java.io.FileInputStream)1 IOException (java.io.IOException)1