Use of uk.ac.sussex.gdsc.smlm.function.ExtendedNonLinearFunction in project GDSC-SMLM by aherbert.
The computeFit method of the class ApacheLvmFitter.
/**
 * Fits the function to the data using the Apache Commons Math Levenberg-Marquardt optimiser.
 *
 * <p>After the optimiser returns, the optimum point is handed to {@code setSolution(a, ...)} and
 * the {@code iterations}, {@code evaluations} and {@code value} (sum of squared residuals) fields
 * are updated from the optimum.
 *
 * @param y the observed data; copied and used as the least-squares target
 * @param fx if not null, filled with the function values produced after initialising the function
 *        with {@code a} once the solution has been set
 * @param a the function parameters; supplies the initial solution and is passed to
 *        {@code setSolution} with the optimum point
 * @param parametersVariance if not null, passed to {@code setDeviations} together with a Fisher
 *        information matrix built from J<sup>T</sup>J of the optimum
 * @return {@code FitStatus.OK} on success; otherwise the status matching the exception raised
 */
@Override
public FitStatus computeFit(double[] y, final double[] fx, double[] a, double[] parametersVariance) {
  final int n = y.length;
  try {
    // Optimiser settings. Changing the convergence thresholds appears to alter only the number
    // of iterations needed to converge, not the resulting fit.
    final double initialStepBoundFactor = 100;
    final double costRelativeTolerance = 1e-10;
    final double parRelativeTolerance = 1e-10;
    final double orthoTolerance = 1e-10;
    final double threshold = Precision.SAFE_MIN;

    // Only the parameters being fitted form the starting point.
    final double[] start = getInitialSolution(a);

    // TODO - Pass in more advanced stopping criteria.

    // The optimiser target is a private copy of the data. The fit is unweighted.
    final double[] target = y.clone();

    final LevenbergMarquardtOptimizer optimizer =
        new LevenbergMarquardtOptimizer(initialStepBoundFactor, costRelativeTolerance,
            parRelativeTolerance, orthoTolerance, threshold);

    final LeastSquaresBuilder builder = new LeastSquaresBuilder()
        .maxEvaluations(Integer.MAX_VALUE)
        .maxIterations(getMaxEvaluations())
        .start(start)
        .target(target);

    final boolean computeTogether = function instanceof ExtendedNonLinearFunction
        && ((ExtendedNonLinearFunction) function).canComputeValuesAndJacobian();

    if (!computeTogether) {
      // Values and Jacobian are provided by two separate wrappers around the function.
      final NonLinearFunction nonLinear = (NonLinearFunction) function;
      builder.model(new MultivariateVectorFunctionWrapper(nonLinear, a, n),
          new MultivariateMatrixFunctionWrapper(nonLinear, a, n));
    } else {
      // The function can produce the values and the Jacobian in one call, or each on its own.
      final ExtendedNonLinearFunction extended = (ExtendedNonLinearFunction) function;
      builder.model(new ValueAndJacobianFunction() {
        @Override
        public Pair<RealVector, RealMatrix> value(RealVector point) {
          final org.apache.commons.lang3.tuple.Pair<double[], double[][]> both =
              extended.computeValuesAndJacobian(point.toArray());
          // Wrap the returned arrays without copying
          final RealVector values = new ArrayRealVector(both.getKey(), false);
          final RealMatrix jacobian = new Array2DRowRealMatrix(both.getValue(), false);
          return new Pair<>(values, jacobian);
        }

        @Override
        public RealVector computeValue(double[] params) {
          final double[] values = extended.computeValues(params);
          return new ArrayRealVector(values, false);
        }

        @Override
        public RealMatrix computeJacobian(double[] params) {
          final double[][] jacobian = extended.computeJacobian(params);
          return new Array2DRowRealMatrix(jacobian, false);
        }
      });
    }

    final Optimum optimum = optimizer.optimize(builder.build());

    // Store the result and the optimiser counters
    setSolution(a, optimum.getPoint().toArray());
    iterations = optimum.getIterations();
    evaluations = optimum.getEvaluations();

    if (parametersVariance != null) {
      // Build transpose(J) * J from the Jacobian at the optimum
      final RealMatrix jacobian = optimum.getJacobian();
      final RealMatrix product = jacobian.transpose().multiply(jacobian);
      final double[][] data;
      if (product instanceof Array2DRowRealMatrix) {
        // Use the backing array directly
        data = ((Array2DRowRealMatrix) product).getDataRef();
      } else {
        data = product.getData();
      }
      setDeviations(parametersVariance, new FisherInformationMatrix(data));
    }

    if (fx != null) {
      // Write the function values into fx in the order the function produces them
      final ValueFunction valueFunction = (ValueFunction) function;
      valueFunction.initialise0(a);
      valueFunction.forEach(new ValueProcedure() {
        int next;

        @Override
        public void execute(double value) {
          fx[next++] = value;
        }
      });
    }

    // The fit is unweighted so the residual dot product is the sum of squared residuals.
    // This equals optimum.getCost() squared.
    final RealVector residuals = optimum.getResiduals();
    value = residuals.dotProduct(residuals);

    return FitStatus.OK;
  } catch (final TooManyEvaluationsException ex) {
    return FitStatus.TOO_MANY_EVALUATIONS;
  } catch (final TooManyIterationsException ex) {
    return FitStatus.TOO_MANY_ITERATIONS;
  } catch (final ConvergenceException ex) {
    // Raised when the QR decomposition fails: report a singular non-linear model (no solution)
    return FitStatus.SINGULAR_NON_LINEAR_MODEL;
  } catch (final Exception ex) {
    // TODO - Find out the other exceptions from the fitter and add return values to match.
    return FitStatus.UNKNOWN;
  }
}
Aggregations