Search in sources :

Example 6 with TestCounter

use of uk.ac.sussex.gdsc.test.utils.TestCounter in project GDSC-SMLM by aherbert.

the class LvmGradientProcedureTest method gradientProcedureComputesGradient.

/**
 * Compares the gradient reported by an {@link LvmGradientProcedure} with a central
 * finite-difference estimate built from the procedure's own value computation.
 *
 * <p>The comparison is soft: each check is run through a {@link TestCounter} constructed with a
 * failure limit computed from the number of data sets and a fraction of 0.1, together with the
 * number of gradients. Each check is run under the index of the gradient being tested.
 *
 * @param seed provides the seed for the random generator used to create the test data
 * @param func the function under test; supplies the gradient count and the gradient indices
 * @param type the procedure type passed to the factory; a fast log is only obtained for
 *        {@code Type.FAST_LOG_MLE}
 * @param precomputed if true the function is wrapped with an offset array that is filled with
 *        half of the data values before the procedure is created
 */
@SuppressWarnings("null")
private void gradientProcedureComputesGradient(RandomSeed seed, ErfGaussian2DFunction func, Type type, boolean precomputed) {
    final int gradientCount = func.getNumberOfGradients();
    final int[] gradientIndices = func.gradientIndices();
    final int samples = 100;
    final ArrayList<double[]> parameterSets = new ArrayList<>(samples);
    final ArrayList<double[]> dataSets = new ArrayList<>(samples);
    createData(RngUtils.create(seed.getSeed()), 1, samples, parameterSets, dataSets, true);

    // Step requested for the finite-difference gradient and the tolerance of the comparison
    final double step = 1e-4;
    final DoubleEquality equality = new DoubleEquality(5e-2, 1e-16);

    // The offset buffer only exists when part of the function is treated as already fitted
    final double[] offset;
    if (precomputed) {
        offset = new double[func.size()];
    } else {
        offset = null;
    }

    final FastLog fastLog;
    if (type == Type.FAST_LOG_MLE) {
        fastLog = getFastLog();
    } else {
        fastLog = null;
    }

    // Must compute most of the time
    final TestCounter failCounter = new TestCounter(TestCounter.computeFailureLimit(samples, 0.1), gradientCount);

    for (int sample = 0; sample < parameterSets.size(); sample++) {
        // Copy of the loop index that can be captured by the failure message lambda
        final int sampleIndex = sample;
        final double[] data = dataSets.get(sample);
        final double[] params = parameterSets.get(sample);
        final double[] shifted = params.clone();

        final LvmGradientProcedure procedure;
        if (precomputed) {
            // Mock fitting part of the function already
            for (int n = 0; n < offset.length; n++) {
                offset[n] = data[n] * 0.5;
            }
            procedure = LvmGradientProcedureUtils.create(data, OffsetGradient1Function.wrapGradient1Function(func, offset), type, fastLog);
        } else {
            procedure = LvmGradientProcedureUtils.create(data, func, type, fastLog);
        }

        procedure.gradient(params);
        // Snapshot beta before the value computations below are run on the same procedure
        final double[] beta = procedure.beta.clone();

        for (int g = 0; g < gradientCount; g++) {
            final int gradientIndex = g;
            final int p = gradientIndices[g];
            final double h = Precision.representableDelta(params[p], step);

            // Value at the parameter shifted up
            shifted[p] = params[p] + h;
            procedure.value(shifted);
            final double upper = procedure.value;

            // Value at the parameter shifted down
            shifted[p] = params[p] - h;
            procedure.value(shifted);
            final double lower = procedure.value;

            // Restore the parameter for the next gradient
            shifted[p] = params[p];

            // Apply a factor of -2 to compute the actual gradients:
            // See Numerical Recipes in C++, 2nd Ed. Equation 15.5.6 for Nonlinear Models
            final double analytic = -2 * beta[g];
            final double numeric = (upper - lower) / (2 * h);

            failCounter.run(g, () -> equality.almostEqualRelativeOrAbsolute(analytic, numeric), () -> {
                Assertions.fail(() -> String.format("Not same gradient @ %d,%d: %s != %s (error=%s)", sampleIndex, gradientIndex, analytic, numeric, DoubleEquality.relativeError(analytic, numeric)));
            });
        }
    }
}
Also used : TestCounter(uk.ac.sussex.gdsc.test.utils.TestCounter) ArrayList(java.util.ArrayList) DoubleEquality(uk.ac.sussex.gdsc.core.utils.DoubleEquality) FastLog(uk.ac.sussex.gdsc.smlm.function.FastLog)

