Search in sources:

Example 21 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.

The class TrackPopulationAnalysis, method createMixed.

/**
 * Creates the multivariate gaussian mixture as the best of many repeats of the expectation
 * maximisation algorithm.
 *
 * <p>Each repeat runs asynchronously. Its initial mixture is estimated from the data using the
 * dot product of each point with a randomly sampled unit vector. A repeat is discarded if it
 * fails with a {@link NonPositiveDefiniteMatrixException} or {@link SingularMatrixException}, or
 * if the fit does not converge. Of the remaining fits, the one with the highest log-likelihood is
 * returned.
 *
 * <p>Side effect: {@code settings.seed} is incremented to seed the random unit vectors.
 *
 * @param data the data
 * @param dimensions the dimensions
 * @param numComponents the number of components
 * @return the multivariate gaussian mixture expectation maximization, or null if no repeat
 *     produced a converged fit
 */
private MultivariateGaussianMixtureExpectationMaximization createMixed(final double[][] data, int dimensions, int numComponents) {
    // Fit a mixed multivariate Gaussian with different repeats.
    final UnitSphereSampler sampler = UnitSphereSampler.of(UniformRandomProviders.create(Mixers.stafford13(settings.seed++)), dimensions);
    final LocalList<CompletableFuture<MultivariateGaussianMixtureExpectationMaximization>> results = new LocalList<>(settings.repeats);
    final DoubleDoubleBiPredicate test = createConvergenceTest(settings.relativeError);
    if (settings.debug) {
        ImageJUtils.log("  Fitting %d components", numComponents);
    }
    final Ticker ticker = ImageJUtils.createTicker(settings.repeats, 2, "Fitting...");
    final AtomicInteger failures = new AtomicInteger();
    for (int i = 0; i < settings.repeats; i++) {
        // Sample on the calling thread so the sampler (presumably not thread-safe) is never
        // used concurrently by the async tasks.
        final double[] vector = sampler.sample();
        results.add(CompletableFuture.supplyAsync(() -> {
            final MultivariateGaussianMixtureExpectationMaximization fitter = new MultivariateGaussianMixtureExpectationMaximization(data);
            try {
                // This may also throw the same exceptions due to inversion of the covariance matrix
                final MixtureMultivariateGaussianDistribution initialMixture = MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, point -> {
                    double dot = 0;
                    for (int j = 0; j < dimensions; j++) {
                        dot += vector[j] * point[j];
                    }
                    return dot;
                });
                final boolean result = fitter.fit(initialMixture, settings.maxIterations, test);
                // Log the result. Note: The ImageJ log is synchronized.
                if (settings.debug) {
                    ImageJUtils.log("  Fit: log-likelihood=%s, iter=%d, converged=%b", fitter.getLogLikelihood(), fitter.getIterations(), result);
                }
                return result ? fitter : null;
            } catch (NonPositiveDefiniteMatrixException | SingularMatrixException ex) {
                failures.getAndIncrement();
                if (settings.debug) {
                    ImageJUtils.log("  Fit failed during iteration %d. No variance in a sub-population " + "component (check alpha is not always 1.0).", fitter.getIterations());
                }
            } finally {
                ticker.tick();
            }
            return null;
        }));
    }
    // Join every future before finishing the progress display and reporting failures. The
    // failure count is only final once all tasks have completed.
    final MultivariateGaussianMixtureExpectationMaximization best;
    try {
        // Select the model with the highest log-likelihood. On ties the earliest repeat wins.
        best = results.stream()
            .map(CompletableFuture::join)
            .filter(f -> f != null)
            .max((f1, f2) -> Double.compare(f1.getLogLikelihood(), f2.getLogLikelihood()))
            .orElse(null);
    } finally {
        ImageJUtils.finished();
    }
    if (failures.get() != 0 && settings.debug) {
        ImageJUtils.log("  %d component fit failed %d/%d", numComponents, failures.get(), settings.repeats);
    }
    return best;
}
Also used : UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) Color(java.awt.Color) Arrays(java.util.Arrays) ByteProcessor(ij.process.ByteProcessor) Calibration(uk.ac.sussex.gdsc.smlm.data.config.CalibrationProtos.Calibration) IntUnaryOperator(java.util.function.IntUnaryOperator) HistogramPlotBuilder(uk.ac.sussex.gdsc.core.ij.HistogramPlot.HistogramPlotBuilder) IdFramePeakResultComparator(uk.ac.sussex.gdsc.smlm.results.sort.IdFramePeakResultComparator) UnaryOperator(java.util.function.UnaryOperator) RealVector(org.apache.commons.math3.linear.RealVector) Evaluation(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation) MultivariateJacobianFunction(org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction) VisibleForTesting(uk.ac.sussex.gdsc.core.data.VisibleForTesting) MemoryPeakResults(uk.ac.sussex.gdsc.smlm.results.MemoryPeakResults) NonPositiveDefiniteMatrixException(org.apache.commons.math3.linear.NonPositiveDefiniteMatrixException) LeastSquaresFactory(org.apache.commons.math3.fitting.leastsquares.LeastSquaresFactory) RowSorter(javax.swing.RowSorter) JFrame(javax.swing.JFrame) LutHelper(uk.ac.sussex.gdsc.core.ij.process.LutHelper) KeyStroke(javax.swing.KeyStroke) ConvergenceException(org.apache.commons.math3.exception.ConvergenceException) DistanceUnit(uk.ac.sussex.gdsc.smlm.data.config.UnitProtos.DistanceUnit) KeyEvent(java.awt.event.KeyEvent) WindowAdapter(java.awt.event.WindowAdapter) TextUtils(uk.ac.sussex.gdsc.core.utils.TextUtils) Plot(ij.gui.Plot) TIntHashSet(gnu.trove.set.hash.TIntHashSet) ImagePlus(ij.ImagePlus) DefaultTableCellRenderer(javax.swing.table.DefaultTableCellRenderer) TDoubleArrayList(gnu.trove.list.array.TDoubleArrayList) SumOfSquaredDeviations(uk.ac.sussex.gdsc.core.math.SumOfSquaredDeviations) BasicStroke(java.awt.BasicStroke) RealMatrix(org.apache.commons.math3.linear.RealMatrix) 
MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) FDistribution(org.apache.commons.math3.distribution.FDistribution) PlugIn(ij.plugin.PlugIn) ActionListener(java.awt.event.ActionListener) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) PolygonRoi(ij.gui.PolygonRoi) StoredData(uk.ac.sussex.gdsc.core.utils.StoredData) WindowManager(ij.WindowManager) PeakResult(uk.ac.sussex.gdsc.smlm.results.PeakResult) Supplier(java.util.function.Supplier) PointRoi(ij.gui.PointRoi) Trace(uk.ac.sussex.gdsc.smlm.results.Trace) MultiDialog(uk.ac.sussex.gdsc.core.ij.gui.MultiDialog) UnitSphereSampler(org.apache.commons.rng.sampling.UnitSphereSampler) GenericDialog(ij.gui.GenericDialog) AbstractTableModel(javax.swing.table.AbstractTableModel) SortUtils(uk.ac.sussex.gdsc.core.utils.SortUtils) Overlay(ij.gui.Overlay) IntFunction(java.util.function.IntFunction) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) Pair(org.apache.commons.math3.util.Pair) Mean(uk.ac.sussex.gdsc.core.math.Mean) Window(java.awt.Window) LeastSquaresProblem(org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem) AttributePeakResult(uk.ac.sussex.gdsc.smlm.results.AttributePeakResult) JScrollPane(javax.swing.JScrollPane) ConvergenceChecker(org.apache.commons.math3.optim.ConvergenceChecker) ListSelectionListener(javax.swing.event.ListSelectionListener) PeakResultStoreList(uk.ac.sussex.gdsc.smlm.results.PeakResultStoreList) Optimum(org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum) DoubleEquality(uk.ac.sussex.gdsc.core.utils.DoubleEquality) TIntObjectHashMap(gnu.trove.map.hash.TIntObjectHashMap) TIntArrayList(gnu.trove.list.array.TIntArrayList) 
Mixers(uk.ac.sussex.gdsc.core.utils.rng.Mixers) TextWindow(ij.text.TextWindow) IntConsumer(java.util.function.IntConsumer) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) DataException(uk.ac.sussex.gdsc.core.data.DataException) NonBlockingExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.NonBlockingExtendedGenericDialog) ScreenDimensionHelper(uk.ac.sussex.gdsc.core.ij.gui.ScreenDimensionHelper) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) CalibrationWriter(uk.ac.sussex.gdsc.smlm.data.config.CalibrationWriter) ListSelectionEvent(javax.swing.event.ListSelectionEvent) ArrayRealVector(org.apache.commons.math3.linear.ArrayRealVector) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) JMenuBar(javax.swing.JMenuBar) BufferedTextWindow(uk.ac.sussex.gdsc.core.ij.BufferedTextWindow) ExtendedGenericDialog(uk.ac.sussex.gdsc.core.ij.gui.ExtendedGenericDialog) JMenu(javax.swing.JMenu) TIntIntHashMap(gnu.trove.map.hash.TIntIntHashMap) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization) WindowEvent(java.awt.event.WindowEvent) List(java.util.List) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) JTable(javax.swing.JTable) LUT(ij.process.LUT) TypeConverter(uk.ac.sussex.gdsc.core.data.utils.TypeConverter) Roi(ij.gui.Roi) IntStream(java.util.stream.IntStream) PrecisionResultProcedure(uk.ac.sussex.gdsc.smlm.results.procedures.PrecisionResultProcedure) ParameterValidator(org.apache.commons.math3.fitting.leastsquares.ParameterValidator) TDoubleList(gnu.trove.list.TDoubleList) ValidationUtils(uk.ac.sussex.gdsc.core.utils.ValidationUtils) CompletableFuture(java.util.concurrent.CompletableFuture) LevenbergMarquardtOptimizer(org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer) WindowOrganiser(uk.ac.sussex.gdsc.core.ij.plugin.WindowOrganiser) 
SingularMatrixException(org.apache.commons.math3.linear.SingularMatrixException) AtomicReference(java.util.concurrent.atomic.AtomicReference) SwingConstants(javax.swing.SwingConstants) DoubleUnaryOperator(java.util.function.DoubleUnaryOperator) JMenuItem(javax.swing.JMenuItem) Statistics(uk.ac.sussex.gdsc.core.utils.Statistics) DoubleData(uk.ac.sussex.gdsc.core.utils.DoubleData) TFloatArrayList(gnu.trove.list.array.TFloatArrayList) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) ChiSquaredDistributionTable(uk.ac.sussex.gdsc.smlm.function.ChiSquaredDistributionTable) LutColour(uk.ac.sussex.gdsc.core.ij.process.LutHelper.LutColour) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) ActionEvent(java.awt.event.ActionEvent) TooManyIterationsException(org.apache.commons.math3.exception.TooManyIterationsException) CalibrationReader(uk.ac.sussex.gdsc.smlm.data.config.CalibrationReader) Consumer(java.util.function.Consumer) ImageWindow(ij.gui.ImageWindow) SimpleRegression(org.apache.commons.math3.stat.regression.SimpleRegression) BinMethod(uk.ac.sussex.gdsc.core.ij.HistogramPlot.BinMethod) HistogramPlot(uk.ac.sussex.gdsc.core.ij.HistogramPlot) ImageJUtils(uk.ac.sussex.gdsc.core.ij.ImageJUtils) TableColumnAdjuster(uk.ac.sussex.gdsc.smlm.ij.gui.TableColumnAdjuster) IJ(ij.IJ) BitSet(java.util.BitSet) DiagonalMatrix(org.apache.commons.math3.linear.DiagonalMatrix) Collections(java.util.Collections) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) UniformRandomProviders(uk.ac.sussex.gdsc.core.utils.rng.UniformRandomProviders) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) CompletableFuture(java.util.concurrent.CompletableFuture) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.DoubleDoubleBiPredicate) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) 
AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Ticker(uk.ac.sussex.gdsc.core.logging.Ticker) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization)

