Usage of hex.glm.GLMModel.GLMParameters.Solver in project h2o-3 by h2oai.
In class GLM, method defaultSolver:
/**
 * Picks the solver used when the caller did not request a specific one.
 *
 * <p>Heuristics, checked in this order:
 * <ol>
 *   <li>L_BFGS when the largest {@code fullN()} of the active data (maximum over classes for
 *       multinomial) is at least 5000;</li>
 *   <li>COORDINATE_DESCENT when lambda search is enabled;</li>
 *   <li>COORDINATE_DESCENT when the active beta constraints have bounds but no proximal
 *       penalty;</li>
 *   <li>L_BFGS for multinomial with {@code _alpha[0] == 0};</li>
 *   <li>IRLSM otherwise.</li>
 * </ol>
 *
 * <p>Side effects: stores the choice in {@code _parms._solver}. When the chosen solver is not
 * L_BFGS and {@code _parms._max_active_predictors} is -1, it is set to 5000.
 *
 * @return the chosen solver
 */
private Solver defaultSolver() {
    Solver s = Solver.IRLSM;
    int max_active = 0;
    if (_parms._family == Family.multinomial) {
        for (int c = 0; c < _nclass; ++c) {
            max_active = Math.max(_state.activeDataMultinomial(c).fullN(), max_active);
        }
    } else {
        max_active = _state.activeData().fullN();
    }
    if (max_active >= 5000) {
        // cutoff has to be somewhere
        s = Solver.L_BFGS;
    } else if (_parms._lambda_search) {
        // lambda search prefers coordinate descent
        // l1 lambda search is better with coordinate descent!
        s = Solver.COORDINATE_DESCENT;
    } else if (_state.activeBC().hasBounds() && !_state.activeBC().hasProximalPenalty()) {
        s = Solver.COORDINATE_DESCENT;
    } else if (_parms._family == Family.multinomial && _parms._alpha[0] == 0) {
        // multinomial does better with lbfgs
        s = Solver.L_BFGS;
    }
    // Report every choice. Previously this call was the body of the final `else`,
    // so it only fired when the IRLSM fallback was picked.
    Log.info(LogMsg("picked solver " + s));
    if (s != Solver.L_BFGS && _parms._max_active_predictors == -1) {
        _parms._max_active_predictors = 5000;
    }
    _parms._solver = s;
    return s;
}
Usage of hex.glm.GLMModel.GLMParameters.Solver in project h2o-3 by h2oai.
In class GLMBasicTestBinomial, method testNoInterceptWithOffset:
/**
 * Binomial GLM without intercept and with an offset column, checked against R's
 * {@code glm(CAPSULE ~ . - ID - RACE - DCAPS - DPROS - 1, family = binomial, offset = offset_train)}
 * for the AUTO, IRLSM, L_BFGS and COORDINATE_DESCENT solvers.
 *
 * <p>For each solver it checks:
 * <ul>
 *   <li>coefficients, null/residual deviance, degrees of freedom, AIC and validation deviance;</li>
 *   <li>that scoring a frame without the offset column throws IllegalArgumentException;</li>
 *   <li>that metrics stored in DKV for the train/test frames match the model's own metrics;</li>
 *   <li>the per-row test predictions.</li>
 * </ul>
 * Coordinate descent is given looser tolerances.
 */
