Search in sources :

Example 1 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

the class JumpDistanceAnalysis method doFitJumpDistanceHistogram.

/**
 * Fit the jump distance histogram using a cumulative sum with the given number of species.
 * <p>
 * Results are sorted by the diffusion coefficient ascending.
 * <p>
 * For a single species a Levenberg-Marquardt (LVM) fit of the single population model is attempted first. If
 * that fit throws, or if {@code n > 1}, the mixed population model is fitted with the bounded Powell optimiser,
 * falling back to a CMAES optimiser with restarts (followed by a Powell re-fit) when Powell fails. The bounded
 * solution is then refined with an unbounded LVM fit which is only accepted if it lowers the sum-of-squares and
 * all parameters remain above zero.
 * <p>
 * Side effects: updates the instance fields {@code calibrated}, {@code ss} and {@code ic}, and sets
 * {@code lastIC} when a fit is returned.
 *
 * @param jdHistogram
 *            The cumulative jump distance histogram. X-axis is um^2, Y-axis is cumulative probability. Must be
 *            monotonic ascending.
 * @param estimatedD
 *            The estimated diffusion coefficient
 * @param n
 *            The number of species in the mixed population
 * @return Array containing: { D (um^2), Fractions }. Can be null if no fit was made or the fitted result is
 *         rejected by {@code isValid}.
 */
