Search in sources:

Example 6 with TimingService

Use of uk.ac.sussex.gdsc.test.utils.TimingService in the GDSC-SMLM project by aherbert.

From the class CubicSplineFunctionTest, method speedTest.

/**
 * Logs a timing report comparing evaluation of the cubic spline function (double and
 * single-precision variants) with equivalent 2D Gaussian functions. This is a report only;
 * nothing is asserted.
 *
 * <p>Parameter sets are generated with the first peak at every integer (x, y, z) in
 * [0, maxx] x [0, maxy] x [-zDepth, zDepth]. When {@code n == 2} a second peak is fixed at
 * (testcx1[0], testcy1[0], testcz1[0]).
 *
 * @param n the number of peaks; 2 selects the two-peak functions (f2, f2f), any other value
 *     selects the one-peak functions (f1, f1f)
 * @param order the order passed to each {@code FunctionTimingTask}; the simple free-circle
 *     Gaussian comparison is only included when {@code order < 2}
 */
@SuppressWarnings("null")
private void speedTest(int n, int order) {
    // No assertions, this is just a report
    Assumptions.assumeTrue(logger.isLoggable(Level.INFO));
    // Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
    final CubicSplineFunction cf = (n == 2) ? f2 : f1;
    Assumptions.assumeTrue(null != cf);
    final CubicSplineFunction cff = (n == 2) ? f2f : f1f;
    final ErfGaussian2DFunction gf = (ErfGaussian2DFunction) GaussianFunctionFactory.create2D(n, maxx, maxy, GaussianFunctionFactory.FIT_ASTIGMATISM, zModel);
    // Optional comparison: a free-circle Gaussian whose parameter sets (bb) carry X_SD/Y_SD
    // computed from zModel at each z.
    final Gaussian2DFunction gf2 = (order < 2) ? GaussianFunctionFactory.create2D(n, maxx, maxy, GaussianFunctionFactory.FIT_SIMPLE_FREE_CIRCLE, zModel) : null;
    // Parameter sets for: l1 = cubic spline, l2 = astigmatism Gaussian, l3 = free-circle Gaussian
    final LocalList<double[]> l1 = new LocalList<>();
    final LocalList<double[]> l2 = new LocalList<>();
    final LocalList<double[]> l3 = new LocalList<>();
    final double[] a = new double[1 + n * CubicSplineFunction.PARAMETERS_PER_PEAK];
    final double[] b = new double[1 + n * Gaussian2DFunction.PARAMETERS_PER_PEAK];
    // bb is only assigned and dereferenced when gf2 != null (hence the null-warning suppression)
    double[] bb = null;
    a[CubicSplineFunction.BACKGROUND] = 0.1;
    b[Gaussian2DFunction.BACKGROUND] = 0.1;
    for (int i = 0; i < n; i++) {
        a[i * CubicSplineFunction.PARAMETERS_PER_PEAK + CubicSplineFunction.SIGNAL] = 10;
        b[i * Gaussian2DFunction.PARAMETERS_PER_PEAK + Gaussian2DFunction.SIGNAL] = 10;
    }
    if (n == 2) {
        // Fix second peak parameters
        a[CubicSplineFunction.PARAMETERS_PER_PEAK + CubicSplineFunction.X_POSITION] = testcx1[0];
        a[CubicSplineFunction.PARAMETERS_PER_PEAK + CubicSplineFunction.Y_POSITION] = testcy1[0];
        a[CubicSplineFunction.PARAMETERS_PER_PEAK + CubicSplineFunction.Z_POSITION] = testcz1[0];
        b[Gaussian2DFunction.PARAMETERS_PER_PEAK + Gaussian2DFunction.X_POSITION] = testcx1[0];
        b[Gaussian2DFunction.PARAMETERS_PER_PEAK + Gaussian2DFunction.Y_POSITION] = testcy1[0];
        b[Gaussian2DFunction.PARAMETERS_PER_PEAK + Gaussian2DFunction.Z_POSITION] = testcz1[0];
    }
    if (gf2 != null) {
        bb = b.clone();
        if (n == 2) {
            // Fix second peak widths from the z-model at the second peak's z
            bb[Gaussian2DFunction.PARAMETERS_PER_PEAK + Gaussian2DFunction.X_SD] = zModel.getSx(testcz1[0]);
            bb[Gaussian2DFunction.PARAMETERS_PER_PEAK + Gaussian2DFunction.Y_SD] = zModel.getSy(testcz1[0]);
        }
    }
    for (int x = 0; x <= maxx; x++) {
        a[CubicSplineFunction.X_POSITION] = x;
        b[Gaussian2DFunction.X_POSITION] = x;
        for (int y = 0; y <= maxy; y++) {
            a[CubicSplineFunction.Y_POSITION] = y;
            b[Gaussian2DFunction.Y_POSITION] = y;
            for (int z = -zDepth; z <= zDepth; z++) {
                a[CubicSplineFunction.Z_POSITION] = z;
                b[Gaussian2DFunction.Z_POSITION] = z;
                l1.add(a.clone());
                l2.add(b.clone());
                if (gf2 != null) {
                    // bb was cloned before these loops, so the first peak position must be
                    // copied in explicitly; otherwise every free-circle evaluation would use
                    // position (0, 0, 0) instead of sweeping the same grid as a and b.
                    bb[Gaussian2DFunction.X_POSITION] = x;
                    bb[Gaussian2DFunction.Y_POSITION] = y;
                    bb[Gaussian2DFunction.Z_POSITION] = z;
                    bb[Gaussian2DFunction.X_SD] = zModel.getSx(z);
                    bb[Gaussian2DFunction.Y_SD] = zModel.getSy(z);
                    l3.add(bb.clone());
                }
            }
        }
    }
    final double[][] x1 = l1.toArray(new double[0][]);
    final double[][] x2 = l2.toArray(new double[0][]);
    final double[][] x3 = l3.toArray(new double[0][]);
    final TimingService ts = new TimingService(5);
    ts.execute(new FunctionTimingTask(gf, x2, order));
    if (gf2 != null) {
        ts.execute(new FunctionTimingTask(gf2, x3, order));
    }
    ts.execute(new FunctionTimingTask(cf, x1, order));
    ts.execute(new FunctionTimingTask(cff, x1, order, " single-precision"));
    final int size = ts.getSize();
    ts.repeat(size);
    logger.info(ts.getReport(size));
}
Also used : ErfGaussian2DFunction(uk.ac.sussex.gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) Gaussian2DFunction(uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService)

