Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
In the class TrackPopulationAnalysis, method createMixed.
/**
 * Creates the multivariate gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat runs asynchronously. Its initial mixture comes from
 * {@link MultivariateGaussianMixtureExpectationMaximization#estimate} using a function that
 * returns the dot product of each data point with a random unit vector. Repeats that do not
 * converge are discarded. Repeats that fail with a {@link NonPositiveDefiniteMatrixException}
 * or {@link SingularMatrixException} are also discarded and counted as failures. The fitter
 * with the highest log-likelihood is returned.
 *
 * <p>Side effect: increments {@code settings.seed}.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the multivariate gaussian mixture expectation maximization, or null if no repeat
 *         produced a converged fit
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data,
    int dimensions, int numComponents) {
  // Fit a mixed multivariate Gaussian with different repeats.
  final UnitSphereSampler sampler = UnitSphereSampler
      .of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
  final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results =
      new LocalList<>(settings.repeats);
  final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
  if (settings.debug) {
    ImageJUtils.log(" Fitting %d components", numComponents);
  }
  final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
  final AtomicInteger failures = new AtomicInteger();
  for (int i = 0; i < settings.repeats; i++) {
    final double[] vector = sampler.sample();
    results.add(CompletableFuture.supplyAsync(() -> {
      final MultivariateGaussianMixtureExpectationMaximization fitter =
          new MultivariateGaussianMixtureExpectationMaximization(data);
      try {
        // This may also throw the same exceptions due to inversion of the covariance matrix
        final MixtureMultivariateGaussianDistribution initialMixture =
            MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents,
                point -> {
                  double dot = 0;
                  for (int j = 0; j < dimensions; j++) {
                    dot += vector[j] * point[j];
                  }
                  return dot;
                });
        final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
        // Log the result. Note: The ImageJ log is synchronized.
        if (settings.debug) {
          ImageJUtils.log(" Fit: log-likelihood=%s, iter=%d, converged=%b",
              fitter.getLogLikelihood(), fitter.getIterations(), result);
        }
        return result ? fitter : null;
      } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
        failures.getAndIncrement();
        if (settings.debug) {
          ImageJUtils.log(" Fit failed during iteration %d. No variance in a sub-population "
              + "component (check alpha is not always 1.0).", fitter.getIterations());
        }
      } finally {
        ticker.tick();
      }
      return null;
    }));
  }

  // Wait for every repeat before finishing the progress bar and reporting failures.
  // Both the ticker and the failure count are updated by the asynchronous tasks, so they
  // are only final once all futures have been joined. The terminal max() joins every future.
  final MultivariateGaussianMixtureExpectationMaximization best;
  try {
    // Collect results and keep the model with the highest log-likelihood.
    best = results.stream()
        .map(CompletableFuture::join)
        .filter(f -> f != null)
        .max((f1, f2) -> Double.compare(f1.getLogLikelihood(), f2.getLogLikelihood()))
        .orElse(null);
  } finally {
    // Always reset the progress, even if a repeat failed with an unexpected exception.
    ImageJUtils.finished();
  }
  if (failures.get() != 0 && settings.debug) {
    ImageJUtils.log(" %d component fit failed %d/%d", numComponents, failures.get(),
        settings.repeats);
  }
  return best;
}
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
In the class TrackPopulationAnalysis, method createModelTable.
/**
 * Shows the final model in the "Track Population Model" table.
 *
 * <p>A mixture model is built from the original data using the component assignment of each
 * data point. Each table row reports the component index, the fraction of data points assigned
 * to that component, the supplied weight, and the mean and standard deviation of every
 * dimension.
 *
 * @param data the data
 * @param weights the weights for each component (one table row per weight)
 * @param component the component assigned to each data point
 */
