use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.
the class TrackPopulationAnalysis method createMixed.
/**
 * Creates the multivariate gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat is submitted with {@code CompletableFuture.supplyAsync}. The initial mixture of a
 * repeat is estimated using the dot product of each data point with a vector drawn from the
 * sampler. A repeat that does not converge, or that fails with a matrix exception, contributes no
 * candidate model.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the converged fit with the highest log-likelihood; or null if no repeat produced a
 *         converged fit
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
  // Fit a mixed multivariate Gaussian with different repeats.
  // Note: The seed is advanced so the next call uses a different sampler.
  final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
  final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
  final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
  if (settings.debug) {
    ImageJUtils.log(" Fitting %d components", numComponents);
  }
  final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
  final AtomicInteger failures = new AtomicInteger();
  for (int i = 0; i < settings.repeats; i++) {
    // Sample on the calling thread so the sampler is never used concurrently.
    final double[] vector = sampler.sample();
    results.add(CompletableFuture.supplyAsync(() -> {
      final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
      try {
        // This may also throw the same exceptions due to inversion of the covariance matrix
        final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
          double dot = 0;
          for (int j = 0; j < dimensions; j++) {
            dot += vector[j] * point[j];
          }
          return dot;
        });
        final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
        // Log the result. Note: The ImageJ log is synchronized.
        if (settings.debug) {
          ImageJUtils.log(" Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
        }
        // Only a converged fit is a candidate for the best model.
        return result ? fitter : null;
      } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
        failures.getAndIncrement();
        if (settings.debug) {
          ImageJUtils.log(" Fit failed during iteration %d. No variance in a sub-population " + "component (check alpha is not always 1.0).", fitter.getIterations());
        }
      } finally {
        ticker.tick();
      }
      return null;
    }));
  }
  // Collect results and select the best model. The max reduction is a terminal operation that
  // joins every future, so all repeats have completed once it returns. This must happen before
  // the progress is marked as finished and before the failure count is read; otherwise the
  // summary may be reported while tasks are still running.
  final MultivariateGaussianMixtureExpectationMaximization best = results.stream().map(CompletableFuture::join).filter(f -> f != null).max((f1, f2) -> Double.compare(f1.getLogLikelihood(), f2.getLogLikelihood())).orElse(null);
  ImageJUtils.finished();
  if (failures.get() != 0 && settings.debug) {
    ImageJUtils.log(" %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
  }
  return best;
}
use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method getColumnMeans.
/**
 * Computes the mean of every column of the data. Each column is extracted from a matrix view of
 * the data and evaluated with the Apache Commons Math {@code Mean} statistic; this is intended to
 * match the means used by the Apache Commons Math Covariance class.
 *
 * @param data the data (rows are observations; the column count is taken from the first row)
 * @return the column means
 */
private static double[] getColumnMeans(double[][] data) {
  final Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(data);
  final Mean statistic = new Mean();
  final double[] means = new double[data[0].length];
  for (int column = 0; column < means.length; column++) {
    means[column] = statistic.evaluate(matrix.getColumn(column));
  }
  return means;
}
use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.
the class MultivariateGaussianMixtureExpectationMaximizationTest method createData2d.
/**
 * Samples data from a mixture of n 2D Gaussian distributions. The number of mixture components
 * is the length of the weights array; all other arrays are indexed by component.
 *
 * @param sampleSize the number of samples to draw
 * @param rng the random generator
 * @param weights the weight of each component
 * @param means the means for the x and y dimensions of each component
 * @param stdDevs the standard deviations for the x and y dimensions of each component
 * @param correlations the correlation between the x and y dimensions of each component
 * @return the samples drawn from the mixture distribution
 */
