Search in sources :

Example 6 with SpeedTag

use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.

the class MultivariateGaussianMixtureExpectationMaximizationTest method testExpectationMaximizationSpeedWithDifferentNumberOfComponents.

/**
 * Benchmarks expectation maximization fitting of a mixture of n 2D Gaussian distributions for
 * n = 2, 3 and 4, running the Commons Math fitter and the GDSC fitter on the same data sets.
 *
 * <p>For each n, 10 data sets are generated from randomly created weights, means, standard
 * deviations and correlations. Both fitters are timed and the test asserts that the mean of
 * {@code ts.get(-1)} is less than half the mean of {@code ts.get(-2)} (presumably the GDSC and
 * Commons results respectively, given the execution order).
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeedWithDifferentNumberOfComponents(RandomSeed seed) {
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
    final UniformRandomProvider random = RngUtils.create(seed.getSeed());
    for (int components = 2; components <= 4; components++) {
        // The anonymous tasks below need an effectively final copy of the loop counter
        final int numComponents = components;

        // Create data; the sequence of random draws is the same for every data set
        final double[][][] dataSets = new double[10][][];
        for (int set = 0; set < dataSets.length; set++) {
            final double[] weights = createWeights(numComponents, random);
            final double[][] means = create(numComponents, 2, random, -5, 5);
            final double[][] stdDevs = create(numComponents, 2, random, 1, 10);
            final double[] correlations = create(numComponents, random, -0.9, 0.9);
            dataSets[set] = createData2d(1000, random, weights, means, stdDevs, correlations);
        }

        // Time initial estimation and fitting: Commons first, then GDSC
        final TimingService timer = new TimingService();
        timer.execute(new FittingSpeedTask("Commons n=" + numComponents + " 2D", dataSets) {

            @Override
            Object run(double[][] points) {
                final MultivariateNormalMixtureExpectationMaximization em =
                    new MultivariateNormalMixtureExpectationMaximization(points);
                em.fit(MultivariateNormalMixtureExpectationMaximization.estimate(points, numComponents));
                return em.getLogLikelihood();
            }
        });
        timer.execute(new FittingSpeedTask("GDSC n=" + numComponents + " 2D", dataSets) {

            @Override
            Object run(double[][] points) {
                final MultivariateGaussianMixtureExpectationMaximization em =
                    new MultivariateGaussianMixtureExpectationMaximization(points);
                em.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(points, numComponents));
                return em.getLogLikelihood();
            }
        });

        // Only build the report when it will actually be written
        if (logger.isLoggable(Level.INFO)) {
            logger.info(timer.getReport());
        }

        // More than twice as fast
        Assertions.assertTrue(timer.get(-1).getMean() < timer.get(-2).getMean() / 2);
    }
}
Also used : UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 7 with SpeedTag

use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.

the class ConvolutionTest method doSpeedTest.

/**
 * Runs {@code speedTest} over a grid of sizes and standard deviations.
 *
 * <p>The size starts at 10 and doubles for each of {@code sizeLoops} outer iterations; the
 * standard deviation starts at 0.5 and doubles for each of {@code sdLoops} inner iterations.
 * The test is skipped unless INFO logging is enabled and HIGH complexity tests are allowed.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void doSpeedTest(RandomSeed seed) {
    Assumptions.assumeTrue(logger.isLoggable(Level.INFO));
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    for (int outer = 0, size = 10; outer < sizeLoops; outer++, size *= 2) {
        // The standard deviation restarts at 0.5 for every size
        double sd = 0.5;
        for (int inner = 0; inner < sdLoops; inner++, sd *= 2) {
            speedTest(rng, size, sd);
        }
    }
}
Also used : UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 8 with SpeedTag

use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.

the class ErfGaussian2DFunctionTest method functionIsFasterUsingForEach.

/**
 * Speed test of forEach versus the equivalent eval() function calls.
 *
 * <p>Builds a parameter array for every combination of the single-peak test values, then times
 * three {@code FunctionTimingTask} and three {@code ForEachTimingTask} instances (constructed
 * with final argument 2, 1 and 0). The timings are repeated and three timing records are
 * logged, each built from {@code ts.get(-i - 3)} and {@code ts.get(-i)} (presumably the
 * function and forEach results for the same argument). Nothing is asserted.
 */
