Search in sources :

Example 11 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

Method computeFit of the class ApacheLvmFitter.

@Override
public FitStatus computeFit(double[] y, final double[] fx, double[] a, double[] parametersVariance) {
    final int n = y.length;
    try {
        // Start point for the optimiser: the parameters to be fitted, extracted from a.
        final double[] startPoint = getInitialSolution(a);
        // The builder receives a private copy of the observations.
        final double[] observed = y.clone();
        // Convergence thresholds. Varying these appears to change only the number of
        // iterations needed to converge, not the resulting fit.
        // TODO - Pass in more advanced stopping criteria.
        final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(
            100, // initial step bound factor
            1e-10, // cost relative tolerance
            1e-10, // parameter relative tolerance
            1e-10, // orthogonality tolerance
            Precision.SAFE_MIN); // QR ranking threshold
        // @formatter:off
        final LeastSquaresBuilder builder = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(getMaxEvaluations())
            .start(startPoint)
            .target(observed);
        // @formatter:on
        final boolean combined = function instanceof ExtendedNonLinearFunction
            && ((ExtendedNonLinearFunction) function).canComputeValuesAndJacobian();
        if (combined) {
            // Values and Jacobian can be computed together, or each individually on demand
            final ExtendedNonLinearFunction extended = (ExtendedNonLinearFunction) function;
            builder.model(new ValueAndJacobianFunction() {
                @Override
                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    final org.apache.commons.lang3.tuple.Pair<double[], double[][]> valuesAndJacobian =
                        extended.computeValuesAndJacobian(point.toArray());
                    return new Pair<>(new ArrayRealVector(valuesAndJacobian.getKey(), false),
                        new Array2DRowRealMatrix(valuesAndJacobian.getValue(), false));
                }

                @Override
                public RealVector computeValue(double[] params) {
                    return new ArrayRealVector(extended.computeValues(params), false);
                }

                @Override
                public RealMatrix computeJacobian(double[] params) {
                    return new Array2DRowRealMatrix(extended.computeJacobian(params), false);
                }
            });
        } else {
            // Values and Jacobian are computed by separate wrappers
            final NonLinearFunction plain = (NonLinearFunction) function;
            builder.model(new MultivariateVectorFunctionWrapper(plain, a, n),
                new MultivariateMatrixFunctionWrapper(plain, a, n));
        }
        final Optimum optimum = optimizer.optimize(builder.build());
        setSolution(a, optimum.getPoint().toArray());
        iterations = optimum.getIterations();
        evaluations = optimum.getEvaluations();
        if (parametersVariance != null) {
            // The Fisher information matrix is formed from transpose(J) * J at the optimum.
            final RealMatrix jacobian = optimum.getJacobian();
            final RealMatrix jtj = jacobian.transpose().multiply(jacobian);
            final double[][] data;
            if (jtj instanceof Array2DRowRealMatrix) {
                // Use the backing array directly to avoid a copy
                data = ((Array2DRowRealMatrix) jtj).getDataRef();
            } else {
                data = jtj.getData();
            }
            setDeviations(parametersVariance, new FisherInformationMatrix(data));
        }
        if (fx != null) {
            // Evaluate the function using parameters a (presumably updated with the fitted
            // values by setSolution above) and store each value in fx.
            final ValueFunction valueFunction = (ValueFunction) this.function;
            valueFunction.initialise0(a);
            valueFunction.forEach(new ValueProcedure() {
                private int next;

                @Override
                public void execute(double value) {
                    fx[next++] = value;
                }
            });
        }
        // The problem is unweighted, so the residual sum of squares is r.r
        // (the same as optimum.getCost() squared).
        final RealVector residuals = optimum.getResiduals();
        value = residuals.dotProduct(residuals);
    } catch (final TooManyEvaluationsException ex) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (final TooManyIterationsException ex) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (final ConvergenceException ex) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (final Exception ex) {
        // TODO - Find out the other exceptions from the fitter and add return values to match.
        return FitStatus.UNKNOWN;
    }
    return FitStatus.OK;
}
Also used : ValueFunction(uk.ac.sussex.gdsc.smlm.function.ValueFunction) ValueProcedure(uk.ac.sussex.gdsc.smlm.function.ValueProcedure) NonLinearFunction(uk.ac.sussex.gdsc.smlm.function.NonLinearFunction) ExtendedNonLinearFunction(uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) ValueAndJacobianFunction(org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction) RealVector(org.apache.commons.math3.linear.RealVector) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) Pair(org.apache.commons.math3.util.Pair) FisherInformationMatrix(uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix) MultivariateMatrixFunctionWrapper(uk.ac.sussex.gdsc.smlm.function.MultivariateMatrixFunctionWrapper) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MultivariateVectorFunctionWrapper(uk.ac.sussex.gdsc.smlm.function.MultivariateVectorFunctionWrapper)

