Use of uk.ac.sussex.gdsc.test.junit5.SeededTest in project GDSC-SMLM by aherbert.
Class MultivariateGaussianMixtureExpectationMaximizationTest, method testExpectationMaximizationSpeed.
/**
 * Benchmarks the expectation maximization fit of a mixture of {@code n} multivariate Gaussian
 * distributions in {@code dim} dimensions.
 *
 * <p>For every combination of 2 to 3 components and 2 to 4 dimensions, ten random datasets are
 * generated. The Commons Math fitter is then timed against the GDSC fitter, both with its plain
 * {@code fit(estimate)} call and with an explicit relative-closeness predicate (1e-6). The GDSC
 * fitter using the plain call must have a mean time below half that of the Commons fitter.
 *
 * @param seed the seed
 */
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeed(RandomSeed seed) {
  Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
  final MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate relChecker =
      TestHelper.doublesAreClose(1e-6)::test;
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  for (int n = 2; n <= 3; n++) {
    // Copy of the loop counter so the anonymous tasks below can capture it
    final int numComponents = n;
    for (int dim = 2; dim <= 4; dim++) {
      // Generate the datasets. The random draw order (weights, means, std devs, correlations,
      // samples) is fixed so that a given seed always reproduces the same data.
      final int nCorrelations = dim - 1;
      final double[][][] datasets = new double[10][][];
      for (int i = 0; i < datasets.length; i++) {
        final double[] weights = createWeights(n, rng);
        final double[][] means = create(n, dim, rng, -5, 5);
        final double[][] stdDevs = create(n, dim, rng, 1, 10);
        final double[][] correlations = IntStream.range(0, n)
            .mapToObj(k -> create(nCorrelations, rng, -0.9, 0.9))
            .toArray(double[][]::new);
        datasets[i] = createDataNd(1000, rng, weights, means, stdDevs, correlations);
      }

      // Each task covers both the initial estimation and the fit
      final String label = " n=" + n + " " + dim + "D";
      final TimingService ts = new TimingService();
      ts.execute(new FittingSpeedTask("Commons" + label, datasets) {
        @Override
        Object run(double[][] data) {
          final MultivariateNormalMixtureExpectationMaximization fitter =
              new MultivariateNormalMixtureExpectationMaximization(data);
          fitter.fit(
              MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
          return fitter.getLogLikelihood();
        }
      });
      ts.execute(new FittingSpeedTask("GDSC" + label, datasets) {
        @Override
        Object run(double[][] data) {
          final MultivariateGaussianMixtureExpectationMaximization fitter =
              new MultivariateGaussianMixtureExpectationMaximization(data);
          fitter.fit(
              MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
          return fitter.getLogLikelihood();
        }
      });
      ts.execute(new FittingSpeedTask("GDSC rel 1e-6" + label, datasets) {
        @Override
        Object run(double[][] data) {
          final MultivariateGaussianMixtureExpectationMaximization fitter =
              new MultivariateGaussianMixtureExpectationMaximization(data);
          fitter.fit(
              MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents),
              1000, relChecker);
          return fitter.getLogLikelihood();
        }
      });

      if (logger.isLoggable(Level.INFO)) {
        logger.info(ts.getReport());
      }
      // Require GDSC to be more than twice as fast as Commons. This assumes negative indices
      // count back from the most recent task: -2 is "GDSC" and -3 is "Commons".
      Assertions.assertTrue(ts.get(-2).getMean() < ts.get(-3).getMean() / 2);
    }
  }
}
Use of uk.ac.sussex.gdsc.test.junit5.SeededTest in project GDSC-SMLM by aherbert.
Class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMixedMultivariateGaussianDistribution.
/**
 * Checks that a mixed multivariate Gaussian distribution built from labelled data has the
 * expected per-component weights, means and covariances, whatever the order of the input rows.
 *
 * @param seed the seed
 */
@SeededTest
void canCreateMixedMultivariateGaussianDistribution(RandomSeed seed) {
  // Two components of three 2D observations each
  final double[][] data1 = {{1, 2}, {2.5, 1.5}, {3.5, 1.0}};
  final double[][] data2 = {{4, 2}, {3.5, -1.5}, {-3.5, 1.0}};
  final double[][] means = new double[2][];
  final double[][][] covariances = new double[2][][];
  means[0] = getColumnMeans(data1);
  covariances[0] = getCovariance(data1);
  means[1] = getColumnMeans(data2);
  covariances[1] = getCovariance(data2);

  // Create components. The labels do not have to be zero based.
  final LocalList<double[]> list = new LocalList<>();
  list.addAll(Arrays.asList(data1));
  list.addAll(Arrays.asList(data2));
  final double[][] data = list.toArray(new double[0][]);
  final int[] components = {-1, -1, -1, 3, 3, 3};

  // Each mixture weight is the fraction of the observations assigned to that component.
  // Both values are exact in binary (3/6), so an exact array comparison is safe.
  final double[] expectedWeights = {
      (double) data1.length / data.length, (double) data2.length / data.length};
  final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-8);

  // Randomise the data
  for (int n = 0; n < 3; n++) {
    final long start = n + seed.getSeedAsLong();
    // Data and labels are shuffled with identically seeded generators. This relies on
    // equal-length arrays receiving the same permutation, so each row keeps its label.
    RandomUtils.shuffle(data, RngUtils.create(start));
    RandomUtils.shuffle(components, RngUtils.create(start));
    final MixtureMultivariateGaussianDistribution dist =
        MultivariateGaussianMixtureExpectationMaximization.createMixed(data, components);
    Assertions.assertArrayEquals(expectedWeights, dist.getWeights());
    final MultivariateGaussianDistribution[] distributions = dist.getDistributions();
    Assertions.assertEquals(means.length, distributions.length);
    // Assumes the distributions are ordered by ascending component label (-1 before 3)
    for (int i = 0; i < means.length; i++) {
      TestAssertions.assertArrayTest(means[i], distributions[i].getMeans(), test);
      TestAssertions.assertArrayTest(covariances[i], distributions[i].getCovariances(), test);
    }
  }
}
Use of uk.ac.sussex.gdsc.test.junit5.SeededTest in project GDSC-SMLM by aherbert.
Class PeakResultDigestTest, method sameSize1ResultsAreEqual.
/**
 * Checks that a digest built from a single generated result matches both that result array and
 * the digest itself.
 *
 * @param seed the seed
 */