@Test
public void testNoInterceptWithOffset() {
    double[] offset_train = new double[] { -0.39771185, +1.20479170, -0.16374109, -0.97885903, -1.42996530, +0.83474893, +0.83474893, -0.74488827, +0.83474893, +0.86851236, +1.41589611, +1.41589611, -1.42996530, -0.39771185, -2.01111248, -0.39771185, -0.16374109, +0.62364452, -0.39771185, +0.60262749, -0.06143251, -1.42996530, -0.06143251, -0.06143251, +0.14967191, -0.06143251, -0.39771185, +0.14967191, +1.20479170, -0.39771185, -0.16374109, -0.06143251, -0.06143251, -1.42996530, -0.39771185, -0.39771185, -0.64257969, +1.65774729, -0.97885903, -0.39771185, -0.39771185, -0.39771185, -1.42996530, +1.41589611, -0.06143251, -0.06143251, -0.39771185, -0.06143251, -0.06143251, -0.39771185, -0.06143251, +0.14967191, -0.39771185, -1.42996530, -0.39771185, -0.64257969, -0.39771185, -0.06143251, -0.06143251, -0.06143251, -1.42996530, -2.01111248, -0.06143251, -0.39771185, -0.39771185, -1.42996530, -0.39771185, -1.42996530, -0.06143251, +1.41589611, +0.14967191, -1.42996530, -1.42996530, -0.06143251, -1.42996530, -1.42996530, -0.06143251, -1.42996530, -0.06143251, -0.39771185, -0.06143251, -1.42996530, -0.06143251, -0.39771185, -1.42996530, -0.06143251, -0.06143251, -0.06143251, -1.42996530, -0.39771185, -1.42996530, -0.43147527, -0.39771185, -0.39771185, -0.39771185, -1.42996530, -1.42996530, -0.43147527, -0.39771185, -0.39771185, -0.39771185, -0.39771185, -1.42996530, -1.42996530, -1.42996530, -0.39771185, +0.14967191, +1.41589611, -1.42996530, +1.41589611, -1.42996530, +1.41589611, -0.06143251, +0.14967191, -0.39771185, -0.97885903, -1.42996530, -0.39771185, -0.39771185, -0.39771185, -0.39771185, -1.42996530, -0.39771185, -0.97885903, -0.06143251, -0.06143251, +0.86851236, -0.39771185, -0.39771185, -0.06143251, -0.39771185, -0.39771185, -0.06143251, +0.14967191, -1.42996530, -1.42996530, -0.39771185, +1.20479170, -1.42996530, -0.39771185, -0.06143251, -1.42996530, -0.97885903, +0.14967191, +0.14967191, -1.42996530, -1.42996530, -0.39771185, -0.06143251, -0.43147527, 
    -0.06143251, -0.39771185, -1.42996530, -0.06143251, -0.39771185, -0.39771185, -1.42996530, -0.39771185, -0.39771185, -0.06143251, -0.39771185, -0.39771185, +0.14967191, -0.06143251, +1.41589611, -0.06143251, -0.39771185, -0.39771185, -0.06143251, -1.42996530, -0.06143251, -1.42996530, -0.39771185, -0.64257969, -0.06143251, +1.20479170, -0.43147527, -0.97885903, -0.39771185, -0.39771185, -0.39771185, +0.14967191, -2.01111248, -1.42996530, -0.06143251, +0.83474893, -1.42996530, -1.42996530, -2.01111248, -1.42996530, -0.06143251, +0.86851236, +0.05524374, -0.39771185, -0.39771185, -0.39771185, +1.41589611, -1.42996530, -0.39771185, -1.42996530, -0.39771185, -0.39771185, -0.06143251, +0.14967191, -1.42996530, -0.39771185, -1.42996530, -1.42996530, -0.39771185, -0.39771185, -0.06143251, -1.42996530, -0.97885903, -1.42996530, -0.39771185, -0.06143251, -0.39771185, -0.06143251, -1.42996530, -1.42996530, -0.06143251, -1.42996530, -0.39771185, +0.14967191, -0.06143251, -1.42996530, -1.42996530, +0.14967191, -0.39771185, -0.39771185, -1.42996530, -0.06143251, -0.06143251, -1.42996530, -0.06143251, -1.42996530, +0.14967191, +1.20479170, -1.42996530, -0.06143251, -0.39771185, -0.39771185, -0.06143251, +0.14967191, -0.06143251, -1.42996530, -1.42996530, -1.42996530, -0.39771185, -0.39771185, -0.39771185, +0.86851236, -0.06143251, -0.97885903, -0.06143251, -0.64257969, +0.14967191, +0.86851236, -0.39771185, -0.39771185, -0.39771185, -0.64257969, -1.42996530, -0.06143251, -0.39771185, -0.39771185, -1.42996530, -1.42996530, -0.06143251, +0.14967191, -0.06143251, +0.86851236, -0.97885903, -1.42996530, -1.42996530, -1.42996530, -1.42996530, +0.86851236, +0.14967191, -1.42996530, -0.97885903, -1.42996530, -1.42996530, -0.06143251, +0.14967191, -1.42996530, -0.64257969, -2.01111248, -0.97885903, -0.39771185 };
    double[] offset_test = new double[] { +1.65774729, -0.97700971, -0.97700971, -0.97700971, +0.05524374, +0.05524374, +0.05524374, +0.05524374, +0.39152308, +0.39152308, +0.39152308, +0.05524374, +0.05524374, +0.05524374, +0.39152308, -0.97700971, +0.05524374, +1.32146795, +0.39152308, +1.65774729, -0.97700971, +1.65774729, +0.39152308, +0.39152308, +1.65774729, +0.60262749, +0.05524374, +0.05524374, +0.05524374, +0.60262749, +0.05524374, -0.97700971, -0.97885903, +0.05524374, -2.01111248, -0.97700971, +0.05524374, +0.39152308, +0.05524374, +0.60262749, +0.60262749, +0.39152308, +0.60262749, -0.97700971, +0.39152308, +1.65774729, +0.39152308, +0.39152308, +0.05524374, +1.86885170, +0.05524374, -0.97700971, +0.60262749, -0.97700971, +0.60262749, -0.97700971, +0.39152308, -0.97700971, -0.43147527, +1.32146795, +0.05524374, +0.05524374, +0.39152308, +0.39152308, +0.05524374, +0.39152308, -0.97700971, +0.05524374, +0.39152308, +0.05524374, +0.60262749, +1.86885170, +0.05524374, +0.05524374, +1.86885170, +0.60262749, -0.64257969, -0.97700971, +0.60262749, +0.39152308, -0.97700971, -0.97700971, +0.05524374, -0.97700971, -0.97700971, +0.05524374, +0.05524374, +0.60262749, +0.05524374, +0.05524374 };
    double[] pred_test = new double[] { +0.88475366, +0.23100271, +0.40966315, +0.08957188, +0.47333302, +0.44622513, +0.56450046, +0.74271010, +0.45129280, +0.72359111, +0.67918401, +0.19882802, +0.42330391, +0.62734862, +0.38055506, +0.47286476, +0.40180469, +0.97907526, +0.61428344, +0.97109299, +0.30489181, +0.81303545, +0.36130639, +0.65434899, +0.98863675, +0.58301866, +0.37950467, +0.53679205, +0.30636941, +0.70320372, +0.45303278, +0.35011042, +0.78165074, +0.44915160, +0.09008065, +0.16789833, +0.45748862, +0.59328118, +0.75002334, +0.35170410, +0.57550279, +0.42038237, +0.76349569, +0.28883753, +0.84824847, +0.72396381, +0.56782477, +0.54078190, +0.51169047, +0.80828547, +0.52001699, +0.26202346, +0.81014557, +0.29986016, +0.62011569, +0.33034872, +0.62284802, +0.28303618, +0.38470707, +0.96444405, +0.36155179, +0.46368503, +0.65192144, +0.43597041, +0.30906461, +0.69259415, +0.21819579, +0.49998652, +0.57162728, +0.44255738, +0.80820564, +0.90616782, +0.49377901, +0.34235025, +0.99621673, +0.65768252, +0.43909050, +0.23205826, +0.71124897, +0.42908417, +0.47880901, +0.29185818, +0.42648317, +0.01247279, +0.18372518, +0.27281535, +0.63807876, +0.44563524, +0.32821696, +0.43636099 };
    // Offset vecs share the layout of the train/test frames and are filled row by row.
    Vec offsetVecTrain = _prostateTrain.anyVec().makeZero();
    try (Vec.Writer vw = offsetVecTrain.open()) {
        for (int i = 0; i < offset_train.length; ++i) {
            vw.set(i, offset_train[i]);
        }
    }
    Vec offsetVecTest = _prostateTest.anyVec().makeZero();
    try (Vec.Writer vw = offsetVecTest.open()) {
        for (int i = 0; i < offset_test.length; ++i) {
            vw.set(i, offset_test[i]);
        }
    }
    Key fKeyTrain = Key.make("prostate_with_offset_train");
    Key fKeyTest = Key.make("prostate_with_offset_test");
    // New frames = [offset] + the original columns (whose vecs are shared, not copied).
    Frame fTrain = new Frame(fKeyTrain, new String[] { "offset" }, new Vec[] { offsetVecTrain });
    fTrain.add(_prostateTrain.names(), _prostateTrain.vecs());
    DKV.put(fKeyTrain, fTrain);
    Frame fTest = new Frame(fKeyTest, new String[] { "offset" }, new Vec[] { offsetVecTest });
    fTest.add(_prostateTest.names(), _prostateTest.vecs());
    DKV.put(fKeyTest, fTest);
    // Call: glm(formula = CAPSULE ~ . - ID - RACE - DCAPS - DPROS - 1, family = binomial,
    // data = train, offset = offset_train)
    //
    // Coefficients:
    // AGE PSA VOL GLEASON
    // -0.054102 0.027517 -0.008937 0.516363
    //
    // Degrees of Freedom: 290 Total (i.e. Null); 286 Residual
    // Null Deviance: 355.7
    // Residual Deviance: 313 AIC: 321
    String[] cfs1 = new String[] { "Intercept", "AGE", "PSA", "VOL", "GLEASON" };
    double[] vals = new double[] { 0, -0.054102, 0.027517, -0.008937, 0.516363 };
    GLMParameters params = new GLMParameters(Family.binomial);
    params._response_column = "CAPSULE";
    params._ignored_columns = new String[] { "ID", "RACE", "DPROS", "DCAPS" };
    params._train = fKeyTrain;
    params._valid = fKeyTest;
    params._offset_column = "offset";
    params._lambda = new double[] { 0 };
    params._alpha = new double[] { 0 };
    params._standardize = false;
    params._objective_epsilon = 0;
    params._gradient_epsilon = 1e-6;
    // not expected to reach max iterations here
    params._max_iterations = 100;
    params._intercept = false;
    params._beta_epsilon = 1e-6;
    params._missing_values_handling = MissingValuesHandling.Skip;
    try {
        for (Solver s : new Solver[] { Solver.AUTO, Solver.IRLSM, Solver.L_BFGS, Solver.COORDINATE_DESCENT }) {
            // Per-iteration handles: if a later iteration fails before assignment, the finally
            // below must not delete objects already cleaned up by an earlier iteration.
            GLMModel model = null;
            Frame scoreTrain = null, scoreTest = null;
            try {
                params._solver = s;
                System.out.println("SOLVER = " + s);
                model = new GLM(params).trainModel().get();
                HashMap<String, Double> coefs = model.coefficients();
                System.out.println("coefs = " + coefs);
                boolean CD = s == Solver.COORDINATE_DESCENT;
                for (int i = 0; i < cfs1.length; ++i) {
                    assertEquals(vals[i], coefs.get(cfs1[i]), CD ? 1e-2 : 1e-4);
                }
                assertEquals(355.7, GLMTest.nullDeviance(model), 1e-1);
                assertEquals(313.0, GLMTest.residualDeviance(model), 1e-1);
                assertEquals(290, GLMTest.nullDOF(model), 0);
                assertEquals(286, GLMTest.resDOF(model), 0);
                assertEquals(321, GLMTest.aic(model), 1e-1);
                assertEquals(88.72363, GLMTest.residualDevianceTest(model), CD ? 1e-2 : 1e-4);
                // Scoring a frame that lacks the offset column must be rejected.
                try {
                    scoreTrain = model.score(_prostateTrain);
                    assertTrue("shoul've thrown IAE", false);
                } catch (IllegalArgumentException iae) {
                    assertTrue(iae.getMessage().contains("Test/Validation dataset is missing offset column"));
                }
                // Metrics stored for the training frame must match the model's training metrics,
                // both before and after explicitly re-scoring it.
                hex.ModelMetricsBinomialGLM mmTrain = (ModelMetricsBinomialGLM) hex.ModelMetricsBinomial.getFromDKV(model, fTrain);
                hex.AUC2 adata = mmTrain._auc;
                assertEquals(model._output._training_metrics.auc_obj()._auc, adata._auc, 1e-8);
                assertEquals(model._output._training_metrics._MSE, mmTrain._MSE, 1e-8);
                assertEquals(((ModelMetricsBinomialGLM) model._output._training_metrics)._resDev, mmTrain._resDev, 1e-8);
                scoreTrain = model.score(fTrain);
                mmTrain = (ModelMetricsBinomialGLM) hex.ModelMetricsBinomial.getFromDKV(model, fTrain);
                adata = mmTrain._auc;
                assertEquals(model._output._training_metrics.auc_obj()._auc, adata._auc, 1e-8);
                assertEquals(model._output._training_metrics._MSE, mmTrain._MSE, 1e-8);
                assertEquals(((ModelMetricsBinomialGLM) model._output._training_metrics)._resDev, mmTrain._resDev, 1e-8);
                // Same check for the validation frame.
                scoreTest = model.score(fTest);
                ModelMetricsBinomialGLM mmTest = (ModelMetricsBinomialGLM) hex.ModelMetricsBinomial.getFromDKV(model, fTest);
                adata = mmTest._auc;
                assertEquals(model._output._validation_metrics.auc_obj()._auc, adata._auc, 1e-8);
                assertEquals(model._output._validation_metrics._MSE, mmTest._MSE, 1e-8);
                assertEquals(((ModelMetricsBinomialGLM) model._output._validation_metrics)._resDev, mmTest._resDev, 1e-8);
                GLMTest.testScoring(model, fTest);
                // Per-row predicted probabilities against the R reference values.
                Vec.Reader preds = scoreTest.vec("p1").new Reader();
                for (int i = 0; i < pred_test.length; ++i) {
                    assertEquals(pred_test[i], preds.at(i), CD ? 1e-3 : 1e-6);
                }
            } finally {
                if (model != null) {
                    model.delete();
                }
                if (scoreTrain != null) {
                    scoreTrain.delete();
                }
                if (scoreTest != null) {
                    scoreTest.delete();
                }
            }
        }
    } finally {
        // Only the offset vecs are owned by this test; the other vecs belong to the shared
        // prostate frames, so only the frame keys are removed from DKV.
        if (fTrain != null) {
            fTrain.remove("offset").remove();
            DKV.remove(fTrain._key);
        }
        if (fTest != null) {
            fTest.remove("offset").remove();
            DKV.remove(fTest._key);
        }
    }
}
Usage of hex.glm.GLMModel.GLMParameters.Solver in project h2o-3 by h2oai.
In class GLMBasicTestRegression, method testPoissonWithOffset:
/**
 * Poisson GLM with offset column "logInsured" on the {@code _canCarTrain} frame, checked against
 * R's glm output for every solver except COORDINATE_DESCENT_NAIVE.
 *
 * <p>Verifies coefficients, deviances, degrees of freedom and AIC. Also checks that scoring a frame
 * without the offset column throws IllegalArgumentException, and that re-scoring the training
 * frame reproduces the stored MSE and deviances.
 */