private double[][] doFitJumpDistanceHistogram(double[][] jdHistogram, double estimatedD, int n) {
    calibrated = isCalibrated();
    if (n == 1) {
        // Fit using a single population model
        LevenbergMarquardtOptimizer lvmOptimizer = new LevenbergMarquardtOptimizer();
        try {
            final JumpDistanceCumulFunction function = new JumpDistanceCumulFunction(jdHistogram[0], jdHistogram[1], estimatedD);
            //@formatter:off
            LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(function.guess()).target(function.getY()).weight(new DiagonalMatrix(function.getWeights())).model(function, new MultivariateMatrixFunction() {

                public double[][] value(double[] point) throws IllegalArgumentException {
                    return function.jacobian(point);
                }
            }).build();
            //@formatter:on
            Optimum lvmSolution = lvmOptimizer.optimize(problem);
            double[] fitParams = lvmSolution.getPoint().toArray();
            // The dot product of the residuals is the sum-of-squares only for an unweighted fit
            ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
            lastIC = ic = Maths.getAkaikeInformationCriterionFromResiduals(ss, function.x.length, 1);
            double[] coefficients = fitParams;
            double[] fractions = new double[] { 1 };
            logger.info("Fit Jump distance (N=1) : %s, SS = %s, IC = %s (%d evaluations)", formatD(fitParams[0]), Maths.rounded(ss, 4), Maths.rounded(ic, 4), lvmSolution.getEvaluations());
            return new double[][] { coefficients, fractions };
        } catch (TooManyIterationsException e) {
            logger.info("LVM optimiser failed to fit (N=1) : Too many iterations : %s", e.getMessage());
        } catch (ConvergenceException e) {
            logger.info("LVM optimiser failed to fit (N=1) : %s", e.getMessage());
        }
        // On failure execution deliberately continues to the bounded mixed-model fit below using n == 1
    }
    // Uses a weighted sum of n exponential functions, each function models a fraction of the particles.
    // An LVM fit cannot restrict the parameters so the fractions do not go below zero.
    // Use the CustomPowell/CMEASOptimizer which supports bounded fitting.
    MixedJumpDistanceCumulFunctionMultivariate function = new MixedJumpDistanceCumulFunctionMultivariate(jdHistogram[0], jdHistogram[1], estimatedD, n);
    double[] lB = function.getLowerBounds();
    int evaluations = 0;
    PointValuePair constrainedSolution = null;
    MaxEval maxEval = new MaxEval(20000);
    CustomPowellOptimizer powellOptimizer = createCustomPowellOptimizer();
    try {
        // The Powell algorithm can use more general bounds: 0 - Infinity
        constrainedSolution = powellOptimizer.optimize(maxEval, new ObjectiveFunction(function), new InitialGuess(function.guess()), new SimpleBounds(lB, function.getUpperBounds(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY)), new CustomPowellOptimizer.BasisStep(function.step()), GoalType.MINIMIZE);
        evaluations = powellOptimizer.getEvaluations();
        logger.debug("Powell optimiser fit (N=%d) : SS = %f (%d evaluations)", n, constrainedSolution.getValue(), evaluations);
    } catch (TooManyEvaluationsException e) {
        logger.info("Powell optimiser failed to fit (N=%d) : Too many evaluations (%d)", n, powellOptimizer.getEvaluations());
    } catch (TooManyIterationsException e) {
        logger.info("Powell optimiser failed to fit (N=%d) : Too many iterations (%d)", n, powellOptimizer.getIterations());
    } catch (ConvergenceException e) {
        logger.info("Powell optimiser failed to fit (N=%d) : %s", n, e.getMessage());
    }
    if (constrainedSolution == null) {
        logger.info("Trying CMAES optimiser with restarts ...");
        double[] uB = function.getUpperBounds();
        SimpleBounds bounds = new SimpleBounds(lB, uB);
        // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
        double[] s = new double[lB.length];
        for (int i = 0; i < s.length; i++) {
            s[i] = (uB[i] - lB[i]) / 3;
        }
        OptimizationData sigma = new CMAESOptimizer.Sigma(s);
        OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(function.x.length))));
        // Iterate this for stability in the initial guess
        CMAESOptimizer cmaesOptimizer = createCMAESOptimizer();
        for (int i = 0; i <= fitRestarts; i++) {
            // Try from the initial guess
            try {
                PointValuePair solution = cmaesOptimizer.optimize(new InitialGuess(function.guess()), new ObjectiveFunction(function), GoalType.MINIMIZE, bounds, sigma, popSize, maxEval);
                if (constrainedSolution == null || solution.getValue() < constrainedSolution.getValue()) {
                    evaluations = cmaesOptimizer.getEvaluations();
                    constrainedSolution = solution;
                    logger.debug("CMAES optimiser [%da] fit (N=%d) : SS = %f (%d evaluations)", i, n, solution.getValue(), evaluations);
                }
            } catch (TooManyEvaluationsException ignored) {
                // Best effort: a failed restart is skipped and the next restart is attempted
            }
            if (constrainedSolution == null)
                continue;
            // Try from the current optimum
            try {
                PointValuePair solution = cmaesOptimizer.optimize(new InitialGuess(constrainedSolution.getPointRef()), new ObjectiveFunction(function), GoalType.MINIMIZE, bounds, sigma, popSize, maxEval);
                if (solution.getValue() < constrainedSolution.getValue()) {
                    evaluations = cmaesOptimizer.getEvaluations();
                    constrainedSolution = solution;
                    logger.debug("CMAES optimiser [%db] fit (N=%d) : SS = %f (%d evaluations)", i, n, solution.getValue(), evaluations);
                }
            } catch (TooManyEvaluationsException ignored) {
                // Best effort: keep the current optimum
            }
        }
        if (constrainedSolution != null) {
            // Re-optimise the CMAES result with Powell
            try {
                PointValuePair solution = powellOptimizer.optimize(maxEval, new ObjectiveFunction(function), new InitialGuess(constrainedSolution.getPointRef()), new SimpleBounds(lB, function.getUpperBounds(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY)), new CustomPowellOptimizer.BasisStep(function.step()), GoalType.MINIMIZE);
                if (solution.getValue() < constrainedSolution.getValue()) {
                    // The accepted solution came from Powell so report the Powell evaluation count
                    evaluations = powellOptimizer.getEvaluations();
                    constrainedSolution = solution;
                    logger.info("Powell optimiser re-fit (N=%d) : SS = %f (%d evaluations)", n, constrainedSolution.getValue(), evaluations);
                }
            } catch (TooManyEvaluationsException ignored) {
                // Best effort: keep the CMAES solution
            } catch (TooManyIterationsException ignored) {
                // Best effort: keep the CMAES solution
            } catch (ConvergenceException ignored) {
                // Best effort: keep the CMAES solution
            }
        }
    }
    if (constrainedSolution == null) {
        logger.info("Failed to fit N=%d", n);
        return null;
    }
    double[] fitParams = constrainedSolution.getPointRef();
    ss = constrainedSolution.getValue();
    // TODO - Try a bounded BFGS optimiser
    // Try and improve using a LVM fit
    final MixedJumpDistanceCumulFunctionGradient functionGradient = new MixedJumpDistanceCumulFunctionGradient(jdHistogram[0], jdHistogram[1], estimatedD, n);
    Optimum lvmSolution;
    LevenbergMarquardtOptimizer lvmOptimizer = new LevenbergMarquardtOptimizer();
    try {
        //@formatter:off
        LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(fitParams).target(functionGradient.getY()).weight(new DiagonalMatrix(functionGradient.getWeights())).model(functionGradient, new MultivariateMatrixFunction() {

            public double[][] value(double[] point) throws IllegalArgumentException {
                return functionGradient.jacobian(point);
            }
        }).build();
        //@formatter:on
        lvmSolution = lvmOptimizer.optimize(problem);
        // The dot product of the residuals is the sum-of-squares only for an unweighted fit.
        // This local deliberately shadows the field; the field is accessed via this.ss inside this block.
        double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
        // All fitted parameters must be above zero
        if (ss < this.ss && Maths.min(lvmSolution.getPoint().toArray()) > 0) {
            logger.info("  Re-fitting improved the SS from %s to %s (-%s%%)", Maths.rounded(this.ss, 4), Maths.rounded(ss, 4), Maths.rounded(100 * (this.ss - ss) / this.ss, 4));
            fitParams = lvmSolution.getPoint().toArray();
            this.ss = ss;
            evaluations += lvmSolution.getEvaluations();
        }
    } catch (TooManyIterationsException e) {
        logger.error("Failed to re-fit : Too many iterations : %s", e.getMessage());
    } catch (ConvergenceException e) {
        logger.error("Failed to re-fit : %s", e.getMessage());
    }
    // Since the fractions must sum to one we subtract 1 degree of freedom from the number of parameters
    ic = Maths.getAkaikeInformationCriterionFromResiduals(ss, function.x.length, fitParams.length - 1);
    // Parameters are interleaved: even index = fraction, odd index = diffusion coefficient
    double[] d = new double[n];
    double[] f = new double[n];
    double sum = 0;
    for (int i = 0; i < d.length; i++) {
        f[i] = fitParams[i * 2];
        sum += f[i];
        d[i] = fitParams[i * 2 + 1];
    }
    // Normalise the fractions to sum to one
    for (int i = 0; i < f.length; i++) {
        f[i] /= sum;
    }
    // Sort by coefficient size
    sort(d, f);
    double[] coefficients = d;
    double[] fractions = f;
    logger.info("Fit Jump distance (N=%d) : %s (%s), SS = %s, IC = %s (%d evaluations)", n, formatD(d), format(f), Maths.rounded(ss, 4), Maths.rounded(ic, 4), evaluations);
    if (isValid(d, f)) {
        lastIC = ic;
        return new double[][] { coefficients, fractions };
    }
    return null;
}
Also used : MaxEval(org.apache.commons.math3.optim.MaxEval) InitialGuess(org.apache.commons.math3.optim.InitialGuess) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) PointValuePair(org.apache.commons.math3.optim.PointValuePair) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) CustomPowellOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CustomPowellOptimizer)

Example 2 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

the class BinomialFitter method fitBinomial.