Example 12 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

Method fitMsd of the class TraceDiffusion.

/**
 * Fit the mean squared displacement (MSD) curve and keep the best model.
 *
 * <p>Up to three linear models are fitted with a Levenberg-Marquardt optimiser:
 *
 * <ol>
 * <li>A line through the origin (1 parameter: the gradient).
 * <li>A line with an intercept 4s^2, where s is the fitted localisation precision
 * (2 parameters).
 * <li>Only when MSD correction is enabled: a line with the MSD-corrected intercept
 * 4s^2 - gradient/3, fitted against the jump number x / exposureTime (2 parameters).
 * </ol>
 *
 * <p>The model with the lowest Akaike Information Criterion wins. Each later model must
 * beat the best model found so far. If the resulting diffusion coefficient is positive, the
 * fit line is added to the plot and the plot is displayed.
 *
 * @param x the x values (time)
 * @param y the y values (MSD)
 * @param title the title
 * @param plot the plot
 * @return [D, precision]; both are 0 if no model could be fitted
 */
private double[] fitMsd(double[] x, double[] y, String title, Plot plot) {
    // The Weimann paper (Plos One e64287) fits:
    // MSD(n dt) = 4D n dt + 4s^2
    // n = number of jumps
    // dt = time difference between frames
    // s = localisation precision
    // Thus we should fit an intercept as well.
    // When x is the time (n dt) then D = gradient / 4.
    // When x is the jump number n then D = gradient / (4*exposureTime).
    double diffCoeff = 0;
    double intercept = 0;
    double precision = 0;
    final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
    Optimum lvmSolution;
    // Information criterion of the best model selected so far (NaN until a model is fitted)
    double ic = Double.NaN;
    // Fit with no intercept
    try {
        final LinearFunction function = new LinearFunction(x, y, clusteringSettings.getFitLength());
        final double[] parameters = new double[] { function.guess() };
        // @formatter:off
        final LeastSquaresProblem problem = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(3000)
            .start(parameters)
            .target(function.getY())
            .weight(new DiagonalMatrix(function.getWeights()))
            .model(function, function::jacobian)
            .build();
        // @formatter:on
        lvmSolution = optimizer.optimize(problem);
        // Sum of squared (weighted) residuals
        final RealVector residuals = lvmSolution.getResiduals();
        final double ss = residuals.dotProduct(residuals);
        ic = getAkaikeInformationCriterionFromResiduals(ss, function.getY().length, 1);
        final double gradient = lvmSolution.getPoint().getEntry(0);
        diffCoeff = gradient / 4;
        ImageJUtils.log("Linear fit (%d points) : Gradient = %s, D = %s um^2/s, SS = %s, " + "IC = %s (%d evaluations)", function.getY().length, MathUtils.rounded(gradient, 4), MathUtils.rounded(diffCoeff, 4), MathUtils.rounded(ss), MathUtils.rounded(ic), lvmSolution.getEvaluations());
    } catch (final TooManyIterationsException ex) {
        ImageJUtils.log("Failed to fit : Too many iterations (%s)", ex.getMessage());
    } catch (final ConvergenceException ex) {
        ImageJUtils.log("Failed to fit : %s", ex.getMessage());
    }
    // Fit with intercept.
    // Optionally include the intercept (which is the estimated precision).
    final boolean fitIntercept = true;
    try {
        final LinearFunctionWithIntercept function = new LinearFunctionWithIntercept(x, y, clusteringSettings.getFitLength(), fitIntercept);
        // @formatter:off
        final LeastSquaresProblem problem = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(3000)
            .start(function.guess())
            .target(function.getY())
            .weight(new DiagonalMatrix(function.getWeights()))
            .model(function, function::jacobian)
            .build();
        // @formatter:on
        lvmSolution = optimizer.optimize(problem);
        final RealVector residuals = lvmSolution.getResiduals();
        final double ss = residuals.dotProduct(residuals);
        final double ic2 = getAkaikeInformationCriterionFromResiduals(ss, function.getY().length, 2);
        final double gradient = lvmSolution.getPoint().getEntry(0);
        final double s = lvmSolution.getPoint().getEntry(1);
        final double intercept2 = 4 * s * s;
        if (ic2 < ic || Double.isNaN(ic)) {
            if (settings.debugFitting) {
                // Convert fitted precision in um to nm
                ImageJUtils.log("Linear fit with intercept (%d points) : Gradient = %s, Intercept = %s, " + "D = %s um^2/s, precision = %s nm, SS = %s, IC = %s (%d evaluations)", function.getY().length, MathUtils.rounded(gradient, 4), MathUtils.rounded(intercept2, 4), MathUtils.rounded(gradient / 4, 4), MathUtils.rounded(s * 1000, 4), MathUtils.rounded(ss), MathUtils.rounded(ic2), lvmSolution.getEvaluations());
            }
            // Record the best criterion so any later model must improve on this one
            ic = ic2;
            intercept = intercept2;
            diffCoeff = gradient / 4;
            precision = s;
        }
    } catch (final TooManyIterationsException ex) {
        ImageJUtils.log("Failed to fit with intercept : Too many iterations (%s)", ex.getMessage());
    } catch (final ConvergenceException ex) {
        ImageJUtils.log("Failed to fit with intercept : %s", ex.getMessage());
    }
    if (clusteringSettings.getMsdCorrection()) {
        // i.e. the intercept is allowed to be a small negative.
        try {
            // This function fits the jump distance (n) not the time (nt) so update x
            final double[] x2 = new double[x.length];
            for (int i = 0; i < x2.length; i++) {
                x2[i] = x[i] / exposureTime;
            }
            final LinearFunctionWithMsdCorrectedIntercept function = new LinearFunctionWithMsdCorrectedIntercept(x2, y, clusteringSettings.getFitLength(), fitIntercept);
            // @formatter:off
            final LeastSquaresProblem problem = new LeastSquaresBuilder()
                .maxEvaluations(Integer.MAX_VALUE)
                .maxIterations(3000)
                .start(function.guess())
                .target(function.getY())
                .weight(new DiagonalMatrix(function.getWeights()))
                .model(function, function::jacobian)
                .build();
            // @formatter:on
            lvmSolution = optimizer.optimize(problem);
            final RealVector residuals = lvmSolution.getResiduals();
            final double ss = residuals.dotProduct(residuals);
            final double ic2 = getAkaikeInformationCriterionFromResiduals(ss, function.getY().length, 2);
            double gradient = lvmSolution.getPoint().getEntry(0);
            final double s = lvmSolution.getPoint().getEntry(1);
            // The gradient here is per jump (n), not per unit time
            final double intercept2 = 4 * s * s - gradient / 3;
            // Q. Is this working?
            // Try fixed precision fitting. Is the gradient correct?
            // Revisit all the equations to see if they are wrong.
            // Try adding the x[0] datapoint using the precision.
            // Change the formula to not be linear at x[0] and to just fit the precision, i.e. the
            // intercept2 = 4 * s * s - gradient / 3 is wrong as the
            // equation is not linear below n=1.
            // Incorporate the exposure time into the gradient to allow comparison to other fits
            gradient /= exposureTime;
            if (ic2 < ic || Double.isNaN(ic)) {
                if (settings.debugFitting) {
                    // Convert fitted precision in um to nm
                    ImageJUtils.log("Linear fit with MSD corrected intercept (%d points) : Gradient = %s, " + "Intercept = %s, D = %s um^2/s, precision = %s nm, SS = %s, " + "IC = %s (%d evaluations)", function.getY().length, MathUtils.rounded(gradient, 4), MathUtils.rounded(intercept2, 4), MathUtils.rounded(gradient / 4, 4), MathUtils.rounded(s * 1000, 4), MathUtils.rounded(ss), MathUtils.rounded(ic2), lvmSolution.getEvaluations());
                }
                ic = ic2;
                intercept = intercept2;
                diffCoeff = gradient / 4;
                precision = s;
            }
        } catch (final TooManyIterationsException ex) {
            ImageJUtils.log("Failed to fit with MSD corrected intercept : Too many iterations (%s)", ex.getMessage());
        } catch (final ConvergenceException ex) {
            ImageJUtils.log("Failed to fit with MSD corrected intercept : %s", ex.getMessage());
        }
    }
    // Add the fit to the plot
    if (diffCoeff > 0) {
        plot.setColor(Color.magenta);
        plot.drawLine(0, intercept, x[x.length - 1], 4 * diffCoeff * x[x.length - 1] + intercept);
        display(title, plot);
        checkTraceSettings(diffCoeff);
    }
    return new double[] { diffCoeff, precision };
}
Also used : LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) RealVector(org.apache.commons.math3.linear.RealVector) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)