@SpeedTag
@Test
void functionIsFasterUsingForEach() {
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
    final ErfGaussian2DFunction erf = (ErfGaussian2DFunction) this.f1;

    // Enumerate all combinations of the parameters for peak 1
    final LocalList<double[]> combinations = new LocalList<>();
    for (final double bg : testbackground) {
        for (final double signal : testsignal1) {
            for (final double cx : testcx1) {
                for (final double cy : testcy1) {
                    for (final double cz : testcz1) {
                        for (final double[] width : testw1) {
                            for (final double angle : testangle1) {
                                combinations.add(
                                    createParameters(bg, signal, cx, cy, cz, width[0], width[1], angle));
                            }
                        }
                    }
                }
            }
        }
    }
    final double[][] x = combinations.toArray(new double[0][]);

    // Scale the number of runs inversely with the number of parameter sets
    final TimingService timer = new TimingService(10000 / x.length);
    timer.execute(new FunctionTimingTask(erf, x, 2));
    timer.execute(new FunctionTimingTask(erf, x, 1));
    timer.execute(new FunctionTimingTask(erf, x, 0));
    timer.execute(new ForEachTimingTask(erf, x, 2));
    timer.execute(new ForEachTimingTask(erf, x, 1));
    timer.execute(new ForEachTimingTask(erf, x, 0));

    // Repeat using the number of results recorded so far
    timer.repeat(timer.getSize());

    if (logger.isLoggable(Level.INFO)) {
        logger.info(timer.getReport());
    }

    // Compare results three positions apart: slow candidate first, fast candidate second
    for (int back = 1; back <= 3; back++) {
        logger.log(TestLogUtils.getTimingRecord(timer.get(-back - 3), timer.get(-back)));
    }
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) TimingResult(uk.ac.sussex.gdsc.test.utils.TimingResult) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) Gaussian2DFunctionTest(uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunctionTest) Test(org.junit.jupiter.api.Test)

Example 9 with SpeedTag

use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.

the class PoissonGaussianConvolutionFunctionTest method pdfFasterThanPmf.

/**
 * Times the likelihood computation of two {@code PoissonGaussianConvolutionFunction} instances
 * created with identical parameters, one with {@code setComputePmf(true)} and one with
 * {@code setComputePmf(false)}, and logs a timing record. Nothing is asserted.
 *
 * <p>For each entry of {@code photons}, 1000 samples are drawn by inverting a normalised
 * cumulative table of likelihoods computed with the first function.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void pdfFasterThanPmf(RandomSeed seed) {
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));

    // Realistic CCD parameters for speed test
    final double s = 7.16;
    final double g = 3.1;
    final PoissonGaussianConvolutionFunction pmfFunction =
        PoissonGaussianConvolutionFunction.createWithStandardDeviation(1 / g, s);
    pmfFunction.setComputePmf(true);
    final PoissonGaussianConvolutionFunction pdfFunction =
        PoissonGaussianConvolutionFunction.createWithStandardDeviation(1 / g, s);
    pdfFunction.setComputePmf(false);

    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());

    // The lowest value considered; it does not depend on the photon count
    final int start = (int) (4 * -s);

    // Generate realistic data from the probability mass function
    final double[][] samples = new double[photons.length][];
    for (int j = 0; j < photons.length; j++) {
        // Accumulate likelihoods until 99.5% of the mass is covered, or the latest value
        // becomes negligible relative to the running sum (only checked once mu exceeds 10)
        final StoredDataStatistics stats = new StoredDataStatistics();
        for (int mu = start; stats.getSum() < 0.995; mu++) {
            final double likelihood = pmfFunction.likelihood(mu, photons[j]);
            stats.add(likelihood);
            if (mu > 10 && likelihood / stats.getSum() < 1e-6) {
                break;
            }
        }

        // Generate cumulative probability
        final double[] cumulative = stats.getValues();
        for (int i = 1; i < cumulative.length; i++) {
            cumulative[i] = cumulative[i] + cumulative[i - 1];
        }

        // Normalise; the final entry is divided last so it serves as the divisor throughout
        final int last = cumulative.length - 1;
        for (int i = 0; i <= last; i++) {
            cumulative[i] /= cumulative[last];
        }

        // Sample: find the first bin whose cumulative probability is not below the deviate
        final double[] sample = new double[1000];
        for (int i = 0; i < sample.length; i++) {
            final double u = rng.nextDouble();
            int bin = 0;
            while (bin < cumulative.length && cumulative[bin] < u) {
                bin++;
            }
            sample[i] = start + bin;
        }
        samples[j] = sample;
    }

    // Warm-up
    run(pmfFunction, samples, photons);
    run(pdfFunction, samples, photons);

    // Time all runs of the first function, then all runs of the second
    long t1 = 0;
    for (int repeat = 0; repeat < 5; repeat++) {
        t1 += run(pmfFunction, samples, photons);
    }
    long t2 = 0;
    for (int repeat = 0; repeat < 5; repeat++) {
        t2 += run(pdfFunction, samples, photons);
    }

    // NOTE(review): the function with computePmf enabled is labelled "cdf" here - confirm
    // the label is intended
    logger.log(TestLogUtils.getTimingRecord("cdf", t1, "pdf", t2));
}
Also used : StoredDataStatistics(uk.ac.sussex.gdsc.core.utils.StoredDataStatistics) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 10 with SpeedTag

use of uk.ac.sussex.gdsc.test.junit5.SpeedTag in project GDSC-SMLM by aherbert.

the class ErfGaussian2DFunctionTest method functionIsFasterThanEquivalentGaussian2DFunction.

/**
 * Speed test versus the equivalent Gaussian2DFunction.
 *
 * <p>Creates a comparison function from the factory using this test's flags with
 * {@code FIT_ERF} cleared, builds parameter arrays for every combination of the single-peak
 * test values, and times both functions. When the flags include {@code FIT_Z} the comparison
 * function receives cloned parameters whose widths are scaled by the z-model and whose
 * z-position is zeroed. Two timing records are logged; nothing is asserted.
 */