Example 22 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.

The class MultivariateGaussianMixtureExpectationMaximizationTest, method getColumnMeans.

/**
 * Computes the mean of every column of {@code data}.
 *
 * <p>Each column mean is evaluated with the Commons Math {@link Mean} statistic, which is how the
 * Commons Math Covariance class computes its means. Results therefore match that class exactly,
 * including floating-point rounding.
 *
 * @param data the data, one row per observation; the number of columns is taken from the first
 *     row
 * @return an array holding the mean of each column
 */
private static double[] getColumnMeans(double[][] data) {
    final Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(data);
    final int columnCount = data[0].length;
    final double[] columnMeans = new double[columnCount];
    final Mean statistic = new Mean();
    for (int column = 0; column < columnCount; column++) {
        columnMeans[column] = statistic.evaluate(matrix.getColumn(column));
    }
    return columnMeans;
}
Also used : IntStream(java.util.stream.IntStream) RandomUtils(uk.ac.sussex.gdsc.core.utils.rng.RandomUtils) Arrays(java.util.Arrays) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) BaseTimingTask(uk.ac.sussex.gdsc.test.utils.BaseTimingTask) RngUtils(uk.ac.sussex.gdsc.test.rng.RngUtils) Covariance(org.apache.commons.math3.stat.correlation.Covariance) ArrayList(java.util.ArrayList) Level(java.util.logging.Level) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) AfterAll(org.junit.jupiter.api.AfterAll) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) TimingService(uk.ac.sussex.gdsc.test.utils.TimingService) BeforeAll(org.junit.jupiter.api.BeforeAll) ContinuousUniformSampler(org.apache.commons.rng.sampling.distribution.ContinuousUniformSampler) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) TestComplexity(uk.ac.sussex.gdsc.test.utils.TestComplexity) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MathUtils(uk.ac.sussex.gdsc.core.utils.MathUtils) TestAssertions(uk.ac.sussex.gdsc.test.api.TestAssertions) RandomSeed(uk.ac.sussex.gdsc.test.junit5.RandomSeed) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) SpeedTag(uk.ac.sussex.gdsc.test.junit5.SpeedTag) Pair(org.apache.commons.math3.util.Pair) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) Logger(java.util.logging.Logger) SamplerUtils(uk.ac.sussex.gdsc.core.utils.rng.SamplerUtils) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) 
Test(org.junit.jupiter.api.Test) List(java.util.List) Assumptions(org.junit.jupiter.api.Assumptions) TestSettings(uk.ac.sussex.gdsc.test.utils.TestSettings) SimpleArrayUtils(uk.ac.sussex.gdsc.core.utils.SimpleArrayUtils) SharedStateContinuousSampler(org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler) Assertions(org.junit.jupiter.api.Assertions) TestHelper(uk.ac.sussex.gdsc.test.api.TestHelper) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) NormalizedGaussianSampler(org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler) LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) Mean(org.apache.commons.math3.stat.descriptive.moment.Mean) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix)

