Search in sources:

Example 56 with DoubleDoubleBiPredicate

Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

In class MultivariateGaussianMixtureExpectationMaximizationTest, method canFit.

@SeededTest
void canFit(RandomSeed seed) {
    // Compare the GDSC expectation-maximization fit against the Commons Math reference
    // implementation for random bivariate mixtures with 2 and 3 components.
    final UniformRandomProvider rand = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate closeTo = TestHelper.doublesAreClose(1e-5, 1e-16);
    final int numberOfSamples = 1000;
    for (final int components : new int[] {2, 3}) {
        // Build a random 2D mixture and draw samples from it
        final double[] mixWeights = createWeights(components, rand);
        final double[][] mixMeans = create(components, 2, rand, -5, 5);
        final double[][] mixStdDevs = create(components, 2, rand, 1, 10);
        final double[] mixCorrelations = create(components, rand, -0.9, 0.9);
        final double[][] samples = createData2d(numberOfSamples, rand, mixWeights, mixMeans, mixStdDevs, mixCorrelations);

        // Fit using the implementation under test
        final MixtureMultivariateGaussianDistribution startModel = MultivariateGaussianMixtureExpectationMaximization.estimate(samples, components);
        final MultivariateGaussianMixtureExpectationMaximization gdscFitter = new MultivariateGaussianMixtureExpectationMaximization(samples);
        Assertions.assertTrue(gdscFitter.fit(startModel));

        // Fit using the Commons Math reference implementation
        final MultivariateNormalMixtureExpectationMaximization cmFitter = new MultivariateNormalMixtureExpectationMaximization(samples);
        cmFitter.fit(MultivariateNormalMixtureExpectationMaximization.estimate(samples, components));

        // Normalise the GDSC log-likelihood by the sample count before comparing with Commons Math
        final double gdscLogLikelihood = gdscFitter.getLogLikelihood() / numberOfSamples;
        Assertions.assertNotEquals(0, gdscLogLikelihood);
        final double cmLogLikelihood = cmFitter.getLogLikelihood();
        TestAssertions.assertTest(cmLogLikelihood, gdscLogLikelihood, closeTo);

        final MixtureMultivariateGaussianDistribution gdscModel = gdscFitter.getFittedModel();
        Assertions.assertNotNull(gdscModel);
        final MixtureMultivariateNormalDistribution cmModel = cmFitter.getFittedModel();

        // The fitted models must agree component by component
        final List<Pair<Double, MultivariateNormalDistribution>> cmComponents = cmModel.getComponents();
        final double[] gdscWeights = gdscModel.getWeights();
        final MultivariateGaussianDistribution[] gdscDistributions = gdscModel.getDistributions();
        Assertions.assertEquals(components, cmComponents.size());
        Assertions.assertEquals(components, gdscWeights.length);
        Assertions.assertEquals(components, gdscDistributions.length);
        for (int c = 0; c < components; c++) {
            final Pair<Double, MultivariateNormalDistribution> reference = cmComponents.get(c);
            final MultivariateGaussianDistribution fitted = gdscDistributions[c];
            TestAssertions.assertTest(reference.getFirst(), gdscWeights[c], closeTo, "weight");
            final MultivariateNormalDistribution referenceDist = reference.getSecond();
            TestAssertions.assertArrayTest(referenceDist.getMeans(), fitted.getMeans(), closeTo, "means");
            TestAssertions.assertArrayTest(referenceDist.getCovariances().getData(), fitted.getCovariances(), closeTo, "covariances");
        }

        final int fitIterations = gdscFitter.getIterations();
        Assertions.assertNotEquals(0, fitIterations);
        // Capping the fit at 2 iterations must report non-convergence when more were needed
        if (fitIterations > 2) {
            Assertions.assertFalse(gdscFitter.fit(startModel, 2, DEFAULT_CONVERGENCE_CHECKER));
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) MultivariateNormalDistribution(org.apache.commons.math3.distribution.MultivariateNormalDistribution) MixtureMultivariateNormalDistribution(org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution) MixtureMultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution) MultivariateGaussianDistribution(uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) MultivariateNormalMixtureExpectationMaximization(org.apache.commons.math3.distribution.fitting.MultivariateNormalMixtureExpectationMaximization) Pair(org.apache.commons.math3.util.Pair) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 57 with DoubleDoubleBiPredicate

Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

In class PsfModelGradient1FunctionTest, method canComputeValueAndGradient.

@Test
void canComputeValueAndGradient() {
    // Z-depth model parameters taken from Smith, et al (2010), page 377
    final double widthX = 1.08;
    final double widthY = 1.01;
    final double gamma = 0.389;
    final double depth = 0.531;
    final double ax = -0.0708;
    final double bx = -0.073;
    final double ay = 0.164;
    final double by = 0.0417;
    final AstigmatismZModel zModel = HoltzerAstigmatismZModel.create(widthX, widthY, gamma, depth, ax, bx, ay, by);

    // A small image ensures the PSF model covers the entire region
    final int maxx = 11;
    final int maxy = 11;
    final int pixels = maxx * maxy;
    final double[] expectedValues = new double[pixels];
    final double[] observedValues = new double[pixels];
    final double[][] expectedGradients = new double[pixels][];
    final double[][] observedGradients = new double[pixels][];

    // Function under test and the reference Erf Gaussian function
    final PsfModelGradient1Function psf = new PsfModelGradient1Function(new GaussianPsfModel(zModel), maxx, maxy);
    final ErfGaussian2DFunction erfFunction = new SingleAstigmatismErfGaussian2DFunction(maxx, maxy, zModel);
    erfFunction.setErfFunction(ErfFunction.COMMONS_MATH);
    final double[] erfParams = new double[Gaussian2DFunction.PARAMETERS_PER_PEAK + 1];
    final DoubleDoubleBiPredicate equality = TestHelper.doublesAreClose(1e-8, 0);

    // Sample positions on a 3x3x3 grid around the image centre
    final double centre = maxx * 0.5;
    final double spacing = 0.33;
    final int[] steps = {-1, 0, 1};
    for (final int i : steps) {
        final double x0 = centre + i * spacing;
        for (final int j : steps) {
            final double x1 = centre + j * spacing;
            for (final int k : steps) {
                final double x2 = k * spacing;
                for (final double intensity : new double[] {23.2, 405.67}) {
                    // The background is constant for the gradients so a single value is used
                    final double[] params = {2.2, intensity, x0, x1, x2};
                    psf.initialise1(params);
                    psf.forEach(new Gradient1Procedure() {
                        private int pos;

                        @Override
                        public void execute(double value, double[] dyDa) {
                            observedValues[pos] = value;
                            observedGradients[pos] = dyDa.clone();
                            pos++;
                        }
                    });

                    // Map onto the Erf Gaussian parameters; x/y are shifted by -0.5
                    // (presumably a pixel-origin convention difference - confirm)
                    erfParams[Gaussian2DFunction.BACKGROUND] = params[0];
                    erfParams[Gaussian2DFunction.SIGNAL] = params[1];
                    erfParams[Gaussian2DFunction.X_POSITION] = params[2] - 0.5;
                    erfParams[Gaussian2DFunction.Y_POSITION] = params[3] - 0.5;
                    erfParams[Gaussian2DFunction.Z_POSITION] = params[4];
                    erfFunction.initialise1(erfParams);
                    erfFunction.forEach(new Gradient1Procedure() {
                        private int pos;

                        @Override
                        public void execute(double value, double[] dyDa) {
                            expectedValues[pos] = value;
                            expectedGradients[pos] = dyDa.clone();
                            pos++;
                        }
                    });

                    for (int p = 0; p < pixels; p++) {
                        TestAssertions.assertTest(expectedValues[p], observedValues[p], equality);
                        TestAssertions.assertArrayTest(expectedGradients[p], observedGradients[p], equality);
                    }
                }
            }
        }
    }
}
Also used : SingleAstigmatismErfGaussian2DFunction(uk.ac.sussex.gdsc.smlm.function.gaussian.erf.SingleAstigmatismErfGaussian2DFunction) ErfGaussian2DFunction(uk.ac.sussex.gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) AstigmatismZModel(uk.ac.sussex.gdsc.smlm.function.gaussian.AstigmatismZModel) HoltzerAstigmatismZModel(uk.ac.sussex.gdsc.smlm.function.gaussian.HoltzerAstigmatismZModel) Gradient1Procedure(uk.ac.sussex.gdsc.smlm.function.Gradient1Procedure) Test(org.junit.jupiter.api.Test)

Example 58 with DoubleDoubleBiPredicate

Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

In class Gaussian2DPeakResultHelperTest, method canComputeCumulative2DAndInverse.

/**
 * Checks the limits of {@code cumulative2D} and {@code inverseCumulative2D}, and that the two
 * functions round-trip radii 0.1, 0.2, ..., 1.0 within a 1e-8 tolerance.
 */
@Test
void canComputeCumulative2DAndInverse() {
    // Limits of the forward and inverse mappings.
    // assertEquals (rather than assertTrue(a == b)) reports the actual value on failure.
    Assertions.assertEquals(0, Gaussian2DPeakResultHelper.cumulative2D(0));
    Assertions.assertEquals(1, Gaussian2DPeakResultHelper.cumulative2D(Double.POSITIVE_INFINITY));
    Assertions.assertEquals(0, Gaussian2DPeakResultHelper.inverseCumulative2D(0));
    Assertions.assertEquals(Double.POSITIVE_INFINITY, Gaussian2DPeakResultHelper.inverseCumulative2D(1));
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-8, 0);
    // Round trip r -> p -> r
    for (int i = 1; i <= 10; i++) {
        final double r = i / 10.0;
        final double p = Gaussian2DPeakResultHelper.cumulative2D(r);
        final double r2 = Gaussian2DPeakResultHelper.inverseCumulative2D(p);
        TestAssertions.assertTest(r, r2, predicate, "r=" + r);
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest) Test(org.junit.jupiter.api.Test)

Example 59 with DoubleDoubleBiPredicate

Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

In class ConvolutionTest, method canComputeConvolution.

@SeededTest
void canComputeConvolution(RandomSeed seed) {
    final UniformRandomProvider rng = RngUtils.create(seed.getSeed());
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-6, 0);
    // The data length starts at 10 and doubles on each outer iteration
    for (int loop = 0, size = 10; loop < sizeLoops; loop++, size *= 2) {
        // The kernel standard deviation starts at 0.5 and doubles on each inner iteration
        double sd = 0.5;
        for (int sdLoop = 0; sdLoop < sdLoops; sdLoop++, sd *= 2) {
            final double[] data = randomData(rng, size);
            final double[] kernel = createKernel(sd);
            final double[] spatial = Convolution.convolve(data, kernel);
            final double[] spatialSwapped = Convolution.convolve(kernel, data);
            final double[] fft = Convolution.convolveFft(data, kernel);
            final double[] fftSwapped = Convolution.convolveFft(kernel, data);
            // Every variant must produce an output of the same length
            final int length = spatial.length;
            Assertions.assertEquals(length, spatialSwapped.length);
            Assertions.assertEquals(length, fft.length);
            Assertions.assertEquals(length, fftSwapped.length);
            // Swapping the argument order must not change the result
            TestAssertions.assertArrayTest(spatial, spatialSwapped, predicate, "Spatial convolution doesn't match");
            TestAssertions.assertArrayTest(fft, fftSwapped, predicate, "FFT convolution doesn't match");
        }
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 60 with DoubleDoubleBiPredicate

Use of uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate in project GDSC-SMLM by aherbert.

In class PeakResultsReaderTest, method checkEqual.

/**
 * Asserts that results read back from a file match the results that were written.
 *
 * <p>Each result is compared according to the file format. For MALK, only the frame, position
 * and intensity are checked. For BINARY, values must be exactly equal. Other formats are compared
 * with a 1e-5 tolerance. The header data (name, configuration, bounds, calibration and PSF) is
 * then compared.
 *
 * @param fileFormat the file format the results were written in
 * @param showDeviations whether to check the parameter deviations
 * @param showEndFrame whether to check the end frame
 * @param showId whether to check the result ID
 * @param showPrecision whether to check the precision
 * @param showCategory whether to check the category
 * @param sort whether to sort the expected results by frame before comparing
 * @param expectedResults the results that were written
 * @param actualResults the results that were read
 */
private static void checkEqual(ResultsFileFormat fileFormat, boolean showDeviations, boolean showEndFrame, boolean showId, boolean showPrecision, boolean showCategory, boolean sort, MemoryPeakResults expectedResults, MemoryPeakResults actualResults) {
    Assertions.assertNotNull(actualResults, "Input results are null");
    Assertions.assertEquals(expectedResults.size(), actualResults.size(), "Size differ");
    final PeakResult[] expected = expectedResults.toArray();
    final PeakResult[] actual = actualResults.toArray();
    if (sort) {
        // Results should be sorted by time.
        // Integer.compare avoids the overflow risk of an int-subtraction comparator.
        Arrays.sort(expected, (o1, o2) -> Integer.compare(o1.getFrame(), o2.getFrame()));
    }
    final DoubleDoubleBiPredicate deltaD = TestHelper.doublesIsCloseTo(1e-5, 0);
    final FloatFloatBiPredicate deltaF = TestHelper.floatsIsCloseTo(1e-5, 0);
    for (int i = 0; i < actualResults.size(); i++) {
        final PeakResult p1 = expected[i];
        final PeakResult p2 = actual[i];
        final ObjectArrayFormatSupplier msg = new ObjectArrayFormatSupplier("%s @ [" + i + "]", 1);
        Assertions.assertEquals(p1.getFrame(), p2.getFrame(), msg.set(0, "Peak"));
        if (fileFormat == ResultsFileFormat.MALK) {
            // MALK is only checked for position and intensity
            TestAssertions.assertTest(p1.getXPosition(), p2.getXPosition(), deltaF, msg.set(0, "X"));
            TestAssertions.assertTest(p1.getYPosition(), p2.getYPosition(), deltaF, msg.set(0, "Y"));
            TestAssertions.assertTest(p1.getIntensity(), p2.getIntensity(), deltaF, msg.set(0, "Intensity"));
            continue;
        }
        Assertions.assertEquals(p1.getOrigX(), p2.getOrigX(), msg.set(0, "Orig X"));
        Assertions.assertEquals(p1.getOrigY(), p2.getOrigY(), msg.set(0, "Orig Y"));
        Assertions.assertNotNull(p2.getParameters(), msg.set(0, "Params is null"));
        if (showEndFrame) {
            Assertions.assertEquals(p1.getEndFrame(), p2.getEndFrame(), msg.set(0, "End frame"));
        }
        if (showId) {
            Assertions.assertEquals(p1.getId(), p2.getId(), msg.set(0, "ID"));
        }
        if (showDeviations) {
            Assertions.assertNotNull(p2.getParameterDeviations(), msg.set(0, "Deviations"));
        }
        if (showCategory) {
            Assertions.assertEquals(p1.getCategory(), p2.getCategory(), msg.set(0, "Category"));
        }
        // Binary should be exact for float numbers
        if (fileFormat == ResultsFileFormat.BINARY) {
            Assertions.assertEquals(p1.getOrigValue(), p2.getOrigValue(), msg.set(0, "Orig value"));
            Assertions.assertEquals(p1.getError(), p2.getError(), msg.set(0, "Error"));
            Assertions.assertEquals(p1.getNoise(), p2.getNoise(), msg.set(0, "Noise"));
            Assertions.assertEquals(p1.getMeanIntensity(), p2.getMeanIntensity(), msg.set(0, "Mean intensity"));
            Assertions.assertArrayEquals(p1.getParameters(), p2.getParameters(), msg.set(0, "Params"));
            if (showDeviations) {
                Assertions.assertArrayEquals(p1.getParameterDeviations(), p2.getParameterDeviations(), msg.set(0, "Params StdDev"));
            }
            if (showPrecision) {
                Assertions.assertEquals(p1.getPrecision(), p2.getPrecision(), msg.set(0, "Precision"));
            }
            continue;
        }
        // Other formats may lose precision so compare with a tolerance
        TestAssertions.assertTest(p1.getOrigValue(), p2.getOrigValue(), deltaF, msg.set(0, "Orig value"));
        TestAssertions.assertTest(p1.getError(), p2.getError(), deltaD, msg.set(0, "Error"));
        TestAssertions.assertTest(p1.getNoise(), p2.getNoise(), deltaF, msg.set(0, "Noise"));
        TestAssertions.assertTest(p1.getMeanIntensity(), p2.getMeanIntensity(), deltaF, msg.set(0, "Mean intensity"));
        TestAssertions.assertArrayTest(p1.getParameters(), p2.getParameters(), deltaF, msg.set(0, "Params"));
        if (showDeviations) {
            TestAssertions.assertArrayTest(p1.getParameterDeviations(), p2.getParameterDeviations(), deltaF, msg.set(0, "Params StdDev"));
        }
        if (showPrecision) {
            // Skip the check only when both precisions are NaN
            final double pa = p1.getPrecision();
            final double pb = p2.getPrecision();
            if (!Double.isNaN(pa) || !Double.isNaN(pb)) {
                TestAssertions.assertTest(pa, pb, deltaD, msg.set(0, "Precision"));
            }
        }
    }
    // Check the header information
    Assertions.assertEquals(expectedResults.getName(), actualResults.getName(), "Name");
    Assertions.assertEquals(expectedResults.getConfiguration(), actualResults.getConfiguration(), "Configuration");
    final Rectangle r1 = expectedResults.getBounds();
    final Rectangle r2 = actualResults.getBounds();
    if (r1 != null) {
        Assertions.assertNotNull(r2, "Bounds");
        Assertions.assertEquals(r1.x, r2.x, "Bounds x");
        Assertions.assertEquals(r1.y, r2.y, "Bounds y");
        Assertions.assertEquals(r1.width, r2.width, "Bounds width");
        Assertions.assertEquals(r1.height, r2.height, "Bounds height");
    } else {
        Assertions.assertNull(r2, "Bounds");
    }
    final Calibration c1 = expectedResults.getCalibration();
    final Calibration c2 = actualResults.getCalibration();
    if (c1 != null) {
        Assertions.assertNotNull(c2, "Calibration");
        // Be lenient and allow no TimeUnit to match TimeUnit.FRAME
        boolean ok = c1.equals(c2);
        if (!ok && new CalibrationReader(c1).getTimeUnitValue() == TimeUnit.TIME_UNIT_NA_VALUE) {
            switch (fileFormat) {
                case BINARY:
                case MALK:
                case TEXT:
                case TSF:
                    final CalibrationWriter writer = new CalibrationWriter(c1);
                    writer.setTimeUnit(TimeUnit.FRAME);
                    ok = writer.getCalibration().equals(c2);
                    break;
                default:
                    // Do not assume frames for other file formats
                    break;
            }
        }
        Assertions.assertTrue(ok, "Calibration");
    } else {
        Assertions.assertNull(c2, "Calibration");
    }
    // assertEquals handles the null cases and reports both values on mismatch
    Assertions.assertEquals(expectedResults.getPsf(), actualResults.getPsf(), "PSF");
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) PSF(uk.ac.sussex.gdsc.smlm.data.config.PSFProtos.PSF) ObjectArrayFormatSupplier(uk.ac.sussex.gdsc.test.utils.functions.ObjectArrayFormatSupplier) Rectangle(java.awt.Rectangle) Calibration(uk.ac.sussex.gdsc.smlm.data.config.CalibrationProtos.Calibration) CalibrationReader(uk.ac.sussex.gdsc.smlm.data.config.CalibrationReader) FloatFloatBiPredicate(uk.ac.sussex.gdsc.test.api.function.FloatFloatBiPredicate) CalibrationWriter(uk.ac.sussex.gdsc.smlm.data.config.CalibrationWriter)

Aggregations

DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)63 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)27 Test (org.junit.jupiter.api.Test)22 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)12 PoissonDistribution (org.apache.commons.math3.distribution.PoissonDistribution)6 TDoubleArrayList (gnu.trove.list.array.TDoubleArrayList)4 ArrayList (java.util.ArrayList)4 MixtureMultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution)4 MultivariateGaussianDistribution (uk.ac.sussex.gdsc.smlm.math3.distribution.fitting.MultivariateGaussianMixtureExpectationMaximization.MixtureMultivariateGaussianDistribution.MultivariateGaussianDistribution)4 UnivariateFunction (org.apache.commons.math3.analysis.UnivariateFunction)3 SimpsonIntegrator (org.apache.commons.math3.analysis.integration.SimpsonIntegrator)3 ArrayRealVector (org.apache.commons.math3.linear.ArrayRealVector)3 RealMatrix (org.apache.commons.math3.linear.RealMatrix)3 RealVector (org.apache.commons.math3.linear.RealVector)3 ParameterizedTest (org.junit.jupiter.params.ParameterizedTest)3 BigDecimal (java.math.BigDecimal)2 MathContext (java.math.MathContext)2 MixtureMultivariateNormalDistribution (org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution)2 MultivariateNormalDistribution (org.apache.commons.math3.distribution.MultivariateNormalDistribution)2 MultivariateJacobianFunction (org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction)2