Search in sources :

Example 11 with ConvergenceChecker

use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

the class PCPALMFitting method runBoundedOptimiser.

/**
 * Minimises the sum-of-squares objective inside the box defined by {@code lB} and {@code uB}.
 *
 * <p>A BFGS optimiser is run first (restarted from its own optimum up to {@code fitRestarts}
 * additional times), followed by restarts of a CMAES optimiser. The solution with the lowest
 * objective value is returned. The field {@code boundedEvaluations} is reset and then
 * accumulates the number of evaluations reported by each optimiser run.
 *
 * @param gr not read by this method
 * @param initialSolution the start point for the first BFGS run and for each CMAES restart
 * @param lB the lower bound of each parameter
 * @param uB the upper bound of each parameter
 * @param function the model function, wrapped as both the objective value and its gradient
 * @return the best solution found, or {@code null} if every optimisation attempt failed
 */
private PointValuePair runBoundedOptimiser(double[][] gr, double[] initialSolution, double[] lB, double[] uB, SumOfSquaresModelFunction function) {
    // Create the functions to optimise
    ObjectiveFunction objective = new ObjectiveFunction(new SumOfSquaresMultivariateFunction(function));
    ObjectiveFunctionGradient gradient = new ObjectiveFunctionGradient(new SumOfSquaresMultivariateVectorFunction(function));
    final boolean debug = false;
    // Try a BFGS optimiser since this will produce a deterministic solution and can respect bounds.
    PointValuePair optimum = null;
    boundedEvaluations = 0;
    final MaxEval maxEvaluations = new MaxEval(2000);
    MultivariateOptimizer opt = null;
    for (int iteration = 0; iteration <= fitRestarts; iteration++) {
        try {
            opt = new BFGSOptimizer();
            final double relativeThreshold = 1e-6;
            // Configure maximum step length for each dimension using the bounds
            double[] stepLength = new double[lB.length];
            for (int i = 0; i < stepLength.length; i++) {
                stepLength[i] = (uB[i] - lB[i]) * 0.3333333;
            }
            // The GoalType is always minimise so no need to pass this in
            optimum = opt.optimize(maxEvaluations, gradient, objective, new InitialGuess((optimum == null) ? initialSolution : optimum.getPointRef()), new SimpleBounds(lB, uB), new BFGSOptimizer.GradientTolerance(relativeThreshold), new BFGSOptimizer.StepLength(stepLength));
            if (debug) {
                System.out.printf("BFGS Iter %d = %g (%d)\n", iteration, optimum.getValue(), opt.getEvaluations());
            }
        } catch (TooManyEvaluationsException e) {
            // No need to restart
            break;
        } catch (RuntimeException e) {
            // No need to restart
            break;
        } finally {
            boundedEvaluations += opt.getEvaluations();
        }
    }
    // Try a CMAES optimiser which is non-deterministic. To overcome this we perform restarts.
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    // Alternative stop fitness: Double.NEGATIVE_INFINITY
    double stopFitness = 0;
    boolean isActiveCMA = true;
    int diagonalOnly = 0;
    int checkFeasableCount = 1;
    // Alternative generator: Well19937c
    RandomGenerator random = new Well44497b();
    boolean generateStatistics = false;
    ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
    double[] range = new double[lB.length];
    for (int i = 0; i < lB.length; i++) {
        range[i] = (uB[i] - lB[i]) / 3;
    }
    OptimizationData sigma = new CMAESOptimizer.Sigma(range);
    OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(initialSolution.length))));
    SimpleBounds bounds = new SimpleBounds(lB, uB);
    opt = new CMAESOptimizer(maxEvaluations.getMaxEval(), stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
    // Restart the optimiser several times and take the best answer.
    for (int iteration = 0; iteration <= fitRestarts; iteration++) {
        try {
            // Start from the initial solution
            PointValuePair constrainedSolution = opt.optimize(new InitialGuess(initialSolution), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug) {
                System.out.printf("CMAES Iter %d initial = %g (%d)\n", iteration, constrainedSolution.getValue(), opt.getEvaluations());
            }
            if (optimum == null || constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (TooManyEvaluationsException e) {
            // Deliberately ignored: keep the best solution found so far
        } catch (TooManyIterationsException e) {
            // Deliberately ignored: keep the best solution found so far
        } finally {
            // Count the evaluations of this run exactly once, whether it succeeded or failed.
            // Previously a successful run added the actual count and then the maximum as well.
            boundedEvaluations += opt.getEvaluations();
        }
        if (optimum == null) {
            continue;
        }
        try {
            // Also restart from the current optimum
            PointValuePair constrainedSolution = opt.optimize(new InitialGuess(optimum.getPointRef()), objective, GoalType.MINIMIZE, bounds, sigma, popSize, maxEvaluations);
            if (debug) {
                System.out.printf("CMAES Iter %d restart = %g (%d)\n", iteration, constrainedSolution.getValue(), opt.getEvaluations());
            }
            if (constrainedSolution.getValue() < optimum.getValue()) {
                optimum = constrainedSolution;
            }
        } catch (TooManyEvaluationsException e) {
            // Deliberately ignored: keep the best solution found so far
        } catch (TooManyIterationsException e) {
            // Deliberately ignored: keep the best solution found so far
        } finally {
            // Record what the optimiser reports rather than assuming the maximum was used
            boundedEvaluations += opt.getEvaluations();
        }
    }
    return optimum;
}
Also used : MultivariateOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer) MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) BFGSOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.gradient.BFGSOptimizer) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) ObjectiveFunctionGradient(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient) Well44497b(org.apache.commons.math3.random.Well44497b) OptimizationData(org.apache.commons.math3.optim.OptimizationData)

