Search in sources :

Example 6 with FastLog

Use of uk.ac.sussex.gdsc.smlm.function.FastLog in project GDSC-SMLM by aherbert.

The class LvmGradientProcedureTest, method gradientProcedureUnrolledComputesSameAsGradientProcedure.

/**
 * Checks that the procedure built by {@code LvmGradientProcedureUtils.create} produces exactly the
 * same value, beta vector, linear alpha and alpha matrix as the procedure built by
 * {@code createProcedure} for the same data, function and type.
 *
 * <p>NOTE(review): {@code createProcedure} is presumably the generic (non-unrolled) reference
 * implementation, going by this method's name; confirm against its definition.
 *
 * @param seed the seed used to generate the fake parameters and observations
 * @param nparams the number of function parameters
 * @param type the gradient procedure type
 * @param precomputed if true the function is wrapped with a precomputed offset array
 */
private void gradientProcedureUnrolledComputesSameAsGradientProcedure(RandomSeed seed, int nparams, Type type, boolean precomputed) {
    final int iter = 10;
    final ArrayList<double[]> paramsList = new ArrayList<>(iter);
    final ArrayList<double[]> yList = new ArrayList<>(iter);
    createFakeData(RngUtils.create(seed.getSeed()), nparams, iter, paramsList, yList);
    Gradient1Function func = new FakeGradientFunction(blockWidth, nparams);
    if (precomputed) {
        final double[] b = SimpleArrayUtils.newArray(func.size(), 0.1, 1.3);
        func = OffsetGradient1Function.wrapGradient1Function(func, b);
    }
    // A FastLog instance is only needed by the fast-log MLE procedure.
    final FastLog fastLog = type == Type.FAST_LOG_MLE ? getFastLog() : null;
    // Use %s for the type: %b prints "true" for any non-null, non-Boolean argument,
    // which would hide which procedure type failed. The trailing space separates the
    // prefix from the message text appended below.
    final String name = String.format("[%d] %s precomputed=%b ", nparams, type, precomputed);
    // Create messages
    final IndexSupplier msgR = new IndexSupplier(1, name + "Result: Not same ", null);
    final IndexSupplier msgOb = new IndexSupplier(1, name + "Observations: Not same beta ", null);
    final IndexSupplier msgOal = new IndexSupplier(1, name + "Observations: Not same alpha linear ", null);
    final IndexSupplier msgOam = new IndexSupplier(1, name + "Observations: Not same alpha matrix ", null);
    for (int i = 0; i < paramsList.size(); i++) {
        final LvmGradientProcedure p1 = createProcedure(type, yList.get(i), func, fastLog);
        p1.gradient(paramsList.get(i));
        final LvmGradientProcedure p2 = LvmGradientProcedureUtils.create(yList.get(i), func, type, fastLog);
        p2.gradient(paramsList.get(i));
        // Exactly the same ...
        Assertions.assertEquals(p1.value, p2.value, msgR.set(0, i));
        Assertions.assertArrayEquals(p1.beta, p2.beta, msgOb.set(0, i));
        Assertions.assertArrayEquals(p1.getAlphaLinear(), p2.getAlphaLinear(), msgOal.set(0, i));
        final double[][] am1 = p1.getAlphaMatrix();
        final double[][] am2 = p2.getAlphaMatrix();
        Assertions.assertArrayEquals(am1, am2, msgOam.set(0, i));
    }
}
Also used : Gradient1Function(uk.ac.sussex.gdsc.smlm.function.Gradient1Function) OffsetGradient1Function(uk.ac.sussex.gdsc.smlm.function.OffsetGradient1Function) IndexSupplier(uk.ac.sussex.gdsc.test.utils.functions.IndexSupplier) ArrayList(java.util.ArrayList) FakeGradientFunction(uk.ac.sussex.gdsc.smlm.function.FakeGradientFunction) FastLog(uk.ac.sussex.gdsc.smlm.function.FastLog)

Example 7 with FastLog

Use of uk.ac.sussex.gdsc.smlm.function.FastLog in project GDSC-SMLM by aherbert.

The class LvmGradientProcedureTest, method gradientProcedureFactoryCreatesOptimisedProcedures.

/**
 * Checks that the generic factory {@code LvmGradientProcedureUtils.create} and the dedicated
 * per-type factories return the expected concrete procedure class for functions with 6, 5, 4
 * and 1 parameters.
 *
 * <p>Assertions pass the expected class first, matching JUnit's
 * {@code assertEquals(expected, actual)} contract, so failure messages are not reversed.
 */