Example 13 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

Method doCurveFit of the class AstigmatismModelManager.

private boolean doCurveFit() {
    // Initial estimates, taken from the (optionally smoothed) width curves:
    // Focal plane = z position where the width is at a minimum
    // s0x/s0y = the minimum width in x/y
    // gamma = half the distance between the two focal planes
    // z0 = half way between the two focal planes
    // d = depth of focus (mean of the x and y estimates)
    double[] sx = fitSx;
    double[] sy = fitSy;
    if (pluginSettings.getSmoothing() > 0) {
        final LoessInterpolator interpolator = new LoessInterpolator(pluginSettings.getSmoothing(), 0);
        sx = interpolator.smooth(fitZ, fitSx);
        sy = interpolator.smooth(fitZ, fitSy);
        // Overlay the smoothed curves on the width plot
        final Plot widths = widthPlot.getPlot();
        widths.setColor(Color.RED);
        widths.addPoints(fitZ, sx, Plot.LINE);
        widths.setColor(Color.BLUE);
        widths.addPoints(fitZ, sy, Plot.LINE);
        widths.setColor(Color.BLACK);
        widths.updateImage();
    }
    final int minIndexX = SimpleArrayUtils.findMinIndex(sx);
    final int minIndexY = SimpleArrayUtils.findMinIndex(sy);
    final double s0x = sx[minIndexX];
    final double s0y = sy[minIndexY];
    final double focalPlaneX = fitZ[minIndexX];
    final double focalPlaneY = fitZ[minIndexY];
    final double halfSeparation = Math.abs(focalPlaneY - focalPlaneX) / 2;
    // The sign of gamma records which focal plane comes first: it is negated when the
    // x focal plane has the lower index.
    final double gamma = minIndexX < minIndexY ? -halfSeparation : halfSeparation;
    final double z0 = (focalPlaneX + focalPlaneY) / 2;
    final double d = (estimateD(minIndexX, fitZ, sx) + estimateD(minIndexY, fitZ, sy)) / 2;

    // Optimise against both sx and sy (unsmoothed) as one combined y-value.
    final double[] observed = new double[fitZ.length * 2];
    System.arraycopy(fitSx, 0, observed, 0, fitSx.length);
    System.arraycopy(fitSy, 0, observed, fitSx.length, fitSy.length);

    // Ax, Bx, Ay and By start at zero.
    parameters = new double[9];
    parameters[P_GAMMA] = gamma;
    parameters[P_Z0] = z0;
    parameters[P_D] = d;
    parameters[P_S0X] = s0x;
    parameters[P_AX] = 0;
    parameters[P_BX] = 0;
    parameters[P_S0Y] = s0y;
    parameters[P_AY] = 0;
    parameters[P_BY] = 0;
    record("Initial", parameters);
    if (pluginSettings.getShowEstimatedCurve()) {
        plotFit(parameters);
        IJ.showMessage(TITLE, "Showing the estimated curve parameters.\nClick OK to continue.");
    }

    // Use an LVM fitter with numerical gradients.
    final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(
        100, // initial step bound factor
        1e-10, // cost relative tolerance
        1e-10, // parameter relative tolerance
        1e-10, // orthogonality tolerance
        Precision.SAFE_MIN); // QR ranking threshold
    // @formatter:off
    final LeastSquaresBuilder builder = new LeastSquaresBuilder()
        .maxEvaluations(Integer.MAX_VALUE)
        .maxIterations(3000)
        .start(parameters)
        .target(observed);
    // @formatter:on
    if (pluginSettings.getWeightedFit()) {
        builder.weight(new DiagonalMatrix(getWeights(sx, sy)));
    }
    builder.model(new AstigmatismVectorFunction(), new AstigmatismMatrixFunction());
    final LeastSquaresProblem problem = builder.build();
    try {
        final Optimum optimum = optimizer.optimize(problem);
        parameters = optimum.getPoint().toArray();
        record("Final", parameters);
        plotFit(parameters);
        saveResult(optimum);
        return true;
    } catch (final Exception ex) {
        IJ.error(TITLE, "Failed to fit curve: " + ex.getMessage());
        return false;
    }
}
Also used : Plot(ij.gui.Plot) ConversionException(uk.ac.sussex.gdsc.core.data.utils.ConversionException) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) LoessInterpolator(org.apache.commons.math3.analysis.interpolation.LoessInterpolator) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)

