use of uk.ac.sussex.gdsc.smlm.function.Gradient2Function in project GDSC-SMLM by aherbert.
the class FitConfiguration method createFunctionSolver.
/**
 * Builds the function solver that matches the currently configured fit solver.
 *
 * <p>The {@code MLE} fit solver is created as a {@code MaximumLikelihoodFitter}. All other
 * recognised fit solvers are created as a {@code SteppingFunctionSolver} implementation.
 *
 * @return the new function solver
 * @throws IllegalStateException if the MLE solver is requested without a CCD/EM-CCD camera, a
 *         positive gain or the amplification; if a derivative based search method is combined
 *         with a modelled camera; or if the fit solver is not recognised
 */
private BaseFunctionSolver createFunctionSolver() {
  // Callers may invoke getFunctionSolver() just to discover whether an exception is raised,
  // so fall back to a dummy function that allows a solver to be constructed.
  if (gaussianFunction == null) {
    gaussianFunction = createGaussianFunction(1, 1, 1);
  }

  if (getFitSolverValue() == FitSolver.MLE_VALUE) {
    // Currently limited to CCD/EM-CCD cameras
    if (!calibration.isCcdCamera()) {
      throw new IllegalStateException(
          "CCD/EM-CCD camera is required for fit solver: " + getFitSolver());
    }
    // The gain must be available
    if (gain <= 0) {
      throw new IllegalStateException("The gain is required for fit solver: " + getFitSolver());
    }

    final MaximumLikelihoodFitter.SearchMethod method = convertSearchMethod();

    // Gradients are only supported by the Poisson likelihood function
    if (method.usesGradients() && isModelCamera()) {
      final String message = String.format(
          "The derivative based search method '%s' can only be used with "
              + "the '%s' likelihood function, i.e. no model camera noise",
          method, MaximumLikelihoodFitter.LikelihoodFunction.POISSON);
      throw new IllegalStateException(message);
    }

    final MaximumLikelihoodFitter mle = new MaximumLikelihoodFitter(gaussianFunction);
    mle.setRelativeThreshold(getRelativeThreshold());
    mle.setAbsoluteThreshold(getAbsoluteThreshold());
    mle.setMaxEvaluations(getMaxFunctionEvaluations());
    mle.setMaxIterations(getMaxIterations());
    mle.setSearchMethod(method);
    mle.setGradientLineMinimisation(isGradientLineMinimisation());

    // Choose the likelihood function
    if (!isModelCamera()) {
      mle.setLikelihoodFunction(MaximumLikelihoodFitter.LikelihoodFunction.POISSON);
    } else {
      // The camera read noise is not validated since 0 is a legitimate value.
      mle.setSigma(calibration.getReadNoise());
      // EMCCD = Poisson+Gamma+Gaussian; CCD = Poisson+Gaussian
      mle.setLikelihoodFunction(
          emCcd ? MaximumLikelihoodFitter.LikelihoodFunction.POISSON_GAMMA_GAUSSIAN
              : MaximumLikelihoodFitter.LikelihoodFunction.POISSON_GAUSSIAN);
    }

    // Every model needs the amplification gain (i.e. how many ADUs/electron)
    if (calibration.hasCountPerElectron()) {
      mle.setAlpha(1.0 / calibration.getCountPerElectron());
      return mle;
    }
    throw new IllegalStateException(
        "The amplification is required for the fit solver: " + getFitSolver());
  }

  // Everything else is built on the stepping function solver
  final ToleranceChecker checker = getToleranceChecker();
  final ParameterBounds parameterBounds = new ParameterBounds(gaussianFunction);
  if (isUseClamping()) {
    setClampValues(parameterBounds);
  }

  final SteppingFunctionSolver steppingSolver;
  switch (getFitSolverValue()) {
    case FitSolver.LVM_LSE_VALUE:
      steppingSolver =
          new LseLvmSteppingFunctionSolver(gaussianFunction, checker, parameterBounds);
      break;

    case FitSolver.LVM_MLE_VALUE:
      checkCameraCalibration();
      steppingSolver =
          new MleLvmSteppingFunctionSolver(gaussianFunction, checker, parameterBounds);
      break;

    case FitSolver.LVM_WLSE_VALUE:
      checkCameraCalibration();
      steppingSolver =
          new WLseLvmSteppingFunctionSolver(gaussianFunction, checker, parameterBounds);
      break;

    case FitSolver.FAST_MLE_VALUE:
      checkCameraCalibration();
      // The cast fails with a ClassCastException when the function does not implement
      // the Gradient2Function interface
      steppingSolver = new FastMleSteppingFunctionSolver((Gradient2Function) gaussianFunction,
          checker, parameterBounds);
      break;

    default:
      throw new IllegalStateException("Unknown fit solver: " + getFitSolver());
  }

  // Apply the settings that are specific to the solver family
  if (steppingSolver instanceof LvmSteppingFunctionSolver) {
    final LvmSteppingFunctionSolver lvm = (LvmSteppingFunctionSolver) steppingSolver;
    lvm.setInitialLambda(getLambda());
  } else if (steppingSolver instanceof FastMleSteppingFunctionSolver) {
    final FastMleSteppingFunctionSolver fastMle = (FastMleSteppingFunctionSolver) steppingSolver;
    fastMle.setLineSearchMethod(convertLineSearchMethod());
  }
  return steppingSolver;
}
use of uk.ac.sussex.gdsc.smlm.function.Gradient2Function in project GDSC-SMLM by aherbert.
the class BaseFunctionSolverTest method fitAndComputeDeviationsMatch.
/**
 * Check the fit and compute deviations match. The first solver will be used to do the fit. This
 * is initialised from the solution so the convergence criteria can be set to accept the first
 * step. The second solver is used to compute deviations (thus is not initialised for fitting).
 *
 * <p>When the first solver is a {@code SteppingFunctionSolver} the fit is repeated using 1 peak
 * plus 1 precomputed peak and the deviations are compared with the 2 peak computation.
 *
 * @param seed the seed
 * @param solver1 the solver 1 (performs the fit)
 * @param solver2 the solver 2 (computes the deviations and values)
 * @param noiseModel the noise model
 * @param useWeights set to true to set weights on both solvers when solver 1 is weighted
 */