Example 23 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project GDSC-SMLM by aherbert.

The class MultivariateGaussianMixtureExpectationMaximizationTest, method createData2d.

/**
 * Draws a sample from a mixture of 2D Gaussian distributions.
 *
 * <p>Component {@code c} is described by {@code weights[c]}, {@code means[c]},
 * {@code stdDevs[c]} and {@code correlations[c]}. The number of components is
 * {@code weights.length}. Each component's covariance matrix is
 * {@code [[sx*sx, r*sx*sy], [r*sx*sy, sy*sy]]}, where {@code sx} and {@code sy} are the standard
 * deviations and {@code r} is the correlation.
 *
 * @param sampleSize the sample size
 * @param rng the random generator
 * @param weights the weights for each component
 * @param means the means for the x and y dimensions
 * @param stdDevs the std devs for the x and y dimensions
 * @param correlations the correlations between the x and y dimensions
 * @return the {@code sampleSize} sampled points
 */
private static double[][] createData2d(int sampleSize, UniformRandomProvider rng, double[] weights, double[][] means, double[][] stdDevs, double[] correlations) {
    // Sampling is delegated to Commons Math.
    final int componentCount = weights.length;
    final List<Pair<Double, MultivariateNormalDistribution>> mixture = new ArrayList<>(componentCount);
    for (int c = 0; c < componentCount; c++) {
        final double[] sd = stdDevs[c];
        final double varianceX = sd[0] * sd[0];
        final double varianceY = sd[1] * sd[1];
        final double covarianceXy = correlations[c] * sd[0] * sd[1];
        final double[][] covarianceMatrix = { { varianceX, covarianceXy }, { covarianceXy, varianceY } };
        mixture.add(new Pair<>(weights[c], new MultivariateNormalDistribution(means[c], covarianceMatrix)));
    }
    return new MixtureMultivariateNormalDistribution(new RandomGeneratorAdapter(rng), mixture).sample(sampleSize);
}
Also used : MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) RandomGeneratorAdapter(uk.ac.sussex.gdsc.core.utils.rng.RandomGeneratorAdapter) ArrayList(java.util.ArrayList) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) Pair(org.apache.commons.math3.util.Pair)