Example 7 with TestCounter

use of uk.ac.sussex.gdsc.test.utils.TestCounter in project GDSC-SMLM by aherbert.

the class SolverSpeedTest method solveLinearAndGaussJordanReturnSameSolutionResult.

/**
 * Solves the same linear systems with a {@link GaussJordan} solver and an
 * {@link EjmlLinearSolver} and checks that the resulting {@code b} arrays are close.
 *
 * <p>The array comparison is soft: it is run through a {@link TestCounter} with a failure limit
 * computed from the iteration count and a fraction of 0.1. Systems where either solver returns
 * false are counted, and the test fails if more than half of them could not be solved.
 *
 * @param seed the seed passed to {@code ensureData} to obtain the test data
 */
@SeededTest
void solveLinearAndGaussJordanReturnSameSolutionResult(RandomSeed seed) {
    final int iter = 100;
    final SolverSpeedTestData data = ensureData(seed, iter);

    // Each solver is given its own copy of the data
    final ArrayList<double[][]> gaussJordanA = copyAdouble(data.adata, iter);
    final ArrayList<double[]> gaussJordanB = copyBdouble(data.bdata, iter);
    final ArrayList<double[][]> ejmlA = copyAdouble(data.adata, iter);
    final ArrayList<double[]> ejmlB = copyBdouble(data.bdata, iter);

    final GaussJordan gaussJordan = new GaussJordan();
    final EjmlLinearSolver ejml = new EjmlLinearSolver();

    final TestCounter failCounter = new TestCounter(TestCounter.computeFailureLimit(iter, 0.1));
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-2, 0);

    int unsolved = 0;
    for (int n = 0; n < iter; n++) {
        final double[][] firstA = gaussJordanA.get(n);
        final double[] firstB = gaussJordanB.get(n);
        final double[][] secondA = ejmlA.get(n);
        final double[] secondB = ejmlB.get(n);

        // Both solvers are always run, even if the first one fails
        final boolean firstSolved = gaussJordan.solve(firstA, firstB);
        final boolean secondSolved = ejml.solve(secondA, secondB);

        if (!firstSolved || !secondSolved) {
            unsolved++;
            continue;
        }

        failCounter.run(() -> {
            TestAssertions.assertArrayTest(firstB, secondB, predicate, "Different b result");
        });
    }

    final int allowedUnsolved = iter / 2;
    if (unsolved > allowedUnsolved) {
        Assertions.fail(String.format("Failed to solve %d / %d", unsolved, iter));
    }
}
Also used : TestCounter(uk.ac.sussex.gdsc.test.utils.TestCounter) DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 8 with TestCounter

use of uk.ac.sussex.gdsc.test.utils.TestCounter in project GDSC-SMLM by aherbert.

the class SolverSpeedTest method gaussJordanFloatAndDoubleReturnSameSolutionAndInversionResult.

/**
 * Solves the same linear systems with a {@link GaussJordan} solver in float and in double
 * precision and checks that both the {@code b} arrays and the {@code a} matrices are close
 * after solving.
 *
 * <p>The comparisons are soft: they are run through a {@link TestCounter} under index 0 for the
 * {@code b} arrays and index 1 for the {@code a} matrices. Systems where either solve returns
 * false are counted, and the test fails if more than half of them could not be solved.
 *
 * @param seed the seed passed to {@code ensureData} to obtain the test data
 */