/**
 * Fit the binomial distribution (n,p) to the cumulative histogram. Performs fitting assuming a fixed n value and
 * attempts to optimise p.
 * <p>
 * A bounded CMAES optimiser is run with restarts (from the initial estimate and from the current optimum). When
 * not using maximum likelihood and more than one point is fitted, the result is refined with an unbounded
 * Levenberg-Marquardt fit that is accepted only if p stays within [0, 1] and the sum-of-squares is reduced.
 * <p>
 * Note: when {@code zeroTruncated} is true and {@code histogram[0] > 0} the input histogram is modified in
 * place: element 0 is set to zero and the remaining elements are renormalised by their sum.
 *
 * @param histogram
 *            The input histogram
 * @param mean
 *            The histogram mean (used to estimate p). Calculated if NaN.
 * @param n
 *            The n to evaluate
 * @param zeroTruncated
 *            True if the model should ignore n=0 (zero-truncated binomial)
 * @return The best fit: the point holds the single parameter p. The value is the sum-of-squares when no re-fit
 *         is performed (e.g. maximum likelihood fitting) or when the LVM re-fit is accepted; otherwise it is
 *         the objective value from the CMAES optimiser. Null if there are no points to fit or fitting failed.
 * @throws IllegalArgumentException
 *             If any of the input data values are negative
 * @throws IllegalArgumentException
 *             If fitting a zero-truncated binomial and there are no values above zero
 */
public PointValuePair fitBinomial(double[] histogram, double mean, int n, boolean zeroTruncated) {
    if (Double.isNaN(mean))
        mean = getMean(histogram);
    if (zeroTruncated && histogram[0] > 0) {
        log("Fitting zero-truncated histogram but there are zero values - Renormalising to ignore zero");
        double cumul = 0;
        for (int i = 1; i < histogram.length; i++) {
            cumul += histogram[i];
        }
        if (cumul == 0)
            throw new IllegalArgumentException("Fitting zero-truncated histogram but there are no non-zero values");
        histogram[0] = 0;
        for (int i = 1; i < histogram.length; i++) {
            histogram[i] /= cumul;
        }
    }
    int nFittedPoints = Math.min(histogram.length, n + 1) - ((zeroTruncated) ? 1 : 0);
    if (nFittedPoints < 1) {
        log("No points to fit (%d): Histogram.length = %d, n = %d, zero-truncated = %b", nFittedPoints, histogram.length, n, zeroTruncated);
        return null;
    }
    // The model is only fitting the probability p
    // For a binomial n*p = mean => p = mean/n
    double[] initialSolution = new double[] { FastMath.min(mean / n, 1) };
    // Create the function
    BinomialModelFunction function = new BinomialModelFunction(histogram, n, zeroTruncated);
    // p is bounded to [0, 1]
    double[] lB = new double[1];
    double[] uB = new double[] { 1 };
    SimpleBounds bounds = new SimpleBounds(lB, uB);
    // Fit
    // CMAESOptimizer or BOBYQAOptimizer support bounds
    // CMAESOptimiser based on Matlab code:
    // https://www.lri.fr/~hansen/cmaes.m
    // Take the defaults from the Matlab documentation
    int maxIterations = 2000;
    // Alternative considered: Double.NEGATIVE_INFINITY
    double stopFitness = 0;
    boolean isActiveCMA = true;
    int diagonalOnly = 0;
    int checkFeasableCount = 1;
    RandomGenerator random = new Well19937c();
    boolean generateStatistics = false;
    ConvergenceChecker<PointValuePair> checker = new SimpleValueChecker(1e-6, 1e-10);
    // The sigma determines the search range for the variables. It should be 1/3 of the initial search region.
    OptimizationData sigma = new CMAESOptimizer.Sigma(new double[] { (uB[0] - lB[0]) / 3 });
    OptimizationData popSize = new CMAESOptimizer.PopulationSize((int) (4 + Math.floor(3 * Math.log(2))));
    try {
        PointValuePair solution = null;
        // A maximum likelihood fit cannot be refined by the least-squares LVM optimiser
        boolean noRefit = maximumLikelihood;
        if (n == 1 && zeroTruncated) {
            // No need to fit
            solution = new PointValuePair(new double[] { 1 }, 0);
            noRefit = true;
        } else {
            GoalType goalType = (maximumLikelihood) ? GoalType.MAXIMIZE : GoalType.MINIMIZE;
            // When maximising, a better result has a HIGHER objective value; when minimising a LOWER one.
            final boolean maximise = (goalType == GoalType.MAXIMIZE);
            // Iteratively fit
            CMAESOptimizer opt = new CMAESOptimizer(maxIterations, stopFitness, isActiveCMA, diagonalOnly, checkFeasableCount, random, generateStatistics, checker);
            for (int iteration = 0; iteration <= fitRestarts; iteration++) {
                try {
                    // Start from the initial solution
                    PointValuePair result = opt.optimize(new InitialGuess(initialSolution), new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (solution == null || (maximise ? result.getValue() > solution.getValue() : result.getValue() < solution.getValue())) {
                        solution = result;
                    }
                } catch (TooManyEvaluationsException ignored) {
                    // Best effort: a failed restart is skipped
                } catch (TooManyIterationsException ignored) {
                    // Best effort: a failed restart is skipped
                }
                if (solution == null)
                    continue;
                try {
                    // Also restart from the current optimum
                    PointValuePair result = opt.optimize(new InitialGuess(solution.getPointRef()), new ObjectiveFunction(function), goalType, bounds, sigma, popSize, new MaxIter(maxIterations), new MaxEval(maxIterations * 2));
                    if (maximise ? result.getValue() > solution.getValue() : result.getValue() < solution.getValue()) {
                        solution = result;
                    }
                } catch (TooManyEvaluationsException ignored) {
                    // Best effort: keep the current optimum
                } catch (TooManyIterationsException ignored) {
                    // Best effort: keep the current optimum
                }
            }
            if (solution == null)
                return null;
        }
        if (noRefit) {
            // Although we fit the log-likelihood, return the sum-of-squares to allow 
            // comparison across different n
            double p = solution.getPointRef()[0];
            double ss = 0;
            double[] obs = function.p;
            double[] exp = function.getP(p);
            for (int i = 0; i < obs.length; i++) {
                ss += (obs[i] - exp[i]) * (obs[i] - exp[i]);
            }
            return new PointValuePair(solution.getPointRef(), ss);
        } else // We can do a LVM refit if the number of fitted points is more than 1
        if (nFittedPoints > 1) {
            // Improve SS fit with a gradient based LVM optimizer
            LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
            try {
                final BinomialModelFunctionGradient gradientFunction = new BinomialModelFunctionGradient(histogram, n, zeroTruncated);
                //@formatter:off
                LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(solution.getPointRef()).target(gradientFunction.p).weight(new DiagonalMatrix(gradientFunction.getWeights())).model(gradientFunction, new MultivariateMatrixFunction() {

                    public double[][] value(double[] point) throws IllegalArgumentException {
                        return gradientFunction.jacobian(point);
                    }
                }).build();
                //@formatter:on
                Optimum lvmSolution = optimizer.optimize(problem);
                // Check the pValue is valid since the LVM is not bounded.
                double p = lvmSolution.getPoint().getEntry(0);
                if (p <= 1 && p >= 0) {
                    // The dot product of the residuals is the sum-of-squares only if the weights are 1
                    double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
                    if (ss < solution.getValue()) {
                        return new PointValuePair(lvmSolution.getPoint().toArray(), ss);
                    }
                }
            } catch (TooManyIterationsException e) {
                log("Failed to re-fit: Too many iterations: %s", e.getMessage());
            } catch (ConvergenceException e) {
                log("Failed to re-fit: %s", e.getMessage());
            } catch (Exception ignored) {
                // Deliberately ignored: any re-fit failure falls back to the bounded CMAES solution
            }
        }
        return solution;
    } catch (Exception e) {
        log("Failed to fit Binomial distribution with N=%d : %s", n, e.getMessage());
    }
    return null;
}
Also used : InitialGuess(org.apache.commons.math3.optim.InitialGuess) MaxEval(org.apache.commons.math3.optim.MaxEval) SimpleBounds(org.apache.commons.math3.optim.SimpleBounds) ObjectiveFunction(org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction) Well19937c(org.apache.commons.math3.random.Well19937c) SimpleValueChecker(org.apache.commons.math3.optim.SimpleValueChecker) RandomGenerator(org.apache.commons.math3.random.RandomGenerator) PointValuePair(org.apache.commons.math3.optim.PointValuePair) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction) CMAESOptimizer(org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer) GoalType(org.apache.commons.math3.optim.nonlinear.scalar.GoalType) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) OptimizationData(org.apache.commons.math3.optim.OptimizationData) MaxIter(org.apache.commons.math3.optim.MaxIter)