Example 14 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

Method fit of the class BlinkEstimator.

/**
 * Fit the dark time to counts of molecules curve. Only use the first n fitted points.
 *
 * <p>Calculates:<br>
 * N = The number of photoblinking molecules in the sample<br>
 * nBlink = The average number of blinks per fluorophore<br>
 * tOff = The off-time
 *
 * <p>The goodness of fit is stored in the {@code r2} and {@code adjustedR2} fields. It is
 * computed from unweighted residuals, so it is valid whatever weights the model supplies.
 *
 * @param td The dark time
 * @param ntd The counts of molecules
 * @param nFittedPoints The number of leading points of {@code td} and {@code ntd} to fit.
 *        The initial estimate reads {@code td[1]}, so {@code td} needs at least 2 entries.
 * @param log Write the fitting results to the ImageJ log window
 * @return The fitted parameters [N, nBlink, tOff], or null if no fit was possible
 */
public double[] fit(double[] td, double[] ntd, int nFittedPoints, boolean log) {
    blinkingModel = new BlinkingFunction();
    blinkingModel.setLogging(true);
    for (int i = 0; i < nFittedPoints; i++) {
        blinkingModel.addPoint(td[i], ntd[i]);
    }
    // Different convergence thresholds seem to have no effect on the resulting fit, only the
    // number of iterations for convergence
    final double initialStepBoundFactor = 100;
    final double costRelativeTolerance = 1e-6;
    final double parRelativeTolerance = 1e-6;
    final double orthoTolerance = 1e-6;
    final double threshold = Precision.SAFE_MIN;
    final LevenbergMarquardtOptimizer optimiser = new LevenbergMarquardtOptimizer(initialStepBoundFactor, costRelativeTolerance, parRelativeTolerance, orthoTolerance, threshold);
    try {
        final double[] obs = blinkingModel.getY();
        //@formatter:off
        final LeastSquaresProblem problem = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(1000)
            // Initial estimate: N = first count, nBlink = 0.1, tOff = second dark time
            .start(new double[] { ntd[0], 0.1, td[1] })
            .target(obs)
            .weight(new DiagonalMatrix(blinkingModel.getWeights()))
            .model(blinkingModel, new MultivariateMatrixFunction() {
                @Override
                public double[][] value(double[] point) throws IllegalArgumentException {
                    return blinkingModel.jacobian(point);
                }
            })
            .build();
        //@formatter:on
        blinkingModel.setLogging(false);
        final Optimum optimum = optimiser.optimize(problem);
        final double[] parameters = optimum.getPoint().toArray();

        // Goodness of fit. optimum.getResiduals() returns weighted residuals (the problem has a
        // weight matrix), which are inconsistent with the unweighted total sum of squares
        // unless every weight is 1. Compute unweighted residuals directly instead.
        final double[] exp = blinkingModel.value(parameters);
        double mean = 0;
        for (final double d : obs) {
            mean += d;
        }
        mean /= obs.length;
        double ssResiduals = 0;
        double ssTotal = 0;
        for (int i = 0; i < obs.length; i++) {
            final double residual = obs[i] - exp[i];
            ssResiduals += residual * residual;
            final double deviation = obs[i] - mean;
            ssTotal += deviation * deviation;
        }
        r2 = 1 - ssResiduals / ssTotal;
        adjustedR2 = getAdjustedCoefficientOfDetermination(ssResiduals, ssTotal, obs.length, parameters.length);
        if (log) {
            Utils.log("  Fit %d points. R^2 = %s. Adjusted R^2 = %s", obs.length, Utils.rounded(r2, 4), Utils.rounded(adjustedR2, 4));
            Utils.log("  N=%s, nBlink=%s, tOff=%s (%s frames)", Utils.rounded(parameters[0], 4), Utils.rounded(parameters[1], 4), Utils.rounded(parameters[2], 4), Utils.rounded(parameters[2] / msPerFrame, 4));
        }
        return parameters;
    } catch (TooManyIterationsException e) {
        if (log) {
            Utils.log("  Failed to fit %d points: Too many iterations: (%s)", blinkingModel.size(), e.getMessage());
        }
        return null;
    } catch (ConvergenceException e) {
        if (log) {
            Utils.log("  Failed to fit %d points", blinkingModel.size());
        }
        return null;
    }
}
Also used : Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)

