Usage of `uk.ac.sussex.gdsc.smlm.function.NonLinearFunction` in the GDSC-SMLM project (by aherbert).
Class `ApacheLvmFitter`, method `computeFit`:
/**
 * Fits the configured function to {@code y} using the Apache Commons Math
 * {@link LevenbergMarquardtOptimizer} with an unweighted least-squares objective.
 *
 * <p>On success the fitted parameters are written back into {@code a} (via {@code setSolution}),
 * the optimizer's iteration and evaluation counts are recorded, and {@code value} is set to the
 * sum of squared residuals at the optimum.
 *
 * @param y the observed data; its length defines the number of data points
 * @param fx if not null, filled with the function values evaluated at the fitted parameters
 * @param a the parameters; the fitted subset is read as the start point (via
 *        {@code getInitialSolution}) and overwritten with the solution
 * @param parametersVariance if not null, populated by {@code setDeviations} from the Fisher
 *        information matrix J<sup>T</sup>J evaluated at the optimum
 * @return {@link FitStatus#OK} on success, otherwise a status describing the optimizer failure
 */
@Override
public FitStatus computeFit(double[] y, final double[] fx, double[] a, double[] parametersVariance) {
  final int n = y.length;
  try {
    // Different convergence thresholds seem to have no effect on the resulting fit, only the
    // number of iterations for convergence.
    final double initialStepBoundFactor = 100;
    final double costRelativeTolerance = 1e-10;
    final double parRelativeTolerance = 1e-10;
    final double orthoTolerance = 1e-10;
    final double threshold = Precision.SAFE_MIN;

    // Extract the parameters to be fitted
    final double[] initialSolution = getInitialSolution(a);

    // TODO - Pass in more advanced stopping criteria.

    // The fit is unweighted, so only the target (a copy of the observations) is required.
    final double[] yd = y.clone();

    final LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(
        initialStepBoundFactor, costRelativeTolerance, parRelativeTolerance, orthoTolerance,
        threshold);

    // NOTE(review): the iteration limit is taken from getMaxEvaluations() while evaluations are
    // unbounded; confirm this mapping is intended.
    final LeastSquaresBuilder builder = new LeastSquaresBuilder()
        .maxEvaluations(Integer.MAX_VALUE)
        .maxIterations(getMaxEvaluations())
        .start(initialSolution)
        .target(yd);

    if (function instanceof ExtendedNonLinearFunction
        && ((ExtendedNonLinearFunction) function).canComputeValuesAndJacobian()) {
      // The function can compute values and Jacobian together, or each individually
      builder.model(new ValueAndJacobianFunction() {
        final ExtendedNonLinearFunction fun = (ExtendedNonLinearFunction) function;

        @Override
        public Pair<RealVector, RealMatrix> value(RealVector point) {
          final double[] p = point.toArray();
          final org.apache.commons.lang3.tuple.Pair<double[], double[][]> result =
              fun.computeValuesAndJacobian(p);
          return new Pair<>(new ArrayRealVector(result.getKey(), false),
              new Array2DRowRealMatrix(result.getValue(), false));
        }

        @Override
        public RealVector computeValue(double[] params) {
          return new ArrayRealVector(fun.computeValues(params), false);
        }

        @Override
        public RealMatrix computeJacobian(double[] params) {
          return new Array2DRowRealMatrix(fun.computeJacobian(params), false);
        }
      });
    } else {
      // Compute values and Jacobian separately
      builder.model(new MultivariateVectorFunctionWrapper((NonLinearFunction) function, a, n),
          new MultivariateMatrixFunctionWrapper((NonLinearFunction) function, a, n));
    }

    final LeastSquaresProblem problem = builder.build();
    final Optimum optimum = optimizer.optimize(problem);

    final double[] parameters = optimum.getPoint().toArray();
    setSolution(a, parameters);
    iterations = optimum.getIterations();
    evaluations = optimum.getEvaluations();

    if (parametersVariance != null) {
      // Fisher information for an unweighted least-squares fit: transpose(J) * J
      final RealMatrix j = optimum.getJacobian();
      final RealMatrix jTj = j.transpose().multiply(j);
      // jTj is a freshly computed matrix, so its backing array can be used without copying
      final double[][] data = (jTj instanceof Array2DRowRealMatrix)
          ? ((Array2DRowRealMatrix) jTj).getDataRef()
          : jTj.getData();
      final FisherInformationMatrix m = new FisherInformationMatrix(data);
      setDeviations(parametersVariance, m);
    }

    // Compute the function values at the fitted parameters
    if (fx != null) {
      final ValueFunction valueFunction = (ValueFunction) this.function;
      valueFunction.initialise0(a);
      valueFunction.forEach(new ValueProcedure() {
        int index;

        @Override
        public void execute(double value) {
          fx[index++] = value;
        }
      });
    }

    // As this is unweighted the sum of squared residuals is the dot product of the residuals
    // with themselves (equivalent to optimum.getCost() squared). Fetch the vector only once.
    final RealVector residuals = optimum.getResiduals();
    value = residuals.dotProduct(residuals);
  } catch (final TooManyEvaluationsException ignored) {
    return FitStatus.TOO_MANY_EVALUATIONS;
  } catch (final TooManyIterationsException ignored) {
    return FitStatus.TOO_MANY_ITERATIONS;
  } catch (final ConvergenceException ignored) {
    // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
    return FitStatus.SINGULAR_NON_LINEAR_MODEL;
  } catch (final Exception ignored) {
    // Deliberate best-effort: any other optimizer failure is reported as UNKNOWN.
    // TODO - Find out the other exceptions from the fitter and add return values to match.
    return FitStatus.UNKNOWN;
  }
  return FitStatus.OK;
}
Usage of `uk.ac.sussex.gdsc.smlm.function.NonLinearFunction` in the GDSC-SMLM project (by aherbert).
Class `ApacheLvmFitter`, method `computeValue`:
/**
 * Computes {@code value} for the data {@code y} at parameters {@code a} by delegating to a
 * gradient calculator's {@code findLinearised} method.
 *
 * @param y the observed data
 * @param fx passed through to {@code findLinearised} (presumably receives the function values;
 *        confirm against GradientCalculator)
 * @param a the function parameters
 * @return always {@code true}
 */