Example 3 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

the class ApacheLVMFitter method computeFit.

/**
 * Fit the function to the data using an unweighted Apache Commons Math Levenberg-Marquardt optimiser.
 * <p>
 * On success the fitted parameters are written back through {@code setSolution(a, ...)} and the instance
 * fields {@code iterations}, {@code evaluations} and {@code value} (sum of squared residuals) are updated.
 *
 * @param y
 *            The data to fit
 * @param y_fit
 *            If not null, filled with the function values at the fitted parameters
 * @param a
 *            The function parameters; used for the initial solution and updated with the fitted solution
 * @param a_dev
 *            If not null, filled with parameter deviations computed from the covariance matrix (or from the
 *            Fisher information if the covariance computation is singular)
 * @return {@code FitStatus.OK} on success, otherwise the status matching the optimiser failure
 */
public FitStatus computeFit(double[] y, final double[] y_fit, double[] a, double[] a_dev) {
    int n = y.length;
    try {
        // Different convergence thresholds seem to have no effect on the resulting fit, only the number of
        // iterations for convergence
        final double initialStepBoundFactor = 100;
        final double costRelativeTolerance = 1e-10;
        final double parRelativeTolerance = 1e-10;
        final double orthoTolerance = 1e-10;
        final double threshold = Precision.SAFE_MIN;
        // Extract the parameters to be fitted
        final double[] initialSolution = getInitialSolution(a);
        // TODO - Pass in more advanced stopping criteria.
        // Create the target and weight arrays (unit weights => unweighted fit)
        final double[] yd = new double[n];
        final double[] w = new double[n];
        for (int i = 0; i < n; i++) {
            yd[i] = y[i];
            w[i] = 1;
        }
        LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(initialStepBoundFactor, costRelativeTolerance, parRelativeTolerance, orthoTolerance, threshold);
        //@formatter:off
        LeastSquaresBuilder builder = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(getMaxEvaluations()).start(initialSolution).target(yd).weight(new DiagonalMatrix(w));
        if (f instanceof ExtendedNonLinearFunction && ((ExtendedNonLinearFunction) f).canComputeValuesAndJacobian()) {
            // Compute together, or each individually
            builder.model(new ValueAndJacobianFunction() {

                final ExtendedNonLinearFunction fun = (ExtendedNonLinearFunction) f;

                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    final double[] p = point.toArray();
                    final Pair<double[], double[][]> result = fun.computeValuesAndJacobian(p);
                    // Wrap the arrays without copying
                    return new Pair<RealVector, RealMatrix>(new ArrayRealVector(result.getFirst(), false), new Array2DRowRealMatrix(result.getSecond(), false));
                }

                public RealVector computeValue(double[] params) {
                    return new ArrayRealVector(fun.computeValues(params), false);
                }

                public RealMatrix computeJacobian(double[] params) {
                    return new Array2DRowRealMatrix(fun.computeJacobian(params), false);
                }
            });
        } else {
            // Compute separately
            builder.model(new MultivariateVectorFunctionWrapper((NonLinearFunction) f, a, n), new MultivariateMatrixFunctionWrapper((NonLinearFunction) f, a, n));
        }
        LeastSquaresProblem problem = builder.build();
        Optimum optimum = optimizer.optimize(problem);
        final double[] parameters = optimum.getPoint().toArray();
        setSolution(a, parameters);
        iterations = optimum.getIterations();
        evaluations = optimum.getEvaluations();
        if (a_dev != null) {
            try {
                double[][] covar = optimum.getCovariances(threshold).getData();
                setDeviationsFromMatrix(a_dev, covar);
            } catch (SingularMatrixException e) {
                // Matrix inversion failed. In order to return a solution 
                // return the reciprocal of the diagonal of the Fisher information 
                // for a loose bound on the limit 
                final int[] gradientIndices = f.gradientIndices();
                final int nparams = gradientIndices.length;
                GradientCalculator calculator = GradientCalculatorFactory.newCalculator(nparams);
                double[][] alpha = new double[nparams][nparams];
                double[] beta = new double[nparams];
                // NOTE(review): the first argument is nparams; confirm findLinearised does not expect the
                // number of data points (n) here.
                calculator.findLinearised(nparams, y, a, alpha, beta, (NonLinearFunction) f);
                FisherInformationMatrix m = new FisherInformationMatrix(alpha);
                setDeviations(a_dev, m.crlb(true));
            }
        }
        // Compute function value
        if (y_fit != null) {
            Gaussian2DFunction f = (Gaussian2DFunction) this.f;
            f.initialise0(a);
            f.forEach(new ValueProcedure() {

                // Index of the next output position
                int i = 0;

                public void execute(double value) {
                    // Advance the index on every call so each value is stored in its own position
                    // (presumably forEach visits the data points in order - confirm against Gaussian2DFunction)
                    y_fit[i++] = value;
                }
            });
        }
        // As this is unweighted then we can do this to get the sum of squared residuals
        // This is the same as optimum.getCost() * optimum.getCost(); The getCost() function
        // just computes the dot product anyway.
        value = optimum.getResiduals().dotProduct(optimum.getResiduals());
    } catch (TooManyEvaluationsException e) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (TooManyIterationsException e) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (ConvergenceException e) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (Exception e) {
        // TODO - Find out the other exceptions from the fitter and add return values to match. 
        return FitStatus.UNKNOWN;
    }
    return FitStatus.OK;
}
Also used : ValueProcedure(gdsc.smlm.function.ValueProcedure) ExtendedNonLinearFunction(gdsc.smlm.function.ExtendedNonLinearFunction) NonLinearFunction(gdsc.smlm.function.NonLinearFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Gaussian2DFunction(gdsc.smlm.function.gaussian.Gaussian2DFunction) ValueAndJacobianFunction(org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) RealVector(org.apache.commons.math3.linear.RealVector) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) GradientCalculator(gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator) Pair(org.apache.commons.math3.util.Pair) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) FisherInformationMatrix(gdsc.smlm.fitting.FisherInformationMatrix) MultivariateMatrixFunctionWrapper(gdsc.smlm.function.MultivariateMatrixFunctionWrapper) SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) 
Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MultivariateVectorFunctionWrapper(gdsc.smlm.function.MultivariateVectorFunctionWrapper) ExtendedNonLinearFunction(gdsc.smlm.function.ExtendedNonLinearFunction)

