Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
The class TrackPopulationAnalysis, method createMixed.
/**
* Creates the multivariate Gaussian mixture as the best of many repeats of the expectation
* maximisation algorithm.
*
* @param data the data
* @param dimensions the dimensions
* @param numComponents the number of components
* @return the multivariate gaussian mixture expectation maximization
*/
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
// Fit a mixed multivariate Gaussian with different repeats.
final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
if (settings.debug) {
ImageJUtils.log(" Fitting %d components", numComponents);
}
final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
final AtomicInteger failures = new AtomicInteger();
for (int i = 0; i < settings.repeats; i++) {
final double[] vector = sampler.sample();
results.add(CompletableFuture.supplyAsync(() -> {
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
try {
// This may also throw the same exceptions due to inversion of the covariance matrix
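// The dot product projects each point onto the random direction; the scalar ranking is
// (presumably) what estimate(...) uses to split the data into initial components, so each
// repeat starts the EM algorithm from a different initial mixture.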
final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
double dot = 0;
for (int j = 0; j < dimensions; j++) {
dot += vector[j] * point[j];
}
return dot;
});
final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
// Log the result. Note: The ImageJ log is synchronized.
if (settings.debug) {
ImageJUtils.log(" Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
}
return result ? fitter : null;
} catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
failures.getAndIncrement();
if (settings.debug) {
ImageJUtils.log("    Fit failed during iteration %d. No variance in a sub-population component (check alpha is not always 1.0).", fitter.getIterations());
}
} finally {
ticker.tick();
}
return null;
}));
}
ImageJUtils.finished();
if (failures.get() != 0 && settings.debug) {
ImageJUtils.log(" %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
}
// Collect results and return the best model.
return results.stream().map(CompletableFuture::join).filter(Objects::nonNull).max(Comparator.comparingDouble(MultivariateGaussianMixtureExpectationMaximization::getLogLikelihood)).orElse(null);
}
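For reference, a minimal standalone sketch of a single fit with the same API; the iteration limit and the absolute-difference convergence lambda below are illustrative assumptions, not the plugin's createConvergenceTest:
// Sketch only: one EM fit using the default initial estimate for the requested number of
// components and a simple absolute-difference convergence check on the log-likelihood.
private static MultivariateGaussianMixtureExpectationMaximization fitOnce(double[][] data, int numComponents) {
  final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
  final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents);
  // Stop when successive log-likelihoods differ by less than 1e-6, or after 1000 iterations.
  final boolean converged = fitter.fit(initialMixture, 1000, (previous, current) -> Math.abs(previous - current) < 1e-6);
  return converged ? fitter : null;
}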
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
The class MultivariateGaussianMixtureExpectationMaximizationTest, method testExpectationMaximizationSpeed.
/**
* Test the speed of implementations of the expectation maximization algorithm with a mixture of n
* ND Gaussian distributions.
*
* @param seed the seed
*/
@SpeedTag
@SeededTest
void testExpectationMaximizationSpeed(RandomSeed seed) {
Assumptions.assumeTrue(TestSettings.allow(TestComplexity.HIGH));
final MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate relChecker = TestHelper.doublesAreClose(1e-6)::test;
// Create data
final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
for (int n = 2; n <= 3; n++) {
for (int dim = 2; dim <= 4; dim++) {
final double[][][] data = new double[10][][];
final int nCorrelations = dim - 1;
for (int i = 0; i < data.length; i++) {
final double[] sampleWeights = createWeights(n, rng);
final double[][] sampleMeans = create(n, dim, rng, -5, 5);
final double[][] sampleStdDevs = create(n, dim, rng, 1, 10);
final double[][] sampleCorrelations = IntStream.range(0, n).mapToObj(component -> create(nCorrelations, rng, -0.9, 0.9)).toArray(double[][]::new);
data[i] = createDataNd(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
}
final int numComponents = n;
// Time initial estimation and fitting
final TimingService ts = new TimingService();
ts.execute(new FittingSpeedTask("Commons n=" + n + " " + dim + "D", data) {
@Override
Object run(double[][] data) {
final MultivariateNormalMixtureExpectationMaximization fitter = new MultivariateNormalMixtureExpectationMaximization(data);
fitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents));
return fitter.getLogLikelihood();
}
});
ts.execute(new FittingSpeedTask("GDSC n=" + n + " " + dim + "D", data) {
@Override
Object run(double[][] data) {
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents));
return fitter.getLogLikelihood();
}
});
ts.execute(new FittingSpeedTask("GDSC rel 1e-6 n=" + n + " " + dim + "D", data) {
@Override
Object run(double[][] data) {
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
fitter.fit(MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents), 1000, relChecker);
return fitter.getLogLikelihood();
}
});
if (logger.isLoggable(Level.INFO)) {
logger.info(ts.getReport());
}
// The GDSC fitter should be more than twice as fast as the Commons Math fitter
Assertions.assertTrue(ts.get(-2).getMean() < ts.get(-3).getMean() / 2);
}
}
}
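For context, a relative-error convergence check equivalent in spirit to the adapted TestHelper predicate above can be written directly as a lambda; this is a sketch of the standard |a - b| <= eps * max(|a|, |b|) form, not the TestHelper implementation:
// Sketch only: relative closeness of successive log-likelihoods.
static MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate relativeChecker(double eps) {
  return (a, b) -> Math.abs(a - b) <= eps * Math.max(Math.abs(a), Math.abs(b));
}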
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
The class MultivariateGaussianMixtureExpectationMaximizationTest, method testFitThrows.
@Test
void testFitThrows() {
// A null initial model does not matter for the initial argument checks
final MixtureMultivariateGaussianDistribution initialMixture = null;
final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(new double[2][3]);
// Test initial model is null and the likelihood is zero
Assertions.assertEquals(0, fitter.getLogLikelihood());
Assertions.assertEquals(0, fitter.getIterations());
Assertions.assertNull(fitter.getFittedModel());
// Valid parameters
final int maxIterations = 10;
// Non-positive iterations
Assertions.assertThrows(IllegalArgumentException.class, () -> {
fitter.fit(initialMixture, 0, DEFAULT_CONVERGENCE_CHECKER);
});
// Null convergence checker
Assertions.assertThrows(NullPointerException.class, () -> {
fitter.fit(initialMixture, maxIterations, null);
});
// Null mixture
Assertions.assertThrows(NullPointerException.class, () -> {
fitter.fit(initialMixture, maxIterations, DEFAULT_CONVERGENCE_CHECKER);
});
// Incorrect data dimensions: the fitter data has 3 columns, so a 50-50 mixture of 2D Gaussians must be rejected
final MixtureMultivariateGaussianDistribution initialMixture2 = new MixtureMultivariateGaussianDistribution(new double[] { 0.5, 0.5 }, new MultivariateGaussianDistribution[] { new MultivariateGaussianDistribution(new double[2], new double[][] { { 1, 0 }, { 0, 2 } }), new MultivariateGaussianDistribution(new double[2], new double[][] { { 1, 0 }, { 0, 2 } }) });
Assertions.assertThrows(IllegalArgumentException.class, () -> {
fitter.fit(initialMixture2, maxIterations, DEFAULT_CONVERGENCE_CHECKER);
});
}
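For comparison, the same constructors build a valid initial model when its dimension matches the data; the means and diagonal covariances below are illustrative values only:
// Sketch only: a 2-component initial mixture for 3-dimensional data (matching the 3 columns
// of the new double[2][3] data above), so the dimension check would pass.
static MixtureMultivariateGaussianDistribution createValid3dMixture() {
  return new MixtureMultivariateGaussianDistribution(new double[] { 0.5, 0.5 }, new MultivariateGaussianDistribution[] { new MultivariateGaussianDistribution(new double[3], new double[][] { { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } }), new MultivariateGaussianDistribution(new double[] { 1, 1, 1 }, new double[][] { { 2, 0, 0 }, { 0, 2, 0 }, { 0, 0, 2 } }) });
}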
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
The class TrackPopulationAnalysis, method fitGaussianMixture.
/**
* Fit the Gaussian mixture to the data. The fitter with the highest likelihood from a number of
* repeats is returned.
*
* @param data the data
* @param sortDimension the sort dimension
* @return the multivariate gaussian mixture
*/
private MultivariateGaussianMixtureExpectationMaximization fitGaussianMixture(final double[][] data, int sortDimension) {
// Get the unmixed multivariate Gaussian.
MultivariateGaussianDistribution unmixed = MultivariateGaussianMixtureExpectationMaximization.createUnmixed(data);
// Normalise the columns of the data
// Get means and SD of each column
final double[] means = unmixed.getMeans();
final double[] sd = unmixed.getStandardDeviations();
final int dimensions = means.length;
for (final double[] value : data) {
for (int i = 0; i < dimensions; i++) {
value[i] = (value[i] - means[i]) / sd[i];
}
}
// Repeat. The mean should be approximately 0 and std.dev. 1.
unmixed = MultivariateGaussianMixtureExpectationMaximization.createUnmixed(data);
// Record the likelihood of the unmixed model
double logLikelihood = Arrays.stream(data).mapToDouble(unmixed::density).map(Math::log).sum();
// x means, x*x covariances
final int parametersPerGaussian = dimensions + dimensions * dimensions;
double aic = MathUtils.getAkaikeInformationCriterion(logLikelihood, parametersPerGaussian);
double bic = MathUtils.getBayesianInformationCriterion(logLikelihood, data.length, parametersPerGaussian);
ImageJUtils.log("1 component log-likelihood=%s. AIC=%s. BIC=%s", logLikelihood, aic, bic);
// Fit a mixed component model.
// Increment the number of components up to a maximum, or stop when the model does not improve.
MultivariateGaussianMixtureExpectationMaximization mixed = null;
for (int numComponents = 2; numComponents <= settings.maxComponents; numComponents++) {
final MultivariateGaussianMixtureExpectationMaximization mixed2 = createMixed(data, dimensions, numComponents);
if (mixed2 == null) {
ImageJUtils.log("Failed to fit a %d component mixture model", numComponents);
break;
}
final double logLikelihood2 = mixed2.getLogLikelihood();
// n * (means, covariances, 1 weight) - 1
// (Note: subtract 1 as the weights are constrained by summing to 1)
final int param2 = numComponents * (parametersPerGaussian + 1) - 1;
final double aic2 = MathUtils.getAkaikeInformationCriterion(logLikelihood2, param2);
final double bic2 = MathUtils.getBayesianInformationCriterion(logLikelihood2, data.length, param2);
// Log-likelihood ratio test statistic
final double lambdaLr = -2 * (logLikelihood - logLikelihood2);
// DF = difference in dimensionality from previous number of components
// means, covariances, 1 weight
final int degreesOfFreedom = parametersPerGaussian + 1;
final double q = ChiSquaredDistributionTable.computeQValue(lambdaLr, degreesOfFreedom);
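// q is the upper-tail chi-squared probability (p-value) of the likelihood-ratio statistic;
// a small q (below the 0.001 threshold used below) means the extra component is significant.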
ImageJUtils.log("%d component log-likelihood=%s. AIC=%s. BIC=%s. LLR significance=%s.", numComponents, logLikelihood2, aic2, bic2, MathUtils.rounded(q));
final double[] weights = mixed2.getFittedModel().getWeights();
// For consistency sort the mixture by the mean of the diffusion coefficient
final double[] values = Arrays.stream(mixed2.getFittedModel().getDistributions()).mapToDouble(d -> d.getMeans()[sortDimension]).toArray();
SortUtils.sortData(weights, values, false, false);
ImageJUtils.log("Population weights: " + Arrays.toString(weights));
if (MathUtils.min(weights) < settings.minWeight) {
ImageJUtils.log("%d component model has population weight %s under minimum level %s", numComponents, MathUtils.min(weights), settings.minWeight);
break;
}
if (aic <= aic2 || bic <= bic2 || q > 0.001) {
ImageJUtils.log("%d component model is not significant", numComponents);
break;
}
aic = aic2;
bic = bic2;
logLikelihood = logLikelihood2;
mixed = mixed2;
}
return mixed;
}
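For reference, a sketch of the information criteria used above, assuming the MathUtils helpers follow the standard definitions (this is not the GDSC implementation):
// Sketch only: standard AIC/BIC with k free parameters and n observations (lower is better).
static double akaikeInformationCriterion(double logLikelihood, int k) {
  return 2.0 * k - 2.0 * logLikelihood;
}
static double bayesianInformationCriterion(double logLikelihood, int n, int k) {
  return k * Math.log(n) - 2.0 * logLikelihood;
}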
Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization in project GDSC-SMLM by aherbert.
The class TrackPopulationAnalysis, method run.
@Override
public void run(String arg) {
SmlmUsageTracker.recordPlugin(this.getClass(), arg);
if (MemoryPeakResults.isMemoryEmpty()) {
IJ.error(TITLE, "No localisations in memory");
return;
}
settings = Settings.load();
// Saved by reference so just save now
settings.save();
// Read in multiple traced datasets
// All datasets must have the same pixel pitch and exposure time
// Get parameters
// Convert datasets to tracks
// For each track compute the 4 local track features using the configured window
//
// Optional:
// Fit a multi-variate Gaussian mixture model to the data
// (using the configured number of components/populations)
// Assign each point in the track using the model.
// Smooth the assignments.
//
// The alternative is to use the localisation category to assign populations.
//
// Plot histograms of each track parameter, coloured by component
final List<MemoryPeakResults> combinedResults = new LocalList<>();
if (!showInputDialog(combinedResults)) {
return;
}
final boolean hasCategory = showHasCategoryDialog(combinedResults);
if (!showDialog(hasCategory)) {
return;
}
ImageJUtils.log(TITLE + "...");
final List<Trace> tracks = getTracks(combinedResults, settings.window, settings.minTrackLength);
if (tracks.isEmpty()) {
IJ.error(TITLE, "No tracks. Please check the input data and min track length setting.");
return;
}
final Calibration cal = combinedResults.get(0).getCalibration();
final CalibrationReader cr = new CalibrationReader(cal);
// Use micrometer / second
final TypeConverter<DistanceUnit> distanceConverter = cr.getDistanceConverter(DistanceUnit.UM);
final double exposureTime = cr.getExposureTime() / 1000.0;
final Pair<int[], double[][]> trackData = extractTrackData(tracks, distanceConverter, exposureTime, hasCategory);
final double[][] data = trackData.getValue();
// Histogram the raw data.
final Array2DRowRealMatrix raw = new Array2DRowRealMatrix(data, false);
final WindowOrganiser wo = new WindowOrganiser();
// Store the histogram data for plotting the components
final double[][] columns = new double[FEATURE_NAMES.length][];
final double[][] limits = new double[FEATURE_NAMES.length][];
// Get column data
for (int i = 0; i < FEATURE_NAMES.length; i++) {
columns[i] = raw.getColumn(i);
if (i == FEATURE_D) {
// Plot using a logarithmic scale
SimpleArrayUtils.apply(columns[i], Math::log10);
}
limits[i] = MathUtils.limits(columns[i]);
}
// Compute histogram bins
final int[] bins = new int[FEATURE_NAMES.length];
if (settings.histogramBins > 0) {
Arrays.fill(bins, settings.histogramBins);
} else {
for (int i = 0; i < FEATURE_NAMES.length; i++) {
bins[i] = HistogramPlot.getBins(StoredData.create(columns[i]), BinMethod.FD);
}
// Use the maximum so all histograms look the same
Arrays.fill(bins, MathUtils.max(bins));
}
// Compute plots
final Plot[] plots = new Plot[FEATURE_NAMES.length];
for (int i = 0; i < FEATURE_NAMES.length; i++) {
final double[][] hist = HistogramPlot.calcHistogram(columns[i], limits[i][0], limits[i][1], bins[i]);
plots[i] = new Plot(TITLE + " " + FEATURE_NAMES[i], getFeatureLabel(i, i == FEATURE_D), "Frequency");
plots[i].addPoints(hist[0], hist[1], Plot.BAR);
ImageJUtils.display(plots[i].getTitle(), plots[i], ImageJUtils.NO_TO_FRONT, wo);
}
wo.tile();
// The component for each data point
int[] component;
// The number of components
int numComponents;
// Data used to fit the Gaussian mixture model
double[][] fitData;
// The fitted model
MixtureMultivariateGaussianDistribution model;
if (hasCategory) {
// Use the category as the component.
// No fit data and no output model
fitData = null;
model = null;
// The component is stored at the end of the raw track data.
final int end = data[0].length - 1;
component = Arrays.stream(data).mapToInt(d -> (int) d[end]).toArray();
numComponents = MathUtils.max(component) + 1;
// In the EM algorithm the probability of each data point is computed and normalised to
// sum to 1. The normalised probabilities are averaged to create the weights.
// Note the probability of each data point uses the previous weight and the algorithm
// iterates.
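// (For reference, the fitted weight of component k would be the mean of the normalised
// responsibilities: weight'[k] = (1/n) * sum_i [ weight[k] * density_k(x_i) / sum_j weight[j] * density_j(x_i) ],
// using the previous weights on the right-hand side.)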
// This is not a fitted model but the input model so use
// zero weights to indicate no fitting was performed.
final double[] weights = new double[numComponents];
// Remove the trailing component to show the 'model' in a table.
createModelTable(Arrays.stream(data).map(d -> Arrays.copyOf(d, end)).toArray(double[][]::new), weights, component);
} else {
// Multivariate Gaussian mixture EM
// Provide option to not use the anomalous exponent in the population mix.
int sortDimension = SORT_DIMENSION;
if (settings.ignoreAlpha) {
// Remove index 0. This shifts the sort dimension.
sortDimension--;
fitData = Arrays.stream(data).map(d -> Arrays.copyOfRange(d, 1, d.length)).toArray(double[][]::new);
} else {
fitData = SimpleArrayUtils.deepCopy(data);
}
final MultivariateGaussianMixtureExpectationMaximization mixed = fitGaussianMixture(fitData, sortDimension);
if (mixed == null) {
IJ.error(TITLE, "Failed to fit a mixture model");
return;
}
model = sortComponents(mixed.getFittedModel(), sortDimension);
// For the best model, assign to the most likely population.
component = assignData(fitData, model);
// Table of the final model using the original data (i.e. not normalised)
final double[] weights = model.getWeights();
numComponents = weights.length;
createModelTable(data, weights, component);
}
// Output coloured histograms of the populations.
final LUT lut = LutHelper.createLut(settings.lutIndex);
IntFunction<Color> colourMap;
if (LutHelper.getColour(lut, 0).equals(Color.BLACK)) {
colourMap = i -> LutHelper.getNonZeroColour(lut, i, 0, numComponents - 1);
} else {
colourMap = i -> LutHelper.getColour(lut, i, 0, numComponents - 1);
}
for (int i = 0; i < FEATURE_NAMES.length; i++) {
// Extract the data for each component
final double[] col = columns[i];
final Plot plot = plots[i];
for (int n = 0; n < numComponents; n++) {
final StoredData feature = new StoredData();
for (int j = 0; j < component.length; j++) {
if (component[j] == n) {
feature.add(col[j]);
}
}
if (feature.size() == 0) {
continue;
}
final double[][] hist = HistogramPlot.calcHistogram(feature.values(), limits[i][0], limits[i][1], bins[i]);
// Colour the points
plot.setColor(colourMap.apply(n));
plot.addPoints(hist[0], hist[1], Plot.BAR);
}
plot.updateImage();
}
createTrackDataTable(tracks, trackData, fitData, model, component, cal, colourMap);
// Analysis.
// Assign the original localisations to their track component.
// Q. What about the start/end not covered by the window?
// Save tracks as a dataset labelled with the sub-track ID?
// Output for the bound component and free components track parameters.
// Compute dwell times.
// Other ...
// Track analysis plugin:
// Extract all continuous segments of the same component.
// Produce MSD plot with error bars.
// Fit using FBM model.
}
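For reference, the per-point assignment used above amounts to picking the component with the largest weighted density. This sketch assumes getDistributions() returns MultivariateGaussianDistribution instances with a density method, as used elsewhere in this class; it is not the assignData source:
// Sketch only: maximum a posteriori assignment of each (normalised) point to a component.
private static int[] assignToComponents(double[][] data, MixtureMultivariateGaussianDistribution model) {
  final double[] weights = model.getWeights();
  final MultivariateGaussianDistribution[] distributions = model.getDistributions();
  final int[] component = new int[data.length];
  for (int i = 0; i < data.length; i++) {
    double best = -1;
    for (int k = 0; k < weights.length; k++) {
      // weight * density is proportional to the posterior probability of component k.
      final double p = weights[k] * distributions[k].density(data[i]);
      if (p > best) {
        best = p;
        component[i] = k;
      }
    }
  }
  return component;
}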