Use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.
From the class MultivariateGaussianMixtureExpectationMaximizationTest, method testExpectationMaximizationSpeedWithDifferentNumberOfComponents.
/**
 * Test the speed of implementations of the expectation maximization algorithm with a mixture of n
 * 2D Gaussian distributions.
 *
 * <p>For each number of components n from 2 to 4, ten random datasets of 1000 2D points are
 * created. Two fitters are timed on initial estimation plus fitting of those datasets: the one
 * labelled "Commons" ({@code MultivariateNormalMixtureExpectationMaximization}) and the one
 * labelled "GDSC" ({@code MultivariateGaussianMixtureExpectationMaximization}). The GDSC fitter
 * must have a mean time less than half that of the Commons fitter.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeedWithDifferentNumberOfComponents(RandomSeed seed) {
  Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
  // Create data
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  for (int n = 2; n <= 4; n++) {
    final double[][][] data = new double[10][][];
    for (int i = 0; i < data.length; i++) {
      final double[] sampleWeights = createWeights(n, rng);
      final double[][] sampleMeans = create(n, 2, rng, -5, 5);
      final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
      final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
      data[i] = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs,
          sampleCorrelations);
    }
    // Effectively-final copy of the loop counter so the anonymous tasks can capture it
    final int numComponents = n;
    // Time initial estimation and fitting
    final TimingService ts = new TimingService();
    ts.execute(new FittingSpeedTask("Commons n=" + n + " 2D", data) {
      @Override
      Object run(double[][] data) {
        final MultivariateNormalMixtureExpectationMaximization fitter =
            new MultivariateNormalMixtureExpectationMaximization(data);
        fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
        return fitter.getLogLikelihood();
      }
    });
    ts.execute(new FittingSpeedTask("GDSC n=" + n + " 2D", data) {
      @Override
      Object run(double[][] data) {
        final MultivariateGaussianMixtureExpectationMaximization fitter =
            new MultivariateGaussianMixtureExpectationMaximization(data);
        fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
        return fitter.getLogLikelihood();
      }
    });
    // Re-run both tasks and compare the repeated timings. Otherwise the task executed first
    // (Commons) may also pay one-off class loading / JIT cost, biasing the comparison.
    ts.repeat(ts.getSize());
    if (logger.isLoggable(Level.INFO)) {
      logger.info(ts.getReport());
    }
    final double commonsTime = ts.get(-2).getMean();
    final double gdscTime = ts.get(-1).getMean();
    // More than twice as fast
    Assertions.assertTrue(gdscTime < commonsTime / 2,
        () -> "GDSC n=" + numComponents + " 2D is not more than twice as fast as Commons: "
            + gdscTime + " vs " + commonsTime);
  }
}
Use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.
From the class ConvolutionTest, method doSpeedTest.
/**
 * Run {@code speedTest} over a grid of data sizes and standard deviations. Sizes start at 10 and
 * double {@code sizeLoops} times; for each size the standard deviation starts at 0.5 and doubles
 * {@code sdLoops} times.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void doSpeedTest(RandomSeed seed) {
  // Timing output is only visible at INFO level, so skip otherwise
  Assumptions.assumeTrue(logger.isLoggable(Level.INFO));
  Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
  final UniformRandomProvider random = RngUtils.create(seed.getSeed());
  for (int sizeIndex = 0, dataSize = 10; sizeIndex < sizeLoops; sizeIndex++, dataSize *= 2) {
    double sigma = 0.5;
    for (int sigmaIndex = 0; sigmaIndex < sdLoops; sigmaIndex++, sigma *= 2) {
      speedTest(random, dataSize, sigma);
    }
  }
}
Use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.
From the class ErfGaussian2DFunctionTest, method functionIsFasterUsingForEach.
// Speed test forEach versus equivalent eval() function calls
@SpeedTag
@Test
void functionIsFasterUsingForEach() {
  Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
  final ErfGaussian2DFunction f1 = (ErfGaussian2DFunction) this.f1;
  // Every combination of the single-peak (peak 1) test parameters
  final LocalList<double[]> parameterSets = new LocalList<>();
  for (final double bg : testbackground) {
    for (final double signal : testsignal1) {
      for (final double cx : testcx1) {
        for (final double cy : testcy1) {
          for (final double cz : testcz1) {
            for (final double[] widths : testw1) {
              for (final double angle : testangle1) {
                parameterSets.add(
                    createParameters(bg, signal, cx, cy, cz, widths[0], widths[1], angle));
              }
            }
          }
        }
      }
    }
  }
  final double[][] x = parameterSets.toArray(new double[0][]);
  final TimingService ts = new TimingService(10000 / x.length);
  // eval()-based tasks followed by forEach-based tasks, each with int argument 2, 1, 0
  for (int level = 2; level >= 0; level--) {
    ts.execute(new FunctionTimingTask(f1, x, level));
  }
  for (int level = 2; level >= 0; level--) {
    ts.execute(new ForEachTimingTask(f1, x, level));
  }
  ts.repeat(ts.getSize());
  if (logger.isLoggable(Level.INFO)) {
    logger.info(ts.getReport());
  }
  // Pair each of the last three results (forEach) with the result three places earlier (eval)
  for (int k = 0; k < 3; k++) {
    final TimingResult slow = ts.get(-4 - k);
    final TimingResult fast = ts.get(-1 - k);
    logger.log(TestLogUtils.getTimingRecord(slow, fast));
  }
}
Use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.
From the class PoissonGaussianConvolutionFunctionTest, method pdfFasterThanPmf.
/**
 * Compare the speed of a {@link PoissonGaussianConvolutionFunction} with
 * {@code setComputePmf(true)} against one with {@code setComputePmf(false)}.
 *
 * <p>For each photon count, 1000 integer samples are drawn by inverse-transform sampling from
 * the normalised cumulative sum of the PMF-computing function's likelihood. Both functions are
 * then timed on the same samples. The timings are only logged; nothing is asserted.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void pdfFasterThanPmf(RandomSeed seed) {
  Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
  // Realistic CCD parameters for speed test
  final double s = 7.16;
  final double g = 3.1;
  final PoissonGaussianConvolutionFunction f1 =
      PoissonGaussianConvolutionFunction.createWithStandardDeviation(1 / g, s);
  f1.setComputePmf(true);
  final PoissonGaussianConvolutionFunction f2 =
      PoissonGaussianConvolutionFunction.createWithStandardDeviation(1 / g, s);
  f2.setComputePmf(false);
  final UniformRandomProvider rg = RngUtils.create(seed.getSeed());
  // Generate realistic data from the probability mass function
  final double[][] samples = new double[photons.length][];
  for (int j = 0; j < photons.length; j++) {
    // Lowest value considered: -4 * s (s is the standard deviation passed to the factory)
    final int start = (int) (4 * -s);
    int mu = start;
    final StoredDataStatistics stats = new StoredDataStatistics();
    // Evaluate the likelihood at successive integers until 99.5% of the mass is collected,
    // or (above 10) a value contributes less than 1e-6 of the running total
    while (stats.getSum() < 0.995) {
      final double p = f1.likelihood(mu, photons[j]);
      stats.add(p);
      if (mu > 10 && p / stats.getSum() < 1e-6) {
        break;
      }
      mu++;
    }
    // Generate cumulative probability
    final double[] data = stats.getValues();
    for (int i = 1; i < data.length; i++) {
      data[i] += data[i - 1];
    }
    // Normalise (the last element is divided last so the total stays intact until then)
    for (int i = 0, end = data.length - 1; i < data.length; i++) {
      data[i] /= data[end];
    }
    // Sample: find the first cumulative value >= a uniform deviate
    final double[] sample = new double[1000];
    for (int i = 0; i < sample.length; i++) {
      final double p = rg.nextDouble();
      int x = 0;
      while (x < data.length && data[x] < p) {
        x++;
      }
      sample[i] = start + x;
    }
    samples[j] = sample;
  }
  // Warm-up
  run(f1, samples, photons);
  run(f2, samples, photons);
  long t1 = 0;
  for (int i = 0; i < 5; i++) {
    t1 += run(f1, samples, photons);
  }
  long t2 = 0;
  for (int i = 0; i < 5; i++) {
    t2 += run(f2, samples, photons);
  }
  // t1 was timed with f1 (computePmf=true), t2 with f2 (computePmf=false)
  logger.log(TestLogUtils.getTimingRecord("pmf", t1, "pdf", t2));
}
Use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.
From the class ErfGaussian2DFunctionTest, method functionIsFasterThanEquivalentGaussian2DFunction.
// Speed test versus equivalent Gaussian2DFunction
@SpeedTag
@Test
void functionIsFasterThanEquivalentGaussian2DFunction() {
  Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
  // Same configuration as the function under test but without the erf flag
  final int plainFlags = this.flags & ~GaussianFunctionFactory.FIT_ERF;
  final Gaussian2DFunction gf =
      GaussianFunctionFactory.create2D(1, maxx, maxy, plainFlags, zModel);
  final boolean zDepth = (plainFlags & GaussianFunctionFactory.FIT_Z) != 0;
  final LocalList<double[]> erfParameters = new LocalList<>();
  final LocalList<double[]> plainParameters = new LocalList<>();
  for (final double bg : testbackground) {
    for (final double signal : testsignal1) {
      for (final double cx : testcx1) {
        for (final double cy : testcy1) {
          for (final double cz : testcz1) {
            for (final double[] widths : testw1) {
              for (final double angle : testangle1) {
                final double[] erfP =
                    createParameters(bg, signal, cx, cy, cz, widths[0], widths[1], angle);
                erfParameters.add(erfP);
                if (zDepth) {
                  // Change to a standard free circular function
                  final double[] plainP = erfP.clone();
                  final double z = plainP[Gaussian2DFunction.Z_POSITION];
                  plainP[Gaussian2DFunction.X_SD] *= zModel.getSx(z);
                  plainP[Gaussian2DFunction.Y_SD] *= zModel.getSy(z);
                  plainP[Gaussian2DFunction.Z_POSITION] = 0;
                  plainParameters.add(plainP);
                }
              }
            }
          }
        }
      }
    }
  }
  final double[][] x = erfParameters.toArray(new double[0][]);
  // Without z-depth the plain function can use the same parameters
  final double[][] x2 = zDepth ? plainParameters.toArray(new double[0][]) : x;
  final TimingService ts = new TimingService(10000 / x.length);
  for (int level = 1; level >= 0; level--) {
    ts.execute(new FunctionTimingTask(gf, x2, level));
  }
  for (int level = 2; level >= 0; level--) {
    ts.execute(new FunctionTimingTask(f1, x, level));
  }
  ts.repeat(ts.getSize());
  if (logger.isLoggable(Level.INFO)) {
    logger.info(ts.getReport());
  }
  // Pair the last two results (erf function) with the results three places earlier (plain)
  for (int k = 0; k < 2; k++) {
    final TimingResult slow = ts.get(-4 - k);
    final TimingResult fast = ts.get(-1 - k);
    logger.log(TestLogUtils.getTimingRecord(slow, fast));
  }
}
Aggregations