Example 4 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

the class TraceDiffusion method fitMSD.

/**
	 * Fit the MSD against x using linear models and report the diffusion coefficient.
	 * <p>
	 * Up to three weighted least-squares fits are performed with a Levenberg-Marquardt optimiser:
	 * <ol>
	 * <li>A line with no intercept (the base model).</li>
	 * <li>A line with an intercept of 4s^2, where s is the fitted precision.</li>
	 * <li>If {@code settings.msdCorrection} is set, a line with the MSD corrected intercept (4s^2 - gradient/3)
	 * fitted against x / exposureTime.</li>
	 * </ol>
	 * An intercept model replaces the current result when its Akaike information criterion is lower than that of
	 * the base model, or when the base model could not be fitted at all.
	 * <p>
	 * If the final D is positive the fit line is added to the plot and the plot is displayed.
	 * 
	 * @param x
	 *            The x values (divided by the exposure time for the MSD corrected fit)
	 * @param y
	 *            The MSD values
	 * @param title
	 *            The title passed to the plot display
	 * @param plot
	 *            The plot to which the fit line is added
	 * @return [D, precision]. D is gradient / 4 of the selected model (0 if no fit was accepted). Precision is the
	 *         fitted intercept parameter s (logged as um); it is 0 when the no-intercept model is selected.
	 */
