Usage of `gdsc.smlm.function.NonLinearFunction` in the GDSC-SMLM project by aherbert.
Class `ApacheLVMFitter`, method `computeFit`.
/**
 * Fits the function {@code f} to the data {@code y} using the Apache Commons Math
 * {@code LevenbergMarquardtOptimizer} with unit weights (an unweighted least-squares fit).
 *
 * @param y
 *            the observed data; its length defines the number of data points
 * @param y_fit
 *            if not null, filled with the function evaluated at the fitted parameters,
 *            one value per data point
 * @param a
 *            the parameters: the initial guess is read from this array and the solution is
 *            written back into it via {@code setSolution}
 * @param a_dev
 *            if not null, receives the parameter deviations
 * @return {@link FitStatus#OK} on success, otherwise a status describing the failure
 */
public FitStatus computeFit(double[] y, final double[] y_fit, double[] a, double[] a_dev) {
    final int n = y.length;
    try {
        // Different convergence thresholds seem to have no effect on the resulting fit, only the number of
        // iterations for convergence
        final double initialStepBoundFactor = 100;
        final double costRelativeTolerance = 1e-10;
        final double parRelativeTolerance = 1e-10;
        final double orthoTolerance = 1e-10;
        final double threshold = Precision.SAFE_MIN;

        // Extract the parameters to be fitted
        final double[] initialSolution = getInitialSolution(a);

        // TODO - Pass in more advanced stopping criteria.

        // Create the target and weight arrays. Unit weights make this an unweighted fit.
        final double[] yd = new double[n];
        final double[] w = new double[n];
        for (int i = 0; i < n; i++) {
            yd[i] = y[i];
            w[i] = 1;
        }

        final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(initialStepBoundFactor,
                costRelativeTolerance, parRelativeTolerance, orthoTolerance, threshold);

        //@formatter:off
        // NOTE(review): the iteration limit comes from getMaxEvaluations() while evaluations are
        // unbounded; confirm this is intended.
        final LeastSquaresBuilder builder = new LeastSquaresBuilder()
                .maxEvaluations(Integer.MAX_VALUE)
                .maxIterations(getMaxEvaluations())
                .start(initialSolution)
                .target(yd)
                .weight(new DiagonalMatrix(w));

        if (f instanceof ExtendedNonLinearFunction && ((ExtendedNonLinearFunction) f).canComputeValuesAndJacobian()) {
            // The function can produce values and Jacobian together in one call, or each individually
            builder.model(new ValueAndJacobianFunction() {
                final ExtendedNonLinearFunction fun = (ExtendedNonLinearFunction) f;

                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    final double[] p = point.toArray();
                    final Pair<double[], double[][]> result = fun.computeValuesAndJacobian(p);
                    return new Pair<RealVector, RealMatrix>(new ArrayRealVector(result.getFirst(), false),
                            new Array2DRowRealMatrix(result.getSecond(), false));
                }

                public RealVector computeValue(double[] params) {
                    return new ArrayRealVector(fun.computeValues(params), false);
                }

                public RealMatrix computeJacobian(double[] params) {
                    return new Array2DRowRealMatrix(fun.computeJacobian(params), false);
                }
            });
        } else {
            // Compute values and Jacobian separately
            builder.model(new MultivariateVectorFunctionWrapper((NonLinearFunction) f, a, n),
                    new MultivariateMatrixFunctionWrapper((NonLinearFunction) f, a, n));
        }

        final LeastSquaresProblem problem = builder.build();
        final Optimum optimum = optimizer.optimize(problem);

        final double[] parameters = optimum.getPoint().toArray();
        setSolution(a, parameters);
        iterations = optimum.getIterations();
        evaluations = optimum.getEvaluations();

        if (a_dev != null) {
            try {
                final double[][] covar = optimum.getCovariances(threshold).getData();
                setDeviationsFromMatrix(a_dev, covar);
            } catch (SingularMatrixException e) {
                // Matrix inversion failed. In order to return a solution
                // return the reciprocal of the diagonal of the Fisher information
                // for a loose bound on the limit
                final int[] gradientIndices = f.gradientIndices();
                final int nparams = gradientIndices.length;
                final GradientCalculator calculator = GradientCalculatorFactory.newCalculator(nparams);
                final double[][] alpha = new double[nparams][nparams];
                final double[] beta = new double[nparams];
                // The first argument is the number of data points (cf. computeValue which passes
                // y.length), not the number of parameters.
                calculator.findLinearised(n, y, a, alpha, beta, (NonLinearFunction) f);
                final FisherInformationMatrix m = new FisherInformationMatrix(alpha);
                setDeviations(a_dev, m.crlb(true));
            }
        }

        // Compute function value
        if (y_fit != null) {
            // Named gf to avoid shadowing the field f
            final Gaussian2DFunction gf = (Gaussian2DFunction) this.f;
            gf.initialise0(a);
            gf.forEach(new ValueProcedure() {
                int i = 0;

                public void execute(double value) {
                    // Advance the index so each data point is stored in its own slot
                    y_fit[i++] = value;
                }
            });
        }

        // As this is unweighted then we can do this to get the sum of squared residuals
        // This is the same as optimum.getCost() * optimum.getCost(); The getCost() function
        // just computes the dot product anyway.
        value = optimum.getResiduals().dotProduct(optimum.getResiduals());
    } catch (TooManyEvaluationsException e) {
        return FitStatus.TOO_MANY_EVALUATIONS;
    } catch (TooManyIterationsException e) {
        return FitStatus.TOO_MANY_ITERATIONS;
    } catch (ConvergenceException e) {
        // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
        return FitStatus.SINGULAR_NON_LINEAR_MODEL;
    } catch (Exception e) {
        // TODO - Find out the other exceptions from the fitter and add return values to match.
        return FitStatus.UNKNOWN;
    }

    return FitStatus.OK;
}
Usage of `gdsc.smlm.function.NonLinearFunction` in the GDSC-SMLM project by aherbert.
Class `ApacheLVMFitter`, method `computeValue`.
/**
 * Evaluates the function {@code f} at parameters {@code a} against the data {@code y} and
 * stores the result of {@code findLinearised} in {@code value}.
 * <p>
 * The calculator is created with {@code newCalculator(..., false)}; presumably this selects
 * the non-MLE (least-squares) calculator — TODO confirm against GradientCalculatorFactory.
 *
 * @param y the observed data
 * @param y_fit passed to the calculator; presumably receives the function values
 * @param a the function parameters
 * @return always {@code true}
 */