void fitAndComputeDeviationsMatch(RandomSeed seed, BaseFunctionSolver solver1,
    BaseFunctionSolver solver2, NoiseModel noiseModel, boolean useWeights) {
  final double[] noise = getNoise(seed, noiseModel);
  if (solver1.isWeighted() && useWeights) {
    solver1.setWeights(getWeights(seed, noiseModel));
    solver2.setWeights(getWeights(seed, noiseModel));
  }

  // Draw target data
  final UniformRandomProvider rg = RngUtils.create(seed.getSeed());
  final double[] data = drawGaussian(p12, noise, noiseModel, rg);

  // fit with 2 peaks using the known params.
  // compare to 2 peak deviation computation.
  final Gaussian2DFunction f2 = GaussianFunctionFactory.create2D(2, size, size, flags, null);
  solver1.setGradientFunction(f2);
  solver2.setGradientFunction(f2);
  double[] params = p12.clone();
  double[] expected = new double[params.length];
  double[] observed = new double[params.length];
  solver1.fit(data, null, params, expected);
  solver2.computeDeviations(data, params, observed);

  // JUnit 5 argument order is (expected, actual, ...) so failures are reported correctly
  Assertions.assertArrayEquals(expected, observed,
      "Fit 2 peaks and deviations 2 peaks do not match");

  // Try again with y-fit values
  params = p12.clone();
  final double[] o1 = new double[f2.size()];
  final double[] o2 = new double[o1.length];
  solver1.fit(data, o1, params, expected);
  solver2.computeValue(data, o2, params);

  Assertions.assertArrayEquals(expected, observed,
      "Fit 2 peaks with yFit and deviations 2 peaks do not match");

  final StandardValueProcedure p = new StandardValueProcedure();
  double[] ev = p.getValues(f2, params);
  Assertions.assertArrayEquals(ev, o1, 1e-8, "Fit 2 peaks yFit");
  Assertions.assertArrayEquals(ev, o2, 1e-8, "computeValue 2 peaks yFit");

  if (solver1 instanceof SteppingFunctionSolver) {
    // fit with 1 peak + 1 precomputed using the known params.
    // compare to 2 peak deviation computation.
    final ErfGaussian2DFunction f1 =
        (ErfGaussian2DFunction) GaussianFunctionFactory.create2D(1, size, size, flags, null);
    final Gradient2Function pf1 = OffsetGradient2Function.wrapGradient2Function(f1, p2v);
    solver1.setGradientFunction(pf1);
    params = p1.clone();
    expected = new double[params.length];
    solver1.fit(data, null, params, expected);

    // To copy the second peak
    final double[] a2 = p12.clone();
    // Add the same fitted first peak
    System.arraycopy(params, 0, a2, 0, params.length);
    solver2.computeDeviations(data, a2, observed);

    // Deviation should be lower with only 1 peak.
    // Due to matrix inversion this may not be the case for all parameters so count.
    int ok = 0;
    int fail = 0;
    final StringBuilder sb = new StringBuilder();
    for (int i = 0; i < expected.length; i++) {
      if (expected[i] <= observed[i]) {
        ok++;
        continue;
      }
      fail++;
      // Put each failing parameter on its own line of the failure message
      if (sb.length() != 0) {
        sb.append('\n');
      }
      TextUtils.formatTo(sb,
          "Fit 1 peak + 1 precomputed is higher than deviations 2 peaks %s: %s > %s",
          Gaussian2DFunction.getName(i), expected[i], observed[i]);
    }
    if (fail > ok) {
      Assertions.fail(sb.toString());
    }

    // Try again with y-fit values.
    // Here 'expected' holds the deviations of the fit without y-fit values and
    // 'observed' receives the deviations of the repeated fit.
    params = p1.clone();
    Arrays.fill(o1, 0);
    Arrays.fill(o2, 0);
    observed = new double[params.length];
    solver1.fit(data, o1, params, observed);
    solver2.computeValue(data, o2, a2);

    Assertions.assertArrayEquals(expected, observed, 1e-8,
        "Fit 1 peak + 1 precomputed with yFit and deviations 1 peak + "
            + "1 precomputed do not match");

    ev = p.getValues(pf1, params);
    Assertions.assertArrayEquals(ev, o1, 1e-8, "Fit 1 peak + 1 precomputed yFit");
    Assertions.assertArrayEquals(ev, o2, 1e-8, "computeValue 1 peak + 1 precomputed yFit");
  }
}
use of uk.ac.sussex.gdsc.smlm.function.Gradient2Function in project GDSC-SMLM by aherbert.
the class BaseFunctionSolverTest method fitAndComputeValueMatch.
/**
 * Verify that the value reported after a fit agrees with the value from a direct computation.
 * Solver 1 performs the fit; it starts from the solution so its convergence criteria can be set
 * to accept the first step. Solver 2 only computes the value (so it is not initialised for
 * fitting).
 *
 * <p>When solver 1 is a {@code SteppingFunctionSolver} the comparison is repeated using 1 peak
 * plus 1 precomputed peak.
 *
 * @param seed the seed
 * @param solver1 the solver used to fit
 * @param solver2 the solver used to compute the value
 * @param noiseModel the noise model
 * @param useWeights set to true to set weights on both solvers when solver 1 is weighted
 */
