Search in sources :

Example 21 with PointValuePair

use of org.apache.commons.math3.optim.PointValuePair in project GDSC-SMLM by aherbert.

the class BFGSOptimizer method bfgs.

/**
 * Runs the iterative quasi-Newton minimisation loop starting from {@code p}.
 *
 * <p>The matrix kept in {@code hessian} starts as the identity and is multiplied directly
 * with the gradient to produce the next search direction, i.e. it is used as an
 * approximation of the inverse Hessian. Each pass performs a line search along the current
 * direction, tests the termination criteria and then updates the matrix from the change
 * in position and the change in gradient.
 *
 * <p>The loop only exits through a {@code return}; before returning, the {@code converged}
 * field is set to the reason:
 * <ul>
 * <li>{@code CHECKER} - the supplied checker accepted the previous and current point;</li>
 * <li>{@code ROUNDOFF_ERROR} - the line search threw {@code LineSearchRoundoffException};</li>
 * <li>{@code POSITION} - {@code positionChecker} accepted the old and new position;</li>
 * <li>{@code GRADIENT} - the scaled maximum gradient fell below {@code gradientTolerance}.</li>
 * </ul>
 *
 * @param checker optional convergence checker; skipped entirely when {@code null}
 * @param p start point; it is passed to {@code applyBounds} before the first gradient
 *        evaluation (presumably clipped in place - confirm against {@code applyBounds})
 * @param lineSearch line search used to move along the search direction; the function
 *        value at the returned point is read from its {@code f} field
 * @return the point and objective value at which one of the exit conditions was met
 */