private static void createModelTable(double[][] data, double[] weights, int[] component) {
  final MultivariateGaussianDistribution[] distributions =
      MultivariateGaussianMixtureExpectationMaximization.createMixed(data, component)
          .getDistributions();

  // Histogram of the component assignments, used to report the fraction in each component
  final int[] counts = new int[MathUtils.max(component) + 1];
  for (final int c : component) {
    counts[c]++;
  }
  final double total = component.length;

  try (BufferedTextWindow table = new BufferedTextWindow(ImageJUtils.refresh(modelTableRef,
      () -> new TextWindow("Track Population Model", createHeader(), "", 800, 300)))) {
    final StringBuilder row = new StringBuilder();
    for (int i = 0; i < weights.length; i++) {
      row.setLength(0);
      row.append(i).append('\t')
          .append(MathUtils.rounded(counts[i] / total)).append('\t')
          .append(MathUtils.rounded(weights[i]));
      final MultivariateGaussianDistribution dist = distributions[i];
      final double[] mu = dist.getMeans();
      final double[] sigma = dist.getStandardDeviations();
      for (int d = 0; d < mu.length; d++) {
        row.append('\t').append(MathUtils.rounded(mu[d]))
            .append('\t').append(MathUtils.rounded(sigma[d]));
      }
      table.append(row.toString());
    }
  }
}
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
In the class MultivariateGaussianMixtureExpectationMaximizationTest, method testFitThrows.
@Test
void testFitThrows() {
  final MultivariateGaussianMixtureExpectationMaximization fitter =
      new MultivariateGaussianMixtureExpectationMaximization(new double[2][3]);

  // Before any fit: zero log-likelihood, zero iterations and no fitted model
  Assertions.assertEquals(0, fitter.getLogLikelihood());
  Assertions.assertEquals(0, fitter.getIterations());
  Assertions.assertNull(fitter.getFittedModel());

  // The initial mixture is irrelevant for the argument checks below, so null is used
  final MixtureMultivariateGaussianDistribution noMixture = null;
  final int iterations = 10;

  // The maximum iterations must be positive
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> fitter.fit(noMixture, 0, DEFAULT_CONVERGENCE_CHECKER));
  // The convergence checker must not be null
  Assertions.assertThrows(NullPointerException.class,
      () -> fitter.fit(noMixture, iterations, null));
  // The initial mixture must not be null
  Assertions.assertThrows(NullPointerException.class,
      () -> fitter.fit(noMixture, iterations, DEFAULT_CONVERGENCE_CHECKER));

  // Dimension mismatch: the fitter data has 3 columns but the mixture is an equal-weight
  // pair of 2D Gaussians
  final MultivariateGaussianDistribution[] components2d = {
      new MultivariateGaussianDistribution(new double[2], new double[][] {{1, 0}, {0, 2}}),
      new MultivariateGaussianDistribution(new double[2], new double[][] {{1, 0}, {0, 2}})};
  final MixtureMultivariateGaussianDistribution mixture2d =
      new MixtureMultivariateGaussianDistribution(new double[] {0.5, 0.5}, components2d);
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> fitter.fit(mixture2d, iterations, DEFAULT_CONVERGENCE_CHECKER));
}
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
In the class MultivariateGaussianMixtureExpectationMaximizationTest, method canCreateMixedMultivariateGaussianDistribution.
@SeededTest
void canCreateMixedMultivariateGaussianDistribution(RandomSeed seed) {
  // Two equally sized groups of 2D points. The expected component parameters are computed
  // directly from each group; equal group sizes give equal (0.5) mixture weights.
  final double[][] group1 = {{1, 2}, {2.5, 1.5}, {3.5, 1.0}};
  final double[][] group2 = {{4, 2}, {3.5, -1.5}, {-3.5, 1.0}};
  final double[] mean1 = getColumnMeans(group1);
  final double[][] cov1 = getCovariance(group1);
  final double[] mean2 = getColumnMeans(group2);
  final double[][] cov2 = getCovariance(group2);
  final double[][] means = {mean1, mean2};
  final double[][][] covariances = {cov1, cov2};

  // Concatenate the groups. The component labels do not have to be zero based.
  final LocalList<double[]> rows = new LocalList<>();
  rows.addAll(Arrays.asList(group1));
  rows.addAll(Arrays.asList(group2));
  final double[][] data = rows.toArray(new double[0][]);
  final int[] components = {-1, -1, -1, 3, 3, 3};

  final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-8);

  // Repeat with different orderings. Data and labels are shuffled by identically seeded
  // generators so that each row keeps its label.
  for (int repeat = 0; repeat < 3; repeat++) {
    final long shuffleSeed = repeat + seed.getSeedAsLong();
    RandomUtils.shuffle(data, RngUtils.create(shuffleSeed));
    RandomUtils.shuffle(components, RngUtils.create(shuffleSeed));

    final MixtureMultivariateGaussianDistribution mixture =
        MultivariateGaussianMixtureExpectationMaximization.createMixed(data, components);
    Assertions.assertArrayEquals(new double[] {0.5, 0.5}, mixture.getWeights());

    final MultivariateGaussianDistribution[] dists = mixture.getDistributions();
    Assertions.assertEquals(means.length, dists.length);
    for (int k = 0; k < means.length; k++) {
      TestAssertions.assertArrayTest(means[k], dists[k].getMeans(), closeTo);
      TestAssertions.assertArrayTest(covariances[k], dists[k].getCovariances(), closeTo);
    }
  }
}
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
In the class MultivariateGaussianMixtureExpectationMaximizationTest, method testCreateMixtureMultivariateGaussianDistributionThrows.
@Test
void testCreateMixtureMultivariateGaussianDistributionThrows() {
  // Valid weights (will be normalised)
  final double[] weights = {1, 3};
  final double[][] means = new double[2][];
  final double[][][] covariances = new double[2][][];
  final double[][] data = {{1, 2}, {2.5, 1.5}, {3.5, 1.0}};
  final double[][] data2 = {{4, 2}, {3.5, -1.5}, {-3.5, 1.0}};
  means[0] = getColumnMeans(data);
  covariances[0] = getCovariance(data);
  means[1] = getColumnMeans(data2);
  covariances[1] = getCovariance(data2);

  // Build the invalid arguments outside the assertThrows lambdas so that only the create call
  // under test can satisfy the expected exception.
  // A mean vector one element longer than the covariance dimension.
  final double[][] meansWrongLength = means.clone();
  meansWrongLength[0] = Arrays.copyOf(meansWrongLength[0], meansWrongLength[0].length + 1);
  // A non-square covariance matrix.
  final double[][][] covariancesNotSquare = covariances.clone();
  covariancesNotSquare[0] = new double[3][2];

  // Test with means and covariances
  // Non positive weights
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> MixtureMultivariateGaussianDistribution.create(new double[] {-1, 1}, means,
          covariances));
  // Weights sum is not finite
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> MixtureMultivariateGaussianDistribution
          .create(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, means, covariances));
  // Incorrect size mean
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> MixtureMultivariateGaussianDistribution.create(weights, meansWrongLength,
          covariances));
  // Bad covariance matrix
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> MixtureMultivariateGaussianDistribution.create(weights, means,
          covariancesNotSquare));

  // Test with the weights and distributions.
  // Create a valid mixture to use the distributions.
  final MixtureMultivariateGaussianDistribution dist =
      MixtureMultivariateGaussianDistribution.create(weights, means, covariances);
  final MultivariateGaussianDistribution[] distributions = dist.getDistributions();
  // Weights sum is not finite
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> MixtureMultivariateGaussianDistribution
          .create(new double[] {Double.MAX_VALUE, Double.MAX_VALUE}, distributions));
  // Weights and distributions length mismatch
  Assertions.assertThrows(IllegalArgumentException.class,
      () -> MixtureMultivariateGaussianDistribution.create(new double[] {1}, distributions));
}
Aggregations