void fitAndComputeValueMatch(RandomSeed seed, BaseFunctionSolver solver1,
    BaseFunctionSolver solver2, NoiseModel noiseModel, boolean useWeights) {
  final double[] noise = getNoise(seed, noiseModel);
  if (solver1.isWeighted() && useWeights) {
    solver1.setWeights(getWeights(seed, noiseModel));
    solver2.setWeights(getWeights(seed, noiseModel));
  }

  // Simulate the target data
  final double[] target =
      drawGaussian(p12, noise, noiseModel, RngUtils.create(seed.getSeed()));

  // Two peak fit starting from the known parameters
  final Gaussian2DFunction twoPeaks =
      GaussianFunctionFactory.create2D(2, size, size, flags, null);
  solver1.setGradientFunction(twoPeaks);
  solver2.setGradientFunction(twoPeaks);
  double[] fitParams = p12.clone();
  solver1.fit(target, null, fitParams, null);
  solver2.computeValue(target, null, fitParams);

  final DoubleDoubleBiPredicate closeEnough = TestHelper.doublesAreClose(1e-10, 0);
  TestAssertions.assertTest(solver1.getValue(), solver2.getValue(), closeEnough,
      "Fit 2 peaks and computeValue");

  // Repeat while collecting the function values
  final double[] fitValues = new double[twoPeaks.size()];
  final double[] computedValues = new double[fitValues.length];
  solver1.fit(target, fitValues, fitParams, null);
  solver2.computeValue(target, computedValues, fitParams);
  TestAssertions.assertTest(solver1.getValue(), solver2.getValue(), closeEnough,
      "Fit 2 peaks and computeValue with yFit");

  final StandardValueProcedure valueProcedure = new StandardValueProcedure();
  double[] functionValues = valueProcedure.getValues(twoPeaks, fitParams);
  Assertions.assertArrayEquals(functionValues, fitValues, 1e-8, "Fit 2 peaks yFit");
  Assertions.assertArrayEquals(functionValues, computedValues, 1e-8,
      "computeValue 2 peaks yFit");

  if (!(solver1 instanceof SteppingFunctionSolver)) {
    return;
  }

  // One fitted peak + one precomputed peak, starting from the known parameters.
  // This is compared against the 2 peak computation.
  final ErfGaussian2DFunction onePeak =
      (ErfGaussian2DFunction) GaussianFunctionFactory.create2D(1, size, size, flags, null);
  final Gradient2Function withPrecomputed =
      OffsetGradient2Function.wrapGradient2Function(onePeak, p2v);
  solver1.setGradientFunction(withPrecomputed);
  solver2.setGradientFunction(withPrecomputed);
  fitParams = p1.clone();
  solver1.fit(target, null, fitParams, null);
  solver2.computeValue(target, null, fitParams);
  TestAssertions.assertTest(solver1.getValue(), solver2.getValue(), closeEnough,
      "Fit 1 peak + 1 precomputed and computeValue");

  // Reset the value buffers and repeat while collecting the function values
  Arrays.fill(fitValues, 0);
  Arrays.fill(computedValues, 0);
  solver1.fit(target, fitValues, fitParams, null);
  solver2.computeValue(target, computedValues, fitParams);
  TestAssertions.assertTest(solver1.getValue(), solver2.getValue(), closeEnough,
      "Fit 1 peak + 1 precomputed and computeValue with yFit");

  functionValues = valueProcedure.getValues(withPrecomputed, fitParams);
  Assertions.assertArrayEquals(functionValues, fitValues, 1e-8,
      "Fit 1 peak + 1 precomputed yFit");
  Assertions.assertArrayEquals(functionValues, computedValues, 1e-8,
      "computeValue 1 peak + 1 precomputed yFit");
}
use of uk.ac.sussex.gdsc.smlm.function.Gradient2Function in project GDSC-SMLM by aherbert.
the class FastMleSteppingFunctionSolver method computeLastFisherInformationMatrix.
/**
 * Computes the Fisher information matrix using a Poisson gradient procedure run on the solver
 * function.
 *
 * @param fx optional array; when it is not null and its length equals the function size the
 *        function is wrapped in a {@code Gradient2FunctionValueStore} together with this array
 * @return the Fisher information matrix built from the linear output of the procedure
 * @throws FunctionSolverException with status {@code INVALID_GRADIENTS} if the procedure
 *         reports NaN gradients
 */