private double[] fitMSD(double[] x, double[] y, String title, Plot2 plot) {
    // The Weimann paper (Plos One e64287) fits:
    // MSD(n dt) = 4D n dt + 4s^2
    // n = number of jumps
    // dt = time difference between frames
    // s = localisation precision
    // Thus we should fit an intercept as well.
    // From the fit D = gradient / (4*exposureTime)
    double D = 0;
    double intercept = 0;
    double precision = 0;
    final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
    // Information criterion of the no-intercept fit. Only meaningful when hasBaseFit is true.
    double ic = 0;
    // Tracks whether the no-intercept fit succeeded. If it did not then there is nothing to compare
    // the intercept models against and they must be accepted on their own merit.
    boolean hasBaseFit = false;
    // Fit with no intercept
    try {
        final LinearFunction function = new LinearFunction(x, y, settings.fitLength);
        final double[] parameters = new double[] { function.guess() };
        //@formatter:off
        final LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(parameters).target(function.getY()).weight(new DiagonalMatrix(function.getWeights())).model(function, new MultivariateMatrixFunction() {

            @Override
            public double[][] value(double[] point) throws IllegalArgumentException {
                return function.jacobian(point);
            }
        }).build();
        //@formatter:on
        final Optimum lvmSolution = optimizer.optimize(problem);
        final double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
        ic = Maths.getAkaikeInformationCriterionFromResiduals(ss, function.getY().length, 1);
        final double gradient = lvmSolution.getPoint().getEntry(0);
        D = gradient / 4;
        hasBaseFit = true;
        Utils.log("Linear fit (%d points) : Gradient = %s, D = %s um^2/s, SS = %s, IC = %s (%d evaluations)", function.getY().length, Utils.rounded(gradient, 4), Utils.rounded(D, 4), Utils.rounded(ss), Utils.rounded(ic), lvmSolution.getEvaluations());
    } catch (TooManyIterationsException e) {
        Utils.log("Failed to fit : Too many iterations (%s)", e.getMessage());
    } catch (ConvergenceException e) {
        Utils.log("Failed to fit : %s", e.getMessage());
    }
    // Fit with intercept.
    // Optionally include the intercept (which is the estimated precision).
    final boolean fitIntercept = true;
    try {
        final LinearFunctionWithIntercept function = new LinearFunctionWithIntercept(x, y, settings.fitLength, fitIntercept);
        //@formatter:off
        final LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(function.guess()).target(function.getY()).weight(new DiagonalMatrix(function.getWeights())).model(function, new MultivariateMatrixFunction() {

            @Override
            public double[][] value(double[] point) throws IllegalArgumentException {
                return function.jacobian(point);
            }
        }).build();
        //@formatter:on
        final Optimum lvmSolution = optimizer.optimize(problem);
        final double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
        final double ic2 = Maths.getAkaikeInformationCriterionFromResiduals(ss, function.getY().length, 2);
        final double gradient = lvmSolution.getPoint().getEntry(0);
        final double s = lvmSolution.getPoint().getEntry(1);
        final double intercept2 = 4 * s * s;
        // Accept when better than the base model, or when there is no base model to compare against
        final boolean accept = !hasBaseFit || ic2 < ic;
        if (accept || debugFitting) {
            // Convert fitted precision in um to nm
            Utils.log("Linear fit with intercept (%d points) : Gradient = %s, Intercept = %s, D = %s um^2/s, precision = %s nm, SS = %s, IC = %s (%d evaluations)", function.getY().length, Utils.rounded(gradient, 4), Utils.rounded(intercept2, 4), Utils.rounded(gradient / 4, 4), Utils.rounded(s * 1000, 4), Utils.rounded(ss), Utils.rounded(ic2), lvmSolution.getEvaluations());
        }
        if (accept) {
            intercept = intercept2;
            D = gradient / 4;
            precision = s;
        }
    } catch (TooManyIterationsException e) {
        Utils.log("Failed to fit with intercept : Too many iterations (%s)", e.getMessage());
    } catch (ConvergenceException e) {
        Utils.log("Failed to fit with intercept : %s", e.getMessage());
    }
    if (settings.msdCorrection) {
        // i.e. the intercept is allowed to be a small negative.
        try {
            // This function fits the jump distance (n) not the time (nt) so update x
            final double[] x2 = new double[x.length];
            for (int i = 0; i < x2.length; i++) {
                x2[i] = x[i] / exposureTime;
            }
            final LinearFunctionWithMSDCorrectedIntercept function = new LinearFunctionWithMSDCorrectedIntercept(x2, y, settings.fitLength, fitIntercept);
            //@formatter:off
            final LeastSquaresProblem problem = new LeastSquaresBuilder().maxEvaluations(Integer.MAX_VALUE).maxIterations(3000).start(function.guess()).target(function.getY()).weight(new DiagonalMatrix(function.getWeights())).model(function, new MultivariateMatrixFunction() {

                @Override
                public double[][] value(double[] point) throws IllegalArgumentException {
                    return function.jacobian(point);
                }
            }).build();
            //@formatter:on
            final Optimum lvmSolution = optimizer.optimize(problem);
            final double ss = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
            final double ic2 = Maths.getAkaikeInformationCriterionFromResiduals(ss, function.getY().length, 2);
            double gradient = lvmSolution.getPoint().getEntry(0);
            final double s = lvmSolution.getPoint().getEntry(1);
            final double intercept2 = 4 * s * s - gradient / 3;
            // Q. Is this working?
            // Try fixed precision fitting. Is the gradient correct?
            // Revisit all the equations to see if they are wrong.
            // Try adding the x[0] datapoint using the precision.
            // Change the formula to not be linear at x[0] and to just fit the precision, i.e. the intercept2 = 4 * s * s - gradient / 3 is wrong as the 
            // equation is not linear below n=1.
            // Incorporate the exposure time into the gradient to allow comparison to other fits 
            gradient /= exposureTime;
            // NOTE: As for the plain intercept model, the comparison is made against the no-intercept
            // model's information criterion (not against the best model found so far).
            final boolean accept = !hasBaseFit || ic2 < ic;
            if (accept || debugFitting) {
                // Convert fitted precision in um to nm
                Utils.log("Linear fit with MSD corrected intercept (%d points) : Gradient = %s, Intercept = %s, D = %s um^2/s, precision = %s nm, SS = %s, IC = %s (%d evaluations)", function.getY().length, Utils.rounded(gradient, 4), Utils.rounded(intercept2, 4), Utils.rounded(gradient / 4, 4), Utils.rounded(s * 1000, 4), Utils.rounded(ss), Utils.rounded(ic2), lvmSolution.getEvaluations());
            }
            if (accept) {
                intercept = intercept2;
                D = gradient / 4;
                precision = s;
            }
        } catch (TooManyIterationsException e) {
            Utils.log("Failed to fit with MSD corrected intercept : Too many iterations (%s)", e.getMessage());
        } catch (ConvergenceException e) {
            Utils.log("Failed to fit with MSD corrected intercept : %s", e.getMessage());
        }
    }
    // Add the fit to the plot
    if (D > 0) {
        plot.setColor(Color.magenta);
        final double lastX = x[x.length - 1];
        plot.drawLine(0, intercept, lastX, 4 * D * lastX + intercept);
        display(title, plot);
        checkTraceDistance(D);
    }
    return new double[] { D, precision };
}
Also used : LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction)