Example 7 with TimingService

Use of uk.ac.sussex.gdsc.test.utils.TimingService in the GDSC-SMLM project by aherbert.

From the class FastLogTest, method canTestDoubleSpeedLog1P.

/**
 * Logs a timing report for three log1p timing tasks (Test1PLog, TestLog1P, TestLog1PApache)
 * evaluated on one million values drawn with nextUniformDouble. The three tasks are submitted
 * twice in the same order (presumably so warm-up effects are visible in the report). This is a
 * report only; nothing is asserted.
 *
 * @param seed the seed for the random values
 */
@SpeedTag
@SeededTest
void canTestDoubleSpeedLog1P(RandomSeed seed) {
    // Report only: skip unless INFO logging is on and the configured complexity allows it
    Assumptions.assumeTrue(logger.isLoggable(Level.INFO));
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
    final UniformRandomProvider source = RngUtils.create(seed.getSeed());
    final double[] values = new double[1000000];
    for (int index = 0; index < values.length; index++) {
        values[index] = nextUniformDouble(source);
    }
    final MathLog mathLog = new MathLog();
    final TimingService timing = new TimingService(5);
    // Two identical passes over the same task sequence
    for (int pass = 0; pass < 2; pass++) {
        timing.execute(new DoubleTimingTask(new Test1PLog(mathLog), 0, values));
        timing.execute(new DoubleTimingTask(new TestLog1P(mathLog), 0, values));
        timing.execute(new DoubleTimingTask(new TestLog1PApache(mathLog), 0, values));
    }
    final int taskCount = timing.getSize();
    timing.repeat(taskCount);
    logger.info(timing.getReport(taskCount));
}
Also used : UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 8 with TimingService