Example 24 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project IR_Base by Linda-sunshine.

The class TUIR, method update_SigmaP.

/*
 * Variational inference for p(P|\nu,\Sigma) for each user.
 *
 * The matrix starts as m_sigma times the number_of_topics x number_of_topics identity. For every
 * item the user reviewed, the method adds (diag(eta) + eta * eta^T) * m_rho / (eta0 * (eta0 + 1)),
 * where eta is the item's m_eta and eta0 is its sum. The accumulated matrix is inverted via an LU
 * decomposition. Every topic slot of the user's m_SigmaP then receives the result of its own
 * getData() call on the inverse, so all topics share the same covariance values.
 */
private void update_SigmaP(_User user) {
    final _User4ETBIR etbirUser = (_User4ETBIR) user;
    final int userIndex = m_usersIndex.get(etbirUser.getUserID());
    // Items reviewed by this user; an empty list when the user has no entry.
    final ArrayList<Integer> reviewedItems = m_mapByUser.containsKey(userIndex) ? m_mapByUser.get(userIndex) : new ArrayList<Integer>();
    RealMatrix accumulated = MatrixUtils.createRealIdentityMatrix(number_of_topics).scalarMultiply(m_sigma);
    for (final Integer itemIndex : reviewedItems) {
        final _Product4ETBIR product = (_Product4ETBIR) m_items.get(itemIndex);
        final RealMatrix etaColumn = MatrixUtils.createColumnRealMatrix(product.m_eta);
        final double etaSum = Utils.sumOfArray(product.m_eta);
        final RealMatrix itemTerm = MatrixUtils.createRealDiagonalMatrix(product.m_eta).add(etaColumn.multiply(etaColumn.transpose()));
        accumulated = accumulated.add(itemTerm.scalarMultiply(m_rho / (etaSum * (etaSum + 1.0))));
    }
    final RealMatrix inverse = new LUDecomposition(accumulated).getSolver().getInverse();
    // All topics share the same covariance; getData() is called once per topic.
    for (int topic = 0; topic < number_of_topics; topic++) {
        etbirUser.m_SigmaP[topic] = inverse.getData();
    }
}
Also used : RealMatrix(org.apache.commons.math3.linear.RealMatrix) LUDecomposition(org.apache.commons.math3.linear.LUDecomposition)