@Test
public void testPoissonWithOffset() {
    // Call: glm(formula = formula, family = poisson, data = D)
    //
    // Coefficients:
    // (Intercept) Merit1 Merit2 Merit3 Class2 Class3 Class4 Class5
    // -2.0357 -0.1378 -0.2207 -0.4930 0.2998 0.4691 0.5259 0.2156
    //
    // Degrees of Freedom: 19 Total (i.e. Null); 12 Residual
    // Null Deviance: 33850
    // Residual Deviance: 579.5 AIC: 805.9
    String[] cfs1 = new String[] { "Intercept", "Merit.1", "Merit.2", "Merit.3", "Class.2", "Class.3", "Class.4", "Class.5" };
    double[] vals = new double[] { -2.0357, -0.1378, -0.2207, -0.4930, 0.2998, 0.4691, 0.5259, 0.2156 };
    GLMParameters parms = new GLMParameters(Family.poisson);
    parms._train = _canCarTrain._key;
    parms._ignored_columns = new String[] { "Insured", "Premium", "Cost" };
    // "response_column":"Claims","offset_column":"logInsured"
    parms._response_column = "Claims";
    parms._offset_column = "logInsured";
    parms._standardize = false;
    parms._lambda = new double[] { 0 };
    parms._alpha = new double[] { 0 };
    parms._objective_epsilon = 0;
    parms._beta_epsilon = 1e-6;
    parms._gradient_epsilon = 1e-10;
    parms._max_iterations = 1000;
    for (Solver s : GLMParameters.Solver.values()) {
        // skip for now, does not handle zero columns (introduced by extra missing bucket with no missing in the dataset)
        if (s == Solver.COORDINATE_DESCENT_NAIVE) {
            continue;
        }
        // Per-iteration handles so a failure in a later iteration never re-deletes
        // the model/frame already cleaned up by an earlier one.
        GLMModel model = null;
        Frame scoreTrain = null;
        try {
            parms._solver = s;
            model = new GLM(parms).trainModel().get();
            HashMap<String, Double> coefs = model.coefficients();
            System.out.println("coefs = " + coefs);
            for (int i = 0; i < cfs1.length; ++i) {
                assertEquals(vals[i], coefs.get(cfs1[i]), 1e-4);
            }
            assertEquals(33850, GLMTest.nullDeviance(model), 5);
            assertEquals(579.5, GLMTest.residualDeviance(model), 1e-4 * 579.5);
            assertEquals(19, GLMTest.nullDOF(model), 0);
            assertEquals(12, GLMTest.resDOF(model), 0);
            assertEquals(805.9, GLMTest.aic(model), 1e-4 * 805.9);
            // Scoring a frame that lacks the offset column must be rejected.
            // `fr` shares the training vecs; removing the column only drops it from this view.
            try {
                Frame fr = new Frame(_canCarTrain.names(), _canCarTrain.vecs());
                fr.remove(parms._offset_column);
                scoreTrain = model.score(fr);
                assertTrue("shoul've thrown IAE", false);
            } catch (IllegalArgumentException iae) {
                assertTrue(iae.getMessage().contains("Test/Validation dataset is missing offset column"));
            }
            scoreTrain = model.score(_canCarTrain);
            hex.ModelMetricsRegressionGLM mmTrain = (ModelMetricsRegressionGLM) hex.ModelMetricsRegression.getFromDKV(model, _canCarTrain);
            assertEquals(model._output._training_metrics._MSE, mmTrain._MSE, 1e-8);
            assertEquals(GLMTest.residualDeviance(model), mmTrain._resDev, 1e-8);
            assertEquals(GLMTest.nullDeviance(model), mmTrain._nullDev, 1e-8);
        } finally {
            if (model != null) {
                model.delete();
            }
            if (scoreTrain != null) {
                scoreTrain.delete();
            }
        }
    }
}
Usage of hex.glm.GLMModel.GLMParameters.Solver in project h2o-3 by h2oai.
In class GLMBasicTestRegression, method testTweedie:
/**
 * Tweedie GLM on the {@code _earinf} frame, checked against R's glm output. It covers variance
 * powers 0, 1, 1.25, 1.5, 1.75 and 2, using link power {@code 1 - p}, and runs every solver except
 * COORDINATE_DESCENT_NAIVE.
 *
 * <p>Verifies coefficients, null/residual deviance (relative tolerance 5e-4) and degrees of
 * freedom. Also checks that Java scoring agrees with the scored frame and that the metrics stored
 * for the training frame match the model's.
 */