protected PointValuePair bfgs(ConvergenceChecker<PointValuePair> checker, double[] p, LineStepSearch lineSearch) {
    final int dim = p.length;
    // Snapshot the field once so the whole run uses one threshold.
    final double eps = epsilon;
    final double[] hessianTimesDelta = new double[dim];
    final double[] direction = new double[dim];
    final double[][] hessian = new double[dim][dim];

    // Gradient at the bounded start point.
    applyBounds(p);
    double[] gradient = computeObjectiveGradient(p);
    checkGradients(gradient, p);

    // Identity matrix and steepest-descent direction to begin with.
    for (int k = 0; k < dim; k++) {
        hessian[k][k] = 1.0;
        direction[k] = -gradient[k];
    }

    PointValuePair latest = null;
    for (;;) {
        incrementIterationCount();

        // Objective value at the current point.
        final double value = computeObjectiveValue(p);

        if (checker != null) {
            final PointValuePair earlier = latest;
            latest = new PointValuePair(p, value);
            if (earlier != null && checker.converged(getIterations(), earlier, latest)) {
                // An optimum has been found.
                converged = CHECKER;
                return latest;
            }
        }

        // Step along the search direction.
        final double[] next;
        try {
            next = lineSearch.lineSearch(p, value, gradient, direction);
        } catch (LineSearchRoundoffException roundoff) {
            // Can occur when the Hessian is nearly singular or non-positive-definite;
            // the algorithm should then be restarted by the caller.
            converged = ROUNDOFF_ERROR;
            return new PointValuePair(p, value);
        }

        // The line search is constrained, so the new point is assumed to be on/within the bounds.
        final double nextValue = lineSearch.f;

        // Stop if the position has effectively not changed.
        if (positionChecker.converged(p, next)) {
            converged = POSITION;
            return new PointValuePair(next, nextValue);
        }

        // The direction array now holds the step actually taken.
        for (int k = 0; k < dim; k++) {
            direction[k] = next[k] - p[k];
        }
        p = next;

        // Keep the previous gradient array; it is overwritten below with the gradient change.
        final double[] gradientDelta = gradient;
        gradient = computeObjectiveGradient(p);
        checkGradients(gradient, p);

        // Largest gradient component, scaled by the coordinate magnitude (at least 1) ...
        double largest = 0;
        for (int k = 0; k < dim; k++) {
            final double scaled = Math.abs(gradient[k]) * FastMath.max(Math.abs(p[k]), 1);
            if (scaled > largest) {
                largest = scaled;
            }
        }
        // ... and relative to the objective value (at least 1).
        largest /= FastMath.max(Math.abs(nextValue), 1);
        if (largest < gradientTolerance) {
            converged = GRADIENT;
            return new PointValuePair(p, nextValue);
        }

        // Change in gradient between the two points.
        for (int k = 0; k < dim; k++) {
            gradientDelta[k] = gradient[k] - gradientDelta[k];
        }

        // Matrix multiplied by the gradient change.
        for (int row = 0; row < dim; row++) {
            double acc = 0.0;
            for (int col = 0; col < dim; col++) {
                acc += hessian[row][col] * gradientDelta[col];
            }
            hessianTimesDelta[row] = acc;
        }

        // Dot products needed for the update denominators and the skip test.
        double deltaDotStep = 0;
        double deltaDotHDelta = 0;
        double deltaNormSq = 0;
        double stepNormSq = 0;
        for (int k = 0; k < dim; k++) {
            deltaDotStep += gradientDelta[k] * direction[k];
            deltaDotHDelta += gradientDelta[k] * hessianTimesDelta[k];
            deltaNormSq += gradientDelta[k] * gradientDelta[k];
            stepNormSq += direction[k] * direction[k];
        }

        // Skip the update unless the step/gradient-change product is sufficiently positive.
        if (deltaDotStep > Math.sqrt(eps * deltaNormSq * stepNormSq)) {
            final double invStep = 1.0 / deltaDotStep;
            final double invH = 1.0 / deltaDotHDelta;
            // The gradient-change array is reused for the correction vector.
            for (int k = 0; k < dim; k++) {
                gradientDelta[k] = invStep * direction[k] - invH * hessianTimesDelta[k];
            }
            // Update the upper triangle and mirror it to keep the matrix symmetric.
            for (int row = 0; row < dim; row++) {
                for (int col = row; col < dim; col++) {
                    hessian[row][col] += invStep * direction[row] * direction[col] - invH * hessianTimesDelta[row] * hessianTimesDelta[col] + deltaDotHDelta * gradientDelta[row] * gradientDelta[col];
                    hessian[col][row] = hessian[row][col];
                }
            }
        }

        // Next search direction: minus the matrix applied to the current gradient.
        for (int row = 0; row < dim; row++) {
            double acc = 0.0;
            for (int col = 0; col < dim; col++) {
                acc -= hessian[row][col] * gradient[col];
            }
            direction[row] = acc;
        }
    }
}
Also used : PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Example 22 with PointValuePair

use of org.apache.commons.math3.optim.PointValuePair in project vcell by virtualcell.

the class FitBleachSpotOp method fitToGaussian.

/**
 * Fits a five-parameter model (centre i, centre j, K, high, radius squared) to a window of
 * the image by minimising the sum of squared differences between the model and the pixel
 * values with a {@link PowellOptimizer}. All coordinates are in pixel-index space.
 *
 * @param init_center_i initial guess for the centre along i (x index)
 * @param init_center_j initial guess for the centre along j (y index)
 * @param init_radius2 initial guess for the squared radius; also determines the window size
 * @param image image providing the dimensions and the pixel array (indexed as
 *        {@code j * num_i + i})
 * @return the optimised parameters together with the final objective value
 */