Example 15 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

Method fitRandomModel of the class PCPALMFitting.

/**
 * Fits the correlation curve with r&gt;0 to the random model using the estimated density and
 * precision. Parameters must be fit within a tolerance of the starting values.
 *
 * <p>Only points with r above {@code sigmaS * fitAboveEstimatedPrecision} (starting from index
 * {@code offset}) are fitted. The result is always passed to {@code addResult}, together with
 * a flag saying whether the fit passed the tolerance and g(r)protein checks.
 *
 * @param gr the correlation curve; gr[0] holds the r values and gr[1] the g(r) values
 * @param sigmaS The estimated precision
 * @param proteinDensity The estimated protein density
 * @param resultColour the colour label passed to {@code addResult}
 * @return The fitted parameters [precision, density], or null if the optimiser failed
 */
private double[] fitRandomModel(double[][] gr, double sigmaS, double proteinDensity, String resultColour) {
    final RandomModelFunction function = new RandomModelFunction();
    randomModel = function;
    log("Fitting %s: Estimated precision = %f nm, estimated protein density = %g um^-2", randomModel.getName(), sigmaS, proteinDensity * 1e6);
    randomModel.setLogging(true);
    for (int i = offset; i < gr[0].length; i++) {
        // Only fit the curve above the estimated resolution (points below it will be subject to error)
        if (gr[0][i] > sigmaS * fitAboveEstimatedPrecision) {
            randomModel.addPoint(gr[0][i], gr[1][i]);
        }
    }
    final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
    Optimum optimum;
    try {
        //@formatter:off
        final LeastSquaresProblem problem = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(3000)
            .start(new double[] { sigmaS, proteinDensity })
            .target(function.getY())
            .weight(new DiagonalMatrix(function.getWeights()))
            .model(function, new MultivariateMatrixFunction() {
                @Override
                public double[][] value(double[] point) throws IllegalArgumentException {
                    return function.jacobian(point);
                }
            })
            .build();
        //@formatter:on
        optimum = optimizer.optimize(problem);
    } catch (TooManyIterationsException e) {
        log("Failed to fit %s: Too many iterations (%s)", randomModel.getName(), e.getMessage());
        return null;
    } catch (ConvergenceException e) {
        log("Failed to fit %s: %s", randomModel.getName(), e.getMessage());
        return null;
    } finally {
        // Disable logging on every exit path, including a failed fit
        randomModel.setLogging(false);
    }
    final double[] parameters = optimum.getPoint().toArray();
    // Ensure the width is positive
    parameters[0] = Math.abs(parameters[0]);
    final double ss = optimum.getResiduals().dotProduct(optimum.getResiduals());
    ic1 = Maths.getAkaikeInformationCriterionFromResiduals(ss, randomModel.size(), parameters.length);
    final double fitSigmaS = parameters[0];
    final double fitProteinDensity = parameters[1];
    // Check the fitted parameters are within tolerance of the initial estimates
    final double e1 = parameterDrift(sigmaS, fitSigmaS);
    final double e2 = parameterDrift(proteinDensity, fitProteinDensity);
    log("  %s fit: SS = %f. cAIC = %f. %d evaluations", randomModel.getName(), ss, ic1, optimum.getEvaluations());
    log("  %s parameters:", randomModel.getName());
    log("    Average precision = %s nm (%s%%)", Utils.rounded(fitSigmaS, 4), Utils.rounded(e1, 4));
    log("    Average protein density = %s um^-2 (%s%%)", Utils.rounded(fitProteinDensity * 1e6, 4), Utils.rounded(e2, 4));
    valid1 = true;
    if (fittingTolerance > 0 && (Math.abs(e1) > fittingTolerance || Math.abs(e2) > fittingTolerance)) {
        log("  Failed to fit %s within tolerance (%s%%): Average precision = %f nm (%s%%), average protein density = %g um^-2 (%s%%)", randomModel.getName(), Utils.rounded(fittingTolerance, 4), fitSigmaS, Utils.rounded(e1, 4), fitProteinDensity * 1e6, Utils.rounded(e2, 4));
        valid1 = false;
    }
    if (valid1) {
        // ---------
        // TODO - My data does not comply with this criteria.
        // This could be due to the PC-PALM Molecule code limiting the nmPerPixel to fit the images in memory
        // thus removing correlations at small r.
        // It could also be due to the nature of the random simulations being 3D not 2D membranes
        // as per the PC-PALM paper.
        // ---------
        // Evaluate g(r)protein where:
        // g(r)peaks = g(r)protein + g(r)stoch
        // g(r)peaks ~ 1           + g(r)stoch
        // Verify g(r)protein should be <1.5 for all r>0
        final double[] grStoch = randomModel.value(parameters);
        final double[] grPeaks = randomModel.getY();
        final double[] radii = randomModel.getX();
        for (int i = 0; i < grPeaks.length; i++) {
            // Only evaluate above the fitted average precision
            if (radii[i] < fitSigmaS) {
                continue;
            }
            // Note the RandomModelFunction evaluates g(r)stoch + 1;
            final double grProtein = grPeaks[i] - (grStoch[i] - 1);
            if (grProtein > gr_protein_threshold) {
                // Failed fit
                log("  Failed to fit %s: g(r)protein %s > %s @ r=%s", randomModel.getName(), Utils.rounded(grProtein, 4), Utils.rounded(gr_protein_threshold, 4), Utils.rounded(radii[i], 4));
                valid1 = false;
            }
        }
    }
    addResult(randomModel.getName(), resultColour, valid1, fitSigmaS, fitProteinDensity, 0, 0, 0, 0, ic1);
    return parameters;
}
Also used : Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)

Aggregations

LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)19 Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)19 LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)19 LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer)19 DiagonalMatrix (org.apache.commons.math3.linear.DiagonalMatrix)18 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)16 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)16 MultivariateMatrixFunction (org.apache.commons.math3.analysis.MultivariateMatrixFunction)9 PointValuePair (org.apache.commons.math3.optim.PointValuePair)8 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)6 InitialGuess (org.apache.commons.math3.optim.InitialGuess)4 MaxEval (org.apache.commons.math3.optim.MaxEval)4 OptimizationData (org.apache.commons.math3.optim.OptimizationData)4 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)4 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)4 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)4 RealVector (org.apache.commons.math3.linear.RealVector)3 Nullable (uk.ac.sussex.gdsc.core.annotation.Nullable)3 ValueAndJacobianFunction (org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction)2 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)2