Example 12 with ConvergenceChecker

use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

the class JumpDistanceAnalysis method createCMAESOptimizer.

/**
 * Builds a CMAES optimiser with fixed settings: 2000 maximum iterations, a stop fitness
 * of zero, active CMA enabled and a value-based convergence checker.
 *
 * @return a new, configured optimiser instance
 */
private CMAESOptimizer createCMAESOptimizer() {
    // Thresholds for the convergence check on the objective value
    final double relativeThreshold = 1e-8;
    final double absoluteThreshold = 1e-10;

    final int iterationLimit = 2000;
    // Alternative stop fitness: Double.NEGATIVE_INFINITY
    final double targetFitness = 0;
    final boolean activeCma = true;
    // Passed as the constructor's diagonalOnly argument
    final int diagonalIterations = 20;
    // Passed as the constructor's checkFeasableCount argument
    final int feasibleCheckCount = 1;
    final boolean collectStatistics = false;

    final RandomGenerator rng = new Well19937c();
    final ConvergenceChecker<PointValuePair> valueChecker =
        new SimpleValueChecker(relativeThreshold, absoluteThreshold);

    // Iterate this for stability in the initial guess
    return new CMAESOptimizer(
        iterationLimit,
        targetFitness,
        activeCma,
        diagonalIterations,
        feasibleCheckCount,
        rng,
        collectStatistics,
        valueChecker);
}
Also used : CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) Well19937c(org.apache.commons.math3.random.Well19937c) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Example 13 with ConvergenceChecker

use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

the class BFGSOptimizer method doOptimize.

/**
 * {@inheritDoc}
 *
 * <p>Runs the BFGS routine once when no restarts are configured. Otherwise the routine is
 * repeated from the previous optimum, up to {@code restarts} additional times, stopping
 * early when the run converged on the gradient, or when two consecutive optima satisfy
 * either the convergence checker (if any) or the position checker.
 */
