Use of org.apache.commons.rng.sampling.UnitSphereSampler in project GDSC-SMLM by aherbert.
The class TrackPopulationAnalysis, method createMixed.
/**
 * Creates the multivariate Gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat draws a random unit vector and uses the projection of each data point onto
 * that vector (the dot product) to build the initial mixture estimate. The repeats are
 * submitted as asynchronous tasks. A repeat that does not converge, or that throws a
 * {@code NonPositiveDefiniteMatrixException} or {@code SingularMatrixException}, contributes no
 * candidate.
 *
 * <p>Note: This method increments {@code settings.seed} as a side effect.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the fit with the highest log-likelihood; or {@code null} if no repeat produced a
 *         converged fit
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
  // Fit a mixed multivariate Gaussian with different repeats.
  final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
  final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
  final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
  if (settings.debug) {
    ImageJUtils.log(" Fitting %d components", numComponents);
  }
  final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
  final AtomicInteger failures = new AtomicInteger();
  for (int i = 0; i < settings.repeats; i++) {
    // Sample on the calling thread so the sampler is never shared between tasks.
    final double[] vector = sampler.sample();
    results.add(CompletableFuture.supplyAsync(() -> {
      final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
      try {
        // This may also throw the same exceptions due to inversion of the covariance matrix
        final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
          // Rank the point by its projection onto the random unit vector.
          double dot = 0;
          for (int j = 0; j < dimensions; j++) {
            dot += vector[j] * point[j];
          }
          return dot;
        });
        final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
        // Log the result. Note: The ImageJ log is synchronized.
        if (settings.debug) {
          ImageJUtils.log(" Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
        }
        return result ? fitter : null;
      } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
        failures.getAndIncrement();
        if (settings.debug) {
          ImageJUtils.log(" Fit failed during iteration %d. No variance in a sub-population " + "component (check alpha is not always 1.0).", fitter.getIterations());
        }
      } finally {
        ticker.tick();
      }
      return null;
    }));
  }
  // Wait for every task before finishing the progress display or reading the failure count.
  // Doing either before the join races with the tasks that are still running.
  final MultivariateGaussianMixtureExpectationMaximization best;
  try {
    // Collect results and select the best model. max() consumes the entire stream so all
    // futures are joined; on a tie the earliest repeat is kept.
    best = results.stream()
        .map(CompletableFuture::join)
        .filter(f -> f != null)
        .max((f1, f2) -> Double.compare(f1.getLogLikelihood(), f2.getLogLikelihood()))
        .orElse(null);
  } finally {
    // Always clear the progress display, even if a task failed with an unexpected exception.
    ImageJUtils.finished();
  }
  if (failures.get() != 0 && settings.debug) {
    ImageJUtils.log(" %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
  }
  return best;
}
Use of org.apache.commons.rng.sampling.UnitSphereSampler in project GDSC-SMLM by aherbert.
The class AiryPsfModel, method sample.
/**
 * Draws samples from an Airy distribution centred at {@code (x0, x1)}.
 *
 * <p>For each requested sample a uniform deviate is drawn. Deviates above
 * {@code POWER[SAMPLE_RINGS]} are discarded, so fewer than {@code n} samples may be returned.
 * Accepted deviates are mapped to a radius using the spline, and a random 2D unit vector
 * scaled by that radius and the width of each dimension gives the offset from the centre.
 *
 * <p>The widths are stored in the instance, and the Airy distribution is created on first use.
 *
 * @param n The number of samples
 * @param x0 The centre in dimension 0
 * @param x1 The centre in dimension 1
 * @param w0 The Airy width for dimension 0
 * @param w1 The Airy width for dimension 1
 * @param rng The random generator to use for sampling
 * @return The sample x and y values
 */
public double[][] sample(final int n, final double x0, final double x1, final double w0, final double w1, UniformRandomProvider rng) {
  this.w0 = w0;
  this.w1 = w1;
  if (spline == null) {
    createAiryDistribution();
  }
  final double[] xs = new double[n];
  final double[] ys = new double[n];
  final UnitSphereSampler directionSampler = UnitSphereSampler.of(rng, 2);
  int accepted = 0;
  for (int remaining = n; remaining > 0; remaining--) {
    final double power = rng.nextDouble();
    // Samples beyond the tabulated rings are dropped.
    // TODO - We could add a simple interpolation here using a spline from AiryPattern.power()
    if (!(power > POWER[SAMPLE_RINGS])) {
      final double r = spline.value(power);
      // Convert to xy using a random vector generator
      final double[] direction = directionSampler.sample();
      xs[accepted] = direction[0] * r * w0 + x0;
      ys[accepted] = direction[1] * r * w1 + x1;
      accepted++;
    }
  }
  if (accepted == n) {
    // Nothing was rejected; the arrays are already the correct length.
    return new double[][] {xs, ys};
  }
  // Trim the unused tail left by rejected samples.
  return new double[][] {Arrays.copyOf(xs, accepted), Arrays.copyOf(ys, accepted)};
}
Aggregations