Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.
The class MultivariateGaussianMixtureExpectationMaximizationTest, method canFit.
@SeededTest
void canFit(RandomSeed seed) {
  // Compare the fitted mixture model against the Commons Math estimation of the same data
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
  final int sampleSize = 1000;
  // n is the number of mixture components
  for (int n = 2; n <= 3; n++) {
    // Parameters of the mixture used to generate the 2D sample data.
    // Note: the order of these calls fixes the order in which the RNG is consumed.
    final double[] trueWeights = createWeights(n, rng);
    final double[][] trueMeans = create(n, 2, rng, -5, 5);
    final double[][] trueStdDevs = create(n, 2, rng, 1, 10);
    final double[] trueCorrelations = create(n, rng, -0.9, 0.9);
    final double[][] data =
        createData2d(sampleSize, rng, trueWeights, trueMeans, trueStdDevs, trueCorrelations);

    // Implementation under test
    final MixtureMultivariateGaussianDistribution startModel =
        MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
    final MultivariateGaussianMixtureExpectationMaximization gdscFitter =
        new MultivariateGaussianMixtureExpectationMaximization(data);
    Assertions.assertTrue(gdscFitter.fit(startModel));

    // Commons Math reference implementation
    final MultivariateNormalMixtureExpectationMaximization cmFitter =
        new MultivariateNormalMixtureExpectationMaximization(data);
    cmFitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, n));

    // The GDSC log-likelihood is divided by the sample size before comparing with the
    // Commons Math value
    final double gdscLogLikelihood = gdscFitter.getLogLikelihood() / sampleSize;
    Assertions.assertNotEquals(0, gdscLogLikelihood);
    final double cmLogLikelihood = cmFitter.getLogLikelihood();
    TestAssertions.assertTest(cmLogLikelihood, gdscLogLikelihood, test);

    // The fitted models must agree component by component
    final MixtureMultivariateGaussianDistribution gdscModel = gdscFitter.getFittedModel();
    Assertions.assertNotNull(gdscModel);
    final MixtureMultivariateNormalDistribution cmModel = cmFitter.getFittedModel();
    final List<Pair<Double, MultivariateNormalDistribution>> cmComponents =
        cmModel.getComponents();
    final double[] gdscWeights = gdscModel.getWeights();
    final MultivariateGaussianDistribution[] gdscDistributions = gdscModel.getDistributions();
    Assertions.assertEquals(n, cmComponents.size());
    Assertions.assertEquals(n, gdscWeights.length);
    Assertions.assertEquals(n, gdscDistributions.length);
    for (int k = 0; k < n; k++) {
      final Pair<Double, MultivariateNormalDistribution> component = cmComponents.get(k);
      TestAssertions.assertTest(component.getFirst(), gdscWeights[k], test, "weight");
      final MultivariateNormalDistribution cmDistribution = component.getSecond();
      TestAssertions.assertArrayTest(cmDistribution.getMeans(), gdscDistributions[k].getMeans(),
          test, "means");
      TestAssertions.assertArrayTest(cmDistribution.getCovariances().getData(),
          gdscDistributions[k].getCovariances(), test, "covariances");
    }

    final int iterations = gdscFitter.getIterations();
    Assertions.assertNotEquals(0, iterations);
    // Limiting the iterations to 2 must prevent convergence when more were needed
    if (iterations > 2) {
      Assertions.assertFalse(gdscFitter.fit(startModel, 2, DEFAULT_CONVERGENCE_CHECKER));
    }
  }
}
Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.
The class PsfModelGradient1FunctionTest, method canComputeValueAndGradient.
@Test
void canComputeValueAndGradient() {
  // Astigmatism z-depth model using values from the Smith, et al (2010) paper (page 377)
  final double sigmaX = 1.08;
  final double sigmaY = 1.01;
  final double gamma = 0.389;
  final double depth = 0.531;
  final double ax = -0.0708;
  final double bx = -0.073;
  final double ay = 0.164;
  final double by = 0.0417;
  final AstigmatismZModel zModel =
      HoltzerAstigmatismZModel.create(sigmaX, sigmaY, gamma, depth, ax, bx, ay, by);

  // A small image size ensures the PSF model covers the entire image
  final int width = 11;
  final int height = 11;
  final int size = width * height;
  // Values and gradients from the reference ERF Gaussian function
  final double[] expectedValues = new double[size];
  final double[][] expectedGradients = new double[size][];
  // Values and gradients from the PSF model function under test
  final double[] observedValues = new double[size];
  final double[][] observedGradients = new double[size][];

  final PsfModelGradient1Function psf =
      new PsfModelGradient1Function(new GaussianPsfModel(zModel), width, height);
  final ErfGaussian2DFunction f =
      new SingleAstigmatismErfGaussian2DFunction(width, height, zModel);
  f.setErfFunction(ErfFunction.COMMONS_MATH);
  final double[] reference = new double[Gaussian2DFunction.PARAMETERS_PER_PEAK + 1];
  final DoubleDoubleBiPredicate equality = TestHelper.doublesAreClose(1e-8, 0);

  final double centre = width * 0.5;
  // Offsets applied to the centre (x, y) and used directly as z
  final double[] shifts = {-0.33, 0, 0.33};
  final double[] intensities = {23.2, 405.67};
  for (final double dx : shifts) {
    final double x0 = centre + dx;
    for (final double dy : shifts) {
      final double x1 = centre + dy;
      for (final double x2 : shifts) {
        for (final double intensity : intensities) {
          // Background is constant for gradients so just use 1 value
          final double[] a = {2.2, intensity, x0, x1, x2};
          psf.initialise1(a);
          psf.forEach(new Gradient1Procedure() {
            int next;

            @Override
            public void execute(double value, double[] dyDa) {
              observedValues[next] = value;
              observedGradients[next] = dyDa.clone();
              next++;
            }
          });

          // Same peak for the reference function; the XY positions are shifted by -0.5
          reference[Gaussian2DFunction.BACKGROUND] = a[0];
          reference[Gaussian2DFunction.SIGNAL] = a[1];
          reference[Gaussian2DFunction.X_POSITION] = a[2] - 0.5;
          reference[Gaussian2DFunction.Y_POSITION] = a[3] - 0.5;
          reference[Gaussian2DFunction.Z_POSITION] = a[4];
          f.initialise1(reference);
          f.forEach(new Gradient1Procedure() {
            int next;

            @Override
            public void execute(double value, double[] dyDa) {
              expectedValues[next] = value;
              expectedGradients[next] = dyDa.clone();
              next++;
            }
          });

          for (int p = 0; p < size; p++) {
            TestAssertions.assertTest(expectedValues[p], observedValues[p], equality);
            TestAssertions.assertArrayTest(expectedGradients[p], observedGradients[p], equality);
          }
        }
      }
    }
  }
}
Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.
The class Gaussian2DPeakResultHelperTest, method canComputeCumulative2DAndInverse.
@Test
void canComputeCumulative2DAndInverse() {
  // Boundary values of the cumulative function and its inverse.
  // assertEquals(double, double) is used rather than assertTrue(x == y) so that a failure
  // reports the actual value; for these non-zero, non-NaN values the pass/fail outcome is the
  // same as the == comparison.
  Assertions.assertEquals(0, Gaussian2DPeakResultHelper.cumulative2D(0));
  Assertions.assertEquals(1.0,
      Gaussian2DPeakResultHelper.cumulative2D(Double.POSITIVE_INFINITY));
  Assertions.assertEquals(0, Gaussian2DPeakResultHelper.inverseCumulative2D(0));
  Assertions.assertEquals(Double.POSITIVE_INFINITY,
      Gaussian2DPeakResultHelper.inverseCumulative2D(1));

  // Round trip: r -> cumulative2D(r) -> inverseCumulative2D must recover r for r in [0.1, 1]
  final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-8, 0);
  for (int i = 1; i <= 10; i++) {
    final double r = i / 10.0;
    final double p = Gaussian2DPeakResultHelper.cumulative2D(r);
    final double r2 = Gaussian2DPeakResultHelper.inverseCumulative2D(p);
    TestAssertions.assertTest(r, r2, predicate);
  }
}
Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.
The class ConvolutionTest, method canComputeConvolution.
@SeededTest
void canComputeConvolution(RandomSeed seed) {
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-6, 0);
  // The data length doubles on each outer iteration (starting at 10) and the kernel
  // standard deviation doubles on each inner iteration (starting at 0.5).
  int size = 10;
  for (int i = 0; i < sizeLoops; i++, size *= 2) {
    double sd = 0.5;
    for (int j = 0; j < sdLoops; j++, sd *= 2) {
      final double[] data = randomData(rng, size);
      final double[] kernel = createKernel(sd);

      // Convolution is commutative: swapping the arguments must give the same result
      final double[] spatial = Convolution.convolve(data, kernel);
      final double[] spatialSwapped = Convolution.convolve(kernel, data);
      final double[] fft = Convolution.convolveFft(data, kernel);
      final double[] fftSwapped = Convolution.convolveFft(kernel, data);

      Assertions.assertEquals(spatial.length, spatialSwapped.length);
      Assertions.assertEquals(spatial.length, fft.length);
      Assertions.assertEquals(spatial.length, fftSwapped.length);

      // NOTE(review): spatial and FFT values are only compared for length, not content
      TestAssertions.assertArrayTest(spatial, spatialSwapped, predicate,
          "Spatial convolution doesn't match");
      TestAssertions.assertArrayTest(fft, fftSwapped, predicate,
          "FFT convolution doesn't match");
    }
  }
}
Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.
The class PeakResultsReaderTest, method checkEqual.
/**
 * Asserts that the actual results match the expected results.
 *
 * <p>The BINARY format is compared exactly. Other formats (except MALK) are compared within a
 * 1e-5 relative tolerance. MALK only compares frame, X, Y and intensity. The header
 * information (name, configuration, bounds, calibration, PSF) is always compared.
 *
 * @param fileFormat the file format the results were read from
 * @param showDeviations true if parameter deviations should be present and compared
 * @param showEndFrame true if the end frame should be compared
 * @param showId true if the ID should be compared
 * @param showPrecision true if the precision should be compared
 * @param showCategory true if the category should be compared
 * @param sort true to sort the expected results by frame before comparison
 * @param expectedResults the expected results
 * @param actualResults the actual results
 */