@Override
protected FisherInformationMatrix computeLastFisherInformationMatrix(double[] fx) {
  final Gradient2Function base = (Gradient2Function) function;

  // Store the y-values into fx when a correctly sized array was supplied
  final Gradient2Function withValues = (fx != null && fx.length == base.size())
      ? new Gradient2FunctionValueStore(base, fx)
      : base;

  // Include the weights (observation variances) when they are present
  final Gradient2Function gradientFunction = (obsVariances != null)
      ? OffsetGradient2Function.wrapGradient2Function(withValues, obsVariances)
      : withValues;

  // Use the Fisher information of a Poisson process
  final PoissonGradientProcedure procedure =
      PoissonGradientProcedureUtils.create(gradientFunction);
  initialiseAndRun(procedure);
  if (procedure.isNaNGradients()) {
    throw new FunctionSolverException(FitStatus.INVALID_GRADIENTS);
  }
  return new FisherInformationMatrix(procedure.getLinear(), procedure.numberOfGradients);
}
use of uk.ac.sussex.gdsc.smlm.function.Gradient2Function in project GDSC-SMLM by aherbert.
the class FastMleGradient2ProcedureTest method gradientProcedureLinearIsFasterThanGradientProcedure.
/**
 * Compares the standard {@code FastMleGradient2Procedure} with the procedure returned by
 * {@code FastMleGradient2ProcedureUtils.createUnrolled}. The {@code d1} and {@code d2} arrays
 * of the two procedures must be equal for every data set. Both procedures are then timed and
 * the unrolled time must be less than 1.5 times the standard time.
 *
 * <p>The test only runs when {@code TestComplexity.MEDIUM} is allowed.
 *
 * @param seed the seed used to create the random data
 * @param nparams the number of parameters passed to the fake gradient function
 */
