Search in sources:

Example 1 with NonLinearFunction

Use of uk.ac.sussex.gdsc.smlm.function.NonLinearFunction in project GDSC-SMLM by aherbert.

Class ApacheLvmFitter, method computeFit.

/**
 * Fits the configured function to the observed data using the Apache Commons Math
 * Levenberg-Marquardt least-squares optimizer (unweighted fit).
 *
 * <p>On success the fitted parameters are written back into {@code a} via {@code setSolution},
 * the {@code iterations} and {@code evaluations} counters are updated, and {@code value} is set to
 * the sum of squared residuals at the optimum.
 *
 * @param y the observed data used as the least-squares target; its length is the number of points
 * @param fx if not null, filled with the function values evaluated at the fitted parameters
 * @param a the function parameters; used to build the initial solution and updated with the fit
 * @param parametersVariance if not null, passed to {@code setDeviations} together with a
 *     {@code FisherInformationMatrix} built from J^T J at the optimum
 * @return {@link FitStatus#OK} on success, otherwise a status describing the failure
 */
@Override
public FitStatus computeFit(double[] y, final double[] fx, double[] a, double[] parametersVariance) {
    final int n = y.length;
    try {
        // Different convergence thresholds seem to have no effect on the resulting fit, only the
        // number of iterations for convergence
        final double initialStepBoundFactor = 100;
        final double costRelativeTolerance = 1e-10;
        final double parRelativeTolerance = 1e-10;
        final double orthoTolerance = 1e-10;
        final double threshold = Precision.SAFE_MIN;

        // Extract the parameters to be fitted
        final double[] initialSolution = getInitialSolution(a);

        // TODO - Pass in more advanced stopping criteria.

        // The fit is unweighted so only a target array is required: use a private copy of the data.
        final double[] yd = y.clone();

        final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(
            initialStepBoundFactor, costRelativeTolerance, parRelativeTolerance, orthoTolerance,
            threshold);

        // Evaluations are unbounded; the fitter's maximum-evaluations setting limits the
        // optimizer iterations instead.
        // @formatter:off
        final LeastSquaresBuilder builder = new LeastSquaresBuilder()
            .maxEvaluations(Integer.MAX_VALUE)
            .maxIterations(getMaxEvaluations())
            .start(initialSolution)
            .target(yd);
        // @formatter:on

        if (function instanceof ExtendedNonLinearFunction
            && ((ExtendedNonLinearFunction) function).canComputeValuesAndJacobian()) {
            // Compute together, or each individually
            builder.model(new ValueAndJacobianFunction() {

                final ExtendedNonLinearFunction fun = (ExtendedNonLinearFunction) function;

                @Override
                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    final double[] p = point.toArray();
                    final org.apache.commons.lang3.tuple.Pair<double[], double[][]> result =
                        fun.computeValuesAndJacobian(p);
                    return new Pair<>(new ArrayRealVector(result.getKey(), false),
                        new Array2DRowRealMatrix(result.getValue(), false));
                }

                @Override
                public RealVector computeValue(double[] params) {
                    return new ArrayRealVector(fun.computeValues(params), false);
                }

                @Override
                public RealMatrix computeJacobian(double[] params) {
                    return new Array2DRowRealMatrix(fun.computeJacobian(params), false);
                }
            });
        } else {
            // Compute separately
            builder.model(new MultivariateVectorFunctionWrapper((NonLinearFunction) function, a, n),
                new MultivariateMatrixFunctionWrapper((NonLinearFunction) function, a, n));
        }

        final LeastSquaresProblem problem = builder.build();
        final Optimum optimum = optimizer.optimize(problem);
        final double[] parameters = optimum.getPoint().toArray();
        setSolution(a, parameters);
        iterations = optimum.getIterations();
        evaluations = optimum.getEvaluations();

        if (parametersVariance != null) {
            // Build the Fisher information matrix from J^T J using the Jacobian at the optimum.
            final RealMatrix jacobian = optimum.getJacobian();
            final RealMatrix jTj = jacobian.transpose().multiply(jacobian);
            // Use the backing array directly when available to avoid an extra copy
            final double[][] data = (jTj instanceof Array2DRowRealMatrix)
                ? ((Array2DRowRealMatrix) jTj).getDataRef()
                : jTj.getData();
            setDeviations(parametersVariance, new FisherInformationMatrix(data));
        }

        // Compute the function values at the fitted parameters.
        // A distinct local name avoids hiding the 'function' field.
        if (fx != null) {
            final ValueFunction valueFunction = (ValueFunction) function;
            valueFunction.initialise0(a);
            valueFunction.forEach(new ValueProcedure() {

                int index;

                @Override
                public void execute(double value) {
                    fx[index++] = value;
                }
            });
        }

        // As this is unweighted the sum of squared residuals is the residual dot product.
        // This is the same as optimum.getCost() * optimum.getCost(); the getCost() function
        // just computes the dot product anyway.
        final RealVector residuals = optimum.getResiduals();
        value = residuals.dotProduct(residuals);
    } catch (final TooManyEvaluationsException ex) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (final TooManyIterationsException ex) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (final ConvergenceException ex) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (final Exception ex) {
        // TODO - Find out the other exceptions from the fitter and add return values to match.
        return FitStatus.UNKNOWN;
    }
    return FitStatus.OK;
}
Also used: ValueFunction (uk.ac.sussex.gdsc.smlm.function.ValueFunction), ValueProcedure (uk.ac.sussex.gdsc.smlm.function.ValueProcedure), NonLinearFunction (uk.ac.sussex.gdsc.smlm.function.NonLinearFunction), ExtendedNonLinearFunction (uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction), LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder), TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException), Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix), ValueAndJacobianFunction (org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction), RealVector (org.apache.commons.math3.linear.RealVector), ArrayRealVector (org.apache.commons.math3.linear.ArrayRealVector), ConvergenceException (org.apache.commons.math3.exception.ConvergenceException), TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException), LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem), Pair (org.apache.commons.math3.util.Pair), FisherInformationMatrix (uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix), MultivariateMatrixFunctionWrapper (uk.ac.sussex.gdsc.smlm.function.MultivariateMatrixFunctionWrapper), Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum), LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer), RealMatrix (org.apache.commons.math3.linear.RealMatrix), MultivariateVectorFunctionWrapper (uk.ac.sussex.gdsc.smlm.function.MultivariateVectorFunctionWrapper)

