Usage of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
From class TrackPopulationAnalysis, method createModelTable.
/**
 * Shows a table describing the final model. The per-point component assignments are used to build
 * a mixture model directly from the original data. Each row of the table lists the component
 * index, the fraction of data points assigned to that component, the supplied component weight,
 * and then the mean and standard deviation of every dimension.
 *
 * @param data the data
 * @param weights the weights for each component (one table row is written per weight)
 * @param component the component assignment of each data point; values are used directly as array
 *        indices here, so presumably zero-based (TODO confirm against callers)
 */
private static void createModelTable(double[][] data, double[] weights, int[] component) {
  final MixtureMultivariateGaussianDistribution mixture =
      MultivariateGaussianMixtureExpectationMaximization.createMixed(data, component);
  final MultivariateGaussianDistribution[] dists = mixture.getDistributions();

  // Tally how many data points were assigned to each component
  final int[] assigned = new int[MathUtils.max(component) + 1];
  for (final int c : component) {
    assigned[c]++;
  }
  final double total = component.length;

  try (BufferedTextWindow window = new BufferedTextWindow(ImageJUtils.refresh(modelTableRef,
      () -> new TextWindow("Track Population Model", createHeader(), "", 800, 300)))) {
    final StringBuilder row = new StringBuilder();
    for (int k = 0; k < weights.length; k++) {
      row.setLength(0);
      row.append(k).append('\t')
          .append(MathUtils.rounded(assigned[k] / total)).append('\t')
          .append(MathUtils.rounded(weights[k]));
      final MultivariateGaussianDistribution dist = dists[k];
      final double[] mu = dist.getMeans();
      final double[] sigma = dist.getStandardDeviations();
      // One (mean, sd) column pair per dimension
      for (int d = 0; d < mu.length; d++) {
        row.append('\t').append(MathUtils.rounded(mu[d]))
            .append('\t').append(MathUtils.rounded(sigma[d]));
      }
      window.append(row.toString());
    }
  }
}
Usage of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
From class MultivariateGaussianMixtureExpectationMaximizationTest, method testFitThrows.
@Test
void testFitThrows() {
  final MultivariateGaussianMixtureExpectationMaximization fitter =
      new MultivariateGaussianMixtureExpectationMaximization(new double[2][3]);

  // Before any fit: zero likelihood, zero iterations and no fitted model
  Assertions.assertEquals(0, fitter.getLogLikelihood());
  Assertions.assertEquals(0, fitter.getIterations());
  Assertions.assertNull(fitter.getFittedModel());

  // The mixture value is irrelevant for the argument checks that are exercised first
  final MixtureMultivariateGaussianDistribution noMixture = null;
  final int iterations = 10;

  // Iterations must be strictly positive
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> fitter.fit(noMixture, 0, DEFAULT_CONVERGENCE_CHECKER));
  // Null convergence checker
  Assertions.assertThrows(NullPointerException.class,
      () -> fitter.fit(noMixture, iterations, null));
  // Null initial mixture
  Assertions.assertThrows(NullPointerException.class,
      () -> fitter.fit(noMixture, iterations, DEFAULT_CONVERGENCE_CHECKER));

  // Dimension mismatch: the data is 3D but the initial mixture is an equal mix of two 2D Gaussians
  final double[] mixingWeights = new double[] {0.5, 0.5};
  final MultivariateGaussianDistribution first =
      new MultivariateGaussianDistribution(new double[2], new double[][] {{1, 0}, {0, 2}});
  final MultivariateGaussianDistribution second =
      new MultivariateGaussianDistribution(new double[2], new double[][] {{1, 0}, {0, 2}});
  final MixtureMultivariateGaussianDistribution mixture2d = new MixtureMultivariateGaussianDistribution(
      mixingWeights, new MultivariateGaussianDistribution[] {first, second});
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> fitter.fit(mixture2d, iterations, DEFAULT_CONVERGENCE_CHECKER));
}
Usage of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
From class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMixedMultivariateGaussianDistribution.
@SeededTest
void canCreateMixedMultivariateGaussianDistribution(RandomSeed seed) {
  final double[][] group1 = {{1, 2}, {2.5, 1.5}, {3.5, 1.0}};
  final double[][] group2 = {{4, 2}, {3.5, -1.5}, {-3.5, 1.0}};

  // Expected parameters of each group, computed independently of the class under test
  final double[][] expectedMeans = {getColumnMeans(group1), getColumnMeans(group2)};
  final double[][][] expectedCovariances = {getCovariance(group1), getCovariance(group2)};

  // Concatenate the groups (rows are shared, not copied)
  final double[][] data = new double[group1.length + group2.length][];
  System.arraycopy(group1, 0, data, 0, group1.length);
  System.arraycopy(group2, 0, data, group1.length, group2.length);
  // Component ids need not be zero based; equal group sizes give equal (normalised) weights
  final int[] components = {-1, -1, -1, 3, 3, 3};

  final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-8);
  for (int repeat = 0; repeat < 3; repeat++) {
    // Identical seeds for both shuffles keep each row paired with its component id
    final long shuffleSeed = seed.getSeedAsLong() + repeat;
    RandomUtils.shuffle(data, RngUtils.create(shuffleSeed));
    RandomUtils.shuffle(components, RngUtils.create(shuffleSeed));

    final MixtureMultivariateGaussianDistribution mixture =
        MultivariateGaussianMixtureExpectationMaximization.createMixed(data, components);
    Assertions.assertArrayEquals(new double[] {0.5, 0.5}, mixture.getWeights());

    final MultivariateGaussianDistribution[] fitted = mixture.getDistributions();
    Assertions.assertEquals(expectedMeans.length, fitted.length);
    for (int k = 0; k < expectedMeans.length; k++) {
      TestAssertions.assertArrayTest(expectedMeans[k], fitted[k].getMeans(), closeTo);
      TestAssertions.assertArrayTest(expectedCovariances[k], fitted[k].getCovariances(), closeTo);
    }
  }
}
Usage of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
From class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMultivariateGaussianDistribution.
@Test
void canCreateMultivariateGaussianDistribution() {
  final double[][] samples = {{1, 2}, {2.5, 1.5}, {3.5, 1.0}};
  final double[] mu = getColumnMeans(samples);
  final double[][] sigma = getCovariance(samples);

  final MultivariateGaussianDistribution gaussian = MultivariateGaussianDistribution.create(mu, sigma);
  // The factory wraps (does not copy) the supplied arrays
  Assertions.assertSame(mu, gaussian.getMeans());
  Assertions.assertSame(sigma, gaussian.getCovariances());

  // Standard deviations are the square roots of the covariance diagonal
  final double[] stdDevs = gaussian.getStandardDeviations();
  Assertions.assertEquals(sigma.length, stdDevs.length);
  for (int d = 0; d < stdDevs.length; d++) {
    Assertions.assertEquals(Math.sqrt(sigma[d][d]), stdDevs[d]);
  }

  // The density must match the Apache Commons Math reference implementation exactly
  final MultivariateNormalDistribution reference = new MultivariateNormalDistribution(mu, sigma);
  for (int s = 0; s < samples.length; s++) {
    final double[] point = samples[s];
    Assertions.assertEquals(reference.density(point), gaussian.density(point));
  }
}
Usage of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
From class TrackPopulationAnalysis, method fitGaussianMixture.
/**
 * Fit a Gaussian mixture to the data.
 *
 * <p>The columns of {@code data} are first standardised <em>in place</em>: each value is centred
 * on its column mean and divided by its column standard deviation. The caller's array is
 * therefore modified.
 *
 * <p>Mixture models with 2, 3, ... up to {@code settings.maxComponents} components are then
 * fitted in turn; fitting of each candidate is delegated to {@code createMixed}. A {@code k}
 * component model is accepted only when all of the following hold:
 * <ul>
 * <li>no component weight is below {@code settings.minWeight};</li>
 * <li>it strictly improves both the AIC and the BIC of the previously accepted model;</li>
 * <li>the log-likelihood ratio test against that model gives {@code q <= 0.001}.</li>
 * </ul>
 * The search stops at the first candidate that fails to fit or is rejected.
 *
 * @param data the data (standardised in place)
 * @param sortDimension the dimension whose component means determine the order in which the
 *        population weights are logged
 * @return the last accepted mixture model, or {@code null} if no model with 2 or more components
 *         was accepted
 */