@Test
public void testTweedie() {
    // -------------------------------------- R examples output ----------------------------------------------------------------
    // Call: glm(formula = Infections ~ ., family = tweedie(0), data = D)
    //
    // Coefficients:
    // (Intercept) SwimmerOccas LocationNonBeach Age20-24 Age25-29 SexMale
    // 0.8910 0.8221 0.7266 -0.5033 -0.2679 -0.1056
    //
    // Degrees of Freedom: 286 Total (i.e. Null); 281 Residual
    // Null Deviance: 1564
    // Residual Deviance: 1469 AIC: NA
    // Call: glm(formula = Infections ~ ., family = tweedie(1), data = D)
    //
    // Coefficients:
    // (Intercept) SwimmerOccas LocationNonBeach Age20-24 Age25-29 SexMale
    // -0.12261 0.61149 0.53454 -0.37442 -0.18973 -0.08985
    //
    // Degrees of Freedom: 286 Total (i.e. Null); 281 Residual
    // Null Deviance: 824.5
    // Residual Deviance: 755.4 AIC: NA
    // Call: glm(formula = Infections ~ ., family = tweedie(1.25), data = D)
    //
    // Coefficients:
    // (Intercept) SwimmerOccas LocationNonBeach Age20-24 Age25-29 SexMale
    // 1.02964 -0.14079 -0.12200 0.08502 0.04269 0.02105
    //
    // Degrees of Freedom: 286 Total (i.e. Null); 281 Residual
    // Null Deviance: 834.2
    // Residual Deviance: 770.8 AIC: NA
    // Call: glm(formula = Infections ~ ., family = tweedie(1.5), data = D)
    //
    // Coefficients:
    // (Intercept) SwimmerOccas LocationNonBeach Age20-24 Age25-29 SexMale
    // 1.05665 -0.25891 -0.22185 0.15325 0.07624 0.03908
    //
    // Degrees of Freedom: 286 Total (i.e. Null); 281 Residual
    // Null Deviance: 967
    // Residual Deviance: 908.9 AIC: NA
    // Call: glm(formula = Infections ~ ., family = tweedie(1.75), data = D)
    //
    // Coefficients:
    // (Intercept) SwimmerOccas LocationNonBeach Age20-24 Age25-29 SexMale
    // 1.08076 -0.35690 -0.30154 0.20556 0.10122 0.05375
    //
    // Degrees of Freedom: 286 Total (i.e. Null); 281 Residual
    // Null Deviance: 1518
    // Residual Deviance: 1465 AIC: NA
    // Call: glm(formula = Infections ~ ., family = tweedie(2), data = D)
    //
    // Coefficients:
    // (Intercept) SwimmerOccas LocationNonBeach Age20-24 Age25-29 SexMale
    // 1.10230 -0.43751 -0.36337 0.24318 0.11830 0.06467
    //
    // Degrees of Freedom: 286 Total (i.e. Null); 281 Residual
    // Null Deviance: 964.4
    // Residual Deviance: 915.7 AIC: NA
    // ---------------------------------------------------------------------------------------------------------------------------
    String[] cfs1 = new String[] { "Intercept", "Swimmer.Occas", "Location.NonBeach", "Age.20-24", "Age.25-29", "Sex.Male" };
    double[][] vals = new double[][] { { 0.89100, 0.82210, 0.72660, -0.50330, -0.26790, -0.10560 }, { -0.12261, 0.61149, 0.53454, -0.37442, -0.18973, -0.08985 }, { 1.02964, -0.14079, -0.12200, 0.08502, 0.04269, 0.02105 }, { 1.05665, -0.25891, -0.22185, 0.15325, 0.07624, 0.03908 }, { 1.08076, -0.35690, -0.30154, 0.20556, 0.10122, 0.05375 }, { 1.10230, -0.43751, -0.36337, 0.24318, 0.11830, 0.06467 } };
    int dof = 286, res_dof = 281;
    double[] nullDev = new double[] { 1564, 824.5, 834.2, 967.0, 1518, 964.4 };
    double[] resDev = new double[] { 1469, 755.4, 770.8, 908.9, 1465, 915.7 };
    double[] varPow = new double[] { 0, 1.0, 1.25, 1.5, 1.75, 2.0 };
    GLMParameters parms = new GLMParameters(Family.tweedie);
    parms._train = _earinf._key;
    parms._ignored_columns = new String[] {};
    // Response is "Infections"; this model has no offset column.
    parms._response_column = "Infections";
    parms._standardize = false;
    parms._lambda = new double[] { 0 };
    parms._alpha = new double[] { 0 };
    parms._gradient_epsilon = 1e-10;
    parms._max_iterations = 1000;
    parms._objective_epsilon = 0;
    parms._beta_epsilon = 1e-6;
    for (int x = 0; x < varPow.length; ++x) {
        double p = varPow[x];
        parms._tweedie_variance_power = p;
        parms._tweedie_link_power = 1 - p;
        for (Solver s : GLMParameters.Solver.values()) {
            // ignore for now, has trouble with zero columns
            if (s == Solver.COORDINATE_DESCENT_NAIVE) {
                continue;
            }
            // Per-iteration handles so a failure in a later iteration never re-deletes
            // the model/frame already cleaned up by an earlier one.
            GLMModel model = null;
            Frame scoreTrain = null;
            try {
                parms._solver = s;
                model = new GLM(parms).trainModel().get();
                HashMap<String, Double> coefs = model.coefficients();
                System.out.println("coefs = " + coefs);
                for (int i = 0; i < cfs1.length; ++i) {
                    assertEquals(vals[x][i], coefs.get(cfs1[i]), 1e-4);
                }
                assertEquals(nullDev[x], (GLMTest.nullDeviance(model)), 5e-4 * nullDev[x]);
                assertEquals(resDev[x], (GLMTest.residualDeviance(model)), 5e-4 * resDev[x]);
                assertEquals(dof, GLMTest.nullDOF(model), 0);
                assertEquals(res_dof, GLMTest.resDOF(model), 0);
                // test scoring
                scoreTrain = model.score(_earinf);
                assertTrue(model.testJavaScoring(_earinf, scoreTrain, 1e-8));
                hex.ModelMetricsRegressionGLM mmTrain = (ModelMetricsRegressionGLM) hex.ModelMetricsRegression.getFromDKV(model, _earinf);
                assertEquals(model._output._training_metrics._MSE, mmTrain._MSE, 1e-8);
                assertEquals(GLMTest.residualDeviance(model), mmTrain._resDev, 1e-8);
                assertEquals(GLMTest.nullDeviance(model), mmTrain._nullDev, 1e-8);
            } finally {
                if (model != null) {
                    model.delete();
                }
                if (scoreTrain != null) {
                    scoreTrain.delete();
                }
            }
        }
    }
}
Usage of hex.glm.GLMModel.GLMParameters.Solver in project h2o-3 by h2oai.
In class GLMTest, method testArcene:
/**
 * Tests strong rules on the arcene dataset (10k predictors, 100 rows).
 * It should be able to obtain a good model (~100 predictors, ~1 explained deviance) with up to
 * 250 active predictors. The run is scaled down (higher lambda min, fewer lambdas) to keep it
 * reasonably fast; the whole test was reported to take about 20s.
 *
 * <p>The test runs a gaussian GLM with L1 lambda search (IRLSM and COORDINATE_DESCENT) and
 * verifies that every lambda yields a submodel. It then checks that the model rank respects a
 * smaller max_active_predictors limit, both with and without lambda search. The objective value
 * itself is not compared.
 */