Example 2 with NonLinearFunction

Use of uk.ac.sussex.gdsc.smlm.function.NonLinearFunction in project GDSC-SMLM by aherbert.

Class ApacheLvmFitter, method computeValue.

/**
 * Evaluates the function at the supplied parameters without performing a fit.
 *
 * <p>Delegates to {@code GradientCalculator.findLinearised} and stores its result in
 * {@code value}. The {@code fx} array is passed straight through to that call (presumably it is
 * filled with the function values; confirm against GradientCalculator).
 *
 * @param y the observed data
 * @param fx output array forwarded to the gradient calculator
 * @param a the function parameters
 * @return always {@code true}
 */
@Override
public boolean computeValue(double[] y, double[] fx, double[] a) {
    final GradientCalculator gradientCalculator =
        GradientCalculatorUtils.newCalculator(function.getNumberOfGradients(), false);
    final int numberOfPoints = y.length;
    // The original note states the constructor guarantees a Gaussian2DFunction; the cast here
    // only requires a NonLinearFunction.
    final NonLinearFunction nonLinearFunction = (NonLinearFunction) function;
    value = gradientCalculator.findLinearised(numberOfPoints, y, fx, a, nonLinearFunction);
    return true;
}
Also used: GradientCalculator (uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator), NonLinearFunction (uk.ac.sussex.gdsc.smlm.function.NonLinearFunction), ExtendedNonLinearFunction (uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction)

Example 3 with NonLinearFunction

Use of uk.ac.sussex.gdsc.smlm.function.NonLinearFunction in project GDSC-SMLM by aherbert.

Class GradientCalculatorSpeedTest, method mleGradientCalculatorComputesLikelihood.