@SeededTest
void gaussJordanFloatAndDoubleReturnSameSolutionAndInversionResult(RandomSeed seed) {
    final int iter = 100;
    final SolverSpeedTestData data = ensureData(seed, iter);

    // Float and double copies of the same data
    final ArrayList<float[][]> floatA = copyAfloat(data.adata, iter);
    final ArrayList<float[]> floatB = copyBfloat(data.bdata, iter);
    final ArrayList<double[][]> doubleA = copyAdouble(data.adata, iter);
    final ArrayList<double[]> doubleB = copyBdouble(data.bdata, iter);

    final GaussJordan solver = new GaussJordan();

    final TestCounter failCounter = new TestCounter(TestCounter.computeFailureLimit(iter, 0.1), 2);
    final DoubleDoubleBiPredicate predicate = TestHelper.doublesAreClose(1e-2, 0);

    int unsolved = 0;
    for (int n = 0; n < floatA.size(); n++) {
        final float[][] singleA = floatA.get(n);
        final float[] singleB = floatB.get(n);
        final double[][] expectedA = doubleA.get(n);
        final double[] expectedB = doubleB.get(n);

        // Both precisions are always solved, even if the first one fails
        final boolean singleSolved = solver.solve(singleA, singleB);
        final boolean doubleSolved = solver.solve(expectedA, expectedB);

        if (!singleSolved || !doubleSolved) {
            unsolved++;
            continue;
        }

        // Widen the float results so they can be compared with the double results
        final double[] widenedB = SimpleArrayUtils.toDouble(singleB);
        final double[][] widenedA = new double[singleA.length][];
        for (int row = 0; row < singleA.length; row++) {
            widenedA[row] = SimpleArrayUtils.toDouble(singleA[row]);
        }

        failCounter.run(0, () -> {
            TestAssertions.assertArrayTest(widenedB, expectedB, predicate, "Different b result");
        });
        failCounter.run(1, () -> {
            TestAssertions.assertArrayTest(widenedA, expectedA, predicate, "Different a result");
        });
    }

    final int allowedUnsolved = iter / 2;
    if (unsolved > allowedUnsolved) {
        Assertions.fail(String.format("Failed to solve %d / %d", unsolved, iter));
    }
}
Also used : DoubleDoubleBiPredicate(uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate) TestCounter(uk.ac.sussex.gdsc.test.utils.TestCounter) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Example 9 with TestCounter

use of uk.ac.sussex.gdsc.test.utils.TestCounter in project GDSC-SMLM by aherbert.

the class FhtFilterTest method canFilter.

/**
 * Checks that an {@link FhtFilter} configured with the given operation produces, in the centre
 * of the image, the same pixels as applying the operation directly with ImageJ's {@link FHT}.
 *
 * <p>Two random images are created. The expected result is computed with {@link FHT}
 * (multiply, conjugate multiply or divide followed by an inverse transform and a quadrant
 * swap). For correlation the position of the maximum is also asserted. The final comparison
 * against the filter output is soft: it covers only the central region and is run through a
 * {@link TestCounter} so that a limited number of mismatches is tolerated.
 *
 * @param seed provides the seed for the random generator used to create the images
 * @param operation the filter operation to test
 * @throws IllegalStateException if the operation is not one of convolution, correlation or
 *         deconvolution
 */
