Use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
Class GBMTest2, method testBalanceWithCrossValidation.
/**
 * Trains a class-balanced GBM classifier with n-fold cross-validation and checks that one
 * cross-validation model is produced per fold and that the resulting model's job state is DONE.
 *
 * @param dataset dataset passed to {@code parseFrame}
 * @param response index of the response column in the parsed frame
 * @param ignored_cols indices of columns excluded from training
 * @param ntrees number of trees to build
 * @param nfolds number of cross-validation folds (and expected number of xval models)
 */
@Override
protected void testBalanceWithCrossValidation(String dataset, int response, int[] ignored_cols, int ntrees, int nfolds) {
  Frame f = parseFrame(dataset);
  GBMModel model = null;
  GBM gbm = new GBM();
  try {
    Vec respVec = f.vec(response);
    // Build a model
    gbm.source = f;
    gbm.response = respVec;
    gbm.ignored_cols = ignored_cols;
    gbm.classification = true;
    gbm.ntrees = ntrees;
    gbm.balance_classes = true;
    gbm.n_folds = nfolds;
    gbm.keep_cross_validation_splits = false;
    gbm.invoke();
    Assert.assertEquals("Number of cross validation models is wrong!", nfolds, gbm.xval_models.length);
    model = UKV.get(gbm.dest());
    // HEX-1817: after invoke() the stored model's parameters must report job state DONE.
    Assert.assertTrue(model.get_params().state == Job.JobState.DONE);
  } finally {
    if (f != null) {
      f.delete();
    }
    // Clean up cross-validation models independently of the main model: they can exist even
    // when an assertion fails before the main model has been fetched.
    if (gbm.xval_models != null) {
      for (Key k : gbm.xval_models) {
        Model m = UKV.get(k);
        // A key may not resolve to a model if training was interrupted.
        if (m != null) {
          m.delete();
        }
      }
    }
    if (model != null) {
      model.delete();
    }
  }
}
Use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
Class GBMCheckpointTest, method testCheckPointReconstruction.
/**
 * Verifies GBM checkpoint restart. The method trains a model, then starts a new training run
 * from it as a checkpoint, and checks two things:
 * (1) the tree data collected after the original build matches the tree data reconstructed
 * by the checkpointed run;
 * (2) the checkpointed model equals a model trained from scratch with the combined number
 * of trees.
 *
 * @param dataset dataset passed to {@code parseFrame}
 * @param response index of the response column
 * @param classification whether to train a classification model (false means regression)
 * @param ntreesInPriorModel number of trees in the initial model used as the checkpoint
 * @param ntreesInANewModel number of trees requested for the model built from the checkpoint
 */
private void testCheckPointReconstruction(String dataset, int response, boolean classification, int ntreesInPriorModel, int ntreesInANewModel) {
  Frame f = parseFrame(dataset);
  GBMModel model = null;
  GBMModel modelFromCheckpoint = null;
  GBMModel modelFinal = null;
  try {
    Vec respVec = f.vec(response);
    // Build the prior model; the hook records tree data right after building.
    GBMWithHooks gbm = new GBMWithHooks();
    gbm.source = f;
    gbm.response = respVec;
    gbm.classification = classification;
    gbm.ntrees = ntreesInPriorModel;
    gbm.collectPoint = WhereToCollect.AFTER_BUILD;
    gbm.score_each_iteration = true;
    gbm.invoke();
    model = UKV.get(gbm.dest());
    // Build a checkpointed model; the hook records tree data after reconstructing it from the checkpoint.
    GBMWithHooks gbmFromCheckpoint = new GBMWithHooks();
    gbmFromCheckpoint.source = f;
    gbmFromCheckpoint.response = respVec;
    gbmFromCheckpoint.classification = classification;
    gbmFromCheckpoint.ntrees = ntreesInANewModel;
    gbmFromCheckpoint.collectPoint = WhereToCollect.AFTER_RECONSTRUCTION;
    gbmFromCheckpoint.checkpoint = gbm.dest();
    gbmFromCheckpoint.score_each_iteration = true;
    gbmFromCheckpoint.invoke();
    modelFromCheckpoint = UKV.get(gbmFromCheckpoint.dest());
    // Check if reconstructed frame computation data are same
    assertArrayEquals("Tree data produced by gbm run and reconstructed from a model do not match!", gbm.treesCols, gbmFromCheckpoint.treesCols);
    // Build a model which contains old+new trees and compare prediction results
    GBM gbmFinal = new GBM();
    gbmFinal.source = f;
    gbmFinal.response = respVec;
    gbmFinal.classification = classification;
    gbmFinal.ntrees = ntreesInANewModel + ntreesInPriorModel;
    gbmFinal.score_each_iteration = true;
    gbmFinal.invoke();
    modelFinal = UKV.get(gbmFinal.dest());
    assertTreeModelEquals(modelFinal, modelFromCheckpoint);
  } finally {
    if (f != null) {
      f.delete();
    }
    if (model != null) {
      model.delete();
    }
    if (modelFromCheckpoint != null) {
      modelFromCheckpoint.delete();
    }
    if (modelFinal != null) {
      modelFinal.delete();
    }
  }
}
Use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
Class GBMTest, method testReproducibility.
/**
 * Checks that GBM regression training is reproducible. N identically configured
 * single-tree models, trained on the same frame rebalanced to 256 chunks, must all
 * report the same MSE (within 1e-15).
 */