@Override
protected PointValuePair doOptimize() {
    final ConvergenceChecker<PointValuePair> checker = getConvergenceChecker();
    double[] start = getStartPoint();
    // Assume minimisation
    sign = -1;
    final LineStepSearch lineSearch = new LineStepSearch();

    // Single run when restarts are disabled
    if (restarts <= 0) {
        return bfgsWithRoundoffCheck(checker, start, lineSearch);
    }

    PointValuePair previous = null;
    PointValuePair latest = null;
    for (int run = 0; run <= restarts; run++) {
        latest = bfgsWithRoundoffCheck(checker, start, lineSearch);

        // With no gradient remaining there is nowhere left to move
        if (converged == GRADIENT) {
            break;
        }

        if (previous != null) {
            // Stop when the optimum did not improve according to the convergence criteria;
            // the position check is only evaluated if the value check did not already pass
            final boolean valueConverged =
                checker != null && checker.converged(getIterations(), previous, latest);
            if (valueConverged
                || positionChecker.converged(previous.getPointRef(), latest.getPointRef())) {
                break;
            }
        }

        // Remember this optimum and restart from it
        previous = latest;
        start = latest.getPointRef();
    }
    return latest;
}
Also used : PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Example 14 with ConvergenceChecker

use of org.apache.commons.math3.optim.ConvergenceChecker in project GDSC-SMLM by aherbert.

the class BFGSOptimizer method bfgs.

/**
 * Runs the iterative BFGS quasi-Newton update from the given start point until one of the
 * stopping conditions is met.
 *
 * <p>Before returning, the field {@code converged} is set to the reason for stopping:
 * {@code CHECKER} (the supplied checker reported convergence), {@code ROUNDOFF_ERROR} (the
 * line search threw {@code LineSearchRoundoffException}), {@code POSITION} (the position
 * checker reported convergence between successive points) or {@code GRADIENT} (the scaled
 * maximum gradient fell below {@code gradientTolerance}).
 *
 * <p>The local matrix named {@code hessian} starts as the identity and is used to compute
 * the search direction as {@code -hessian * g}, i.e. it plays the role of an inverse-Hessian
 * approximation.
 *
 * @param checker optional convergence checker on successive (point, value) pairs; may be
 *        {@code null}, in which case that test is skipped
 * @param p the start point; it is passed to {@code applyBounds} before use
 *        (NOTE(review): presumably modified in place - confirm against applyBounds)
 * @param lineSearch the line search used to move along the search direction
 * @return the point and objective value at which the routine stopped
 */