Use of uk.ac.sussex.gdsc.test.utils.TimingService in the GDSC-SMLM project by aherbert.

From the class FastLogTest, method canTestFloatSpeed.

/**
 * Logs a timing report comparing float log implementations on one million values drawn with
 * nextUniformFloat. MathLog and FastMathLog are timed first. Then, for each configured precision
 * value, the table-based logs (IcsiFastLog, FFastLog, DFastLog, TurboLog, TurboLog2) are timed
 * via both the TestLog and TestFastLog wrappers. This is a report only; nothing is asserted.
 *
 * @param seed the seed for the random values
 */
@SpeedTag
@SeededTest
void canTestFloatSpeed(RandomSeed seed) {
    // Report only: skip unless INFO logging is on and the configured complexity allows it
    Assumptions.assumeTrue(logger.isLoggable(Level.INFO));
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.MEDIUM));
    final UniformRandomProvider source = RngUtils.create(seed.getSeed());
    final float[] values = new float[1000000];
    for (int index = 0; index < values.length; index++) {
        values[index] = nextUniformFloat(source);
    }
    final TimingService timing = new TimingService(5);
    // Reference implementations, timed with a precision argument of 0
    timing.execute(new FloatTimingTask(new TestLog(new MathLog()), 0, values));
    timing.execute(new FloatTimingTask(new TestLog(new FastMathLog()), 0, values));
    final int[] precisions = { 11 };
    for (final int dropped : precisions) {
        // Table size parameter; 23 is presumably the float mantissa bit count
        final int kept = 23 - dropped;
        final IcsiFastLog icsi = IcsiFastLog.create(kept, DataType.FLOAT);
        timing.execute(new FloatTimingTask(new TestLog(icsi), dropped, values));
        timing.execute(new FloatTimingTask(new TestFastLog(icsi), dropped, values));
        final FFastLog floatLog = new FFastLog(kept);
        timing.execute(new FloatTimingTask(new TestLog(floatLog), dropped, values));
        timing.execute(new FloatTimingTask(new TestFastLog(floatLog), dropped, values));
        final DFastLog doubleLog = new DFastLog(kept);
        timing.execute(new FloatTimingTask(new TestLog(doubleLog), dropped, values));
        timing.execute(new FloatTimingTask(new TestFastLog(doubleLog), dropped, values));
        final TurboLog turbo = new TurboLog(kept);
        timing.execute(new FloatTimingTask(new TestLog(turbo), dropped, values));
        timing.execute(new FloatTimingTask(new TestFastLog(turbo), dropped, values));
        // TurboLog2 reaches the same precision with one fewer table bit
        final TurboLog2 turbo2 = new TurboLog2(kept - 1);
        timing.execute(new FloatTimingTask(new TestLog(turbo2), dropped + 1, values));
        timing.execute(new FloatTimingTask(new TestFastLog(turbo2), dropped + 1, values));
    }
    final int taskCount = timing.getSize();
    timing.repeat(taskCount);
    logger.info(timing.getReport(taskCount));
}
Also used : UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 9 with TimingService

Use of uk.ac.sussex.gdsc.test.utils.TimingService in the GDSC-SMLM project by aherbert.

From the class MultivariateGaussianMixtureExpectationMaximizationTest, method testExpectationMaximizationSpeed.

