Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
Example from class MultivariateGaussianMixtureExpectationMaximizationTest, method canFit.
/**
 * Fits random 2D Gaussian mixtures with 2 and 3 components and checks that the fitted model
 * agrees with the Commons Math {@code MultivariateNormalMixtureExpectationMaximization} result.
 *
 * <p>Compares the log-likelihood, the component weights, means and covariances. When the full
 * fit needed more than 2 iterations, also checks that a fit capped at 2 iterations reports
 * failure.
 *
 * @param seed the seed used to create the random source
 */
@SeededTest
void canFit(RandomSeed seed) {
  // The Commons Math fitter is used as the reference implementation.
  final UniformRandomProvider random = RngUtils.create(seed.getSeed());
  final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-5, 1e-16);
  final int size = 1000;
  for (int components = 2; components <= 3; components++) {
    // Draw the mixture parameters, then the sample. The draw order is fixed so a given
    // seed always reproduces the same data.
    final double[] mixWeights = createWeights(components, random);
    final double[][] mixMeans = create(components, 2, random, -5, 5);
    final double[][] mixStdDevs = create(components, 2, random, 1, 10);
    final double[] mixCorrelations = create(components, random, -0.9, 0.9);
    final double[][] points =
        createData2d(size, random, mixWeights, mixMeans, mixStdDevs, mixCorrelations);

    // Fit using the implementation under test.
    final MixtureMultivariateGaussianDistribution start =
        MultivariateGaussianMixtureExpectationMaximization.estimate(points, components);
    final MultivariateGaussianMixtureExpectationMaximization actualFitter =
        new MultivariateGaussianMixtureExpectationMaximization(points);
    Assertions.assertTrue(actualFitter.fit(start));

    // Fit using the reference implementation, starting from its own estimate.
    final MultivariateNormalMixtureExpectationMaximization expectedFitter =
        new MultivariateNormalMixtureExpectationMaximization(points);
    expectedFitter.fit(
        MultivariateNormalMixtureExpectationMaximization.estimate(points, components));

    // Scale the tested log-likelihood by 1 / sample size before comparing it with the reference.
    // NOTE(review): this assumes the reference reports a per-point average and the tested
    // fitter reports a total. Confirm against both implementations.
    final double actualLogLikelihood = actualFitter.getLogLikelihood() / size;
    Assertions.assertNotEquals(0, actualLogLikelihood);
    final double expectedLogLikelihood = expectedFitter.getLogLikelihood();
    TestAssertions.assertTest(expectedLogLikelihood, actualLogLikelihood, closeTo);

    final MixtureMultivariateGaussianDistribution actualModel = actualFitter.getFittedModel();
    Assertions.assertNotNull(actualModel);
    final MixtureMultivariateNormalDistribution expectedModel = expectedFitter.getFittedModel();

    // The two fitted mixtures must agree component by component.
    final List<Pair<Double, MultivariateNormalDistribution>> expectedComponents =
        expectedModel.getComponents();
    final double[] actualWeights = actualModel.getWeights();
    final MultivariateGaussianDistribution[] actualComponents = actualModel.getDistributions();
    Assertions.assertEquals(components, expectedComponents.size());
    Assertions.assertEquals(components, actualWeights.length);
    Assertions.assertEquals(components, actualComponents.length);
    for (int k = 0; k < components; k++) {
      final Pair<Double, MultivariateNormalDistribution> expected = expectedComponents.get(k);
      TestAssertions.assertTest(expected.getFirst(), actualWeights[k], closeTo, "weight");
      final MultivariateNormalDistribution expectedDist = expected.getSecond();
      final MultivariateGaussianDistribution actualDist = actualComponents[k];
      TestAssertions.assertArrayTest(expectedDist.getMeans(), actualDist.getMeans(), closeTo,
          "means");
      TestAssertions.assertArrayTest(expectedDist.getCovariances().getData(),
          actualDist.getCovariances(), closeTo, "covariances");
    }

    final int iterations = actualFitter.getIterations();
    Assertions.assertNotEquals(0, iterations);
    // Without convergence: refit from the same start with at most 2 iterations. This is only
    // checked when the full fit needed more than 2 iterations.
    if (iterations > 2) {
      Assertions.assertFalse(actualFitter.fit(start, 2, DEFAULT_CONVERGENCE_CHECKER));
    }
  }
}
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
Example from class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMixtureMultivariateGaussianDistribution.
/**
 * Checks the {@code MixtureMultivariateGaussianDistribution} factory methods.
 *
 * <p>Builds a two-component mixture from un-normalised weights and the column means and
 * covariances of two small 2D samples. Then checks that:
 * <ul>
 * <li>the weights are normalised and the component parameters are kept;</li>
 * <li>the density matches the Commons Math mixture at the first sample's points;</li>
 * <li>the package-private {@code create(weights, distributions)} factory normalises the
 * supplied weights array in place.</li>
 * </ul>
 */
@Test
void canCreateMixtureMultivariateGaussianDistribution() {
  // Un-normalised weights; the factory should expose them as {0.25, 0.75}.
  final double[] weights = { 1, 3 };
  final double[][] first = { { 1, 2 }, { 2.5, 1.5 }, { 3.5, 1.0 } };
  final double[][] second = { { 4, 2 }, { 3.5, -1.5 }, { -3.5, 1.0 } };
  final double[][][] samples = { first, second };
  final double[][] means = new double[samples.length][];
  final double[][][] covariances = new double[samples.length][][];
  for (int k = 0; k < samples.length; k++) {
    means[k] = getColumnMeans(samples[k]);
    covariances[k] = getCovariance(samples[k]);
  }

  final MixtureMultivariateGaussianDistribution mixture =
      MixtureMultivariateGaussianDistribution.create(weights, means, covariances);
  Assertions.assertArrayEquals(new double[] { 0.25, 0.75 }, mixture.getWeights());
  final MultivariateGaussianDistribution[] components = mixture.getDistributions();
  Assertions.assertEquals(weights.length, components.length);
  for (int k = 0; k < means.length; k++) {
    final MultivariateGaussianDistribution component = components[k];
    Assertions.assertArrayEquals(means[k], component.getMeans());
    Assertions.assertArrayEquals(covariances[k], component.getCovariances());
  }

  // Cross-check the mixture density against Apache Commons Math.
  final MixtureMultivariateNormalDistribution reference =
      new MixtureMultivariateNormalDistribution(weights, means, covariances);
  for (final double[] point : first) {
    Assertions.assertEquals(reference.density(point), mixture.density(point), 1e-10);
  }

  // Neither the public factory nor the Commons Math constructor may modify the caller's array.
  Assertions.assertArrayEquals(new double[] { 1, 3 }, weights);
  // The package-private factory is intended to store the weights array by reference and to
  // normalise it in place. Only content equality is asserted here, not identity.
  final MixtureMultivariateGaussianDistribution shared =
      MixtureMultivariateGaussianDistribution.create(weights, components);
  Assertions.assertArrayEquals(weights, shared.getWeights());
  Assertions.assertArrayEquals(new double[] { 0.25, 0.75 }, weights);
}
Aggregations