@Test
public void testReproducibility() {
  Frame tfr = null;
  final int N = 5;
  double[] mses = new double[N];
  Scope.enter();
  try {
    // Load data, hack frames
    // NOTE(review): the key is named "air.hex" although the data is covtype; harmless but misleading.
    tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");
    // rebalance to 256 chunks
    Key dest = Key.make("df.rebalanced.hex");
    RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
    H2O.submitTask(rb);
    rb.join();
    // Fetch the rebalanced frame before deleting the original, so a failed fetch never
    // leaves tfr pointing at an already-deleted frame (which finally would delete twice).
    Frame rebalanced = DKV.get(dest).get();
    tfr.delete();
    tfr = rebalanced;
    for (int i = 0; i < N; ++i) {
      GBM parms = new GBM();
      parms.source = tfr;
      parms.response = tfr.lastVec();
      parms.nbins = 1000;
      parms.ntrees = 1;
      parms.max_depth = 8;
      parms.learn_rate = 0.1;
      parms.min_rows = 10;
      parms.family = Family.AUTO;
      parms.classification = false;
      // Build a first model; all remaining models should be equal
      GBMModel gbm = parms.fork().get();
      try {
        mses[i] = gbm.mse();
      } finally {
        gbm.delete();
      }
    }
  } finally {
    if (tfr != null) {
      tfr.delete();
    }
    // Always leave the scope, even if loading or training failed.
    Scope.exit();
  }
  // Log every trial before asserting so a failure still shows all MSEs.
  for (int i = 0; i < mses.length; ++i) {
    Log.info("trial: " + i + " -> mse: " + mses[i]);
  }
  for (int i = 0; i < mses.length; ++i) {
    // JUnit order is (expected, actual): the first trial is the reference.
    assertEquals(mses[0], mses[i], 1e-15);
  }
}
Use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
Class ModelSerializationTest, method testGBMModelMultinomial.
/**
 * Round-trips a GBM model trained on iris through save/load and checks that the reloaded
 * model matches the original, both tree-by-tree and in its binary form.
 *
 * @throws IOException if saving or loading the model fails
 */
@Test
public void testGBMModelMultinomial() throws IOException {
  GBMModel original = null;
  GBMModel reloaded = null;
  try {
    // NOTE(review): arguments presumably are (dataset, ignored cols, response index,
    // classification, ntrees) — confirm against prepareGBMModel.
    original = prepareGBMModel("smalldata/iris/iris.csv", EIA, 4, true, 5);
    reloaded = saveAndLoad(original);
    // The reloaded model must be indistinguishable from the original.
    assertTreeModelEquals(original, reloaded);
    assertModelBinaryEquals(original, reloaded);
  } finally {
    // Release both models from the store, original first.
    if (original != null) {
      original.delete();
    }
    if (reloaded != null) {
      reloaded.delete();
    }
  }
}