Example 25 with Covariance

Use of org.apache.commons.math3.stat.correlation.Covariance in project lucene-solr by apache.

The class CovarianceEvaluator, method evaluate.

/**
 * Computes the covariance of the two numeric columns produced by the first two sub-evaluators.
 *
 * <p>Each sub-evaluator is expected to return a {@code List<Number>} for the tuple. The values are
 * converted to {@code double} and passed to Commons Math {@link Covariance#covariance(double[],
 * double[])}. That method rejects columns of different lengths or with fewer than two values by
 * throwing an unchecked exception.
 *
 * @param tuple the tuple to evaluate
 * @return the covariance of the two columns
 * @throws IOException if a sub-evaluator fails
 */
@SuppressWarnings("unchecked")
public Number evaluate(Tuple tuple) throws IOException {
    StreamEvaluator colEval1 = subEvaluators.get(0);
    StreamEvaluator colEval2 = subEvaluators.get(1);
    // Unchecked: the sub-evaluators are assumed to return lists of numbers for this function.
    List<Number> numbers1 = (List<Number>) colEval1.evaluate(tuple);
    List<Number> numbers2 = (List<Number>) colEval2.evaluate(tuple);
    double[] column1 = toDoubleArray(numbers1);
    double[] column2 = toDoubleArray(numbers2);
    Covariance covariance = new Covariance();
    return covariance.covariance(column1, column2);
}

/**
 * Copies a list of numbers into a new primitive array, converting each with
 * {@link Number#doubleValue()}.
 *
 * @param numbers the numbers (elements must be non-null)
 * @return a new array with the same length and order as {@code numbers}
 */
private static double[] toDoubleArray(List<Number> numbers) {
    final double[] values = new double[numbers.size()];
    for (int i = 0; i < values.length; i++) {
        values[i] = numbers.get(i).doubleValue();
    }
    return values;
}
Also used : Covariance(org.apache.commons.math3.stat.correlation.Covariance) List(java.util.List)

Aggregations

RealMatrix (org.apache.commons.math3.linear.RealMatrix)27 Array2DRowRealMatrix (org.apache.commons.math3.linear.Array2DRowRealMatrix)10 IntStream (java.util.stream.IntStream)9 ArrayList (java.util.ArrayList)8 List (java.util.List)7 Nonnull (javax.annotation.Nonnull)7 Collectors (java.util.stream.Collectors)6 Stream (java.util.stream.Stream)5 Covariance (org.apache.commons.math3.stat.correlation.Covariance)5 java.util (java.util)4 Supplier (java.util.function.Supplier)4 Nullable (javax.annotation.Nullable)4 ImmutablePair (org.apache.commons.lang3.tuple.ImmutablePair)4 TooManyEvaluationsException (org.apache.commons.math3.exception.TooManyEvaluationsException)4 EigenDecomposition (org.apache.commons.math3.linear.EigenDecomposition)4 Logger (org.apache.logging.log4j.Logger)4 UserException (org.broadinstitute.hellbender.exceptions.UserException)4 Nd4jIOUtils (org.broadinstitute.hellbender.tools.coveragemodel.nd4jutils.Nd4jIOUtils)4 Utils (org.broadinstitute.hellbender.utils.Utils)4 ParamUtils (org.broadinstitute.hellbender.utils.param.ParamUtils)4