private static double[][] createData2d(int sampleSize, UniformRandomProvider rng, double[] weights, double[][] means, double[][] stdDevs, double[] correlations) {
  // Commons Math provides the mixture sampling
  final ArrayList<Pair<Double, MultivariateNormalDistribution>> mixture = new ArrayList<>();
  final int numComponents = weights.length;
  for (int c = 0; c < numComponents; c++) {
    // Covariance matrix: variances on the diagonal, correlation * sx * sy off the diagonal
    final double[] sd = stdDevs[c];
    final double sigmaX = sd[0];
    final double sigmaY = sd[1];
    final double covXy = correlations[c] * sigmaX * sigmaY;
    final double[][] covariance = {
        {sigmaX * sigmaX, covXy},
        {covXy, sigmaY * sigmaY}
    };
    final MultivariateNormalDistribution gaussian = new MultivariateNormalDistribution(means[c], covariance);
    mixture.add(new Pair<>(weights[c], gaussian));
  }
  return new MixtureMultivariateNormalDistribution(new RandomGeneratorAdapter(rng), mixture).sample(sampleSize);
}
use of org.apache.commons.math3.stat.correlation.Covariance in project IR_Base by Linda-sunshine.
the class TUIR method update_SigmaP.
// Variational inference for p(P|\nu,\Sigma) for each user.
// Accumulates m_sigma * I plus, for every item reviewed by the user, the term
// (diag(eta) + eta * eta^T) * m_rho / (eta0 * (eta0 + 1)) where eta0 is the sum of eta.
// The inverse of the accumulated matrix is stored as the covariance of every topic.
private void update_SigmaP(_User user) {
  final _User4ETBIR target = (_User4ETBIR) user;
  final int userIdx = m_usersIndex.get(target.getUserID());

  // All the items reviewed by this user; remains empty if the user has no entry in the map.
  ArrayList<Integer> reviewed = new ArrayList<>();
  if (m_mapByUser.containsKey(userIdx)) {
    reviewed = m_mapByUser.get(userIdx);
  }

  RealMatrix accumulated = MatrixUtils.createRealIdentityMatrix(number_of_topics).scalarMultiply(m_sigma);
  for (Integer itemIdx : reviewed) {
    final _Product4ETBIR product = (_Product4ETBIR) m_items.get(itemIdx);
    final RealMatrix etaColumn = MatrixUtils.createColumnRealMatrix(product.m_eta);
    final double etaSum = Utils.sumOfArray(product.m_eta);
    final RealMatrix outerProduct = etaColumn.multiply(etaColumn.transpose());
    final RealMatrix itemStat = MatrixUtils.createRealDiagonalMatrix(product.m_eta).add(outerProduct);
    final double weight = m_rho / (etaSum * (etaSum + 1.0));
    accumulated = accumulated.add(itemStat.scalarMultiply(weight));
  }

  final RealMatrix inverse = new LUDecomposition(accumulated).getSolver().getInverse();
  // All topics share the same covariance; getData() is called per topic so each topic
  // is assigned the result of its own call.
  for (int topic = 0; topic < number_of_topics; topic++) {
    target.m_SigmaP[topic] = inverse.getData();
  }
}
use of org.apache.commons.math3.stat.correlation.Covariance in project lucene-solr by apache.
the class CovarianceEvaluator method evaluate.
/**
 * Evaluates the first two sub-evaluators against the tuple, treats each result as a list of
 * numbers and returns the covariance of the two columns.
 *
 * @param tuple the tuple passed to both sub-evaluators
 * @return the covariance computed by the Commons Math {@code Covariance} class
 * @throws IOException declared by the signature; presumably raised by the sub-evaluators
 */
public Number evaluate(Tuple tuple) throws IOException {
  StreamEvaluator firstEvaluator = subEvaluators.get(0);
  StreamEvaluator secondEvaluator = subEvaluators.get(1);

  List<Number> firstValues = (List<Number>) firstEvaluator.evaluate(tuple);
  List<Number> secondValues = (List<Number>) secondEvaluator.evaluate(tuple);

  double[] firstColumn = toColumn(firstValues);
  double[] secondColumn = toColumn(secondValues);

  return new Covariance().covariance(firstColumn, secondColumn);
}

/** Copies the double value of every number in the list into a new array, in list order. */
private static double[] toColumn(List<Number> values) {
  double[] column = new double[values.size()];
  for (int idx = 0; idx < column.length; idx++) {
    column[idx] = values.get(idx).doubleValue();
  }
  return column;
}
Aggregations