private static void checkEqual(ResultsFileFormat fileFormat, boolean showDeviations,
    boolean showEndFrame, boolean showId, boolean showPrecision, boolean showCategory,
    boolean sort, MemoryPeakResults expectedResults, MemoryPeakResults actualResults) {
  Assertions.assertNotNull(actualResults, "Input results are null");
  Assertions.assertEquals(expectedResults.size(), actualResults.size(), "Size differ");
  final PeakResult[] expected = expectedResults.toArray();
  final PeakResult[] actual = actualResults.toArray();
  if (sort) {
    // Results should be sorted by time.
    // Integer.compare avoids the overflow risk of subtracting the frame numbers.
    Arrays.sort(expected, (o1, o2) -> Integer.compare(o1.getFrame(), o2.getFrame()));
  }
  // Note: TSF requires the bias be subtracted; no bias adjustment is applied below.
  final DoubleDoubleBiPredicate deltaD = TestHelper.doublesIsCloseTo(1e-5, 0);
  final FloatFloatBiPredicate deltaF = TestHelper.floatsIsCloseTo(1e-5, 0);
  for (int i = 0; i < actualResults.size(); i++) {
    final PeakResult p1 = expected[i];
    final PeakResult p2 = actual[i];
    // Mutable message supplier: slot 0 is set to the field name before each assertion
    final ObjectArrayFormatSupplier msg = new ObjectArrayFormatSupplier("%s @ [" + i + "]", 1);
    Assertions.assertEquals(p1.getFrame(), p2.getFrame(), msg.set(0, "Peak"));
    if (fileFormat == ResultsFileFormat.MALK) {
      // MALK stores only the position and intensity
      TestAssertions.assertTest(p1.getXPosition(), p2.getXPosition(), deltaF, msg.set(0, "X"));
      TestAssertions.assertTest(p1.getYPosition(), p2.getYPosition(), deltaF, msg.set(0, "Y"));
      TestAssertions.assertTest(p1.getIntensity(), p2.getIntensity(), deltaF,
          msg.set(0, "Intensity"));
      continue;
    }
    Assertions.assertEquals(p1.getOrigX(), p2.getOrigX(), msg.set(0, "Orig X"));
    Assertions.assertEquals(p1.getOrigY(), p2.getOrigY(), msg.set(0, "Orig Y"));
    Assertions.assertNotNull(p2.getParameters(), msg.set(0, "Params is null"));
    if (showEndFrame) {
      Assertions.assertEquals(p1.getEndFrame(), p2.getEndFrame(), msg.set(0, "End frame"));
    }
    if (showId) {
      Assertions.assertEquals(p1.getId(), p2.getId(), msg.set(0, "ID"));
    }
    if (showDeviations) {
      Assertions.assertNotNull(p2.getParameterDeviations(), msg.set(0, "Deviations"));
    }
    if (showCategory) {
      Assertions.assertEquals(p1.getCategory(), p2.getCategory(), msg.set(0, "Category"));
    }
    // Binary should be exact for float numbers
    if (fileFormat == ResultsFileFormat.BINARY) {
      Assertions.assertEquals(p1.getOrigValue(), p2.getOrigValue(), msg.set(0, "Orig value"));
      Assertions.assertEquals(p1.getError(), p2.getError(), msg.set(0, "Error"));
      Assertions.assertEquals(p1.getNoise(), p2.getNoise(), msg.set(0, "Noise"));
      Assertions.assertEquals(p1.getMeanIntensity(), p2.getMeanIntensity(),
          msg.set(0, "Mean intensity"));
      Assertions.assertArrayEquals(p1.getParameters(), p2.getParameters(), msg.set(0, "Params"));
      if (showDeviations) {
        Assertions.assertArrayEquals(p1.getParameterDeviations(), p2.getParameterDeviations(),
            msg.set(0, "Params StdDev"));
      }
      if (showPrecision) {
        Assertions.assertEquals(p1.getPrecision(), p2.getPrecision(), msg.set(0, "Precision"));
      }
      continue;
    }
    // Otherwise have an error
    TestAssertions.assertTest(p1.getOrigValue(), p2.getOrigValue(), deltaF,
        msg.set(0, "Orig value"));
    TestAssertions.assertTest(p1.getError(), p2.getError(), deltaD, msg.set(0, "Error"));
    TestAssertions.assertTest(p1.getNoise(), p2.getNoise(), deltaF, msg.set(0, "Noise"));
    TestAssertions.assertTest(p1.getMeanIntensity(), p2.getMeanIntensity(), deltaF,
        msg.set(0, "Mean intensity"));
    TestAssertions.assertArrayTest(p1.getParameters(), p2.getParameters(), deltaF,
        msg.set(0, "Params"));
    if (showDeviations) {
      TestAssertions.assertArrayTest(p1.getParameterDeviations(), p2.getParameterDeviations(),
          deltaF, msg.set(0, "Params StdDev"));
    }
    if (showPrecision) {
      // Handle NaN precisions: the comparison is skipped only when both values are NaN
      final double pa = p1.getPrecision();
      final double pb = p2.getPrecision();
      if (!(Double.isNaN(pa) && Double.isNaN(pb))) {
        TestAssertions.assertTest(pa, pb, deltaD, msg.set(0, "Precision"));
      }
    }
  }

  // Check the header information
  Assertions.assertEquals(expectedResults.getName(), actualResults.getName(), "Name");
  Assertions.assertEquals(expectedResults.getConfiguration(), actualResults.getConfiguration(),
      "Configuration");

  final Rectangle r1 = expectedResults.getBounds();
  final Rectangle r2 = actualResults.getBounds();
  if (r1 != null) {
    Assertions.assertNotNull(r2, "Bounds");
    Assertions.assertEquals(r1.x, r2.x, "Bounds x");
    Assertions.assertEquals(r1.y, r2.y, "Bounds y");
    Assertions.assertEquals(r1.width, r2.width, "Bounds width");
    Assertions.assertEquals(r1.height, r2.height, "Bounds height");
  } else {
    Assertions.assertNull(r2, "Bounds");
  }

  final Calibration c1 = expectedResults.getCalibration();
  final Calibration c2 = actualResults.getCalibration();
  if (c1 != null) {
    Assertions.assertNotNull(c2, "Calibration");
    // Be lenient and allow no TimeUnit to match TimeUnit.FRAME
    boolean ok = c1.equals(c2);
    if (!ok && new CalibrationReader(c1).getTimeUnitValue() == TimeUnit.TIME_UNIT_NA_VALUE) {
      switch (fileFormat) {
        case BINARY:
        case MALK:
        case TEXT:
        case TSF: {
          final CalibrationWriter writer = new CalibrationWriter(c1);
          writer.setTimeUnit(TimeUnit.FRAME);
          ok = writer.getCalibration().equals(c2);
          break;
        }
        default:
          // Do not assume frames for other file formats
          break;
      }
    }
    Assertions.assertTrue(ok, "Calibration");
  } else {
    Assertions.assertNull(c2, "Calibration");
  }

  final PSF psf1 = expectedResults.getPsf();
  final PSF psf2 = actualResults.getPsf();
  if (psf1 != null) {
    Assertions.assertNotNull(psf2, "PSF");
    // assertEquals uses psf1.equals(psf2) and reports both values on failure
    Assertions.assertEquals(psf1, psf2, "PSF");
  } else {
    Assertions.assertNull(psf2, "PSF");
  }
}
Aggregations