Use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.
The class ComputationState, method setLambda.
public void setLambda(double lambda) {
  adjustToNewLambda(0, _lambda);
  // strong rules are to be applied on the gradient with no l2 penalty
  // NOTE: we start with lambdaOld being 0, not lambda_max
  // non-recursive strong rules should use lambdaMax instead of _lambda
  // However, it seems to be working nicely to use 0 instead and be more aggressive on the predictor pruning
  // (should be safe as we check the KKTs anyway)
  applyStrongRules(lambda, _lambda);
  adjustToNewLambda(lambda, 0);
  _lambda = lambda;
  _gslvr = new GLMGradientSolver(_job, _parms, _activeData, l2pen(), _activeBC);
}
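For reference, the screening step invoked here (applyStrongRules) keeps a predictor only when the magnitude of its gradient entry exceeds the sequential strong-rule threshold alpha * (2 * lambdaNew - lambdaOld). Below is a minimal, hedged sketch of that test with an illustrative helper name; it is not the h2o-3 implementation.

  // Hedged sketch of the sequential strong-rule screen: a predictor j survives when
  // |gradient[j]| > alpha * (2 * lambdaNew - lambdaOld). Helper name is illustrative.
  static int[] strongRuleSurvivors(double[] gradient, double alpha, double lambdaNew, double lambdaOld) {
    final double rhs = Math.max(0, alpha * (2 * lambdaNew - lambdaOld));
    int[] kept = new int[gradient.length];
    int n = 0;
    for (int j = 0; j < gradient.length; ++j)
      if (Math.abs(gradient[j]) > rhs)
        kept[n++] = j;
    return java.util.Arrays.copyOf(kept, n);
  }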
Use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.
The class ComputationState, method removeCols.
public int[] removeCols(int[] cols) {
  int[] activeCols = ArrayUtils.removeIds(_activeData.activeCols(), cols);
  if (_beta != null)
    _beta = ArrayUtils.removeIds(_beta, cols);
  if (_u != null)
    _u = ArrayUtils.removeIds(_u, cols);
  if (_ginfo != null && _ginfo._gradient != null)
    _ginfo._gradient = ArrayUtils.removeIds(_ginfo._gradient, cols);
  _activeData = _dinfo.filterExpandedColumns(activeCols);
  _activeBC = _bc.filterExpandedColumns(activeCols);
  _gslvr = new GLMGradientSolver(_job, _parms, _activeData, (1 - _alpha) * _lambda, _activeBC);
  return activeCols;
}
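Note that the rebuilt solver receives (1 - _alpha) * _lambda as its penalty argument, i.e. only the smooth ridge (L2) share of the elastic-net penalty; the L1 share is typically handled separately (e.g., by proximal or soft-thresholding updates). A small hedged sketch of that split under the usual elastic-net parameterization, with illustrative helper names:

  // Illustrative split of the elastic-net penalty lambda * (alpha * ||b||_1 + (1 - alpha) * ||b||_2^2 / 2):
  // the gradient solver only sees the smooth L2 part; the non-smooth L1 part is applied elsewhere.
  static double l2Penalty(double lambda, double alpha) { return (1 - alpha) * lambda; }
  static double l1Penalty(double lambda, double alpha) { return alpha * lambda; }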
Use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.
The class L_BFGS_Test, method logistic.
@Test
public void logistic() {
  Key parsedKey = Key.make("prostate");
  DataInfo dinfo = null;
  try {
    GLMParameters glmp = new GLMParameters(Family.binomial, Family.binomial.defaultLink);
    glmp._alpha = new double[] { 0 };
    glmp._lambda = new double[] { 1e-5 };
    Frame source = parse_test_file(parsedKey, "smalldata/glm_test/prostate_cat_replaced.csv");
    source.add("CAPSULE", source.remove("CAPSULE"));
    source.remove("ID").remove();
    Frame valid = new Frame(source._names.clone(), source.vecs().clone());
    dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, /* weights */
        false, /* offset */
        false, /* fold */
        false);
    DKV.put(dinfo._key, dinfo);
    glmp._obj_reg = 1 / 380.0;
    GLMGradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
    L_BFGS lbfgs = new L_BFGS().setGradEps(1e-8);
    double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);
    beta[beta.length - 1] = new GLMWeightsFun(glmp).link(source.vec("CAPSULE").mean());
    L_BFGS.Result r = lbfgs.solve(solver, beta, solver.getGradient(beta), new L_BFGS.ProgressMonitor() {
      int _i = 0;
      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(++_i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
        return true;
      }
    });
    assertEquals(378.34, 2 * r.ginfo._objVal * source.numRows(), 1e-1);
  } finally {
    if (dinfo != null)
      DKV.remove(dinfo._key);
    Value v = DKV.get(parsedKey);
    if (v != null) {
      v.<Frame>get().delete();
    }
  }
}
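The final assertion relies on _obj_reg = 1/380 being the reciprocal of the row count, so r.ginfo._objVal appears to be the per-row average negative log-likelihood (plus a tiny ridge term); multiplying by 2 * numRows then recovers the residual deviance compared against 378.34. A hedged sketch of that rescaling, with an illustrative helper and the small penalty contribution ignored:

  // Hedged sketch: if objVal is the average negative log-likelihood (obj_reg = 1/numRows),
  // then the residual deviance is approximately 2 * objVal * numRows, as asserted above.
  static double approxDeviance(double objVal, long numRows) {
    return 2 * objVal * numRows;
  }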
Use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.
The class L_BFGS_Test, method testArcene.
// Test LSM on arcene - wide dataset with ~10k columns
// test warm start and max # of iterations
@Test
public void testArcene() {
  Key parsedKey = Key.make("arcene_parsed");
  DataInfo dinfo = null;
  try {
    Frame source = parse_test_file(parsedKey, "smalldata/glm_test/arcene.csv");
    Frame valid = new Frame(source._names.clone(), source.vecs().clone());
    GLMParameters glmp = new GLMParameters(Family.gaussian);
    glmp._lambda = new double[] { 1e-5 };
    glmp._alpha = new double[] { 0 };
    glmp._obj_reg = 0.01;
    dinfo = new DataInfo(source, valid, 1, false, DataInfo.TransformType.STANDARDIZE, DataInfo.TransformType.NONE, true, false, false, /* weights */
        false, /* offset */
        false, /* fold */
        false);
    DKV.put(dinfo._key, dinfo);
    GradientSolver solver = new GLMGradientSolver(null, glmp, dinfo, 1e-5, null);
    L_BFGS lbfgs = new L_BFGS().setMaxIter(20);
    double[] beta = MemoryManager.malloc8d(dinfo.fullN() + 1);
    beta[beta.length - 1] = new GLMWeightsFun(glmp).link(source.lastVec().mean());
    L_BFGS.Result r1 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {
      int _i = 0;
      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(++_i + ":" + ginfo._objVal);
        return true;
      }
    });
    lbfgs.setMaxIter(50);
    final int iter = r1.iter;
    L_BFGS.Result r2 = lbfgs.solve(solver, r1.coefs, r1.ginfo, new L_BFGS.ProgressMonitor() {
      int _i = 0;
      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(iter + " + " + ++_i + ":" + ginfo._objVal);
        return true;
      }
    });
    System.out.println();
    lbfgs = new L_BFGS().setMaxIter(100);
    L_BFGS.Result r3 = lbfgs.solve(solver, beta.clone(), solver.getGradient(beta), new L_BFGS.ProgressMonitor() {
      int _i = 0;
      public boolean progress(double[] beta, GradientInfo ginfo) {
        System.out.println(++_i + ":" + ginfo._objVal + ", " + ArrayUtils.l2norm2(ginfo._gradient, false));
        return true;
      }
    });
    assertEquals(r1.iter, 20);
    // assertEquals(r1.iter + r2.iter, r3.iter); // should be equal? got mismatch by 2
    assertEquals(r2.ginfo._objVal, r3.ginfo._objVal, 1e-8);
    assertEquals(.5 * glmp._lambda[0] * ArrayUtils.l2norm(r3.coefs, true) + r3.ginfo._objVal, 1e-4, 5e-4);
    assertTrue("iter# expected < 100, got " + r3.iter, r3.iter < 100);
  } finally {
    if (dinfo != null)
      DKV.remove(dinfo._key);
    Value v = DKV.get(parsedKey);
    if (v != null) {
      v.<Frame>get().delete();
    }
  }
}
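The warm-start part of this test works because L_BFGS.Result carries both the final coefficients and the last GradientInfo, so the second solve can resume from r1.coefs and r1.ginfo instead of restarting from beta. A generic, hedged sketch of that pattern; Solver and Result below are illustrative stand-ins, not the h2o-3 types:

  // Hypothetical warm-start pattern mirrored by the test: run a capped solve,
  // then resume from the returned state rather than from the original starting point.
  interface Solver {
    Result solve(double[] start, Object gradAtStart, int maxIter);
  }

  final class Result {
    final double[] coefs;  // coefficients at the last iterate
    final Object ginfo;    // gradient info at the last iterate
    final int iter;        // iterations used
    Result(double[] coefs, Object ginfo, int iter) { this.coefs = coefs; this.ginfo = ginfo; this.iter = iter; }
  }

  static Result solveWithWarmStart(Solver s, double[] beta0, Object grad0) {
    Result first = s.solve(beta0, grad0, 20);       // capped first pass
    return s.solve(first.coefs, first.ginfo, 50);   // continue from where the first pass stopped
  }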
Use of hex.glm.GLM.GLMGradientSolver in project h2o-3 by h2oai.
The class ComputationState, method applyStrongRules.
/**
 * Apply strong rules to filter out expected inactive (with zero coefficient) predictors.
 *
 * @return indices of expected active predictors.
 */
protected void applyStrongRules(double lambdaNew, double lambdaOld) {
  lambdaNew = Math.min(_lambdaMax, lambdaNew);
  lambdaOld = Math.min(_lambdaMax, lambdaOld);
  if (_parms._family == Family.multinomial) /* && _parms._solver != GLMParameters.Solver.L_BFGS */ {
    applyStrongRulesMultinomial(lambdaNew, lambdaOld);
    return;
  }
  int P = _dinfo.fullN();
  _activeBC = _bc;
  _activeData = _activeData != null ? _activeData : _dinfo;
  _allIn = _allIn || _parms._alpha[0] * lambdaNew == 0 || _activeBC.hasBounds();
  if (!_allIn) {
    int newlySelected = 0;
    final double rhs = Math.max(0, _alpha * (2 * lambdaNew - lambdaOld));
    int[] newCols = MemoryManager.malloc4(P);
    int j = 0;
    int[] oldActiveCols = _activeData._activeCols == null ? new int[] { P } : _activeData.activeCols();
    for (int i = 0; i < P; ++i) {
      if (j < oldActiveCols.length && oldActiveCols[j] == i)
        j++;
      else if (_ginfo._gradient[i] > rhs || -_ginfo._gradient[i] > rhs)
        newCols[newlySelected++] = i;
    }
    if (_parms._max_active_predictors != -1 && (oldActiveCols.length + newlySelected - 1) > _parms._max_active_predictors) {
      Integer[] bigInts = ArrayUtils.toIntegers(newCols, 0, newlySelected);
      Arrays.sort(bigInts, new Comparator<Integer>() {
        @Override
        public int compare(Integer o1, Integer o2) {
          return (int) Math.signum(_ginfo._gradient[o2.intValue()] * _ginfo._gradient[o2.intValue()] - _ginfo._gradient[o1.intValue()] * _ginfo._gradient[o1.intValue()]);
        }
      });
      newCols = ArrayUtils.toInt(bigInts, 0, _parms._max_active_predictors - oldActiveCols.length + 1);
      Arrays.sort(newCols);
    } else
      newCols = Arrays.copyOf(newCols, newlySelected);
    newCols = ArrayUtils.sortedMerge(oldActiveCols, newCols); // merge already active columns in
    int active = newCols.length;
    _allIn = active == P;
    if (!_allIn) {
      int[] cols = newCols;
      // intercept is always selected, even if it is false (it's gonna be dropped later, it is needed for other stuff too)
      assert cols[active - 1] == P;
      _beta = ArrayUtils.select(_beta, cols);
      if (_u != null)
        _u = ArrayUtils.select(_u, cols);
      _activeData = _dinfo.filterExpandedColumns(cols);
      assert _activeData.activeCols().length == _beta.length;
      assert _u == null || _activeData.activeCols().length == _u.length;
      _ginfo = new GLMGradientInfo(_ginfo._likelihood, _ginfo._objVal, ArrayUtils.select(_ginfo._gradient, cols));
      _activeBC = _bc.filterExpandedColumns(_activeData.activeCols());
      _gslvr = new GLMGradientSolver(_job, _parms, _activeData, (1 - _alpha) * _lambda, _bc);
      assert _beta.length == cols.length;
      return;
    }
  }
  _activeData = _dinfo;
}
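When the strong-rule screen would push the active set past _max_active_predictors, the newly selected columns are ranked by squared gradient and only the largest survive, which is what the comparator above does. A hedged, self-contained sketch of that capping step on plain arrays, with an illustrative helper name; it is not the h2o-3 implementation:

  // Keep at most `budget` candidate columns, preferring the largest |gradient| entries,
  // then restore ascending column order as the original code does with Arrays.sort(newCols).
  static int[] capBySquaredGradient(int[] candidates, final double[] gradient, int budget) {
    Integer[] boxed = new Integer[candidates.length];
    for (int i = 0; i < candidates.length; ++i) boxed[i] = candidates[i];
    java.util.Arrays.sort(boxed, (a, b) -> Double.compare(gradient[b] * gradient[b], gradient[a] * gradient[a]));
    int n = Math.min(budget, boxed.length);
    int[] kept = new int[n];
    for (int i = 0; i < n; ++i) kept[i] = boxed[i];
    java.util.Arrays.sort(kept);
    return kept;
  }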