Use of gdsc.smlm.function.Gradient2Function in the GDSC-SMLM project by aherbert.
Class FastMLEGradient2ProcedureTest, method gradientProcedureLinearIsFasterThanGradientProcedure.
/**
 * Speed test comparing the unrolled procedure created by
 * {@code FastMLEGradient2ProcedureFactory.createUnrolled} against the standard
 * {@code FastMLEGradient2Procedure} for a function with {@code nparams} parameters.
 * <p>
 * The test is skipped unless speed tests are enabled. It first checks that both procedures
 * produce exactly the same first and second derivatives ({@code d1}, {@code d2}) on every
 * generated data set. It then times each procedure over a simulated optimisation workload that
 * uses a dummy gradient function, so the function evaluation cost is excluded.
 *
 * @param nparams the number of function parameters
 */
private void gradientProcedureLinearIsFasterThanGradientProcedure(final int nparams) {
    org.junit.Assume.assumeTrue(speedTests || TestSettings.RUN_SPEED_TESTS);
    final int iter = 100;
    rdg = new RandomDataGenerator(new Well19937c(30051977));
    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);
    createData(1, iter, paramsList, yList);
    // Index the data by what was actually generated, not by the requested count.
    final int size = paramsList.size();

    // Remove the timing of the function call by creating a dummy function
    final Gradient2Function func = new FakeGradientFunction(blockWidth, nparams);

    // Correctness: the unrolled procedure must match the standard one exactly (delta 0).
    // Each instance is evaluated twice so that reuse of the same procedure is also checked.
    for (int i = 0; i < size; i++) {
        final FastMLEGradient2Procedure p1 = new FastMLEGradient2Procedure(yList.get(i), func);
        p1.computeSecondDerivative(paramsList.get(i));
        p1.computeSecondDerivative(paramsList.get(i));
        final FastMLEGradient2Procedure p2 =
                FastMLEGradient2ProcedureFactory.createUnrolled(yList.get(i), func);
        p2.computeSecondDerivative(paramsList.get(i));
        p2.computeSecondDerivative(paramsList.get(i));
        // Check they are the same
        Assert.assertArrayEquals("D1 " + i, p1.d1, p2.d1, 0);
        Assert.assertArrayEquals("D2 " + i, p1.d2, p2.d2, 0);
    }

    // Realistic number of evaluations per fit for an optimisation.
    // Deliberately NOT named 'loops': Timer has a 'loops' field (read below as t1.loops).
    // If that field is visible to the anonymous subclasses, it would shadow a same-named
    // local inside run() and silently change the workload.
    final int evaluationsPerFit = 15;

    // Run till stable timing
    final Timer t1 = new Timer() {
        @Override
        void run() {
            for (int i = 0, k = 0; i < size; i++) {
                final FastMLEGradient2Procedure p1 =
                        new FastMLEGradient2Procedure(yList.get(i), func);
                for (int j = evaluationsPerFit; j-- > 0; ) {
                    p1.computeSecondDerivative(paramsList.get(k++ % size));
                }
            }
        }
    };
    final long time1 = t1.getTime();

    // Use at least as many timing repetitions as the first timer needed to stabilise.
    final Timer t2 = new Timer(t1.loops) {
        @Override
        void run() {
            for (int i = 0, k = 0; i < size; i++) {
                final FastMLEGradient2Procedure p2 =
                        FastMLEGradient2ProcedureFactory.createUnrolled(yList.get(i), func);
                for (int j = evaluationsPerFit; j-- > 0; ) {
                    p2.computeSecondDerivative(paramsList.get(k++ % size));
                }
            }
        }
    };
    final long time2 = t2.getTime();

    log("Standard = %d : Unrolled %d = %d : %fx\n", time1, nparams, time2, (1.0 * time1) / time2);

    // Lenient threshold: the unrolled version may be up to 1.5x slower before failing. This
    // reduces flakiness from timing noise, although the method name implies a strict speed-up.
    Assert.assertTrue(
            String.format("Unrolled time %d is not below 1.5x standard time %d", time2, time1),
            time2 < time1 * 1.5);
}
Aggregations