@Override
public boolean computeValue(double[] y, double[] y_fit, double[] a) {
    final GradientCalculator gradientCalculator = GradientCalculatorFactory
            .newCalculator(f.gradientIndices().length, false);
    // The calculator works on any NonLinearFunction implementation
    value = gradientCalculator.findLinearised(y.length, y, y_fit, a, (NonLinearFunction) f);
    return true;
}
Usage of `gdsc.smlm.function.NonLinearFunction` in the GDSC-SMLM project by aherbert.
Class `GradientCalculatorSpeedTest`, method `mleGradientCalculatorComputesLikelihood`.
/**
 * Checks that {@code MLEGradientCalculator} Poisson likelihood and log-likelihood match
 * Apache Commons Math {@code PoissonDistribution}, for single counts and summed over the data.
 */
@Test
public void mleGradientCalculatorComputesLikelihood() {
    // Stub function whose eval(x) returns the value a[0] captured by initialise(a);
    // the remaining methods are inert stubs.
    //@formatter:off
    NonLinearFunction func = new NonLinearFunction() {
        double mu;
        public void initialise(double[] a) {
            mu = a[0];
        }
        public int[] gradientIndices() {
            return null;
        }
        public double eval(int x, double[] dyda) {
            return 0;
        }
        public double eval(int x) {
            return mu;
        }
        public double eval(int x, double[] dyda, double[] w) {
            return 0;
        }
        public double evalw(int x, double[] w) {
            return 0;
        }
        public boolean canComputeWeights() {
            return false;
        }
        public int getNumberOfGradients() {
            return 0;
        }
    };
    //@formatter:on

    final int[] counts = Utils.newArray(100, 0, 1);
    final double[] observed = Utils.newArray(100, 0, 1.0);
    final double[] means = { 0.79, 2.5, 5.32 };

    for (int k = 0; k < means.length; k++) {
        final double mean = means[k];
        double expectedSum = 0;
        final PoissonDistribution poisson = new PoissonDistribution(mean);

        for (final int x : counts) {
            final double actualP = MLEGradientCalculator.likelihood(mean, x);
            final double expectedP = poisson.probability(x);
            Assert.assertEquals("likelihood", expectedP, actualP, expectedP * 1e-10);

            final double actualLogP = MLEGradientCalculator.logLikelihood(mean, x);
            final double expectedLogP = poisson.logProbability(x);
            Assert.assertEquals("log likelihood", expectedLogP, actualLogP, Math.abs(expectedLogP) * 1e-10);

            expectedSum += expectedLogP;
        }

        final MLEGradientCalculator calculator = new MLEGradientCalculator(1);
        final double actualSum = calculator.logLikelihood(observed, new double[] { mean }, func);
        Assert.assertEquals("sum log likelihood", expectedSum, actualSum, Math.abs(expectedSum) * 1e-10);
    }
}
Aggregations