/**
 * Test the speed of implementations of the expectation maximization algorithm with a mixture of n
 * ND Gaussian distributions.
 *
 * <p>For n = 2..3 components and 2..4 dimensions, ten random mixture datasets are fitted three
 * ways: the Commons Math fitter, the GDSC fitter, and the GDSC fitter using a
 * {@code TestHelper.doublesAreClose(1e-6)} convergence checker. The default GDSC fitter must take
 * less than half the mean time of the Commons fitter. On failure, the message reports both means
 * and the failing configuration.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeed(RandomSeed seed) {
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
    final MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate relChecker = TestHelper.doublesAreClose(1e-6)::test;
    // Create data
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    for (int n = 2; n <= 3; n++) {
        for (int dim = 2; dim <= 4; dim++) {
            final double[][][] data = new double[10][][];
            final int nCorrelations = dim - 1;
            for (int i = 0; i < data.length; i++) {
                final double[] sampleWeights = createWeights(n, rng);
                final double[][] sampleMeans = create(n, dim, rng, -5, 5);
                final double[][] sampleStdDevs = create(n, dim, rng, 1, 10);
                final double[][] sampleCorrelations = IntStream.range(0, n).mapToObj(component -> create(nCorrelations, rng, -0.9, 0.9)).toArray(double[][]::new);
                data[i] = createDataNd(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
            }
            // Final copy so the anonymous tasks can capture the component count
            final int numComponents = n;
            // Time initial estimation and fitting
            final TimingService ts = new TimingService();
            ts.execute(new FittingSpeedTask("Commons n=" + n + " " + dim + "D", data) {

                @Override
                Object run(double[][] data) {
                    final MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data);
                    fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
                    return fitter.getLogLikelihood();
                }
            });
            ts.execute(new FittingSpeedTask("GDSC n=" + n + " " + dim + "D", data) {

                @Override
                Object run(double[][] data) {
                    final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
                    fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
                    return fitter.getLogLikelihood();
                }
            });
            ts.execute(new FittingSpeedTask("GDSC rel 1e-6 n=" + n + " " + dim + "D", data) {

                @Override
                Object run(double[][] data) {
                    final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
                    fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents), 1000, relChecker);
                    return fitter.getLogLikelihood();
                }
            });
            if (logger.isLoggable(Level.INFO)) {
                logger.info(ts.getReport());
            }
            // More than twice as fast: the default GDSC task (second of three, index -2) against
            // the Commons task (first of three, index -3). Negative indices presumably count
            // back from the most recently executed task.
            Assertions.assertTrue(ts.get(-2).getMean() < ts.get(-3).getMean() / 2,
                "GDSC n=" + n + " " + dim + "D mean=" + ts.get(-2).getMean()
                    + " is not less than half of Commons mean=" + ts.get(-3).getMean());
        }
    }
}
Also used : IntStream(java.util.stream.IntStream) RandomUtils(uk.ac.sussex.gdsc.core.utils.rng.RandomUtils) Arrays(java.util.Arrays) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) BaseTimingTask(uk.ac.sussex.gdsc.test.utils.BaseTimingTask) RngUtils(uk.ac.sussex.gdsc.test.rng.RngUtils) Covariance(org.apache.commons.math3.stat.correlation.Covariance) ArrayList(java.util.ArrayList) Level(java.util.logging.Level) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) AfterAll(org.junit.jupiter.api.AfterAll) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) BeforeAll(org.junit.jupiter.api.BeforeAll) ContinuousUniformSampler(org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) TestComplexity(uk.ac.sussex.gdsc.test.utils.TestComplexity) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) TestAssertions(uk.ac.sussex.gdsc.test.api.TestAssertions) RandomSeed(uk.ac.sussex.gdsc.test.junit5.RandomSeed) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) Pair(org.apache.commons.math3.util.Pair) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) Logger(java.util.logging.Logger) SamplerUtils(uk.ac.sussex.gdsc.core.utils.rng.SamplerUtils) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test) List(java.util.List) Assumptions(org.junit.jupiter.api.Assumptions) TestSettings(uk.ac.sussex.gdsc.test.utils.TestSettings) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) SharedStateContinuousSampler(org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler) Assertions(org.junit.jupiter.api.Assertions) TestHelper(uk.ac.sussex.gdsc.test.api.TestHelper) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) NormalizedGaussianSampler(org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList)

Example 10 with TimingService

Use of uk.ac.sussex.gdsc.test.utils.TimingService in the GDSC-SMLM project by aherbert.

From the class MultivariateGaussianMixtureExpectationMaximizationTest, method testExpectationMaximizationSpeedWithDifferentNumberOfComponents.

/**
 * Test the speed of implementations of the expectation maximization algorithm with a mixture of n
 * 2D Gaussian distributions.
 *
 * <p>For n = 2..4 components, ten random 2D mixture datasets are fitted with the Commons Math
 * fitter and with the GDSC fitter. The GDSC fitter must take less than half the mean time of the
 * Commons fitter. On failure, the message reports both means and the failing component count.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeedWithDifferentNumberOfComponents(RandomSeed seed) {
    Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
    // Create data
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    for (int n = 2; n <= 4; n++) {
        final double[][][] data = new double[10][][];
        for (int i = 0; i < data.length; i++) {
            final double[] sampleWeights = createWeights(n, rng);
            final double[][] sampleMeans = create(n, 2, rng, -5, 5);
            final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
            final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
            data[i] = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
        }
        // Final copy so the anonymous tasks can capture the component count
        final int numComponents = n;
        // Time initial estimation and fitting
        final TimingService ts = new TimingService();
        ts.execute(new FittingSpeedTask("Commons n=" + n + " 2D", data) {

            @Override
            Object run(double[][] data) {
                final MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data);
                fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
                return fitter.getLogLikelihood();
            }
        });
        ts.execute(new FittingSpeedTask("GDSC n=" + n + " 2D", data) {

            @Override
            Object run(double[][] data) {
                final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
                fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
                return fitter.getLogLikelihood();
            }
        });
        if (logger.isLoggable(Level.INFO)) {
            logger.info(ts.getReport());
        }
        // More than twice as fast: the GDSC task (index -1) against the Commons task (index -2).
        // Negative indices presumably count back from the most recently executed task.
        Assertions.assertTrue(ts.get(-1).getMean() < ts.get(-2).getMean() / 2,
            "GDSC n=" + n + " 2D mean=" + ts.get(-1).getMean()
                + " is not less than half of Commons mean=" + ts.get(-2).getMean());
    }
}
Also used : UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

TimingService (uk.ac.sussex.gdsc.test.utils.TimingService): 23, SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest): 15, UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 14, SpeedTag (uk.ac.sussex.gdsc.test.junit5.SpeedTag): 13, TimingResult (uk.ac.sussex.gdsc.test.utils.TimingResult): 9, Test (org.junit.jupiter.api.Test): 7, LocalList (uk.ac.sussex.gdsc.core.utils.LocalList): 6, Gaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction): 4, BaseTimingTask (uk.ac.sussex.gdsc.test.utils.BaseTimingTask): 3, FloatProcessor (ij.process.FloatProcessor): 2, MultivariateNormalMixtureExpectationMaximization (org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization): 2, SharedStateContinuousSampler (org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler): 2, DenseMatrix64F (org.ejml.data.DenseMatrix64F): 2, GradientCalculator (uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator): 2, ValueProcedure (uk.ac.sussex.gdsc.smlm.function.ValueProcedure): 2, Gaussian2DFunctionTest (uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunctionTest): 2, ImageProcessor (ij.process.ImageProcessor): 1, Rectangle (java.awt.Rectangle): 1, ArrayList (java.util.ArrayList): 1, Arrays (java.util.Arrays): 1