use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class TrackPopulationAnalysis method run.
/**
 * Plugin entry point: analyse the local features of traced datasets and split the track data
 * into populations (components).
 *
 * <p>The method aborts with an error dialog if there are no localisations in memory or if no
 * tracks are produced from the selected input. Otherwise it:
 *
 * <ol>
 * <li>Loads (and immediately saves) the plugin settings and shows the input/option dialogs.
 * <li>Builds tracks from the selected results and extracts the per-track feature data.
 * <li>Plots a histogram for each feature in {@code FEATURE_NAMES}; the {@code FEATURE_D} column
 * is plotted as log10 of the value.
 * <li>Assigns each data row to a component, either from the category stored in the last column
 * of the track data (when the input has categories) or by fitting a multivariate Gaussian
 * mixture model and assigning using that model.
 * <li>Shows a model table, overlays per-component coloured histograms on the feature plots and
 * creates the track data table.
 * </ol>
 *
 * @param arg the plugin argument (passed to the usage tracker)
 */
@Override
public void run(String arg) {
SmlmUsageTracker.recordPlugin(this.getClass(), arg);
if (MemoryPeakResults.isMemoryEmpty()) {
IJ.error(TITLE, "No localisations in memory");
return;
}
settings = Settings.load();
// Saved by reference so just save now
settings.save();
// Read in multiple traced datasets
// All datasets must have the same pixel pitch and exposure time
// Get parameters
// Convert datasets to tracks
// For each track compute the 4 local track features using the configured window
//
// Optional:
// Fit a multi-variate Gaussian mixture model to the data
// (using the configured number of components/populations)
// Assign each point in the track using the model.
// Smooth the assignments.
//
// The alternative is to use the localisation category to assign populations.
//
// Plot histograms of each track parameter, coloured by component
final List<MemoryPeakResults> combinedResults = new LocalList<>();
if (!showInputDialog(combinedResults)) {
return;
}
final boolean hasCategory = showHasCategoryDialog(combinedResults);
if (!showDialog(hasCategory)) {
return;
}
ImageJUtils.log(TITLE + "...");
final List<Trace> tracks = getTracks(combinedResults, settings.window, settings.minTrackLength);
if (tracks.isEmpty()) {
IJ.error(TITLE, "No tracks. Please check the input data and min track length setting.");
return;
}
// The calibration of the first dataset is used for all datasets
// (see the note above: all datasets must share pixel pitch and exposure time).
final Calibration cal = combinedResults.get(0).getCalibration();
final CalibrationReader cr = new CalibrationReader(cal);
// Use micrometer / second
final TypeConverter<DistanceUnit> distanceConverter = cr.getDistanceConverter(DistanceUnit.UM);
// NOTE(review): presumably the calibration exposure time is in milliseconds and this
// converts it to seconds - confirm against CalibrationReader.
final double exposureTime = cr.getExposureTime() / 1000.0;
final Pair<int[], double[][]> trackData = extractTrackData(tracks, distanceConverter, exposureTime, hasCategory);
final double[][] data = trackData.getValue();
// Histogram the raw data.
// Wrap the data as a matrix (copyArray = false) to extract the feature columns.
final Array2DRowRealMatrix raw = new Array2DRowRealMatrix(data, false);
final WindowOrganiser wo = new WindowOrganiser();
// Store the histogram data for plotting the components
final double[][] columns = new double[FEATURE_NAMES.length][];
final double[][] limits = new double[FEATURE_NAMES.length][];
// Get column data
for (int i = 0; i < FEATURE_NAMES.length; i++) {
columns[i] = raw.getColumn(i);
if (i == FEATURE_D) {
// Plot using a logarithmic scale
SimpleArrayUtils.apply(columns[i], Math::log10);
}
limits[i] = MathUtils.limits(columns[i]);
}
// Compute histogram bins:
// use the configured number if positive; otherwise compute the bins for each feature.
final int[] bins = new int[FEATURE_NAMES.length];
if (settings.histogramBins > 0) {
Arrays.fill(bins, settings.histogramBins);
} else {
for (int i = 0; i < FEATURE_NAMES.length; i++) {
bins[i] = HistogramPlot.getBins(StoredData.create(columns[i]), BinMethod.FD);
}
// Use the maximum so all histograms look the same
Arrays.fill(bins, MathUtils.max(bins));
}
// Compute plots: one histogram of all the data for each feature
final Plot[] plots = new Plot[FEATURE_NAMES.length];
for (int i = 0; i < FEATURE_NAMES.length; i++) {
final double[][] hist = HistogramPlot.calcHistogram(columns[i], limits[i][0], limits[i][1], bins[i]);
plots[i] = new Plot(TITLE + " " + FEATURE_NAMES[i], getFeatureLabel(i, i == FEATURE_D), "Frequency");
plots[i].addPoints(hist[0], hist[1], Plot.BAR);
ImageJUtils.display(plots[i].getTitle(), plots[i], ImageJUtils.NO_TO_FRONT, wo);
}
wo.tile();
// The component for each data point
int[] component;
// The number of components
int numComponents;
// Data used to fit the Gaussian mixture model
double[][] fitData;
// The fitted model
MixtureMultivariateGaussianDistribution model;
if (hasCategory) {
// Use the category as the component.
// No fit data and no output model
fitData = null;
model = null;
// The component is stored at the end of the raw track data.
final int end = data[0].length - 1;
component = Arrays.stream(data).mapToInt(d -> (int) d[end]).toArray();
// The number of components is the largest component index plus one
numComponents = MathUtils.max(component) + 1;
// In the EM algorithm the probability of each data point is computed and normalised to
// sum to 1. The normalised probabilities are averaged to create the weights.
// Note the probability of each data point uses the previous weight and the algorithm
// iterates.
// This is not a fitted model but the input model so use
// zero weights to indicate no fitting was performed.
final double[] weights = new double[numComponents];
// Remove the trailing component to show the 'model' in a table.
createModelTable(Arrays.stream(data).map(d -> Arrays.copyOf(d, end)).toArray(double[][]::new), weights, component);
} else {
// Multivariate Gaussian mixture EM
// Provide option to not use the anomalous exponent in the population mix.
int sortDimension = SORT_DIMENSION;
if (settings.ignoreAlpha) {
// Remove index 0. This shifts the sort dimension.
sortDimension--;
fitData = Arrays.stream(data).map(d -> Arrays.copyOfRange(d, 1, d.length)).toArray(double[][]::new);
} else {
// Fit on a copy of the data; the original data is used for the model table below
fitData = SimpleArrayUtils.deepCopy(data);
}
final MultivariateGaussianMixtureExpectationMaximization mixed = fitGaussianMixture(fitData, sortDimension);
if (mixed == null) {
IJ.error(TITLE, "Failed to fit a mixture model");
return;
}
model = sortComponents(mixed.getFittedModel(), sortDimension);
// For the best model, assign to the most likely population.
component = assignData(fitData, model);
// Table of the final model using the original data (i.e. not normalised)
final double[] weights = model.getWeights();
numComponents = weights.length;
createModelTable(data, weights, component);
}
// Output coloured histograms of the populations.
final LUT lut = LutHelper.createLut(settings.lutIndex);
IntFunction<Color> colourMap;
// If the first LUT colour is black then use the non-zero colour lookup for the components
if (LutHelper.getColour(lut, 0).equals(Color.BLACK)) {
colourMap = i -> LutHelper.getNonZeroColour(lut, i, 0, numComponents - 1);
} else {
colourMap = i -> LutHelper.getColour(lut, i, 0, numComponents - 1);
}
for (int i = 0; i < FEATURE_NAMES.length; i++) {
// Extract the data for each component
final double[] col = columns[i];
final Plot plot = plots[i];
for (int n = 0; n < numComponents; n++) {
final StoredData feature = new StoredData();
for (int j = 0; j < component.length; j++) {
if (component[j] == n) {
feature.add(col[j]);
}
}
// Nothing to plot for a component with no assigned data
if (feature.size() == 0) {
continue;
}
// Use the same limits and bins as the histogram of all the data for this feature
final double[][] hist = HistogramPlot.calcHistogram(feature.values(), limits[i][0], limits[i][1], bins[i]);
// Colour the points
plot.setColor(colourMap.apply(n));
plot.addPoints(hist[0], hist[1], Plot.BAR);
}
plot.updateImage();
}
createTrackDataTable(tracks, trackData, fitData, model, component, cal, colourMap);
// Analysis.
// Assign the original localisations to their track component.
// Q. What about the start/end not covered by the window?
// Save tracks as a dataset labelled with the sub-track ID?
// Output for the bound component and free components track parameters.
// Compute dwell times.
// Other ...
// Track analysis plugin:
// Extract all continuous segments of the same component.
// Produce MSD plot with error bars.
// Fit using FBM model.
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximization method fit.
/**
 * Fit a Gaussian mixture model to the instance data using expectation-maximisation, starting
 * from the provided mixture.
 *
 * <p>Each pass performs an E-step (the normalised weighted density of every data row under every
 * component) followed by an M-step (new weights, means and covariances from those values). The
 * log-likelihood of the data under the model at the start of the pass is stored and compared
 * with the value from the previous pass using the convergence predicate.
 *
 * <p>The result depends on the data and on the initial mixture. Data with many local optima may
 * require several runs using different initial mixtures. If a SingularMatrixException is
 * encountered it is possible that another initialisation would work.
 *
 * @param initialMixture model containing initial values of weights and multivariate normals
 * @param maxIterations maximum iterations allowed for fit
 * @param convergencePredicate convergence predicate used to test the log-likelihoods between
 *        successive iterations
 * @return true if converged within the threshold
 * @throws IllegalArgumentException if maxIterations is less than one or if initialMixture mean
 *         vector and data number of columns are not equal
 * @throws SingularMatrixException if any component's covariance matrix is singular during
 *         fitting.
 * @throws NonPositiveDefiniteMatrixException if any component's covariance matrix is not positive
 *         definite during fitting.
 */
public boolean fit(MixtureMultivariateGaussianDistribution initialMixture, int maxIterations, DoubleDoubleBiPredicate convergencePredicate) {
  ValidationUtils.checkStrictlyPositive(maxIterations, "maxIterations");
  ValidationUtils.checkNotNull(convergencePredicate, "convergencePredicate");
  ValidationUtils.checkNotNull(initialMixture, "initialMixture");

  final int numRows = data.length;
  final int numComponents = initialMixture.weights.length;
  // The model and the data must have the same dimension
  final int dimensions = data[0].length;
  final int modelDimensions = initialMixture.distributions[0].means.length;
  ValidationUtils.checkArgument(dimensions == modelDimensions, "Mixture model dimension mismatch with data columns: %d != %d", dimensions, modelDimensions);

  // Reset the fit state; the fit starts from the initial mixture
  logLikelihood = -Double.MAX_VALUE;
  iterations = 0;
  fittedModel = initialMixture;

  // NOTE(review): the counter is tested with <= before it is incremented so this allows
  // up to (maxIterations + 1) passes.
  while (iterations++ <= maxIterations) {
    final double lastLogLikelihood = logLikelihood;

    // The current model
    final double[] currentWeights = fittedModel.weights;
    final MultivariateGaussianDistribution[] currentDistributions = fittedModel.distributions;

    // --- E-step ---
    // Fraction of the total density of a row due to each component
    final double[][] responsibility = new double[numRows][numComponents];
    // Per component: sum of the responsibilities
    final double[] responsibilitySum = new double[numComponents];
    // Per component: sum of the responsibility multiplied by the row
    final double[][] weightedDataSum = new double[numComponents][dimensions];
    // Working space for the (weight * density) of each component for the current row
    final double[] weightedDensity = new double[numComponents];
    double logLikelihoodSum = 0;

    for (int row = 0; row < numRows; row++) {
      final double[] x = data[row];

      // Total density of the row under the mixture
      double total = 0;
      for (int c = 0; c < numComponents; c++) {
        final double density = currentWeights[c] * currentDistributions[c].density(x);
        weightedDensity[c] = density;
        total += density;
      }
      logLikelihoodSum += Math.log(total);

      // Normalise and accumulate
      final double[] rowResponsibility = responsibility[row];
      for (int c = 0; c < numComponents; c++) {
        final double r = weightedDensity[c] / total;
        rowResponsibility[c] = r;
        responsibilitySum[c] += r;
        final double[] sums = weightedDataSum[c];
        for (int d = 0; d < dimensions; d++) {
          sums[d] += r * x[d];
        }
      }
    }

    logLikelihood = logLikelihoodSum;

    // --- M-step ---
    final double[] nextWeights = new double[numComponents];
    final double[][] nextMeans = new double[numComponents][dimensions];
    for (int c = 0; c < numComponents; c++) {
      final double sum = responsibilitySum[c];
      nextWeights[c] = sum / numRows;
      final double[] mean = nextMeans[c];
      final double[] sums = weightedDataSum[c];
      for (int d = 0; d < dimensions; d++) {
        mean[d] = sums[d] / sum;
      }
    }

    // Covariance matrices are symmetric: accumulate only the lower triangle
    // (including the diagonal) of the weighted outer product of (row - mean).
    final double[][][] nextCovariances = new double[numComponents][dimensions][dimensions];
    final double[] delta = new double[dimensions];
    for (int row = 0; row < numRows; row++) {
      final double[] x = data[row];
      final double[] rowResponsibility = responsibility[row];
      for (int c = 0; c < numComponents; c++) {
        subtract(x, nextMeans[c], delta);
        final double r = rowResponsibility[c];
        final double[][] cov = nextCovariances[c];
        for (int a = 0; a < dimensions; a++) {
          final double scaled = delta[a] * r;
          final double[] covRow = cov[a];
          for (int b = 0; b <= a; b++) {
            covRow[b] += scaled * delta[b];
          }
        }
      }
    }

    // Normalise each covariance by the responsibility sum, mirror the lower triangle
    // to the upper triangle and create the component distributions.
    final MultivariateGaussianDistribution[] nextDistributions = new MultivariateGaussianDistribution[numComponents];
    for (int c = 0; c < numComponents; c++) {
      final double scale = 1.0 / responsibilitySum[c];
      final double[][] cov = nextCovariances[c];
      for (int a = 0; a < dimensions; a++) {
        final double[] covRow = cov[a];
        for (int b = 0; b < a; b++) {
          final double value = covRow[b] * scale;
          covRow[b] = value;
          cov[b][a] = value;
        }
        covRow[a] *= scale;
      }
      nextDistributions[c] = new MultivariateGaussianDistribution(nextMeans[c], cov);
    }

    // Store the updated model
    fittedModel = MixtureMultivariateGaussianDistribution.create(nextWeights, nextDistributions);

    // Stop when the log-likelihood has converged
    if (convergencePredicate.test(lastLogLikelihood, logLikelihood)) {
      return true;
    }
  }

  // The iteration limit was reached without convergence
  return false;
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximization method estimate.
/**
 * Helper method to create a multivariate Gaussian mixture model which can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>The function is used to extract a value for ranking all points. These are then split
 * uniformly into the specified number of components. This is equivalent to projection of the
 * points onto a line and partitioning the points along the line uniformly; this cuts the points
 * into sets using hyper-planes defined by the vector of the line. A good ranking metric is a dot
 * product with a random unit vector uniformly sampled from the surface of an n-dimension
 * hypersphere.
 *
 * <pre>{@code
 * double[] vec = ...;
 * ToDoubleFunction<double[]> rankingMetric =
 *   values -> {
 *     double d = 0;
 *     for (int i = 0; i < vec.length; i++) {
 *       d += vec[i] * values[i];
 *     }
 *     return d;
 *   };
 * }</pre>
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * <p>Each component is created from the column means and the covariance of the rows in its bin.
 * All components are given the same weight ({@code 1 / numComponents}).
 *
 * @param data data to estimate distribution
 * @param numComponents number of components for estimated mixture
 * @param rankingMetric the function to generate the ranking metric
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows, if {@code numComponents < 2},
 *         if {@code numComponents} is greater than the number of data rows, or if the data has
 *         less than 2 columns.
 */
public static MixtureMultivariateGaussianDistribution estimate(double[][] data, int numComponents, ToDoubleFunction<double[]> rankingMetric) {
  ValidationUtils.checkArgument(data.length >= 2, "Estimation requires at least 2 data points: %d", data.length);
  ValidationUtils.checkArgument(numComponents >= 2, "Multivariate Gaussian mixture requires at least 2 components: %d", numComponents);
  ValidationUtils.checkArgument(numComponents <= data.length, "Number of components %d greater than data length %d", numComponents, data.length);
  ValidationUtils.checkNotNull(rankingMetric, "rankingMetric");

  final int numRows = data.length;
  final int numCols = data[0].length;
  ValidationUtils.checkArgument(numCols >= 2, "Multivariate Gaussian requires at least 2 data columns: %d", numCols);

  // Sort the data by the ranking metric
  final ProjectedData[] sortedData = new ProjectedData[numRows];
  for (int i = 0; i < numRows; i++) {
    sortedData[i] = new ProjectedData(data[i], rankingMetric.applyAsDouble(data[i]));
  }
  Arrays.sort(sortedData, (o1, o2) -> Double.compare(o1.value, o2.value));

  // components of mixture model to be created
  final MultivariateGaussianDistribution[] distributions = new MultivariateGaussianDistribution[numComponents];

  // create a component based on data in each bin
  for (int binIndex = 0; binIndex < numComponents; binIndex++) {
    // The bin boundaries are computed using long arithmetic. The product of the bin index
    // and the number of rows can exceed the range of an int for large data; the quotient
    // is at most numRows so the cast back to int is safe.

    // minimum index (inclusive) from sorted data for this bin
    final int minIndex = (int) (((long) binIndex * numRows) / numComponents);

    // maximum index (exclusive) from sorted data for this bin
    final int maxIndex = (int) (((long) (binIndex + 1) * numRows) / numComponents);

    // number of data records that will be in this bin
    final int numBinRows = maxIndex - minIndex;

    // data for this bin
    final double[][] binData = new double[numBinRows][];

    // mean of each column for the data in this bin
    final double[] columnMeans = new double[numCols];

    // populate bin and create component
    for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
      final double[] values = sortedData[i].data;
      binData[iBin] = values;
      for (int j = 0; j < numCols; j++) {
        columnMeans[j] += values[j];
      }
    }

    // Convert the column sums to means
    SimpleArrayUtils.multiply(columnMeans, 1.0 / numBinRows);

    distributions[binIndex] = new MultivariateGaussianDistribution(columnMeans, covariance(columnMeans, binData));
  }

  // uniform weight for each bin
  final double[] weights = SimpleArrayUtils.newDoubleArray(numComponents, 1.0 / numComponents);

  return new MixtureMultivariateGaussianDistribution(weights, distributions);
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximization method createMixed.
/**
 * Helper method to create a multivariate Gaussian mixture model from data that has already been
 * classified. The model can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>The data rows are grouped by their component value. One distribution is created for each
 * distinct value, in ascending order of the value, using the column means and the covariance of
 * the rows in the group.
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * <p>The weights for each component will be uniform.
 *
 * @param data data for the distribution
 * @param component the component for each data point
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows, if there is a size mismatch
 *         between the data and components length, if the data has less than 2 columns, or if
 *         the number of components is less than 2.
 */
public static MixtureMultivariateGaussianDistribution createMixed(double[][] data, int[] component) {
  final int numRows = data.length;
  ValidationUtils.checkArgument(numRows >= 2, "Estimation requires at least 2 data points: %d", numRows);
  ValidationUtils.checkArgument(numRows == component.length, "Data and component size mismatch: %d != %d", numRows, component.length);
  final int dimensions = data[0].length;
  ValidationUtils.checkArgument(dimensions >= 2, "Multivariate Gaussian requires at least 2 data columns: %d", dimensions);

  // Pair each row with its label and order by the label
  final ClassifiedData[] classified = new ClassifiedData[numRows];
  for (int row = 0; row < numRows; row++) {
    classified[row] = new ClassifiedData(data[row], component[row]);
  }
  Arrays.sort(classified, (a, b) -> Integer.compare(a.value, b.value));

  // After sorting the first and last labels differ only if there are at least 2 labels
  ValidationUtils.checkArgument(classified[0].value != classified[numRows - 1].value, "Mixture model requires at least 2 data components");

  // Create one distribution for each run of identical labels
  final LocalList<MultivariateGaussianDistribution> mvns = new LocalList<>();
  for (int start = 0, end; start < numRows; start = end) {
    // Find the (exclusive) end of the run
    final int label = classified[start].value;
    end = start + 1;
    while (end < numRows && classified[end].value == label) {
      end++;
    }

    // Collect the rows of the run and sum each column
    final int size = end - start;
    final double[][] members = new double[size][];
    final double[] means = new double[dimensions];
    for (int k = 0; k < size; k++) {
      final double[] row = classified[start + k].data;
      members[k] = row;
      for (int c = 0; c < dimensions; c++) {
        means[c] += row[c];
      }
    }
    // Convert the column sums to means
    SimpleArrayUtils.multiply(means, 1.0 / size);

    mvns.add(new MultivariateGaussianDistribution(means, covariance(means, members)));
  }

  // Every component has the same weight
  final int count = mvns.size();
  final double[] uniformWeights = SimpleArrayUtils.newDoubleArray(count, 1.0 / count);
  return new MixtureMultivariateGaussianDistribution(uniformWeights, mvns.toArray(new MultivariateGaussianDistribution[0]));
}
use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method canEstimateInitialMixture.
/**
 * Test the estimated initial mixture against the model estimated by the Commons Math
 * implementation using the same data. The weights must be exactly equal; the means and
 * covariances are compared using a tolerance.
 *
 * @param seed the seed for the random data
 */
@SeededTest
void canEstimateInitialMixture(RandomSeed seed) {
  final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
  final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-5, 1e-16);

  // Test using 2 and 3 components
  for (final int numComponents : new int[] {2, 3}) {
    // Random 2D mixture parameters used to sample the data
    final double[] sampleWeights = createWeights(numComponents, rng);
    final double[][] sampleMeans = create(numComponents, 2, rng, -5, 5);
    final double[][] sampleStdDevs = create(numComponents, 2, rng, 1, 10);
    final double[] sampleCorrelations = create(numComponents, rng, -0.9, 0.9);
    final double[][] data = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);

    // Estimate using the method under test and the reference implementation
    final MixtureMultivariateGaussianDistribution actual = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents);
    final MixtureMultivariateNormalDistribution expected = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);

    final List<Pair<Double, MultivariateNormalDistribution>> expectedComponents = expected.getComponents();
    final double[] actualWeights = actual.getWeights();
    final MultivariateGaussianDistribution[] actualDistributions = actual.getDistributions();

    // Both models must have the requested number of components
    Assertions.assertEquals(numComponents, expectedComponents.size());
    Assertions.assertEquals(numComponents, actualWeights.length);
    Assertions.assertEquals(numComponents, actualDistributions.length);

    int index = 0;
    for (final Pair<Double, MultivariateNormalDistribution> pair : expectedComponents) {
      // Must be binary equal for estimated model
      Assertions.assertEquals(pair.getFirst(), actualWeights[index], "weight");
      final MultivariateNormalDistribution reference = pair.getSecond();
      final MultivariateGaussianDistribution observed = actualDistributions[index];
      TestAssertions.assertArrayTest(reference.getMeans(), observed.getMeans(), closeTo, "means");
      TestAssertions.assertArrayTest(reference.getCovariances().getData(), observed.getCovariances(), closeTo, "covariances");
      index++;
    }
  }
}
Aggregations