private MultivariateGaussianMixtureExpectationMaximization fitGaussianMixture(final double[][] data, int sortDimension) {
  // Get the unmixed multivariate Gaussian of the raw data.
  final MultivariateGaussianDistribution raw = MultivariateGaussianMixtureExpectationMaximization.createUnmixed(data);

  // Standardise the columns of the data in place using the mean and SD of each column
  final double[] means = raw.getMeans();
  final double[] sd = raw.getStandardDeviations();
  final int dimensions = means.length;
  for (final double[] value : data) {
    for (int i = 0; i < dimensions; i++) {
      value[i] = (value[i] - means[i]) / sd[i];
    }
  }

  // Repeat. The mean should be approximately 0 and std.dev. 1.
  final MultivariateGaussianDistribution unmixed = MultivariateGaussianMixtureExpectationMaximization.createUnmixed(data);
  // Record the likelihood of the unmixed model
  double logLikelihood = Arrays.stream(data).mapToDouble(unmixed::density).map(Math::log).sum();

  // Free parameters of one full-covariance Gaussian: d means plus the d(d+1)/2 unique entries of
  // the symmetric covariance matrix. (Counting all d*d entries over-penalises larger models.)
  final int parametersPerGaussian = dimensions + dimensions * (dimensions + 1) / 2;
  double aic = MathUtils.getAkaikeInformationCriterion(logLikelihood, parametersPerGaussian);
  double bic = MathUtils.getBayesianInformationCriterion(logLikelihood, data.length, parametersPerGaussian);
  ImageJUtils.log("1 component log-likelihood=%s. AIC=%s. BIC=%s", logLikelihood, aic, bic);

  // Fit a mixed component model.
  // Increment the number of components up to a maximum or until the model does not improve.
  MultivariateGaussianMixtureExpectationMaximization mixed = null;
  for (int numComponents = 2; numComponents <= settings.maxComponents; numComponents++) {
    final MultivariateGaussianMixtureExpectationMaximization mixed2 = createMixed(data, dimensions, numComponents);
    if (mixed2 == null) {
      ImageJUtils.log("Failed to fit a %d component mixture model", numComponents);
      break;
    }
    final double logLikelihood2 = mixed2.getLogLikelihood();
    // n * (means, covariances, 1 weight) - 1
    // (Note: subtract 1 as the weights are constrained by summing to 1)
    final int param2 = numComponents * (parametersPerGaussian + 1) - 1;
    final double aic2 = MathUtils.getAkaikeInformationCriterion(logLikelihood2, param2);
    final double bic2 = MathUtils.getBayesianInformationCriterion(logLikelihood2, data.length, param2);
    // Log-likelihood ratio test statistic
    final double lambdaLr = -2 * (logLikelihood - logLikelihood2);
    // DF = difference in dimensionality from the previous number of components:
    // one extra Gaussian (means, covariances) plus 1 extra weight
    final int degreesOfFreedom = parametersPerGaussian + 1;
    final double q = ChiSquaredDistributionTable.computeQValue(lambdaLr, degreesOfFreedom);
    ImageJUtils.log("%d component log-likelihood=%s. AIC=%s. BIC=%s. LLR significance=%s.", numComponents, logLikelihood2, aic2, bic2, MathUtils.rounded(q));

    // Copy the weights: they are re-ordered below purely for reporting. The getter may expose the
    // model's internal array, and sorting that in place would detach the weights from their
    // distributions.
    final double[] weights = mixed2.getFittedModel().getWeights().clone();
    // For consistency sort the weights by the mean of the diffusion coefficient
    final double[] values = Arrays.stream(mixed2.getFittedModel().getDistributions()).mapToDouble(d -> d.getMeans()[sortDimension]).toArray();
    SortUtils.sortData(weights, values, false, false);
    ImageJUtils.log("Population weights: " + Arrays.toString(weights));

    final double minWeight = MathUtils.min(weights);
    if (minWeight < settings.minWeight) {
      ImageJUtils.log("%d component model has population weight %s under minimum level %s", numComponents, minWeight, settings.minWeight);
      break;
    }
    if (aic <= aic2 || bic <= bic2 || q > 0.001) {
      ImageJUtils.log("%d component model is not significant", numComponents);
      break;
    }
    // Accept this model as the baseline for the next comparison
    aic = aic2;
    bic = bic2;
    logLikelihood = logLikelihood2;
    mixed = mixed2;
  }
  return mixed;
}
Aggregations