@SeededTest
void sameSize1ResultsAreEqual(RandomSeed seed) {
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  final PeakResult[] single = createResults(rng, 1, 5, false, false, false, false);
  final PeakResultsDigest expected = new PeakResultsDigest(single);
  // Compare against the raw results, then against the digest itself
  Assertions.assertTrue(expected.matches(single));
  Assertions.assertTrue(expected.matches(expected));
}
Use of uk.ac.sussex.gdsc.test.junit5.SeededTest in project GDSC-SMLM by aherbert.
Class PeakResultDigestTest, method sameResultsAreEqualWithDeviation.
/**
 * Checks that a digest built from ten generated results matches both those results and the
 * digest itself. The first boolean flag of {@code createResults} is set here, presumably to
 * include deviations (TODO confirm against createResults).
 *
 * @param seed the seed
 */
@SeededTest
void sameResultsAreEqualWithDeviation(RandomSeed seed) {
  final UniformRandomProvider source = RngUtils.create(seed.getSeed());
  final PeakResult[] generated = createResults(source, 10, 5, true, false, false, false);
  final PeakResultsDigest fingerprint = new PeakResultsDigest(generated);
  Assertions.assertTrue(fingerprint.matches(generated));
  // A digest must also match itself
  Assertions.assertTrue(fingerprint.matches(fingerprint));
}
Use of uk.ac.sussex.gdsc.test.junit5.SeededTest in project GDSC-SMLM by aherbert.
Class PeakResultDigestTest, method differentResultsAreNotEqual.
/**
 * Checks that a digest of ten generated results matches neither independently generated results
 * of size 10, 1 or 0, nor digests built from them.
 *
 * @param seed the seed
 */
@SeededTest
void differentResultsAreNotEqual(RandomSeed seed) {
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  final PeakResultsDigest reference =
      new PeakResultsDigest(createResults(rng, 10, 5, false, false, false, false));
  // Same-size, smaller and empty result sets, drawn from the same generator in this order
  final int[] sizes = {10, 1, 0};
  for (int k = 0; k < sizes.length; k++) {
    final PeakResult[] other = createResults(rng, sizes[k], 5, false, false, false, false);
    Assertions.assertFalse(reference.matches(other));
    Assertions.assertFalse(reference.matches(new PeakResultsDigest(other)));
  }
}
Aggregations