Use of `uk.ac.sussex.gdsc.smlm.function.FastLog` in project GDSC-SMLM by aherbert.
The class `LvmGradientProcedureTest`, method `gradientProcedureUnrolledComputesSameAsGradientProcedure`.
/**
 * Checks that the procedure returned by {@code LvmGradientProcedureUtils.create(...)} produces
 * exactly the same value, beta, alpha (linear) and alpha (matrix) as the procedure returned by
 * {@code createProcedure(...)}, over 10 sets of fake parameters and observations.
 *
 * @param seed the seed used to generate the fake parameters and observations
 * @param nparams the number of function parameters
 * @param type the procedure type under test
 * @param precomputed if true, the fake function is wrapped with
 *     {@code OffsetGradient1Function.wrapGradient1Function}
 */
private void gradientProcedureUnrolledComputesSameAsGradientProcedure(RandomSeed seed, int nparams, Type type, boolean precomputed) {
  final int iter = 10;
  final ArrayList<double[]> paramsList = new ArrayList<>(iter);
  final ArrayList<double[]> yList = new ArrayList<>(iter);
  createFakeData(RngUtils.create(seed.getSeed()), nparams, iter, paramsList, yList);

  Gradient1Function func = new FakeGradientFunction(blockWidth, nparams);
  if (precomputed) {
    // NOTE(review): presumably an array of func.size() values starting at 0.1 with step 1.3,
    // used as a precomputed offset for the function - confirm against SimpleArrayUtils.
    final double[] b = SimpleArrayUtils.newArray(func.size(), 0.1, 1.3);
    func = OffsetGradient1Function.wrapGradient1Function(func, b);
  }

  // Only the FAST_LOG_MLE type is given a FastLog; all other types receive null.
  final FastLog fastLog = type == Type.FAST_LOG_MLE ? getFastLog() : null;

  // Use %s for the type: %b would print "true" for any non-null object and hide the type.
  // A trailing space separates this prefix from the message text appended below.
  final String name = String.format("[%d] %s precomputed=%b ", nparams, type, precomputed);

  // Create messages. The IndexSupplier slot 0 is set to the iteration index before each check.
  final IndexSupplier msgR = new IndexSupplier(1, name + "Result: Not same ", null);
  final IndexSupplier msgOb = new IndexSupplier(1, name + "Observations: Not same beta ", null);
  final IndexSupplier msgOal = new IndexSupplier(1, name + "Observations: Not same alpha linear ", null);
  final IndexSupplier msgOam = new IndexSupplier(1, name + "Observations: Not same alpha matrix ", null);

  for (int i = 0; i < paramsList.size(); i++) {
    final double[] params = paramsList.get(i);
    final double[] y = yList.get(i);

    final LvmGradientProcedure p1 = createProcedure(type, y, func, fastLog);
    p1.gradient(params);
    final LvmGradientProcedure p2 = LvmGradientProcedureUtils.create(y, func, type, fastLog);
    p2.gradient(params);

    // Exactly the same (no tolerance): both procedures must compute identical results.
    Assertions.assertEquals(p1.value, p2.value, msgR.set(0, i));
    Assertions.assertArrayEquals(p1.beta, p2.beta, msgOb.set(0, i));
    Assertions.assertArrayEquals(p1.getAlphaLinear(), p2.getAlphaLinear(), msgOal.set(0, i));
    Assertions.assertArrayEquals(p1.getAlphaMatrix(), p2.getAlphaMatrix(), msgOam.set(0, i));
  }
}
Use of `uk.ac.sussex.gdsc.smlm.function.FastLog` in project GDSC-SMLM by aherbert.
The class `LvmGradientProcedureTest`, method `gradientProcedureFactoryCreatesOptimisedProcedures`.
/**
 * Checks the concrete procedure class chosen by the generic factory
 * ({@code LvmGradientProcedureUtils.create}) and the dedicated factories
 * ({@code LsqLvmGradientProcedureUtils}, {@code MleLvmGradientProcedureUtils},
 * {@code WLsqLvmGradientProcedureUtils}).
 *
 * <p>For 6, 5 and 4 parameters a numbered class is expected (e.g. {@code ...Procedure6}); for
 * 1 parameter the un-numbered class is expected. For the MLE and fast-log MLE types, the
 * observation array {@code {1}} selects the X-suffixed classes, while {@code {0}} does not.
 *
 * <p>Assertions use JUnit's (expected, actual) argument order so that failure reports label
 * the expected and actual classes correctly.
 *
 * <p>NOTE(review): this method takes no RandomSeed and uses no randomness; a plain {@code @Test}
 * may be sufficient instead of {@code @SeededTest} - confirm.
 */