private void gradientProcedureLinearIsFasterThanGradientProcedure(RandomSeed seed, final int nparams) {
// Abort (not fail) the test when medium complexity tests are not enabled
Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
final int iter = 100;
final ArrayList<double[]> paramsList = new ArrayList<>(iter);
final ArrayList<double[]> yList = new ArrayList<>(iter);
createData(RngUtils.create(seed.getSeed()), 1, iter, paramsList, yList);
// Remove the timing of the function call by creating a dummy function
final Gradient2Function func = new FakeGradientFunction(blockWidth, nparams);
// Correctness check: run each procedure twice on the same parameters and compare the results
for (int i = 0; i < paramsList.size(); i++) {
final FastMleGradient2Procedure p1 = new FastMleGradient2Procedure(yList.get(i), func);
p1.computeSecondDerivative(paramsList.get(i));
p1.computeSecondDerivative(paramsList.get(i));
final FastMleGradient2Procedure p2 = FastMleGradient2ProcedureUtils.createUnrolled(yList.get(i), func);
p2.computeSecondDerivative(paramsList.get(i));
p2.computeSecondDerivative(paramsList.get(i));
// Check they are the same
// An effectively final copy of the index is required by the message lambdas
final int ii = i;
Assertions.assertArrayEquals(p1.d1, p2.d1, () -> "D1 " + ii);
Assertions.assertArrayEquals(p1.d2, p2.d2, () -> "D2 " + ii);
}
// Realistic loops for an optimisation
final int loops = 15;
// Run till stable timing
// NOTE(review): inside the anonymous Timer subclasses the simple name 'loops' is the local
// constant above only if Timer does not expose an inherited member with the same name
// (t1.loops is read below) - confirm against the Timer definition.
final Timer t1 = new Timer() {
@Override
void run() {
// k cycles through the parameter sets across all the inner evaluations
for (int i = 0, k = 0; i < paramsList.size(); i++) {
final FastMleGradient2Procedure p1 = new FastMleGradient2Procedure(yList.get(i), func);
for (int j = loops; j-- > 0; ) {
p1.computeSecondDerivative(paramsList.get(k++ % iter));
}
}
}
};
final long time1 = t1.getTime();
// Presumably passes the number of timing runs used by t1 so that t2 does a comparable
// amount of work - TODO confirm against the Timer(int) constructor.
final Timer t2 = new Timer(t1.loops) {
@Override
void run() {
// Same workload as t1 but using the unrolled procedure
for (int i = 0, k = 0; i < paramsList.size(); i++) {
final FastMleGradient2Procedure p2 = FastMleGradient2ProcedureUtils.createUnrolled(yList.get(i), func);
for (int j = loops; j-- > 0; ) {
p2.computeSecondDerivative(paramsList.get(k++ % iter));
}
}
}
};
final long time2 = t2.getTime();
// Report both timings and the ratio of standard to unrolled time
logger.log(TestLogUtils.getRecord(Level.INFO, "Standard = %d : Unrolled %d = %d : %fx", time1, nparams, time2, (1.0 * time1) / time2));
// The unrolled procedure is allowed to take up to (but not including) 1.5x the standard time
Assertions.assertTrue(time2 < time1 * 1.5);
}
Aggregations