@Test
public void testArcene() throws InterruptedException, ExecutionException {
    Key parsed = Key.make("arcene_parsed");
    Key<GLMModel> modelKey = Key.make("arcene_model");
    // Holds the most recently fetched model; reset to null after every explicit delete so the
    // finally block never deletes the same model twice.
    GLMModel model = null;
    Frame fr = parse_test_file(parsed, "smalldata/glm_test/arcene.csv");
    try {
        Scope.enter();
        // L1-penalized lambda search
        GLMParameters params = new GLMParameters(Family.gaussian);
        params._lambda = null;
        params._response_column = fr._names[0];
        params._train = parsed;
        params._lambda_search = true;
        params._nlambdas = 35;
        params._lambda_min_ratio = 0.18;
        params._max_iterations = 100000;
        params._max_active_predictors = 10000;
        params._alpha = new double[] { 1 };
        // L_BFGS is left out: its lambda search is too slow here.
        for (Solver s : new Solver[] { Solver.IRLSM, Solver.COORDINATE_DESCENT }) {
            params._solver = s;
            GLM glm = new GLM(params, modelKey);
            glm.trainModel().get();
            model = DKV.get(modelKey).get();
            System.out.println(model._output._model_summary);
            // If strong rules work, every lambda on the path yields a submodel within the active-predictor limit.
            assertEquals(params._nlambdas, model._output._submodels.length);
            System.out.println(model._output._training_metrics);
            // Clean up this solver's model before the next one reuses the key.
            model.delete();
            model = null;
        }
        // Lambda search with a tight limit on active predictors.
        params._solver = Solver.COORDINATE_DESCENT;
        params._max_active_predictors = 100;
        params._lambda_min_ratio = 1e-2;
        params._nlambdas = 100;
        GLM glm = new GLM(params, modelKey);
        glm.trainModel().get();
        model = DKV.get(modelKey).get();
        assertTrue(model._output.rank() <= params._max_active_predictors);
        System.out.println(model._output._model_summary);
        System.out.println(model._output._training_metrics);
        System.out.println("============================================================================================================");
        model.delete();
        model = null;
        // Single fit without lambda search, still limited to 250 active predictors.
        params._max_active_predictors = 250;
        params._lambda = null;
        params._lambda_search = false;
        glm = new GLM(params, modelKey);
        glm.trainModel().get();
        model = DKV.get(modelKey).get();
        assertTrue(model._output.rank() <= params._max_active_predictors);
        System.out.println(model._output._model_summary);
        System.out.println(model._output._training_metrics);
        System.out.println("============================================================================================================");
        model.delete();
        model = null;
    } finally {
        fr.delete();
        // Only non-null if an assertion or training failure interrupted the body.
        if (model != null) {
            model.delete();
        }
        Scope.exit();
    }
}
Aggregations