protected PointValuePair bfgs(ConvergenceChecker<PointValuePair> checker, double[] p, LineStepSearch lineSearch) {
    final int n = p.length;
    final double EPS = epsilon;
    // hdg holds hessian * dg; xi holds the current search direction / last step
    double[] hdg = new double[n];
    double[] xi = new double[n];
    double[][] hessian = new double[n][n];
    // Get the gradient for the bounded point
    applyBounds(p);
    double[] g = computeObjectiveGradient(p);
    checkGradients(g, p);
    // Initialise the hessian to the identity and the search direction to steepest descent
    for (int i = 0; i < n; i++) {
        hessian[i][i] = 1.0;
        xi[i] = -g[i];
    }
    PointValuePair current = null;
    // The loop only exits through one of the return statements below
    // (or an exception propagating from a called method)
    while (true) {
        incrementIterationCount();
        // Get the value of the point
        double fp = computeObjectiveValue(p);
        if (checker != null) {
            // Compare this iteration's point/value against the previous iteration's
            PointValuePair previous = current;
            current = new PointValuePair(p, fp);
            if (previous != null && checker.converged(getIterations(), previous, current)) {
                // We have found an optimum.
                converged = CHECKER;
                return current;
            }
        }
        // Move along the search direction.
        final double[] pnew;
        try {
            pnew = lineSearch.lineSearch(p, fp, g, xi);
        } catch (LineSearchRoundoffException e) {
            // This can happen if the Hessian is nearly singular or non-positive-definite.
            // In this case the algorithm should be restarted.
            converged = ROUNDOFF_ERROR;
            //System.out.printf("Roundoff error, iter=%d\n", getIterations());
            return new PointValuePair(p, fp);
        }
        // We assume the new point is on/within the bounds since the line search is constrained
        // The line search exposes its result value through the field f; it is used below
        // as the objective value at pnew
        double fret = lineSearch.f;
        // Test for convergence on change in position
        if (positionChecker.converged(p, pnew)) {
            converged = POSITION;
            return new PointValuePair(pnew, fret);
        }
        // Update the line direction: xi becomes the actual step taken
        for (int i = 0; i < n; i++) {
            xi[i] = pnew[i] - p[i];
        }
        p = pnew;
        // Save the old gradient (by reference; g is reassigned on the next statement)
        double[] dg = g;
        // Get the gradient for the new point
        g = computeObjectiveGradient(p);
        checkGradients(g, p);
        // If necessary recompute the function value.
        // Doing this after the gradient evaluation allows the value to be cached when
        // computing the objective gradient
        fp = fret;
        // Test for convergence on zero gradient.
        // Each gradient component is scaled by max(|p[i]|, 1) and the largest is kept
        double test = 0;
        for (int i = 0; i < n; i++) {
            final double temp = Math.abs(g[i]) * FastMath.max(Math.abs(p[i]), 1);
            //final double temp = Math.abs(g[i]);
            if (test < temp)
                test = temp;
        }
        // Compute the biggest gradient relative to the objective function
        test /= FastMath.max(Math.abs(fp), 1);
        if (test < gradientTolerance) {
            converged = GRADIENT;
            return new PointValuePair(p, fp);
        }
        // dg becomes the change in gradient (new - old), overwriting the old gradient array
        for (int i = 0; i < n; i++) dg[i] = g[i] - dg[i];
        // hdg = hessian * dg
        for (int i = 0; i < n; i++) {
            hdg[i] = 0.0;
            for (int j = 0; j < n; j++) hdg[i] += hessian[i][j] * dg[j];
        }
        // Dot products for the denominators of the update:
        // fac = dg.xi, fae = dg.hdg, sumdg = |dg|^2, sumxi = |xi|^2
        double fac = 0, fae = 0, sumdg = 0, sumxi = 0;
        for (int i = 0; i < n; i++) {
            fac += dg[i] * xi[i];
            fae += dg[i] * hdg[i];
            sumdg += dg[i] * dg[i];
            sumxi += xi[i] * xi[i];
        }
        // Skip the matrix update when fac is not sufficiently positive
        if (fac > Math.sqrt(EPS * sumdg * sumxi)) {
            fac = 1.0 / fac;
            final double fad = 1.0 / fae;
            // dg is reused to hold the vector that forms the third term of the update
            for (int i = 0; i < n; i++) dg[i] = fac * xi[i] - fad * hdg[i];
            // Update the upper triangle and mirror it so the matrix stays symmetric
            for (int i = 0; i < n; i++) {
                for (int j = i; j < n; j++) {
                    hessian[i][j] += fac * xi[i] * xi[j] - fad * hdg[i] * hdg[j] + fae * dg[i] * dg[j];
                    hessian[j][i] = hessian[i][j];
                }
            }
        }
        // Compute the next search direction: xi = -hessian * g
        for (int i = 0; i < n; i++) {
            xi[i] = 0.0;
            for (int j = 0; j < n; j++) xi[i] -= hessian[i][j] * g[j];
        }
    }
}
Also used : PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Aggregations

PointValuePair (org.apache.commons.math3.optim.PointValuePair)13 SimpleValueChecker (org.apache.commons.math3.optim.SimpleValueChecker)8 GoalType (org.apache.commons.math3.optim.nonlinear.scalar.GoalType)6 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)6 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)6 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)4 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)4 InitialGuess (org.apache.commons.math3.optim.InitialGuess)4 MaxEval (org.apache.commons.math3.optim.MaxEval)4 OptimizationData (org.apache.commons.math3.optim.OptimizationData)4 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)4 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)4 UnivariatePointValuePair (org.apache.commons.math3.optim.univariate.UnivariatePointValuePair)4 RandomGeneratorAdapter (uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter)3 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)2 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)2 MathIllegalStateException (org.apache.commons.math3.exception.MathIllegalStateException)2 MathInternalError (org.apache.commons.math3.exception.MathInternalError)2 LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)2 Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)2