@SeededTest
void gradientProcedureFactoryCreatesOptimisedProcedures() {
  // Index i holds a function with i parameters; index 0 is intentionally left null.
  final DummyGradientFunction[] f = new DummyGradientFunction[7];
  for (int i = 1; i < f.length; i++) {
    f[i] = new DummyGradientFunction(i);
  }
  final LvmGradientProcedureUtils.Type mle = LvmGradientProcedureUtils.Type.MLE;
  final LvmGradientProcedureUtils.Type wlsq = LvmGradientProcedureUtils.Type.WLSQ;
  final LvmGradientProcedureUtils.Type lsq = LvmGradientProcedureUtils.Type.LSQ;
  final LvmGradientProcedureUtils.Type fmle = LvmGradientProcedureUtils.Type.FAST_LOG_MLE;
  final FastLog fl = getFastLog();
  // @formatter:off
  // Generic factory
  final double[] y0 = new double[1];
  final double[] y1 = new double[] { 1 };
  Assertions.assertEquals(LsqLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], lsq, fl).getClass());
  Assertions.assertEquals(LsqLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], lsq, fl).getClass());
  Assertions.assertEquals(LsqLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], lsq, fl).getClass());
  Assertions.assertEquals(LsqLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], lsq, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX6.class, LvmGradientProcedureUtils.create(y1, f[6], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX5.class, LvmGradientProcedureUtils.create(y1, f[5], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX4.class, LvmGradientProcedureUtils.create(y1, f[4], mle, fl).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX.class, LvmGradientProcedureUtils.create(y1, f[1], mle, fl).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], wlsq, fl).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], wlsq, fl).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], wlsq, fl).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], wlsq, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure6.class, LvmGradientProcedureUtils.create(y0, f[6], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure5.class, LvmGradientProcedureUtils.create(y0, f[5], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure4.class, LvmGradientProcedureUtils.create(y0, f[4], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure.class, LvmGradientProcedureUtils.create(y0, f[1], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX6.class, LvmGradientProcedureUtils.create(y1, f[6], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX5.class, LvmGradientProcedureUtils.create(y1, f[5], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX4.class, LvmGradientProcedureUtils.create(y1, f[4], fmle, fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX.class, LvmGradientProcedureUtils.create(y1, f[1], fmle, fl).getClass());
  // Dedicated factories
  Assertions.assertEquals(LsqLvmGradientProcedure6.class, LsqLvmGradientProcedureUtils.create(y0, f[6]).getClass());
  Assertions.assertEquals(LsqLvmGradientProcedure5.class, LsqLvmGradientProcedureUtils.create(y0, f[5]).getClass());
  Assertions.assertEquals(LsqLvmGradientProcedure4.class, LsqLvmGradientProcedureUtils.create(y0, f[4]).getClass());
  Assertions.assertEquals(LsqLvmGradientProcedure.class, LsqLvmGradientProcedureUtils.create(y0, f[1]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure6.class, MleLvmGradientProcedureUtils.create(y0, f[6]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure5.class, MleLvmGradientProcedureUtils.create(y0, f[5]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure4.class, MleLvmGradientProcedureUtils.create(y0, f[4]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedure.class, MleLvmGradientProcedureUtils.create(y0, f[1]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX6.class, MleLvmGradientProcedureUtils.create(y1, f[6]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX5.class, MleLvmGradientProcedureUtils.create(y1, f[5]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX4.class, MleLvmGradientProcedureUtils.create(y1, f[4]).getClass());
  Assertions.assertEquals(MleLvmGradientProcedureX.class, MleLvmGradientProcedureUtils.create(y1, f[1]).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure6.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[6]).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure5.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[5]).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure4.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[4]).getClass());
  Assertions.assertEquals(WLsqLvmGradientProcedure.class, WLsqLvmGradientProcedureUtils.create(y0, null, f[1]).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure6.class, MleLvmGradientProcedureUtils.create(y0, f[6], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure5.class, MleLvmGradientProcedureUtils.create(y0, f[5], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure4.class, MleLvmGradientProcedureUtils.create(y0, f[4], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedure.class, MleLvmGradientProcedureUtils.create(y0, f[1], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX6.class, MleLvmGradientProcedureUtils.create(y1, f[6], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX5.class, MleLvmGradientProcedureUtils.create(y1, f[5], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX4.class, MleLvmGradientProcedureUtils.create(y1, f[4], fl).getClass());
  Assertions.assertEquals(FastLogMleLvmGradientProcedureX.class, MleLvmGradientProcedureUtils.create(y1, f[1], fl).getClass());
  // @formatter:on
}
Aggregations