Example 6 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

The class TrackPopulationAnalysis, method run.

@Override
public void run(String arg) {
    SmlmUsageTracker.recordPlugin(this.getClass(), arg);
    if (MemoryPeakResults.isMemoryEmpty()) {
        IJ.error(TITLE, "No localisations in memory");
        return;
    }
    settings = Settings.load();
    // Saved by reference so just save now
    settings.save();
    // Read in multiple traced datasets
    // All datasets must have the same pixel pitch and exposure time
    // Get parameters
    // Convert datasets to tracks
    // For each track compute the 4 local track features using the configured window
    // 
    // Optional:
    // Fit a multi-variate Gaussian mixture model to the data
    // (using the configured number of components/populations)
    // Assign each point in the track using the model.
    // Smooth the assignments.
    // 
    // The alternative is to use the localisation category to assign populations.
    // 
    // Plot histograms of each track parameter, coloured by component
    final List<MemoryPeakResults> combinedResults = new LocalList<>();
    if (!showInputDialog(combinedResults)) {
        return;
    }
    final boolean hasCategory = showHasCategoryDialog(combinedResults);
    if (!showDialog(hasCategory)) {
        return;
    }
    ImageJUtils.log(TITLE + "...");
    final List<Trace> tracks = getTracks(combinedResults, settings.window, settings.minTrackLength);
    if (tracks.isEmpty()) {
        IJ.error(TITLE, "No tracks. Please check the input data and min track length setting.");
        return;
    }
    final Calibration cal = combinedResults.get(0).getCalibration();
    final CalibrationReader cr = new CalibrationReader(cal);
    // Use micrometer / second
    final TypeConverter<DistanceUnit> distanceConverter = cr.getDistanceConverter(DistanceUnit.UM);
    final double exposureTime = cr.getExposureTime() / 1000.0;
    final Pair<int[], double[][]> trackData = extractTrackData(tracks, distanceConverter, exposureTime, hasCategory);
    final double[][] data = trackData.getValue();
    // Histogram the raw data.
    final Array2DRowRealMatrix raw = new Array2DRowRealMatrix(data, false);
    final WindowOrganiser wo = new WindowOrganiser();
    // Store the histogram data for plotting the components
    final double[][] columns = new double[FEATURE_NAMES.length][];
    final double[][] limits = new double[FEATURE_NAMES.length][];
    // Get column data
    for (int i = 0; i < FEATURE_NAMES.length; i++) {
        columns[i] = raw.getColumn(i);
        if (i == FEATURE_D) {
            // Plot using a logarithmic scale
            SimpleArrayUtils.apply(columns[i], Math::log10);
        }
        limits[i] = MathUtils.limits(columns[i]);
    }
    // Compute histogram bins
    final int[] bins = new int[FEATURE_NAMES.length];
    if (settings.histogramBins > 0) {
        Arrays.fill(bins, settings.histogramBins);
    } else {
        for (int i = 0; i < FEATURE_NAMES.length; i++) {
            bins[i] = HistogramPlot.getBins(StoredData.create(columns[i]), BinMethod.FD);
        }
        // Use the maximum so all histograms look the same
        Arrays.fill(bins, MathUtils.max(bins));
    }
    // Compute plots
    final Plot[] plots = new Plot[FEATURE_NAMES.length];
    for (int i = 0; i < FEATURE_NAMES.length; i++) {
        final double[][] hist = HistogramPlot.calcHistogram(columns[i], limits[i][0], limits[i][1], bins[i]);
        plots[i] = new Plot(TITLE + " " + FEATURE_NAMES[i], getFeatureLabel(i, i == FEATURE_D), "Frequency");
        plots[i].addPoints(hist[0], hist[1], Plot.BAR);
        ImageJUtils.display(plots[i].getTitle(), plots[i], ImageJUtils.NO_TO_FRONT, wo);
    }
    wo.tile();
    // The component for each data point
    int[] component;
    // The number of components
    int numComponents;
    // Data used to fit the Gaussian mixture model
    double[][] fitData;
    // The fitted model
    MixtureMultivariateGaussianDistribution model;
    if (hasCategory) {
        // Use the category as the component.
        // No fit data and no output model
        fitData = null;
        model = null;
        // The component is stored at the end of the raw track data.
        final int end = data[0].length - 1;
        component = Arrays.stream(data).mapToInt(d -> (int) d[end]).toArray();
        numComponents = MathUtils.max(component) + 1;
        // In the EM algorithm the probability of each data point is computed and normalised to
        // sum to 1. The normalised probabilities are averaged to create the weights.
        // Note the probability of each data point uses the previous weight and the algorithm
        // iterates.
        // This is not a fitted model but the input model so use
        // zero weights to indicate no fitting was performed.
        final double[] weights = new double[numComponents];
        // Remove the trailing component to show the 'model' in a table.
        createModelTable(Arrays.stream(data).map(d -> Arrays.copyOf(d, end)).toArray(double[][]::new), weights, component);
    } else {
        // Multivariate Gaussian mixture EM
        // Provide option to not use the anomalous exponent in the population mix.
        int sortDimension = SORT_DIMENSION;
        if (settings.ignoreAlpha) {
            // Remove index 0. This shifts the sort dimension.
            sortDimension--;
            fitData = Arrays.stream(data).map(d -> Arrays.copyOfRange(d, 1, d.length)).toArray(double[][]::new);
        } else {
            fitData = SimpleArrayUtils.deepCopy(data);
        }
        final MultivariateGaussianMixtureExpectationMaximization mixed = fitGaussianMixture(fitData, sortDimension);
        if (mixed == null) {
            IJ.error(TITLE, "Failed to fit a mixture model");
            return;
        }
        model = sortComponents(mixed.getFittedModel(), sortDimension);
        // For the best model, assign to the most likely population.
        component = assignData(fitData, model);
        // Table of the final model using the original data (i.e. not normalised)
        final double[] weights = model.getWeights();
        numComponents = weights.length;
        createModelTable(data, weights, component);
    }
    // Output coloured histograms of the populations.
    final LUT lut = LutHelper.createLut(settings.lutIndex);
    IntFunction<Color> colourMap;
    if (LutHelper.getColour(lut, 0).equals(Color.BLACK)) {
        colourMap = i -> LutHelper.getNonZeroColour(lut, i, 0, numComponents - 1);
    } else {
        colourMap = i -> LutHelper.getColour(lut, i, 0, numComponents - 1);
    }
    for (int i = 0; i < FEATURE_NAMES.length; i++) {
        // Extract the data for each component
        final double[] col = columns[i];
        final Plot plot = plots[i];
        for (int n = 0; n < numComponents; n++) {
            final StoredData feature = new StoredData();
            for (int j = 0; j < component.length; j++) {
                if (component[j] == n) {
                    feature.add(col[j]);
                }
            }
            if (feature.size() == 0) {
                continue;
            }
            final double[][] hist = HistogramPlot.calcHistogram(feature.values(), limits[i][0], limits[i][1], bins[i]);
            // Colour the points
            plot.setColor(colourMap.apply(n));
            plot.addPoints(hist[0], hist[1], Plot.BAR);
        }
        plot.updateImage();
    }
    createTrackDataTable(tracks, trackData, fitData, model, component, cal, colourMap);
    // Analysis.
    // Assign the original localisations to their track component.
    // Q. What about the start/end not covered by the window?
    // Save tracks as a dataset labelled with the sub-track ID?
    // Output for the bound component and free components track parameters.
    // Compute dwell times.
    // Other ...
    // Track analysis plugin:
    // Extract all continuous segments of the same component.
    // Produce MSD plot with error bars.
    // Fit using FBM model.
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) Array2DRowRealMatrix(org.apache.commons.math3.linear.Array2DRowRealMatrix) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MemoryPeakResults(uk.ac.sussex.gdsc.smlm.results.MemoryPeakResults) DistanceUnit(uk.ac.sussex.gdsc.smlm.data.config.UnitProtos.DistanceUnit) Plot(ij.gui.Plot) HistogramPlot(uk.ac.sussex.gdsc.core.ij.HistogramPlot) Color(java.awt.Color) LUT(ij.process.LUT) Calibration(uk.ac.sussex.gdsc.smlm.data.config.CalibrationProtos.Calibration) WindowOrganiser(uk.ac.sussex.gdsc.core.ij.plugin.WindowOrganiser) CalibrationReader(uk.ac.sussex.gdsc.smlm.data.config.CalibrationReader) Trace(uk.ac.sussex.gdsc.smlm.results.Trace) MultivariateGaussianMixtureExpectationMaximization(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization) StoredData(uk.ac.sussex.gdsc.core.utils.StoredData)
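The fitGaussianMixture and assignData helpers called above are not shown on this page. The following is a minimal, hypothetical sketch of the same fit-and-assign flow, built only from API calls that appear elsewhere in these examples (estimate, fit, getFittedModel, getWeights, getDistributions, density). The constructor taking the data matrix is an assumption based on the equivalent Commons Math class, and the method name fitAndAssign is illustrative.

import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution;

/** Hypothetical sketch: fit a mixture to the track features and assign each row to a component. */
static int[] fitAndAssign(double[][] fitData, int numComponents) {
    // Assumed constructor: the data matrix is supplied up front (as in the Commons Math analogue).
    final MultivariateGaussianMixtureExpectationMaximization em =
        new MultivariateGaussianMixtureExpectationMaximization(fitData);
    // Seed the fit by ranking and partitioning the data (see the estimate method below).
    final MixtureMultivariateGaussianDistribution initial =
        MultivariateGaussianMixtureExpectationMaximization.estimate(fitData, numComponents);
    // Run EM from the initial mixture using the fit(MixtureMultivariateGaussianDistribution) overload.
    em.fit(initial);
    final MixtureMultivariateGaussianDistribution model = em.getFittedModel();
    final double[] weights = model.getWeights();
    final MultivariateGaussianDistribution[] dists = model.getDistributions();
    // Assign each point to the component with the highest weighted density.
    final int[] component = new int[fitData.length];
    for (int i = 0; i < fitData.length; i++) {
        double best = -1;
        for (int j = 0; j < weights.length; j++) {
            final double p = weights[j] * dists[j].density(fitData[i]);
            if (p > best) {
                best = p;
                component[i] = j;
            }
        }
    }
    return component;
}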

Example 7 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

The class MultivariateGaussianMixtureExpectationMaximization, method fit.

/**
 * Fit a mixture model to the data supplied to the constructor.
 *
 * <p>The quality of the fit depends on the concavity of the data provided to the constructor and
 * the initial mixture provided to this function. If the data has many local optima, multiple runs
 * of the fitting function with different initial mixtures may be required to find the optimal
 * solution. If a SingularMatrixException is encountered, it is possible that another
 * initialisation would work.
 *
 * @param initialMixture model containing initial values of weights and multivariate normals
 * @param maxIterations maximum iterations allowed for fit
 * @param convergencePredicate convergence predicate used to test the logLikelihoods between
 *        successive iterations
 * @return true if converged within the threshold
 * @throws IllegalArgumentException if maxIterations is less than one or if initialMixture mean
 *         vector and data number of columns are not equal
 * @throws SingularMatrixException if any component's covariance matrix is singular during
 *         fitting.
 * @throws NonPositiveDefiniteMatrixException if any component's covariance matrix is not positive
 *         definite during fitting.
 */
public boolean fit(MixtureMultivariateGaussianDistribution initialMixture, int maxIterations, DoubleDoubleBiPredicate convergencePredicate) {
    ValidationUtils.checkStrictlyPositive(maxIterations, "maxIterations");
    ValidationUtils.checkNotNull(convergencePredicate, "convergencePredicate");
    ValidationUtils.checkNotNull(initialMixture, "initialMixture");
    final int n = data.length;
    final int k = initialMixture.weights.length;
    // Number of data columns.
    final int numCols = data[0].length;
    final int numMeanColumns = initialMixture.distributions[0].means.length;
    ValidationUtils.checkArgument(numCols == numMeanColumns, "Mixture model dimension mismatch with data columns: %d != %d", numCols, numMeanColumns);
    logLikelihood = -Double.MAX_VALUE;
    iterations = 0;
    // Initialize model to fit to initial mixture.
    fittedModel = initialMixture;
    while (iterations++ <= maxIterations) {
        final double previousLogLikelihood = logLikelihood;
        double sumLogLikelihood = 0;
        // Weight and distribution of each component
        final double[] weights = fittedModel.weights;
        final MultivariateGaussianDistribution[] mvns = fittedModel.distributions;
        // E-step: compute the data dependent parameters of the expectation function.
        // The fraction of each row's total density attributed to each component
        final double[][] gamma = new double[n][k];
        // Sum of gamma for each component
        final double[] gammaSums = new double[k];
        // Sum of gamma times its row for each component
        final double[][] gammaDataProdSums = new double[k][numCols];
        // Cache for the weight multiplied by the distribution density
        final double[] mvnDensity = new double[k];
        for (int i = 0; i < n; i++) {
            final double[] point = data[i];
            // Compute densities for each component and the row density
            double rowDensity = 0;
            for (int j = 0; j < k; j++) {
                final double d = weights[j] * mvns[j].density(point);
                mvnDensity[j] = d;
                rowDensity += d;
            }
            sumLogLikelihood += Math.log(rowDensity);
            for (int j = 0; j < k; j++) {
                gamma[i][j] = mvnDensity[j] / rowDensity;
                gammaSums[j] += gamma[i][j];
                for (int col = 0; col < numCols; col++) {
                    gammaDataProdSums[j][col] += gamma[i][j] * point[col];
                }
            }
        }
        logLikelihood = sumLogLikelihood;
        // M-step: compute the new parameters based on the expectation function.
        final double[] newWeights = new double[k];
        final double[][] newMeans = new double[k][numCols];
        for (int j = 0; j < k; j++) {
            newWeights[j] = gammaSums[j] / n;
            for (int col = 0; col < numCols; col++) {
                newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
            }
        }
        // Compute new covariance matrices.
        // These are symmetric so we compute the triangular half.
        final double[][][] newCovMats = new double[k][numCols][numCols];
        final double[] vec = new double[numCols];
        for (int i = 0; i < n; i++) {
            final double[] point = data[i];
            for (int j = 0; j < k; j++) {
                subtract(point, newMeans[j], vec);
                final double g = gamma[i][j];
                // covariance = vec * vecT
                // covariance[ii][jj] = vec[ii] * vec[jj] * gamma[i][j]
                final double[][] covar = newCovMats[j];
                for (int ii = 0; ii < numCols; ii++) {
                    // pre-compute
                    final double vig = vec[ii] * g;
                    final double[] covari = covar[ii];
                    for (int jj = 0; jj <= ii; jj++) {
                        covari[jj] += vig * vec[jj];
                    }
                }
            }
        }
        // Converting to arrays for use by fitted model
        final MultivariateGaussianDistribution[] distributions = new MultivariateGaussianDistribution[k];
        for (int j = 0; j < k; j++) {
            // Make symmetric and normalise by gamma sum
            final double norm = 1.0 / gammaSums[j];
            final double[][] covar = newCovMats[j];
            for (int ii = 0; ii < numCols; ii++) {
                // diagonal
                covar[ii][ii] *= norm;
                // off-diagonal elements
                for (int jj = 0; jj < ii; jj++) {
                    final double tmp = covar[ii][jj] * norm;
                    covar[ii][jj] = tmp;
                    covar[jj][ii] = tmp;
                }
            }
            distributions[j] = new MultivariateGaussianDistribution(newMeans[j], covar);
        }
        // Update current model
        fittedModel = MixtureMultivariateGaussianDistribution.create(newWeights, distributions);
        // Check convergence
        if (convergencePredicate.test(previousLogLikelihood, logLikelihood)) {
            return true;
        }
    }
    // No convergence
    return false;
}
Also used : MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)
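In standard notation (introduced here for reference, not taken from the source), the E-step and M-step in the loop above are the usual Gaussian mixture EM updates:

\gamma_{ij} = \frac{w_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}{\sum_{l=1}^{k} w_l \, \mathcal{N}(x_i \mid \mu_l, \Sigma_l)},
\qquad
\log L = \sum_{i=1}^{n} \log \sum_{j=1}^{k} w_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)

w_j^{new} = \frac{1}{n} \sum_{i=1}^{n} \gamma_{ij},
\qquad
\mu_j^{new} = \frac{\sum_{i=1}^{n} \gamma_{ij} \, x_i}{\sum_{i=1}^{n} \gamma_{ij}},
\qquad
\Sigma_j^{new} = \frac{\sum_{i=1}^{n} \gamma_{ij} \, (x_i - \mu_j^{new})(x_i - \mu_j^{new})^{\mathsf T}}{\sum_{i=1}^{n} \gamma_{ij}}

Here gamma corresponds to the gamma array, the denominators to gammaSums, and the covariance update is accumulated in its lower triangle and then mirrored, exactly as in the code.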

Example 8 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

The class MultivariateGaussianMixtureExpectationMaximization, method estimate.

/**
 * Helper method to create a multivariate Gaussian mixture model which can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>The function is used to extract a value for ranking all points. These are then split
 * uniformly into the specified number of components. This is equivalent to projection of the
 * points onto a line and partitioning the points along the line uniformly; this cuts the points
 * into sets using hyper-planes defined by the vector of the line. A good ranking metric is a dot
 * product with a random unit vector uniformly sampled from the surface of an n-dimension
 * hypersphere.
 *
 * <pre>
 * double[] vec = ...;
 * ToDoubleFunction&lt;double[]&gt; rankingMetric =
 *   values -&gt; {
 *     double d = 0;
 *     for (int i = 0; i &lt; vec.length; i++) {
 *       d += vec[i] * values[i];
 *     }
 *     return d;
 *   };
 * </pre>
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * @param data data to estimate distribution
 * @param numComponents number of components for estimated mixture
 * @param rankingMetric the function to generate the ranking metric
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows, if {@code numComponents < 2}, or
 *         if {@code numComponents} is greater than the number of data rows.
 */
public static MixtureMultivariateGaussianDistribution estimate(double[][] data, int numComponents, ToDoubleFunction<double[]> rankingMetric) {
    ValidationUtils.checkArgument(data.length >= 2, "Estimation requires at least 2 data points: %d", data.length);
    ValidationUtils.checkArgument(numComponents >= 2, "Multivariate Gaussian mixture requires at least 2 components: %d", numComponents);
    ValidationUtils.checkArgument(numComponents <= data.length, "Number of components %d greater than data length %d", numComponents, data.length);
    ValidationUtils.checkNotNull(rankingMetric, "rankingMetric");
    final int numRows = data.length;
    final int numCols = data[0].length;
    ValidationUtils.checkArgument(numCols >= 2, "Multivariate Gaussian requires at least 2 data columns: %d", numCols);
    // Sort the data
    final ProjectedData[] sortedData = new ProjectedData[numRows];
    for (int i = 0; i < numRows; i++) {
        sortedData[i] = new ProjectedData(data[i], rankingMetric.applyAsDouble(data[i]));
    }
    Arrays.sort(sortedData, (o1, o2) -> Double.compare(o1.value, o2.value));
    // components of mixture model to be created
    final MultivariateGaussianDistribution[] distributions = new MultivariateGaussianDistribution[numComponents];
    // create a component based on data in each bin
    for (int binIndex = 0; binIndex < numComponents; binIndex++) {
        // minimum index (inclusive) from sorted data for this bin
        final int minIndex = (binIndex * numRows) / numComponents;
        // maximum index (exclusive) from sorted data for this bin
        final int maxIndex = ((binIndex + 1) * numRows) / numComponents;
        // number of data records that will be in this bin
        final int numBinRows = maxIndex - minIndex;
        // data for this bin
        final double[][] binData = new double[numBinRows][];
        // mean of each column for the data in this bin
        final double[] columnMeans = new double[numCols];
        // populate bin and create component
        for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
            final double[] values = sortedData[i].data;
            binData[iBin] = values;
            for (int j = 0; j < numCols; j++) {
                columnMeans[j] += values[j];
            }
        }
        SimpleArrayUtils.multiply(columnMeans, 1.0 / numBinRows);
        distributions[binIndex] = new MultivariateGaussianDistribution(columnMeans, covariance(columnMeans, binData));
    }
    // uniform weight for each bin
    final double[] weights = SimpleArrayUtils.newDoubleArray(numComponents, 1.0 / numComponents);
    return new MixtureMultivariateGaussianDistribution(weights, distributions);
}
Also used : MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)
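A small, hypothetical usage sketch for this overload: build the ranking metric as a dot product with a random unit vector (independent Gaussian coordinates normalised to unit length, a standard way to sample the hypersphere uniformly) and pass it to estimate. Only the three-argument estimate call is taken from the source; the vector construction and the method name are illustrative.

import java.util.Random;
import java.util.function.ToDoubleFunction;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization;
import uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution;

/** Hypothetical sketch: estimate an initial mixture using a random projection as the ranking metric. */
static MixtureMultivariateGaussianDistribution estimateWithRandomProjection(
        double[][] data, int numComponents, long seed) {
    final int dims = data[0].length;
    final Random rng = new Random(seed);
    // Random unit vector: independent Gaussian coordinates, normalised to length 1.
    final double[] vec = new double[dims];
    double sum = 0;
    for (int i = 0; i < dims; i++) {
        vec[i] = rng.nextGaussian();
        sum += vec[i] * vec[i];
    }
    final double scale = 1.0 / Math.sqrt(sum);
    for (int i = 0; i < dims; i++) {
        vec[i] *= scale;
    }
    // Rank each point by its projection onto the vector, as suggested in the javadoc above.
    final ToDoubleFunction<double[]> rankingMetric = values -> {
        double d = 0;
        for (int i = 0; i < vec.length; i++) {
            d += vec[i] * values[i];
        }
        return d;
    };
    return MultivariateGaussianMixtureExpectationMaximization.estimate(data, numComponents, rankingMetric);
}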

Example 9 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

The class MultivariateGaussianMixtureExpectationMaximization, method createMixed.

/**
 * Helper method to create a multivariate Gaussian mixture model which can be used to initialize
 * {@link #fit(MixtureMultivariateGaussianDistribution)}.
 *
 * <p>This method can be used with the data supplied to the instance constructor to try to
 * determine a good mixture model at which to start the fit, but it is not guaranteed to supply a
 * model which will find the optimal solution or even converge.
 *
 * <p>The weights for each component will be uniform.
 *
 * @param data data for the distribution
 * @param component the component for each data point
 * @return Multivariate Gaussian mixture model estimated from the data
 * @throws IllegalArgumentException if data has less than 2 rows, if there is a size mismatch
 *         between the data and components length, or if the number of components is less than 2.
 */
public static MixtureMultivariateGaussianDistribution createMixed(double[][] data, int[] component) {
    ValidationUtils.checkArgument(data.length >= 2, "Estimation requires at least 2 data points: %d", data.length);
    ValidationUtils.checkArgument(data.length == component.length, "Data and component size mismatch: %d != %d", data.length, component.length);
    final int numRows = data.length;
    final int numCols = data[0].length;
    ValidationUtils.checkArgument(numCols >= 2, "Multivariate Gaussian requires at least 2 data columns: %d", numCols);
    // Sort the data
    final ClassifiedData[] sortedData = new ClassifiedData[data.length];
    for (int i = 0; i < numRows; i++) {
        sortedData[i] = new ClassifiedData(data[i], component[i]);
    }
    Arrays.sort(sortedData, (o1, o2) -> Integer.compare(o1.value, o2.value));
    ValidationUtils.checkArgument(sortedData[0].value != sortedData[sortedData.length - 1].value, "Mixture model requires at least 2 data components");
    // components of mixture model to be created
    final LocalList<MultivariateGaussianDistribution> distributions = new LocalList<>();
    int from = 0;
    while (from < sortedData.length) {
        // Find the end
        int to = from + 1;
        final int comp = sortedData[from].value;
        while (to < sortedData.length && sortedData[to].value == comp) {
            to++;
        }
        // number of data records that will be in this component
        final int numCompRows = to - from;
        // data for this component
        final double[][] compData = new double[numCompRows][];
        // mean of each column for the data
        final double[] columnMeans = new double[numCols];
        // populate and create component
        int count = 0;
        for (int i = from; i < to; i++) {
            final double[] values = sortedData[i].data;
            compData[count++] = values;
            for (int j = 0; j < numCols; j++) {
                columnMeans[j] += values[j];
            }
        }
        SimpleArrayUtils.multiply(columnMeans, 1.0 / numCompRows);
        distributions.add(new MultivariateGaussianDistribution(columnMeans, covariance(columnMeans, compData)));
        from = to;
    }
    // uniform weight for each bin
    final int numComponents = distributions.size();
    final double[] weights = SimpleArrayUtils.newDoubleArray(numComponents, 1.0 / numComponents);
    return new MixtureMultivariateGaussianDistribution(weights, distributions.toArray(new MultivariateGaussianDistribution[0]));
}
Also used : LocalList(uk.ac.sussex.gdsc.core.utils.LocalList) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)
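Where the data already carry a per-point category (as in the hasCategory branch of TrackPopulationAnalysis above), the mixed model can be used to seed a full EM fit. A hypothetical sketch follows, again assuming the constructor takes the data matrix; the method name refineCategorisedModel is illustrative, and only createMixed, fit and getFittedModel are taken from the source.

/** Hypothetical sketch: refine a category-derived mixture with EM. */
static MixtureMultivariateGaussianDistribution refineCategorisedModel(double[][] data, int[] category) {
    // Initial mixture built directly from the existing per-point categories.
    final MixtureMultivariateGaussianDistribution initial =
        MultivariateGaussianMixtureExpectationMaximization.createMixed(data, category);
    // Assumed constructor: the data matrix is supplied up front (as in the Commons Math analogue).
    final MultivariateGaussianMixtureExpectationMaximization em =
        new MultivariateGaussianMixtureExpectationMaximization(data);
    // Refine with the fit(MixtureMultivariateGaussianDistribution) overload referenced in the javadoc above.
    em.fit(initial);
    return em.getFittedModel();
}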

Example 10 with MixtureMultivariateGaussianDistribution

Use of uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution in project GDSC-SMLM by aherbert.

The class MultivariateGaussianMixtureExpectationMaximizationTest, method canEstimateInitialMixture.

@SeededTest
void canEstimateInitialMixture(RandomSeed seed) {
    // Test versus the Commons Math estimation
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate test = TestHelper.doublesAreClose(1e-5, 1e-16);
    // Number of components
    for (int n = 2; n <= 3; n++) {
        final double[] sampleWeights = createWeights(n, rng);
        final double[][] sampleMeans = create(n, 2, rng, -5, 5);
        final double[][] sampleStdDevs = create(n, 2, rng, 1, 10);
        final double[] sampleCorrelations = create(n, rng, -0.9, 0.9);
        final double[][] data = createData2d(1000, rng, sampleWeights, sampleMeans, sampleStdDevs, sampleCorrelations);
        final MixtureMultivariateGaussianDistribution model1 = MultivariateGaussianMixtureExpectationMaximization.estimate(data, n);
        final MixtureMultivariateNormalDistribution model2 = MultivariateNormalMixtureExpectationMaximization.estimate(data, n);
        final List<Pair<Double, MultivariateNormalDistribution>> comp = model2.getComponents();
        final double[] weights = model1.getWeights();
        final MultivariateGaussianDistribution[] distributions = model1.getDistributions();
        Assertions.assertEquals(n, comp.size());
        Assertions.assertEquals(n, weights.length);
        Assertions.assertEquals(n, distributions.length);
        for (int i = 0; i < n; i++) {
            // Must be binary equal for estimated model
            Assertions.assertEquals(comp.get(i).getFirst(), weights[i], "weight");
            final MultivariateNormalDistribution d = comp.get(i).getSecond();
            TestAssertions.assertArrayTest(d.getMeans(), distributions[i].getMeans(), test, "means");
            TestAssertions.assertArrayTest(d.getCovariances().getData(), distributions[i].getCovariances(), test, "covariances");
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution): 10
MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution): 9
SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest): 6
LocalList (uk.ac.sussex.gdsc.core.utils.LocalList): 4
MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution): 3
Test (org.junit.jupiter.api.Test): 3
Plot (ij.gui.Plot): 2
LUT (ij.process.LUT): 2
TextWindow (ij.text.TextWindow): 2
Color (java.awt.Color): 2
MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution): 2
Pair (org.apache.commons.math3.util.Pair): 2
UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider): 2
BufferedTextWindow (uk.ac.sussex.gdsc.core.ij.BufferedTextWindow): 2
DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate): 2
TDoubleList (gnu.trove.list.TDoubleList): 1
TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList): 1
TFloatArrayList (gnu.trove.list.array.TFloatArrayList): 1
TIntArrayList (gnu.trove.list.array.TIntArrayList): 1
TIntIntHashMap (gnu.trove.map.hash.TIntIntHashMap): 1