/**
 * Checks Poisson likelihoods against Apache Commons Math for observations 0 to 99 at several means.
 *
 * <p>Per observation, PoissonCalculator's likelihood and log-likelihood must match
 * PoissonDistribution's probability and log-probability. MleGradientCalculator.logLikelihood over
 * the whole array must exactly equal the sum of PoissonCalculator's log-likelihoods. It must also
 * be close to the sum of PoissonDistribution's log-probabilities.
 */
@SeededTest
void mleGradientCalculatorComputesLikelihood() {
    // Function whose value at every point is the constant mean taken from parameter 0;
    // the remaining methods return trivial values.
    // @formatter:off
    final NonLinearFunction func = new NonLinearFunction() {

        double mean;

        @Override
        public void initialise(double[] a) {
            mean = a[0];
        }

        @Override
        public double eval(int x) {
            return mean;
        }

        @Override
        public double eval(int x, double[] dyda) {
            return 0;
        }

        @Override
        public double evalw(int x, double[] w) {
            return 0;
        }

        @Override
        public double evalw(int x, double[] dyda, double[] w) {
            return 0;
        }

        @Override
        public int[] gradientIndices() {
            return null;
        }

        @Override
        public int getNumberOfGradients() {
            return 0;
        }

        @Override
        public boolean canComputeWeights() {
            return false;
        }
    };
    // @formatter:on

    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-10, 0);
    final int[] observations = SimpleArrayUtils.natural(100);
    final double[] observationValues = SimpleArrayUtils.newArray(100, 0, 1.0);
    final double[] means = { 0.79, 2.5, 5.32 };

    for (final double u : means) {
        final PoissonDistribution distribution = new PoissonDistribution(u);
        double expectedSum = 0;
        double observedSum = 0;

        // Compare each observation individually, accumulating the log-likelihoods in order
        for (final int x : observations) {
            final double likelihood = PoissonCalculator.likelihood(u, x);
            final double expectedLikelihood = distribution.probability(x);
            TestAssertions.assertTest(expectedLikelihood, likelihood, predicate, "likelihood");

            final double logLikelihood = PoissonCalculator.logLikelihood(u, x);
            final double expectedLogLikelihood = distribution.logProbability(x);
            TestAssertions.assertTest(expectedLogLikelihood, logLikelihood, predicate, "log likelihood");

            observedSum += logLikelihood;
            expectedSum += expectedLogLikelihood;
        }

        // The calculator total must be bit-identical to the PoissonCalculator sum
        final MleGradientCalculator calculator = new MleGradientCalculator(1);
        final double total = calculator.logLikelihood(observationValues, new double[] { u }, func);
        Assertions.assertEquals(observedSum, total, "sum log likelihood should exactly match the PoissonCalculator");
        TestAssertions.assertTest(expectedSum, total, predicate, "sum log likelihood");
    }
}
Also used: PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution), DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate), NonLinearFunction (uk.ac.sussex.gdsc.smlm.function.NonLinearFunction), SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

NonLinearFunction (uk.ac.sussex.gdsc.smlm.function.NonLinearFunction): 3; ExtendedNonLinearFunction (uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction): 2; PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution): 1; ConvergenceException (org.apache.commons.math3.exception.ConvergenceException): 1; TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException): 1; TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException): 1; LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder): 1; Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum): 1; LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem): 1; LevenbergMarquardtOptimizer (org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer): 1; ValueAndJacobianFunction (org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction): 1; Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 1; ArrayRealVector (org.apache.commons.math3.linear.ArrayRealVector): 1; RealMatrix (org.apache.commons.math3.linear.RealMatrix): 1; RealVector (org.apache.commons.math3.linear.RealVector): 1; Pair (org.apache.commons.math3.util.Pair): 1; FisherInformationMatrix (uk.ac.sussex.gdsc.smlm.fitting.FisherInformationMatrix): 1; GradientCalculator (uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator): 1; MultivariateMatrixFunctionWrapper (uk.ac.sussex.gdsc.smlm.function.MultivariateMatrixFunctionWrapper): 1; MultivariateVectorFunctionWrapper (uk.ac.sussex.gdsc.smlm.function.MultivariateVectorFunctionWrapper): 1