Use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.
The class L_BFGS_Test, method testArcene.
/**
 * Tests LSM on arcene, a wide dataset with ~10k columns, and exercises
 * warm start and the maximum number of iterations.
 *
 * <p>Three L-BFGS runs are performed against the same gradient solver:
 * <ol>
 *   <li>r1: capped at 20 iterations, started from the initial beta;</li>
 *   <li>r2: same solver object with the cap raised to 50, warm-started from
 *       r1's coefficients and gradient info;</li>
 *   <li>r3: a fresh solver capped at 100 iterations, started from the same
 *       initial beta as r1.</li>
 * </ol>
 * The test then checks that r1 used exactly 20 iterations, that the
 * warm-started run (r2) and the single long run (r3) reach the same objective
 * value within 1e-8, and that r3 finished in fewer than 100 iterations.
 */
@Test
public void testArcene() {
  Key parsedKey = Key.make("arcene_parsed");
  DataInfo dinfo = null;
  try {
    Frame source = parse_test_file(parsedKey, "smalldata/glm_test/arcene.csv");
    // Validation frame reuses the source's names and vecs (arrays are cloned, vecs are shared).
    Frame valid = new Frame(source._names.clone(), source.vecs().clone());

    GLMParameters glmp = new GLMParameters(Family.gaussian);
    glmp._lambda = new double[] { 1e-5 };
    glmp._alpha = new double[] { 0 };
    glmp._obj_reg = 0.01;

    dinfo = new DataInfo(source, valid, 1, false,
        DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE,
        true, false, false,
        /* weights */ false,
        /* offset */ false,
        /* fold */ false);
    // Published to the DKV here; removed again in the finally block.
    DKV.put(dinfo._key, dinfo);

    GradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
    L_BFGS lbfgs = new L_BFGS().setMaxIter(20);

    // Initial coefficients: all zeros except the last slot, which is set to
    // link(mean of the last column) -- presumably the intercept; confirm against DataInfo layout.
    double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);
    beta[beta.length - 1] = new GLMWeightsFun(glmp).link(source.lastVec().mean());

    // Run 1: cold start, at most 20 iterations.
    L_BFGS.Result r1 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {
      int _i = 0;

      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(++_i + ":" + ginfo._objVal);
        return true;
      }
    });

    // Run 2: warm start from run 1's result with the iteration cap raised to 50.
    lbfgs.setMaxIter(50);
    final int iter = r1.iter;
    L_BFGS.Result r2 = lbfgs.solve(solver, r1.coefs, r1.ginfo, new L_BFGS.ProgressMonitor() {
      int _i = 0;

      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(iter + " + " + ++_i + ":" + ginfo._objVal);
        return true;
      }
    });
    System.out.println();

    // Run 3: fresh solver, cold start, at most 100 iterations.
    lbfgs = new L_BFGS().setMaxIter(100);
    L_BFGS.Result r3 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {
      int _i = 0;

      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(++_i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
        return true;
      }
    });

    // JUnit's assertEquals takes (expected, actual); keep that order so failure messages read correctly.
    assertEquals(20, r1.iter);
    // NOTE: r1.iter + r2.iter is deliberately not asserted to equal r3.iter;
    // the original author observed a mismatch of 2 between them.
    assertEquals(r2.ginfo._objVal, r3.ginfo._objVal, 1e-8);
    // Objective value plus 0.5 * lambda * l2norm(coefs) is expected to be within 5e-4 of 1e-4.
    assertEquals(1e-4, .5 * glmp._lambda[0] * ArrayUtils.l2norm(r3.coefs, true) + r3.ginfo._objVal, 5e-4);
    assertTrue("iter# expected < 100, got " + r3.iter, r3.iter < 100);
  } finally {
    // Clean up everything this test put into the DKV, even if an assertion failed.
    if (dinfo != null)
      DKV.remove(dinfo._key);
    Value v = DKV.get(parsedKey);
    if (v != null) {
      v.<Frame>get().delete();
    }
  }
}
Aggregations