Search in sources :

Example 1 with Gradient2Function

Use of gdsc.smlm.function.Gradient2Function in the project GDSC-SMLM by aherbert.

From the class FastMLEGradient2ProcedureTest, the method gradientProcedureLinearIsFasterThanGradientProcedure.

/**
 * Compares the run time of a directly constructed {@code FastMLEGradient2Procedure} with the
 * procedure returned by {@code FastMLEGradient2ProcedureFactory.createUnrolled} for a function
 * with the given number of parameters.
 *
 * <p>The test is skipped unless {@code speedTests} or {@code TestSettings.RUN_SPEED_TESTS} is
 * set. Before timing, both procedures are run on every data set and their {@code d1} and
 * {@code d2} arrays are required to match exactly. The test then fails unless the unrolled
 * time is below 1.5 times the standard time.
 *
 * <p>NOTE(review): the names {@code loops}, {@code iter}, {@code func}, {@code paramsList} and
 * {@code yList} are read inside the anonymous {@code Timer} subclasses. They are presumed to
 * resolve to the locals of this method (i.e. {@code Timer} does not expose inherited members
 * with the same names) - confirm against the {@code Timer} declaration before renaming them.
 *
 * @param nparams the number of parameters passed to the {@code FakeGradientFunction}
 */
private void gradientProcedureLinearIsFasterThanGradientProcedure(final int nparams) {
    org.junit.Assume.assumeTrue(speedTests || TestSettings.RUN_SPEED_TESTS);

    final int iter = 100;
    rdg = new RandomDataGenerator(new Well19937c(30051977));

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);
    createData(1, iter, paramsList, yList);

    // A dummy function keeps the cost of the function call out of the timings
    final Gradient2Function func = new FakeGradientFunction(blockWidth, nparams);

    // Verify the two implementations agree before timing them.
    // Each procedure is evaluated twice on the same parameters, as in the timed loops
    // where a single procedure instance is reused for many evaluations.
    for (int index = 0; index < paramsList.size(); index++) {
        final double[] y = yList.get(index);
        final double[] a = paramsList.get(index);

        final FastMLEGradient2Procedure standard = new FastMLEGradient2Procedure(y, func);
        standard.computeSecondDerivative(a);
        standard.computeSecondDerivative(a);

        final FastMLEGradient2Procedure unrolled = FastMLEGradient2ProcedureFactory.createUnrolled(y, func);
        unrolled.computeSecondDerivative(a);
        unrolled.computeSecondDerivative(a);

        // Exact equality is required (delta of zero)
        Assert.assertArrayEquals("D1 " + index, standard.d1, unrolled.d1, 0);
        Assert.assertArrayEquals("D2 " + index, standard.d2, unrolled.d2, 0);
    }

    // Number of evaluations per procedure: realistic for an optimisation
    final int loops = 15;

    // Run till stable timing
    final Timer t1 = new Timer() {

        @Override
        void run() {
            // Cycles through the parameter sets across all procedures
            int next = 0;
            for (int index = 0; index < paramsList.size(); index++) {
                final FastMLEGradient2Procedure procedure = new FastMLEGradient2Procedure(yList.get(index), func);
                for (int repeat = 0; repeat < loops; repeat++) {
                    procedure.computeSecondDerivative(paramsList.get(next % iter));
                    next++;
                }
            }
        }
    };
    final long time1 = t1.getTime();

    // The second timer is constructed from the loop count recorded by the first
    final Timer t2 = new Timer(t1.loops) {

        @Override
        void run() {
            // Cycles through the parameter sets across all procedures
            int next = 0;
            for (int index = 0; index < paramsList.size(); index++) {
                final FastMLEGradient2Procedure procedure = FastMLEGradient2ProcedureFactory.createUnrolled(yList.get(index), func);
                for (int repeat = 0; repeat < loops; repeat++) {
                    procedure.computeSecondDerivative(paramsList.get(next % iter));
                    next++;
                }
            }
        }
    };
    final long time2 = t2.getTime();

    final double speedUp = (1.0 * time1) / time2;
    log("Standard = %d : Unrolled %d = %d : %fx\n", time1, nparams, time2, speedUp);

    // The unrolled version must take less than 1.5x the standard time
    Assert.assertTrue(time2 < time1 * 1.5);
}
Also used : RandomDataGenerator(org.apache.commons.math3.random.RandomDataGenerator) Gradient2Function(gdsc.smlm.function.Gradient2Function) PrecomputedGradient2Function(gdsc.smlm.function.PrecomputedGradient2Function) ArrayList(java.util.ArrayList) Well19937c(org.apache.commons.math3.random.Well19937c) FakeGradientFunction(gdsc.smlm.function.FakeGradientFunction)

Aggregations

FakeGradientFunction (gdsc.smlm.function.FakeGradientFunction): 1, Gradient2Function (gdsc.smlm.function.Gradient2Function): 1, PrecomputedGradient2Function (gdsc.smlm.function.PrecomputedGradient2Function): 1, ArrayList (java.util.ArrayList): 1, RandomDataGenerator (org.apache.commons.math3.random.RandomDataGenerator): 1, Well19937c (org.apache.commons.math3.random.Well19937c): 1