Search in sources :

Example 1 with MultivariateMatrixFunctionWrapper

Use of uk.ac.sussex.gdsc.smlm.function.MultivariateMatrixFunctionWrapper in the project GDSC-SMLM by aherbert.

From the class ApacheLvmFitter, method computeFit:

/**
 * Fits the function to the target data {@code y} using the Apache Commons Math
 * Levenberg-Marquardt least-squares optimiser.
 *
 * <p>On success the optimised point is passed to {@code setSolution(a, ...)}, the iteration
 * and evaluation counts reported by the optimiser are stored, and {@code value} is set to the
 * sum of squared residuals. Any exception raised inside the fit is converted into a failure
 * status rather than propagated.
 *
 * @param y the target data; a copy is handed to the optimiser
 * @param fx if not null, filled in order with the values produced by the function after it
 *        has been initialised with {@code a}
 * @param a the function parameters; used to obtain the initial solution and passed to
 *        {@code setSolution} together with the optimised point
 * @param parametersVariance if not null, passed to {@code setDeviations} along with a matrix
 *        built from transpose(J) * J, where J is the Jacobian at the optimum
 * @return {@link FitStatus#OK} when the fit completes, otherwise the status matching the
 *         exception that stopped it
 */
@Override
public FitStatus computeFit(double[] y, final double[] fx, double[] a, double[] parametersVariance) {
    final int n = y.length;
    try {
        // Optimiser settings. Different convergence thresholds seem to have no effect on the
        // resulting fit, only the number of iterations for convergence.
        final double stepBound = 100;
        final double costTolerance = 1e-10;
        final double parameterTolerance = 1e-10;
        final double orthogonalityTolerance = 1e-10;
        final double rankingThreshold = Precision.SAFE_MIN;

        // Extract the parameters to be fitted
        final double[] start = getInitialSolution(a);

        // TODO - Pass in more advanced stopping criteria.

        // The optimiser receives its own copy of the target data. No weights are supplied.
        final double[] target = y.clone();

        final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(
            stepBound, costTolerance, parameterTolerance, orthogonalityTolerance, rankingThreshold);

        // Evaluations are unbounded; the configured limit is applied to the iterations.
        final LeastSquaresBuilder builder = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(getMaxEvaluations())
            .start(start)
            .target(target);

        final boolean computeTogether = function instanceof ExtendedNonLinearFunction
            && ((ExtendedNonLinearFunction) function).canComputeValuesAndJacobian();
        if (!computeTogether) {
            // Values and Jacobian are computed by separate wrappers
            final NonLinearFunction separate = (NonLinearFunction) function;
            builder.model(new MultivariateVectorFunctionWrapper(separate, a, n),
                new MultivariateMatrixFunctionWrapper(separate, a, n));
        } else {
            // The function can supply values and Jacobian together, or each individually
            final ExtendedNonLinearFunction extended = (ExtendedNonLinearFunction) function;
            final ValueAndJacobianFunction model = new ValueAndJacobianFunction() {

                @Override
                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    final org.apache.commons.lang3.tuple.Pair<double[], double[][]> both =
                        extended.computeValuesAndJacobian(point.toArray());
                    // Wrap the returned arrays without copying
                    final RealVector values = new ArrayRealVector(both.getKey(), false);
                    final RealMatrix jacobian = new Array2DRowRealMatrix(both.getValue(), false);
                    return new Pair<>(values, jacobian);
                }

                @Override
                public RealVector computeValue(double[] params) {
                    return new ArrayRealVector(extended.computeValues(params), false);
                }

                @Override
                public RealMatrix computeJacobian(double[] params) {
                    return new Array2DRowRealMatrix(extended.computeJacobian(params), false);
                }
            };
            builder.model(model);
        }

        final Optimum optimum = optimizer.optimize(builder.build());

        // Store the result and the optimiser statistics
        setSolution(a, optimum.getPoint().toArray());
        iterations = optimum.getIterations();
        evaluations = optimum.getEvaluations();

        if (parametersVariance != null) {
            // Compute transpose(J)J from the Jacobian at the optimum
            final RealMatrix jacobian = optimum.getJacobian();
            final RealMatrix product = jacobian.transpose().multiply(jacobian);
            final double[][] data;
            if (product instanceof Array2DRowRealMatrix) {
                // Use the backing array directly to avoid a copy
                data = ((Array2DRowRealMatrix) product).getDataRef();
            } else {
                data = product.getData();
            }
            setDeviations(parametersVariance, new FisherInformationMatrix(data));
        }

        if (fx != null) {
            // Compute the function values for the current parameters
            final ValueFunction valueFunction = (ValueFunction) this.function;
            valueFunction.initialise0(a);
            valueFunction.forEach(new ValueProcedure() {

                int next;

                @Override
                public void execute(double value) {
                    fx[next++] = value;
                }
            });
        }

        // As this is unweighted the sum of squared residuals is the dot product of the
        // residuals with themselves. This is the same as optimum.getCost() * optimum.getCost().
        final RealVector residuals = optimum.getResiduals();
        value = residuals.dotProduct(residuals);
    } catch (final TooManyEvaluationsException ex) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (final TooManyIterationsException ex) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (final ConvergenceException ex) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (final Exception ex) {
        // TODO - Find out the other exceptions from the fitter and add return values to match.
        return FitStatus.UNKNOWN;
    }
    return FitStatus.OK;
}
Also used : ValueFunction(uk.ac.sussex.gdsc.smlm.function.ValueFunction) ValueProcedure(uk.ac.sussex.gdsc.smlm.function.ValueProcedure) NonLinearFunction(uk.ac.sussex.gdsc.smlm.function.NonLinearFunction) ExtendedNonLinearFunction(uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction) LeastSquaresBuilder(org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) ValueAndJacobianFunction(org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction) RealVector(org.apache.commons.math3.linear.RealVector) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) Pair(org.apache.commons.math3.util.Pair) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) FisherInformationMatrix(uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix) MultivariateMatrixFunctionWrapper(uk.ac.sussex.gdsc.smlm.function.MultivariateMatrixFunctionWrapper) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) TooManyEvaluationsException(org.apache.commons.math3.exception.TooManyEvaluationsException) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) MultivariateVectorFunctionWrapper(uk.ac.sussex.gdsc.smlm.function.MultivariateVectorFunctionWrapper) 
ExtendedNonLinearFunction(uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction)

Aggregations

ConvergenceException (org.apache.commons.math3.exception.ConvergenceException)1 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)1 TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException)1 LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder)1 Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum)1 LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem)1 LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer)1 ValueAndJacobianFunction (org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction)1 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)1 ArrayRealVector (org.apache.commons.math3.linear.ArrayRealVector)1 RealMatrix (org.apache.commons.math3.linear.RealMatrix)1 RealVector (org.apache.commons.math3.linear.RealVector)1 Pair (org.apache.commons.math3.util.Pair)1 FisherInformationMatrix (uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix)1 ExtendedNonLinearFunction (uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction)1 MultivariateMatrixFunctionWrapper (uk.ac.sussex.gdsc.smlm.function.MultivariateMatrixFunctionWrapper)1 MultivariateVectorFunctionWrapper (uk.ac.sussex.gdsc.smlm.function.MultivariateVectorFunctionWrapper)1 NonLinearFunction (uk.ac.sussex.gdsc.smlm.function.NonLinearFunction)1 ValueFunction (uk.ac.sussex.gdsc.smlm.function.ValueFunction)1 ValueProcedure (uk.ac.sussex.gdsc.smlm.function.ValueProcedure)1