Example 5 with LeastSquaresProblem

use of org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem in project GDSC-SMLM by aherbert.

the class PCPALMFitting method fitEmulsionModel.

/**
	 * Fit the emulsion model to the correlation curve g(r) using the estimated precision and protein density as
	 * the starting point.
	 * <p>
	 * Only points with r above {@code sigmaS * fitAboveEstimatedPrecision} are used. The model is first fit with a
	 * bounded optimiser that constrains the precision and density to be near their estimates. If {@code useLSE} is
	 * set the result is re-fit with a Levenberg-Marquardt optimiser and that solution is kept when it lowers the sum
	 * of squares.
	 * <p>
	 * The fit is then scored (information criterion stored in {@code ic3}), logged, validated (stored in
	 * {@code valid2}) and added to the results.
	 * 
	 * @param gr
	 *            The correlation curve: gr[0] holds r and gr[1] holds g(r)
	 * @param sigmaS
	 *            The estimated precision
	 * @param proteinDensity
	 *            The estimated protein density
	 * @param resultColour
	 *            The colour passed to the result entry
	 * @return The fitted parameters [precision, density, domain radius, domain density, coherence], or null if the
	 *         bounded optimiser produced no solution
	 */
private double[] fitEmulsionModel(double[][] gr, double sigmaS, double proteinDensity, String resultColour) {
    final EmulsionModelFunctionGradient function = new EmulsionModelFunctionGradient();
    emulsionModel = function;
    log("Fitting %s: Estimated precision = %f nm, estimated protein density = %g um^-2", emulsionModel.getName(), sigmaS, proteinDensity * 1e6);
    emulsionModel.setLogging(true);

    // Only fit the curve above the estimated resolution (points below it will be subject to error)
    final double minimumR = sigmaS * fitAboveEstimatedPrecision;
    for (int i = offset; i < gr[0].length; i++) {
        final double r = gr[0][i];
        if (r > minimumR) {
            emulsionModel.addPoint(r, gr[1][i]);
        }
    }

    // The model is: sigma, density, range, amplitude, alpha
    final double[] initialSolution = new double[] { sigmaS, proteinDensity, sigmaS * 5, 1, sigmaS * 5 };

    // Constrain the fitting to be close to the estimated precision (sigmaS) and protein density.
    // LVM fitting does not support constrained fitting so use a bounded optimiser.
    final SumOfSquaresModelFunction emulsionModelMulti = new SumOfSquaresModelFunction(emulsionModel);
    final double[] x = emulsionModelMulti.x;
    final double[] y = emulsionModelMulti.y;

    // Range should be equal to the first time the g(r) curve crosses 1
    for (int i = 0; i < x.length; i++) {
        if (y[i] < 1) {
            final double crossing;
            if (i > 0) {
                crossing = (x[i - 1] + x[i]) * 0.5;
            } else {
                crossing = x[i];
            }
            initialSolution[2] = crossing;
            initialSolution[4] = crossing;
            break;
        }
    }

    // Put some bounds around the initial guess. Use the fitting tolerance (in %) if provided.
    final double limit;
    if (fittingTolerance > 0) {
        limit = 1 + fittingTolerance / 100;
    } else {
        limit = 2;
    }
    final double[] lB = new double[] { initialSolution[0] / limit, initialSolution[1] / limit, 0, 0, 0 };
    // The amplitude and range should not extend beyond the limits of the g(r) curve.
    // TODO - Find out the expected range for the alpha parameter.
    final double maxX = Maths.max(x);
    final double[] uB = new double[] { initialSolution[0] * limit, initialSolution[1] * limit, maxX, Maths.max(gr[1]), maxX * 2 };
    log("Fitting %s using a bounded search: %s < precision < %s & %s < density < %s", emulsionModel.getName(), Utils.rounded(lB[0], 4), Utils.rounded(uB[0], 4), Utils.rounded(lB[1] * 1e6, 4), Utils.rounded(uB[1] * 1e6, 4));

    final PointValuePair constrainedSolution = runBoundedOptimiser(gr, initialSolution, lB, uB, emulsionModelMulti);
    if (constrainedSolution == null) {
        return null;
    }
    double[] parameters = constrainedSolution.getPointRef();
    int evaluations = boundedEvaluations;

    // Refit using a LVM
    if (useLSE) {
        log("Re-fitting %s using a gradient optimisation", emulsionModel.getName());
        final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer();
        try {
            //@formatter:off
            final LeastSquaresProblem problem = new LeastSquaresBuilder()
                    .maxEvaluations(Integer.MAX_VALUE)
                    .maxIterations(3000)
                    .start(parameters)
                    .target(function.getY())
                    .weight(new DiagonalMatrix(function.getWeights()))
                    .model(function, new MultivariateMatrixFunction() {

                        @Override
                        public double[][] value(double[] point) throws IllegalArgumentException {
                            return function.jacobian(point);
                        }
                    })
                    .build();
            //@formatter:on
            final Optimum lvmSolution = optimizer.optimize(problem);
            evaluations += lvmSolution.getEvaluations();
            final double refitSS = lvmSolution.getResiduals().dotProduct(lvmSolution.getResiduals());
            // Keep the gradient solution only when it improves on the bounded solution
            if (refitSS < constrainedSolution.getValue()) {
                log("Re-fitting %s improved the SS from %s to %s (-%s%%)", emulsionModel.getName(), Utils.rounded(constrainedSolution.getValue(), 4), Utils.rounded(refitSS, 4), Utils.rounded(100 * (constrainedSolution.getValue() - refitSS) / constrainedSolution.getValue(), 4));
                parameters = lvmSolution.getPoint().toArray();
            }
        } catch (TooManyIterationsException e) {
            log("Failed to re-fit %s: Too many iterations (%s)", emulsionModel.getName(), e.getMessage());
        } catch (ConvergenceException e) {
            log("Failed to re-fit %s: %s", emulsionModel.getName(), e.getMessage());
        }
    }
    emulsionModel.setLogging(false);

    // Ensure the width is positive
    parameters[0] = Math.abs(parameters[0]);

    // Score the selected parameters using the residual sum of squares
    final double[] observed = emulsionModel.getY();
    final double[] predicted = emulsionModel.value(parameters);
    double ss = 0;
    for (int i = 0; i < observed.length; i++) {
        final double residual = observed[i] - predicted[i];
        ss += residual * residual;
    }
    ic3 = Maths.getAkaikeInformationCriterionFromResiduals(ss, emulsionModel.size(), parameters.length);

    final double fitSigmaS = parameters[0];
    final double fitProteinDensity = parameters[1];
    // The radius of the cluster domain
    final double domainRadius = parameters[2];
    // The density of the cluster domain
    final double domainDensity = parameters[3];
    // The coherence length between circles
    final double coherence = parameters[4];
    // This is from the PC-PALM paper. It may not be correct for the emulsion model.
    final double nCluster = 2 * domainDensity * Math.PI * domainRadius * domainRadius * fitProteinDensity;
    final double e1 = parameterDrift(sigmaS, fitSigmaS);
    final double e2 = parameterDrift(proteinDensity, fitProteinDensity);

    final String modelIndent = "  ";
    log(modelIndent + "%s fit: SS = %f. cAIC = %f. %d evaluations", emulsionModel.getName(), ss, ic3, evaluations);
    log(modelIndent + "%s parameters:", emulsionModel.getName());
    log("    Average precision = %s nm (%s%%)", Utils.rounded(fitSigmaS, 4), Utils.rounded(e1, 4));
    log("    Average protein density = %s um^-2 (%s%%)", Utils.rounded(fitProteinDensity * 1e6, 4), Utils.rounded(e2, 4));
    log("    Domain radius = %s nm", Utils.rounded(domainRadius, 4));
    log("    Domain density = %s", Utils.rounded(domainDensity, 4));
    log("    Domain coherence = %s", Utils.rounded(coherence, 4));
    log("    nCluster = %s", Utils.rounded(nCluster, 4));

    // Check the fitted parameters are within tolerance of the initial estimates
    valid2 = true;
    final boolean outsideTolerance = fittingTolerance > 0 && (Math.abs(e1) > fittingTolerance || Math.abs(e2) > fittingTolerance);
    if (outsideTolerance) {
        log("  Failed to fit %s within tolerance (%s%%): Average precision = %f nm (%s%%), average protein density = %g um^-2 (%s%%)", emulsionModel.getName(), Utils.rounded(fittingTolerance, 4), fitSigmaS, Utils.rounded(e1, 4), fitProteinDensity * 1e6, Utils.rounded(e2, 4));
        valid2 = false;
    }
    // Check extra parameters. Domain radius should be higher than the precision. Density should be positive
    if (domainRadius < fitSigmaS) {
        log("  Failed to fit %s: Domain radius is smaller than the average precision (%s < %s)", emulsionModel.getName(), Utils.rounded(domainRadius, 4), Utils.rounded(fitSigmaS, 4));
        valid2 = false;
    }
    if (domainDensity < 0) {
        log("  Failed to fit %s: Domain density is negative (%s)", emulsionModel.getName(), Utils.rounded(domainDensity, 4));
        valid2 = false;
    }
    // The information criterion must not be worse than the reference value held in ic1
    if (ic3 > ic1) {
        log("  Failed to fit %s - Information Criterion has increased %s%%", emulsionModel.getName(), Utils.rounded((100 * (ic3 - ic1) / ic1), 4));
        valid2 = false;
    }

    addResult(emulsionModel.getName(), resultColour, valid2, fitSigmaS, fitProteinDensity, domainRadius, domainDensity, nCluster, coherence, ic3);
    return parameters;
}
Also used : PointValuePair(org.apache.commons.math3.optim.PointValuePair) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) MultivariateMatrixFunction(org.apache.commons.math3.analysis.MultivariateMatrixFunction)

Aggregations

LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)19 Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)19 LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)19 LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer)19 DiagonalMatrix (org.apache.commons.math3.linear.DiagonalMatrix)18 ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)16 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)16 MultivariateMatrixFunction (org.apache.commons.math3.analysis.MultivariateMatrixFunction)9 PointValuePair (org.apache.commons.math3.optim.PointValuePair)8 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)6 InitialGuess (org.apache.commons.math3.optim.InitialGuess)4 MaxEval (org.apache.commons.math3.optim.MaxEval)4 OptimizationData (org.apache.commons.math3.optim.OptimizationData)4 SimpleBounds (org.apache.commons.math3.optim.SimpleBounds)4 ObjectiveFunction (org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction)4 CMAESOptimizer (org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer)4 RealVector (org.apache.commons.math3.linear.RealVector)3 Nullable (uk.ac.sussex.gdsc.core.annotation.Nullable)3 ValueAndJacobianFunction (org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction)2 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)2