@SeededTest
void gradientProcedureFactoryCreatesOptimisedProcedures() {
    // f[i] is a dummy function with i parameters; f[0] is left null and is never used.
    final DummyGradientFunction[] f = new DummyGradientFunction[7];
    for (int i = 1; i < f.length; i++) {
        f[i] = new DummyGradientFunction(i);
    }
    final LvmGradientProcedureUtils.Type mle = LvmGradientProcedureUtils.Type.MLE;
    final LvmGradientProcedureUtils.Type wlsq = LvmGradientProcedureUtils.Type.WLSQ;
    final LvmGradientProcedureUtils.Type lsq = LvmGradientProcedureUtils.Type.LSQ;
    final LvmGradientProcedureUtils.Type fmle = LvmGradientProcedureUtils.Type.FAST_LOG_MLE;
    final FastLog fl = getFastLog();
    // @formatter:off
    // y0 is all zeros; y1 holds a non-zero value. The MLE factories are expected to return
    // the "X" procedure variants only for y1.
    final double[] y0 = new double[1];
    final double[] y1 = new double[] { 1 };
    // Generic factory
    Assertions.assertEquals(LsqLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], lsq, fl).getClass());
    Assertions.assertEquals(LsqLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], lsq, fl).getClass());
    Assertions.assertEquals(LsqLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], lsq, fl).getClass());
    Assertions.assertEquals(LsqLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], lsq, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX6.class, LvmGradientProcedureUtils.create(y1, f[6], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX5.class, LvmGradientProcedureUtils.create(y1, f[5], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX4.class, LvmGradientProcedureUtils.create(y1, f[4], mle, fl).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX.class, LvmGradientProcedureUtils.create(y1, f[1], mle, fl).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], wlsq, fl).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], wlsq, fl).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], wlsq, fl).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], wlsq, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX6.class, LvmGradientProcedureUtils.create(y1, f[6], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX5.class, LvmGradientProcedureUtils.create(y1, f[5], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX4.class, LvmGradientProcedureUtils.create(y1, f[4], fmle, fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX.class, LvmGradientProcedureUtils.create(y1, f[1], fmle, fl).getClass());
    // Dedicated factories
    Assertions.assertEquals(LsqLvmGradientProcedure6.class, LsqLvmGradientProcedureUtils.create(y0, f[6]).getClass());
    Assertions.assertEquals(LsqLvmGradientProcedure5.class, LsqLvmGradientProcedureUtils.create(y0, f[5]).getClass());
    Assertions.assertEquals(LsqLvmGradientProcedure4.class, LsqLvmGradientProcedureUtils.create(y0, f[4]).getClass());
    Assertions.assertEquals(LsqLvmGradientProcedure.class, LsqLvmGradientProcedureUtils.create(y0, f[1]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure6.class, MleLvmGradientProcedureUtils.create(y0, f[6]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure5.class, MleLvmGradientProcedureUtils.create(y0, f[5]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure4.class, MleLvmGradientProcedureUtils.create(y0, f[4]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedure.class, MleLvmGradientProcedureUtils.create(y0, f[1]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX6.class, MleLvmGradientProcedureUtils.create(y1, f[6]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX5.class, MleLvmGradientProcedureUtils.create(y1, f[5]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX4.class, MleLvmGradientProcedureUtils.create(y1, f[4]).getClass());
    Assertions.assertEquals(MleLvmGradientProcedureX.class, MleLvmGradientProcedureUtils.create(y1, f[1]).getClass());
    // Weighted LSQ with a null second argument (as in the original test).
    Assertions.assertEquals(WLsqLvmGradientProcedure6.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[6]).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure5.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[5]).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure4.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[4]).getClass());
    Assertions.assertEquals(WLsqLvmGradientProcedure.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[1]).getClass());
    // Passing a FastLog to the MLE factory selects the fast-log procedures.
    Assertions.assertEquals(FastLogMleLvmGradientProcedure6.class, MleLvmGradientProcedureUtils.create(y0, f[6], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure5.class, MleLvmGradientProcedureUtils.create(y0, f[5], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure4.class, MleLvmGradientProcedureUtils.create(y0, f[4], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedure.class, MleLvmGradientProcedureUtils.create(y0, f[1], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX6.class, MleLvmGradientProcedureUtils.create(y1, f[6], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX5.class, MleLvmGradientProcedureUtils.create(y1, f[5], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX4.class, MleLvmGradientProcedureUtils.create(y1, f[4], fl).getClass());
    Assertions.assertEquals(FastLogMleLvmGradientProcedureX.class, MleLvmGradientProcedureUtils.create(y1, f[1], fl).getClass());
    // @formatter:on
}
Also used : Type(uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient.LvmGradientProcedureUtils.Type) FastLog(uk.ac.sussex.gdsc.smlm.function.FastLog) DummyGradientFunction(uk.ac.sussex.gdsc.smlm.function.DummyGradientFunction) SeededTest(uk.ac.sussex.gdsc.test.junit5.SeededTest)

Aggregations

FastLog (uk.ac.sussex.gdsc.smlm.function.FastLog)7 ArrayList (java.util.ArrayList)6 FakeGradientFunction (uk.ac.sussex.gdsc.smlm.function.FakeGradientFunction)4 DoubleEquality (uk.ac.sussex.gdsc.core.utils.DoubleEquality)2 Gradient1Function (uk.ac.sussex.gdsc.smlm.function.Gradient1Function)2 OffsetGradient1Function (uk.ac.sussex.gdsc.smlm.function.OffsetGradient1Function)2 TestCounter (uk.ac.sussex.gdsc.test.utils.TestCounter)2 TimingResult (uk.ac.sussex.gdsc.test.utils.TimingResult)2 IndexSupplier (uk.ac.sussex.gdsc.test.utils.functions.IndexSupplier)2 UniformRandomProvider (org.apache.commons.rng.UniformRandomProvider)1 SharedStateContinuousSampler (org.apache.commons.rng.sampling.distribution.SharedStateContinuousSampler)1 DenseMatrix64F (org.ejml.data.DenseMatrix64F)1 Type (uk.ac.sussex.gdsc.smlm.fitting.nonlinear.gradient.LvmGradientProcedureUtils.Type)1 DummyGradientFunction (uk.ac.sussex.gdsc.smlm.function.DummyGradientFunction)1 ValueProcedure (uk.ac.sussex.gdsc.smlm.function.ValueProcedure)1 Gaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.Gaussian2DFunction)1 ErfGaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction)1 SingleFreeCircularErfGaussian2DFunction (uk.ac.sussex.gdsc.smlm.function.gaussian.erf.SingleFreeCircularErfGaussian2DFunction)1 DoubleDoubleBiPredicate (uk.ac.sussex.gdsc.test.api.function.DoubleDoubleBiPredicate)1 SeededTest (uk.ac.sussex.gdsc.test.junit5.SeededTest)1