
Example 6 with GradientCalculator

Use of gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator in project GDSC-SMLM by aherbert.

From the class SolverSpeedTest, method createData.

private boolean createData(float[][] alpha, float[] beta, boolean positiveDefinite) {
    // Generate a 2D Gaussian on a 10 x 10 grid (100 data points).
    // rand is a random number generator defined elsewhere in SolverSpeedTest (not shown).
    SingleFreeCircularGaussian2DFunction func = new SingleFreeCircularGaussian2DFunction(10, 10);
    double[] a = new double[] {
        20 + rand.nextDouble() * 5, // Background
        10 + rand.nextDouble() * 5, // Amplitude
        0,                          // Angle
        5 + rand.nextDouble() * 2,  // X position
        5 + rand.nextDouble() * 2,  // Y position
        5 + rand.nextDouble() * 2,  // X width
        5 + rand.nextDouble() * 2   // Y width
    };
    int[] x = new int[100]; // only its length (the number of data points) is used
    double[] y = new double[100];
    func.initialise(a);
    for (int i = 0; i < x.length; i++) {
        // Evaluate the function and add uniform random noise in [-5, 5]
        y[i] = func.eval(i) + ((rand.nextDouble() < 0.5) ? -rand.nextDouble() * 5 : rand.nextDouble() * 5);
    }
    // Perturb each parameter by up to +/-1 so the system is built away from the true values
    for (int i = 0; i < a.length; i++) {
        a[i] += (rand.nextDouble() < 0.5) ? -rand.nextDouble() : rand.nextDouble();
    }
    // Compute the linearised Hessian (alpha) and gradient vector (beta)
    // for the 6 fitted parameters of this function
    GradientCalculator calc = new GradientCalculator(6);
    double[][] alpha2 = new double[6][6];
    double[] beta2 = new double[6];
    calc.findLinearised(y.length, y, a, alpha2, beta2, func);
    // Scale the Hessian diagonal by lambda (Levenberg-Marquardt style damping)
    double lambda = 1.001;
    for (int i = 0; i < alpha2.length; i++) {
        alpha2[i][i] *= lambda;
    }
    // Copy back into the single-precision output arrays
    for (int i = 0; i < beta.length; i++) {
        beta[i] = (float) beta2[i];
        for (int j = 0; j < beta.length; j++) {
            alpha[i][j] = (float) alpha2[i][j];
        }
    }
    // Optionally reject the data unless alpha is positive definite, using a Cholesky (LDL^T)
    // solve as the test. copydouble is a helper defined elsewhere in SolverSpeedTest
    // (not shown) that presumably copies the float arrays to double arrays.
    if (positiveDefinite) {
        EJMLLinearSolver solver = new EJMLLinearSolver();
        return solver.solveCholeskyLDLT(copydouble(alpha), copydouble(beta));
    }
    return true;
}
Also used: SingleFreeCircularGaussian2DFunction (gdsc.smlm.function.gaussian.SingleFreeCircularGaussian2DFunction), GradientCalculator (gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator)
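createData builds the linearised least-squares system that a Levenberg-Marquardt solver would work with. alpha holds products of parameter gradients summed over the data points, beta holds residual-weighted gradients, and the diagonal of alpha is then scaled by lambda before the optional positive-definiteness check. The following is a minimal, self-contained sketch of those three steps for a hypothetical 1D Gaussian with parameters {B, A, mu, s}. The class and method names are illustrative and not part of GDSC-SMLM. It assumes the common Numerical Recipes style convention (alpha[j][k] = sum of g_j * g_k, beta[j] = sum of (y_i - f_i) * g_j) rather than reproducing GradientCalculator.findLinearised, and it swaps EJML's LDL^T solve for a plain Cholesky LL^T factorisation, which succeeds with positive pivots exactly when a symmetric matrix is positive definite.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not GDSC-SMLM code: Gauss-Newton curvature matrix (alpha) and
 * gradient vector (beta) for a 1D Gaussian, lambda scaling of the diagonal, and a
 * Cholesky LL^T positive-definiteness test (swapped in for EJML's LDL^T solve).
 */
public class CreateDataSketch {

    /** Model: B + A * exp(-(x - mu)^2 / (2 s^2)) with parameters p = {B, A, mu, s}. */
    static double value(double[] p, double x) {
        double d = x - p[2];
        return p[0] + p[1] * Math.exp(-d * d / (2 * p[3] * p[3]));
    }

    /** Analytic partial derivatives of value() with respect to B, A, mu and s. */
    static double[] gradient(double[] p, double x) {
        double d = x - p[2];
        double s2 = p[3] * p[3];
        double e = Math.exp(-d * d / (2 * s2));
        return new double[] {1.0, e, p[1] * e * d / s2, p[1] * e * d * d / (s2 * p[3])};
    }

    /**
     * Fills alpha[j][k] = sum_i g_j * g_k and beta[j] = sum_i (y_i - f_i) * g_j for
     * x_i = i, and returns the sum of squared residuals (assumed convention).
     */
    static double linearise(double[] y, double[] p, double[][] alpha, double[] beta) {
        for (double[] row : alpha) {
            Arrays.fill(row, 0.0);
        }
        Arrays.fill(beta, 0.0);
        double ss = 0;
        for (int i = 0; i < y.length; i++) {
            double[] g = gradient(p, i);
            double r = y[i] - value(p, i);
            ss += r * r;
            for (int j = 0; j < p.length; j++) {
                beta[j] += r * g[j];
                for (int k = 0; k < p.length; k++) {
                    alpha[j][k] += g[j] * g[k];
                }
            }
        }
        return ss;
    }

    /** Multiplies each diagonal element by lambda, as createData does with 1.001. */
    static void scaleDiagonal(double[][] alpha, double lambda) {
        for (int i = 0; i < alpha.length; i++) {
            alpha[i][i] *= lambda;
        }
    }

    /** Cholesky LL^T reading only the lower triangle (a is assumed symmetric). */
    static boolean isPositiveDefinite(double[][] a) {
        int n = a.length;
        double[][] l = new double[n][n];
        for (int j = 0; j < n; j++) {
            double d = a[j][j];
            for (int k = 0; k < j; k++) {
                d -= l[j][k] * l[j][k];
            }
            if (!(d > 0)) {
                return false; // non-positive or NaN pivot: not positive definite
            }
            l[j][j] = Math.sqrt(d);
            for (int i = j + 1; i < n; i++) {
                double s = a[i][j];
                for (int k = 0; k < j; k++) {
                    s -= l[i][k] * l[j][k];
                }
                l[i][j] = s / l[j][j];
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] truth = {20, 10, 12, 3};
        double[] y = new double[25];
        for (int i = 0; i < y.length; i++) {
            y[i] = value(truth, i); // noise-free synthetic data
        }
        double[] guess = {21, 9, 12.5, 3.2}; // perturbed starting parameters
        double[][] alpha = new double[4][4];
        double[] beta = new double[4];
        double ss = linearise(y, guess, alpha, beta);
        scaleDiagonal(alpha, 1.001);
        System.out.println("sum of squares = " + ss);
        System.out.println("beta = " + Arrays.toString(beta));
        System.out.println("positive definite = " + isPositiveDefinite(alpha));
    }
}
```

At the true parameters the residuals, and therefore beta, are exactly zero, while alpha stays positive definite as long as the gradient columns are linearly independent over the data points. Filling alpha with the full j, k loop keeps it exactly symmetric, because g_j * g_k and g_k * g_j are the same floating-point product.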

Aggregations

GradientCalculator (gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator): 6
Gaussian2DFunction (gdsc.smlm.function.gaussian.Gaussian2DFunction): 4
ValueProcedure (gdsc.smlm.function.ValueProcedure): 3
RandomDataGenerator (org.apache.commons.math3.random.RandomDataGenerator): 3
Well19937c (org.apache.commons.math3.random.Well19937c): 3
TimingService (gdsc.core.test.TimingService): 2
TurboList (gdsc.core.utils.TurboList): 2
ExtendedNonLinearFunction (gdsc.smlm.function.ExtendedNonLinearFunction): 2
NonLinearFunction (gdsc.smlm.function.NonLinearFunction): 2
DenseMatrix64F (org.ejml.data.DenseMatrix64F): 2
FisherInformationMatrix (gdsc.smlm.fitting.FisherInformationMatrix): 1
MultivariateMatrixFunctionWrapper (gdsc.smlm.function.MultivariateMatrixFunctionWrapper): 1
MultivariateVectorFunctionWrapper (gdsc.smlm.function.MultivariateVectorFunctionWrapper): 1
SingleFreeCircularGaussian2DFunction (gdsc.smlm.function.gaussian.SingleFreeCircularGaussian2DFunction): 1
ConvergenceException (org.apache.commons.math3.exception.ConvergenceException): 1
TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException): 1
TooManyIterationsException (org.apache.commons.math3.exception.TooManyIterationsException): 1
LeastSquaresBuilder (org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder): 1
Optimum (org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum): 1
LeastSquaresProblem (org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem): 1