@SpeedTag
@Test
void functionIsFasterThanEquivalentGaussian2DFunction() {
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
    final int gaussianFlags = this.flags & ~GaussianFunctionFactory.FIT_ERF;
    final Gaussian2DFunction gf = GaussianFunctionFactory.create2D(1, maxx, maxy, gaussianFlags, zModel);
    final boolean zDepth = (gaussianFlags & GaussianFunctionFactory.FIT_Z) != 0;

    final LocalList<double[]> erfParams = new LocalList<>();
    final LocalList<double[]> gaussianParams = new LocalList<>();
    // Enumerate all combinations of the parameters for peak 1
    for (final double bg : testbackground) {
        for (final double signal : testsignal1) {
            for (final double cx : testcx1) {
                for (final double cy : testcy1) {
                    for (final double cz : testcz1) {
                        for (final double[] width : testw1) {
                            for (final double angle : testangle1) {
                                final double[] original =
                                    createParameters(bg, signal, cx, cy, cz, width[0], width[1], angle);
                                erfParams.add(original);
                                if (!zDepth) {
                                    continue;
                                }
                                // Change to a standard free circular function
                                final double[] converted = original.clone();
                                converted[Gaussian2DFunction.X_SD] *=
                                    zModel.getSx(converted[Gaussian2DFunction.Z_POSITION]);
                                converted[Gaussian2DFunction.Y_SD] *=
                                    zModel.getSy(converted[Gaussian2DFunction.Z_POSITION]);
                                converted[Gaussian2DFunction.Z_POSITION] = 0;
                                gaussianParams.add(converted);
                            }
                        }
                    }
                }
            }
        }
    }

    final double[][] x = erfParams.toArray(new double[0][]);
    // Without z-depth fitting both functions share the same parameter arrays
    final double[][] x2;
    if (zDepth) {
        x2 = gaussianParams.toArray(new double[0][]);
    } else {
        x2 = x;
    }

    // Scale the number of runs inversely with the number of parameter sets
    final TimingService timer = new TimingService(10000 / x.length);
    timer.execute(new FunctionTimingTask(gf, x2, 1));
    timer.execute(new FunctionTimingTask(gf, x2, 0));
    timer.execute(new FunctionTimingTask(f1, x, 2));
    timer.execute(new FunctionTimingTask(f1, x, 1));
    timer.execute(new FunctionTimingTask(f1, x, 0));

    // Repeat using the number of results recorded so far
    timer.repeat(timer.getSize());

    if (logger.isLoggable(Level.INFO)) {
        logger.info(timer.getReport());
    }

    // Compare results three positions apart: slow candidate first, fast candidate second
    for (int back = 1; back <= 2; back++) {
        logger.log(TestLogUtils.getTimingRecord(timer.get(-back - 3), timer.get(-back)));
    }
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) TimingResult(uk.ac.sussex.gdsc.test.utils.TimingResult) Gaussian2DFunction(uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) Gaussian2DFunctionTest(uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunctionTest) Test(org.junit.jupiter.api.Test)

Aggregations

SpeedTag (uk.ac.sussex.gdsc.test.junit5.SpeedTag): 16; SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest): 13; TimingService (uk.ac.sussex.gdsc.test.utils.TimingService): 13; UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 12; TimingResult (uk.ac.sussex.gdsc.test.utils.TimingResult): 6; Test (org.junit.jupiter.api.Test): 5; LocalList (uk.ac.sussex.gdsc.core.utils.LocalList): 3; MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization): 2; Gaussian2DFunctionTest (uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunctionTest): 2; FloatProcessor (ij.process.FloatProcessor): 1; ArrayList (java.util.ArrayList): 1; Arrays (java.util.Arrays): 1; List (java.util.List): 1; Level (java.util.logging.Level): 1; Logger (java.util.logging.Logger): 1; IntStream (java.util.stream.IntStream): 1; MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution): 1; MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 1; Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix): 1; Covariance (org.apache.commons.math3.stat.correlation.Covariance): 1