use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
the class ModelSerializationTest method testGBMModelBinomial.
// Builds a GBM model from the prostate data set, pushes it through saveAndLoad and
// asserts that the reloaded copy equals the one that was saved.
@Test
public void testGBMModelBinomial() throws IOException {
  GBMModel original = null;
  GBMModel restored = null;
  try {
    original = prepareGBMModel("smalldata/logreg/prostate.csv", ari(0), 1, true, 5);
    restored = saveAndLoad(original);
    // Compare the two copies with both equality helpers
    assertTreeModelEquals(original, restored);
    assertModelBinaryEquals(original, restored);
  } finally {
    // Delete whichever models were actually created
    if (original != null) {
      original.delete();
    }
    if (restored != null) {
      restored.delete();
    }
  }
}
use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
the class GBMTest method testGBMTrainTest.
// Test-on-Train. Slow test, needed to build a good model.
// Trains a GBM on ecology_model.csv (response "Angaus", "Site" column dropped), scores
// ecology_eval.csv with it and prints the resulting confusion matrix.
// The test is silently skipped when either data file cannot be found.
@Test
public void testGBMTrainTest() {
  File file1 = TestUtil.find_test_file("smalldata/gbm_test/ecology_model.csv");
  File file2 = TestUtil.find_test_file("smalldata/gbm_test/ecology_eval.csv");
  // Silently ignore if a file is not found. Both files are checked before any key is
  // created, so a missing eval file is skipped the same way as a missing train file.
  if (file1 == null || file2 == null)
    return;
  Key fkey1 = NFSFileVec.make(file1);
  Key dest1 = Key.make("train.hex");
  Key fkey2 = NFSFileVec.make(file2);
  Key dest2 = Key.make("test.hex");
  // The Builder
  GBM gbm = new GBM();
  // The Model
  GBM.GBMModel gbmmodel = null;
  // Declared outside the try so the cleanup below can see whatever was created
  Frame fr = null, ftest = null, fpreds = null;
  try {
    fr = ParseDataset2.parse(dest1, new Key[] { fkey1 });
    // Remove unique ID; too predictive
    UKV.remove(fr.remove("Site")._key);
    // Train on the outcome
    gbm.response = fr.vecs()[fr.find("Angaus")];
    gbm.source = fr;
    gbm.ntrees = 5;
    gbm.max_depth = 10;
    gbm.learn_rate = 0.2f;
    gbm.min_rows = 10;
    gbm.nbins = 100;
    gbm.invoke();
    gbmmodel = UKV.get(gbm.dest());
    testHTML(gbmmodel);
    //HEX-1817
    Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
    // Score the evaluation data set
    ftest = ParseDataset2.parse(dest2, new Key[] { fkey2 });
    fpreds = gbm.score(ftest);
    // Build a confusion matrix
    ConfusionMatrix CM = new ConfusionMatrix();
    CM.actual = ftest;
    CM.vactual = ftest.vecs()[ftest.find("Angaus")];
    CM.predict = fpreds;
    CM.vpredict = fpreds.vecs()[fpreds.find("predict")];
    // Start it, do it
    CM.invoke();
    StringBuilder sb = new StringBuilder();
    CM.toASCII(sb);
    System.out.println(sb);
  } finally {
    // Remove the original hex frame key. Fall back to the parsed frame when the failure
    // happened before it was handed to the builder, and skip the delete when parsing
    // itself failed, so a cleanup NPE cannot replace the real exception.
    Frame source = gbm.source != null ? gbm.source : fr;
    if (source != null)
      source.delete();
    if (ftest != null)
      ftest.delete();
    if (fpreds != null)
      fpreds.delete();
    // Remove the model
    if (gbmmodel != null)
      gbmmodel.delete();
    // The response is only set once the training frame has been parsed
    if (gbm.response != null)
      UKV.remove(gbm.response._key);
    // Remove GBM Job
    gbm.remove();
  }
}
use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
the class GBMTest method testModelAdapt.
// Adapt a trained model to a test dataset with different enums.
// Trains on KDDTrain.arff.gz (response is column 41) and then scores KDDTest.arff.gz;
// the test passes when training finishes in state DONE and scoring does not throw.
@Test
public void testModelAdapt() {
  File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz");
  Key fkey1 = NFSFileVec.make(file1);
  Key dest1 = Key.make("KDDTrain.hex");
  File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz");
  Key fkey2 = NFSFileVec.make(file2);
  Key dest2 = Key.make("KDDTest.hex");
  GBM gbm = new GBM();
  // The Model
  GBM.GBMModel gbmmodel = null;
  // Declared outside the try so they are released even when scoring throws
  Frame ftest = null, preds = null;
  try {
    gbm.source = ParseDataset2.parse(dest1, new Key[] { fkey1 });
    // Response is col 41
    gbm.response = gbm.source.vecs()[41];
    gbm.ntrees = 2;
    gbm.max_depth = 8;
    gbm.learn_rate = 0.2f;
    gbm.min_rows = 10;
    gbm.nbins = 50;
    gbm.invoke();
    gbmmodel = UKV.get(gbm.dest());
    testHTML(gbmmodel);
    //HEX-1817
    Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
    // The test data set has a few more enums than the train
    ftest = ParseDataset2.parse(dest2, new Key[] { fkey2 });
    preds = gbm.score(ftest);
  } finally {
    // Remove the test frame and its predictions, if they were created
    if (ftest != null)
      ftest.delete();
    if (preds != null)
      preds.delete();
    // Remove the model
    if (gbmmodel != null)
      gbmmodel.delete();
    // Remove original hex frame key; guarded so that a failed parse is not masked by an
    // NPE thrown from this cleanup block
    if (gbm.source != null)
      gbm.source.delete();
    if (gbm.response != null)
      UKV.remove(gbm.response._key);
    // Remove GBM Job
    gbm.remove();
  }
}
use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
the class GBMTest method basicGBM.
/**
 * Parses a data file, trains a small GBM (4 trees, depth 4) on all of its columns, checks
 * that the job finished in state DONE and scores the training frame once.
 *
 * @param fname      path of the data file, resolved through {@code TestUtil.find_test_file}
 * @param hexname    name of the key the parsed frame is stored under
 * @param prep       callback run on the parsed frame; its result selects the response column.
 *                   A negative result switches the builder to regression and the column
 *                   index is taken as {@code ~result}.
 * @param validation when true, a new {@code Frame} built from the training frame is used as
 *                   the validation set; otherwise no validation set is used
 * @param family     value assigned to the builder's {@code family} field
 * @return the trained model, or {@code null} when the data file cannot be found. Note that
 *         the model has already been deleted by the cleanup block when it is returned.
 */
