Use of org.apache.commons.math3.distribution.MultivariateNormalDistribution in project GDSC-SMLM by aherbert.
In class MultivariateGaussianMixtureExpectationMaximizationTest, method canFit.
/**
 * Fits a 2D Gaussian mixture model to generated data and checks the result against the Commons
 * Math {@code MultivariateNormalMixtureExpectationMaximization} reference implementation.
 *
 * <p>For 2 and 3 mixture components, data is sampled from a randomly generated mixture. Each
 * fitter starts from its own {@code estimate} of the initial model. The test asserts that:
 * <ul>
 * <li>the per-point log-likelihoods agree;</li>
 * <li>the fitted component weights, means and covariances agree (compared by component index);</li>
 * <li>refitting with an iteration limit of 2 reports non-convergence, when the full fit needed
 * more than 2 iterations.</li>
 * </ul>
 *
 * @param seed the random seed
 */
@SeededTest
void canFit(RandomSeed seed) {
  // Test versus the Commons Math estimation
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
  final int sampleSize = 1000;
  // Number of components
  for (int n = 2; n <= 3; n++) {
    // Generate 2D data from a random mixture with n components.
    final double[] sampleWeights = createWeights(n, rng);
    final double[][] sampleMeans = create(n, 2, rng, -5, 5);
    final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
    final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
    final double[][] data = createData2d(sampleSize, rng, sampleWeights, sampleMeans,
        sampleStdDevs, sampleCorrelations);

    // Fit using the implementation under test.
    final MixtureMultivariateGaussianDistribution initialModel1 =
        MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
    final MultivariateGaussianMixtureExpectationMaximization fitter1 =
        new MultivariateGaussianMixtureExpectationMaximization(data);
    Assertions.assertTrue(fitter1.fit(initialModel1));

    // Fit using the Commons Math reference implementation.
    final MultivariateNormalMixtureExpectationMaximization fitter2 =
        new MultivariateNormalMixtureExpectationMaximization(data);
    fitter2.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, n));

    // Commons Math reports the log-likelihood averaged over the data points, so normalise
    // fitter1's value by the number of points that were actually fitted.
    final double ll1 = fitter1.getLogLikelihood() / data.length;
    // Double literal: ensure a primitive double comparison, not a boxed Integer vs Double.
    Assertions.assertNotEquals(0.0, ll1);
    final double ll2 = fitter2.getLogLikelihood();
    TestAssertions.assertTest(ll2, ll1, test);

    final MixtureMultivariateGaussianDistribution model1 = fitter1.getFittedModel();
    Assertions.assertNotNull(model1);
    final MixtureMultivariateNormalDistribution model2 = fitter2.getFittedModel();

    // Check fitted models are the same (components are compared by index).
    final List<Pair<Double, MultivariateNormalDistribution>> comp = model2.getComponents();
    final double[] weights = model1.getWeights();
    final MultivariateGaussianDistribution[] distributions = model1.getDistributions();
    Assertions.assertEquals(n, comp.size());
    Assertions.assertEquals(n, weights.length);
    Assertions.assertEquals(n, distributions.length);
    for (int i = 0; i < n; i++) {
      TestAssertions.assertTest(comp.get(i).getFirst(), weights[i], test, "weight");
      final MultivariateNormalDistribution d = comp.get(i).getSecond();
      TestAssertions.assertArrayTest(d.getMeans(), distributions[i].getMeans(), test, "means");
      TestAssertions.assertArrayTest(d.getCovariances().getData(),
          distributions[i].getCovariances(), test, "covariances");
    }

    final int iterations = fitter1.getIterations();
    Assertions.assertNotEquals(0, iterations);
    // Test without convergence: if the full fit needed more than 2 iterations, capping the
    // fit at 2 iterations must report failure.
    if (iterations > 2) {
      Assertions.assertFalse(fitter1.fit(initialModel1, 2, DEFAULT_CONVERGENCE_CHECKER));
    }
  }
}
Aggregations