private static void canFilter(RandomSeed seed, Operation operation) {
    final int size = 16;
    final int ex = 5;
    final int ey = 7;
    final int ox = 1;
    final int oy = 2;
    final UniformRandomProvider r = RngUtils.create(seed.getSeed());
    final FloatProcessor fp1 = createProcessor(size, ex, ey, 4, 4, r);
    // This is offset from the centre
    final FloatProcessor fp2 = createProcessor(size, size / 2 + ox, size / 2 + oy, 4, 4, r);
    // Keep copies of the raw pixels for the spatial check and for the filter under test
    final float[] input1 = ((float[]) fp1.getPixels()).clone();
    final float[] input2 = ((float[]) fp2.getPixels()).clone();
    final FHT fht1 = new FHT(fp1);
    fht1.transform();
    final FHT fht2 = new FHT(fp2);
    fht2.transform();
    // Expected result computed directly with the ImageJ FHT
    FHT fhtE;
    switch (operation) {
        case CONVOLUTION:
            fhtE = fht1.multiply(fht2);
            break;
        case CORRELATION:
            fhtE = fht1.conjugateMultiply(fht2);
            break;
        case DECONVOLUTION:
            fhtE = fht1.divide(fht2);
            break;
        default:
            throw new IllegalStateException("Unknown operation: " + operation);
    }
    fhtE.inverseTransform();
    fhtE.swapQuadrants();
    final float[] e = (float[]) fhtE.getPixels();
    if (operation == Operation.CORRELATION) {
        // Test the max correlation position.
        // The index is decoded using the image width rather than a repeated literal.
        final int maxIndex = SimpleArrayUtils.findMaxIndex(e);
        final int x = maxIndex % size;
        final int y = maxIndex / size;
        Assertions.assertEquals(ex, x + ox);
        Assertions.assertEquals(ey, y + oy);
    }
    // Test versus a spatial domain filter in the middle of the image
    if (operation != Operation.DECONVOLUTION) {
        double sum = 0;
        float[] i2 = input2;
        if (operation == Operation.CONVOLUTION) {
            i2 = i2.clone();
            KernelFilter.rotate180(i2);
        }
        for (int i = 0; i < input1.length; i++) {
            sum += input1[i] * i2[i];
        }
        // NOTE(review): this assertion compares sum with itself and so cannot fail. The
        // disabled code below suggests the intended expected value was the centre pixel of
        // the FHT result, e[size / 2 * size + size / 2] - TODO confirm that this comparison
        // holds within the tolerance before enabling it.
        // double exp = e[size / 2 * size + size / 2];
        // logger.fine(() -> String.format("Sum = %f vs [%d] %f", sum, size / 2 * size + size / 2,
        // exp);
        Assertions.assertEquals(sum, sum, 1e-3);
    }
    // Test the FHT filter
    final FhtFilter ff = new FhtFilter(input2, size, size);
    ff.setOperation(operation);
    ff.filter(input1, size, size);
    // There may be differences due to the use of the JTransforms library
    final double error = (operation == Operation.DECONVOLUTION) ? 5e-2 : 1e-4;
    final FloatFloatBiPredicate predicate = TestHelper.floatsAreClose(error, 0);
    // This tests everything and can fail easily depending on the random generator
    // due to edge artifacts.
    // TestAssertions.assertArrayTest(e, input1, TestHelper.almostEqualFloats(error, 0));
    // This tests the centre to ignore edge differences
    final int min = size / 4;
    final int max = size - min;
    // Number of pixels in the square central region [min, max) x [min, max)
    final int width = max - min;
    final int repeats = width * width;
    // Use a fail counter for a 'soft' test that detects major problems
    final int failureLimit = TestCounter.computeFailureLimit(repeats, 0.1);
    final TestCounter failCounter = new TestCounter(failureLimit);
    final IndexSupplier msg = new IndexSupplier(2);
    for (int y = min; y < max; y++) {
        msg.set(1, y);
        for (int x = min; x < max; x++) {
            // Copy of the loop index that can be captured by the lambda
            final int xx = x;
            final int i = y * size + x;
            failCounter.run(() -> {
                TestAssertions.assertTest(e[i], input1[i], predicate, msg.set(0, xx));
            });
        }
    }
}
Also used : TestCounter(uk.ac.sussex.gdsc.test.utils.TestCounter) FHT(ij.process.FHT) FloatProcessor(ij.process.FloatProcessor) IndexSupplier(uk.ac.sussex.gdsc.test.utils.functions.IndexSupplier) FloatFloatBiPredicate(uk.ac.sussex.gdsc.test.api.function.FloatFloatBiPredicate) UniformRandomProvider(org.apache.commons.rng.UniformRandomProvider) FhtFilter(uk.ac.sussex.gdsc.smlm.ij.filters.FhtFilter)

Aggregations

TestCounter (uk.ac.sussex.gdsc.test.utils.TestCounter)9 ArrayList (java.util.ArrayList)5 DoubleEquality (uk.ac.sussex.gdsc.core.utils.DoubleEquality)5 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)3 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)3 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)2 FastLog (uk.ac.sussex.gdsc.smlm.function.FastLog)2 FHT (ij.process.FHT)1 FloatProcessor (ij.process.FloatProcessor)1 SharedStateContinuousSampler (org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler)1 DenseMatrix64F (org.ejml.data.DenseMatrix64F)1 ValueProcedure (uk.ac.sussex.gdsc.smlm.function.ValueProcedure)1 Gaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction)1 ErfGaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction)1 SingleFreeCircularErfGaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.erf.SingleFreeCircularErfGaussian2DFunction)1 FhtFilter (uk.ac.sussex.gdsc.smlm.ij.filters.FhtFilter)1 FloatFloatBiPredicate (uk.ac.sussex.gdsc.test.api.function.FloatFloatBiPredicate)1 IndexSupplier (uk.ac.sussex.gdsc.test.utils.functions.IndexSupplier)1