public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean validation, Family family) {
  File file = TestUtil.find_test_file(fname);
  // Silently abort test if the file is missing
  if (file == null)
    return null;
  Key fkey = NFSFileVec.make(file);
  Key dest = Key.make(hexname);
  // The Builder
  GBM gbm = new GBM();
  // The Model
  GBM.GBMModel gbmmodel = null;
  try {
    Frame fr = gbm.source = ParseDataset2.parse(dest, new Key[] { fkey });
    UKV.remove(fkey);
    int idx = prep.prep(fr);
    // Negative index: regression on column ~idx
    if (idx < 0) {
      gbm.classification = false;
      idx = ~idx;
    }
    gbm.response = fr.vecs()[idx];
    gbm.family = family;
    assert gbm.family != Family.bernoulli || gbm.classification;
    gbm.ntrees = 4;
    gbm.max_depth = 4;
    gbm.min_rows = 1;
    gbm.nbins = 50;
    // Use every column of the frame
    gbm.cols = new int[fr.numCols()];
    for (int i = 0; i < gbm.cols.length; i++) gbm.cols[i] = i;
    gbm.validation = validation ? new Frame(gbm.source) : null;
    gbm.learn_rate = .2f;
    gbm.score_each_iteration = true;
    gbm.invoke();
    gbmmodel = UKV.get(gbm.dest());
    testHTML(gbmmodel);
    //HEX-1817
    Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
    //System.out.println(gbmmodel.toJava());
    Frame preds = gbm.score(gbm.source);
    preds.delete();
    return gbmmodel;
  } finally {
    // Remove original hex frame key; guarded so that a failed parse is not masked by an
    // NPE thrown from this cleanup block
    if (gbm.source != null)
      gbm.source.delete();
    // Remove validation dataset if specified
    if (gbm.validation != null)
      gbm.validation.delete();
    // Remove the model
    if (gbmmodel != null)
      gbmmodel.delete();
    // Remove GBM Job
    gbm.remove();
  }
}
use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.
the class GBMTest method testGBMRegression.
// Trains a single depth-1 regression tree on column 2 of the Mfgdata file, with column 1
// as the response, and checks the mean squared error of its predictions on the training
// frame against a fixed expected value.
@Test
public void testGBMRegression() {
  File file = TestUtil.find_test_file("./smalldata/gbm_test/Mfgdata_gaussian_GBM_testing.csv");
  Key fkey = NFSFileVec.make(file);
  Key dest = Key.make("mfg.hex");
  // The Builder
  GBM gbm = new GBM();
  // The Model
  GBM.GBMModel gbmmodel = null;
  // Declared outside the try so the predictions are released even when the MSE check fails
  Frame preds = null;
  try {
    Frame fr = gbm.source = ParseDataset2.parse(dest, new Key[] { fkey });
    UKV.remove(fkey);
    // Regression
    gbm.classification = false;
    gbm.family = GBM.Family.AUTO;
    // Row in col 0, dependent in col 1, predictor in col 2
    gbm.response = fr.vecs()[1];
    gbm.ntrees = 1;
    gbm.max_depth = 1;
    gbm.min_rows = 1;
    gbm.nbins = 20;
    // Just column 2
    gbm.cols = new int[] { 2 };
    gbm.validation = null;
    gbm.learn_rate = 1.0f;
    gbm.score_each_iteration = true;
    gbm.invoke();
    gbmmodel = UKV.get(gbm.dest());
    //HEX-1817
    Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
    preds = gbm.score(gbm.source);
    // Mean squared error: CompErr's _sum divided by the number of prediction rows
    double sq_err = new CompErr().doAll(gbm.response, preds.vecs()[0])._sum;
    double mse = sq_err / preds.numRows();
    assertEquals(79152.1233, mse, 0.1);
  } finally {
    // Remove the predictions, if scoring got that far
    if (preds != null)
      preds.delete();
    // Remove original hex frame key; guarded so that a failed parse is not masked by an
    // NPE thrown from this cleanup block
    if (gbm.source != null)
      gbm.source.delete();
    // Remove validation dataset if specified
    if (gbm.validation != null)
      gbm.validation.delete();
    // Remove the model
    if (gbmmodel != null)
      gbmmodel.delete();
    // Remove GBM Job
    gbm.remove();
  }
}