@Override
public boolean computeValue(double[] y, double[] fx, double[] a) {
  final int gradientCount = function.getNumberOfGradients();
  final GradientCalculator gradientCalculator =
      GradientCalculatorUtils.newCalculator(gradientCount, false);
  final int dataCount = y.length;
  // NOTE(review): assumes the configured function also implements NonLinearFunction
  // (reportedly a Gaussian2DFunction supplied to the constructor) - confirm.
  final NonLinearFunction nonLinearFunction = (NonLinearFunction) function;
  value = gradientCalculator.findLinearised(dataCount, y, fx, a, nonLinearFunction);
  return true;
}
Usage of `uk.ac.sussex.gdsc.smlm.function.NonLinearFunction` in the GDSC-SMLM project (by aherbert).
Class `GradientCalculatorSpeedTest`, method `mleGradientCalculatorComputesLikelihood`:
/**
 * Checks PoissonCalculator likelihoods against a PoissonDistribution for several means. It then
 * checks that MleGradientCalculator.logLikelihood over observations 0..99 equals the summed
 * PoissonCalculator log-likelihoods.
 */
@SeededTest
void mleGradientCalculatorComputesLikelihood() {
  // Stub function: eval(int) returns the single parameter set in initialise; the remaining
  // methods return fixed placeholder values.
  // @formatter:off
  final NonLinearFunction func = new NonLinearFunction() {
    double mean;

    @Override
    public void initialise(double[] a) {
      mean = a[0];
    }

    @Override
    public int[] gradientIndices() {
      return null;
    }

    @Override
    public double eval(int x, double[] dyda) {
      return 0;
    }

    @Override
    public double eval(int x) {
      return mean;
    }

    @Override
    public double evalw(int x, double[] dyda, double[] w) {
      return 0;
    }

    @Override
    public double evalw(int x, double[] w) {
      return 0;
    }

    @Override
    public boolean canComputeWeights() {
      return false;
    }

    @Override
    public int getNumberOfGradients() {
      return 0;
    }
  };
  // @formatter:on

  final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-10, 0);
  final int[] counts = SimpleArrayUtils.natural(100);
  final double[] countValues = SimpleArrayUtils.newArray(100, 0, 1.0);
  final double[] means = { 0.79, 2.5, 5.32 };

  for (int k = 0; k < means.length; k++) {
    final double u = means[k];
    final PoissonDistribution pd = new PoissonDistribution(u);
    double observedSum = 0;
    double expectedSum = 0;

    // Compare each observation individually, then assert the summed log-likelihood below.
    for (int i = 0; i < counts.length; i++) {
      final int x = counts[i];

      final double likelihood = PoissonCalculator.likelihood(u, x);
      final double probability = pd.probability(x);
      TestAssertions.assertTest(probability, likelihood, predicate, "likelihood");

      final double logLikelihood = PoissonCalculator.logLikelihood(u, x);
      final double logProbability = pd.logProbability(x);
      TestAssertions.assertTest(logProbability, logLikelihood, predicate, "log likelihood");

      observedSum += logLikelihood;
      expectedSum += logProbability;
    }

    final MleGradientCalculator calculator = new MleGradientCalculator(1);
    final double total = calculator.logLikelihood(countValues, new double[] { u }, func);
    Assertions.assertEquals(observedSum, total,
        "sum log likelihood should exactly match the PoissonCalculator");
    TestAssertions.assertTest(expectedSum, total, predicate, "sum log likelihood");
  }
}
Aggregations