static GaussianFitResults fitToGaussian(double init_center_i, double init_center_j, double init_radius2, FloatImage image) {
    final ISize dims = image.getISize();
    final int num_i = dims.getX();
    final int num_j = dims.getY();
    final float[] pixels = image.getFloatPixels();

    // Fitting window around the initial centre, clamped to the image.
    // The cast applies to the square root only, so the root is truncated before the
    // multiplication by 4.
    // NOTE(review): possibly (int) (sqrt * 4) was intended - confirm.
    final int windowSize = ((int) Math.sqrt(init_radius2)) * 4;
    final int halfWindow = windowSize / 2;
    final int iMin = (int) Math.max(0, Math.floor(init_center_i - halfWindow));
    final int iMax = (int) Math.min(num_i - 1, Math.ceil(init_center_i + halfWindow));
    final int jMin = (int) Math.max(0, Math.floor(init_center_j - halfWindow));
    final int jMax = (int) Math.min(num_j - 1, Math.ceil(init_center_j + halfWindow));

    // Layout of the parameter vector handed to the optimizer.
    final int IDX_CENTER_I = 0;
    final int IDX_CENTER_J = 1;
    final int IDX_K = 2;
    final int IDX_HIGH = 3;
    final int IDX_RADIUS_SQUARED = 4;
    final int PARAM_COUNT = 5;

    final double[] start = new double[PARAM_COUNT];
    start[IDX_CENTER_I] = init_center_i;
    start[IDX_CENTER_J] = init_center_j;
    start[IDX_K] = 10;
    start[IDX_HIGH] = 1.0;
    start[IDX_RADIUS_SQUARED] = init_radius2;

    final PowellOptimizer optimizer = new PowellOptimizer(1e-4, 1e-1);

    final MultivariateFunction objective = new MultivariateFunction() {

        /** Sum of squared residuals between the model and the image over the window. */
        @Override
        public double value(double[] point) {
            final double centerI = point[IDX_CENTER_I];
            final double centerJ = point[IDX_CENTER_J];
            final double high = point[IDX_HIGH];
            final double k = point[IDX_K];
            final double radiusSq = point[IDX_RADIUS_SQUARED];

            // NOTE(review): the model below uses the raw pixel indices; the centre
            // parameters are only reported in the printout and do not influence the
            // residual. Confirm whether offsets relative to the centre were intended.
            double sumOfSquares = 0;
            for (int row = jMin; row <= jMax; row++) {
                final double y = row;
                final double ySq = y * y;
                final int rowStart = row * num_i;
                for (int col = iMin; col <= iMax; col++) {
                    final double x = col;
                    final double model = high - FastMath.exp(-k * FastMath.exp(-2 * (x * x + ySq) / radiusSq));
                    final double residual = model - pixels[rowStart + col];
                    sumOfSquares += residual * residual;
                }
            }
            // Progress trace, emitted on every evaluation of the objective.
            System.out.println(new GaussianFitResults(centerI, centerJ, radiusSq, k, high, sumOfSquares));
            return sumOfSquares;
        }
    };

    final PointValuePair best = optimizer.optimize(new ObjectiveFunction(objective), new InitialGuess(start), new MaxEval(100000), GoalType.MINIMIZE);
    final double[] fitted = best.getPoint();
    final double finalError = best.getValue();
    return new GaussianFitResults(
            fitted[IDX_CENTER_I],
            fitted[IDX_CENTER_J],
            fitted[IDX_RADIUS_SQUARED],
            fitted[IDX_K],
            fitted[IDX_HIGH],
            finalError);
}
Also used : MultivariateFunction(org.apache.commons.math3.analysis.MultivariateFunction) InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) ISize(org.vcell.util.ISize) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) PowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer) PointValuePair(org.apache.commons.math3.optim.PointValuePair)

Aggregations

PointValuePair (org.apache.commons.math3.optim.PointValuePair)22 InitialGuess (org.apache.commons.math3.optim.InitialGuess)11 MaxEval (org.apache.commons.math3.optim.MaxEval)11 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)11 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)9 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)7 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)6 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)6 OptimizationData (org.apache.commons.math3.optim.OptimizationData)5 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)5 SimpleValueChecker (org.apache.commons.math3.optim.SimpleValueChecker)5 UnivariatePointValuePair (org.apache.commons.math3.optim.univariate.UnivariatePointValuePair)5 MultivariateMatrixFunction (org.apache.commons.math3.analysis.MultivariateMatrixFunction)4 LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)4 Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)4 LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)4 LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer)4 DiagonalMatrix (org.apache.commons.math3.linear.DiagonalMatrix)4 CustomPowellOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer)4 RandomGenerator (org.apache.commons.math3.random.RandomGenerator)4