Use of hex.gbm.GBM in project h2o-2 by h2oai.
Class Runner, method main.
/**
 * Does the work: validates the options, loads the data, then configures and
 * invokes either a GBM or a DRF model builder (chosen by {@code ARGS.gbm}).
 *
 * <p>Order of operations: wait for the cluster to reach {@code ARGS.clusterSize},
 * range-check the numeric options, check the train/test file option combinations,
 * resolve {@code mtries}, parse the training file, obtain a test frame (either by
 * splitting the training frame or by parsing {@code ARGS.testFile}), restrict the
 * frames to the requested columns plus the response, log per-column statistics,
 * and finally build the model.
 *
 * @param ARGS the parsed options. {@code ARGS.mtries} is overwritten in place with
 *             {@code (int) Math.sqrt(number of selected columns)} when it is 0.
 * @throws IllegalArgumentException (a {@link RuntimeException}) if an option is out
 *             of bounds or the file/split options are combined in a disallowed way
 */
static void main(OptArgs ARGS) {
  // Finish building the cluster
  TestUtil.stall_till_cloudsize(ARGS.clusterSize);

  // Sanity check basic args
  if (ARGS.ntrees <= 0 || ARGS.ntrees > 100000) {
    throw new IllegalArgumentException("ntrees " + ARGS.ntrees + " out of bounds");
  }
  if (ARGS.sample < 0 || ARGS.sample > 1.0f) {
    throw new IllegalArgumentException("sample " + ARGS.sample + " out of bounds");
  }
  if (ARGS.learn < 0 || ARGS.learn > 1.0f) {
    throw new IllegalArgumentException("learn " + ARGS.learn + " out of bounds");
  }
  if (ARGS.nbins < 2 || ARGS.nbins > 100000) {
    throw new IllegalArgumentException("nbins " + ARGS.nbins + " out of bounds");
  }
  if (ARGS.depth <= 0) {
    throw new IllegalArgumentException("depth " + ARGS.depth + " out of bounds");
  }
  // A NaN splitTestTrain fails both comparisons and so passes this check; NaN is
  // what the code below treats as "no split requested".
  if (ARGS.splitTestTrain < 0 || ARGS.splitTestTrain > 1.0f) {
    throw new IllegalArgumentException("splitTestTrain " + ARGS.splitTestTrain + " out of bounds");
  }

  // NOTE(review): the two checks below compare Strings by identity (==/!=), not by
  // equals(). This looks deliberate - it asks "is the field still the default
  // object?", i.e. "was the option left unset?" - so it is kept as is. Confirm
  // against how OptArgs fields get populated.
  // If trainFile is NOT set, you are doing the default file and cannot set testFile.
  if ((ARGS.trainFile == OptArgs.defaultTrainFile) && (ARGS.testFile != OptArgs.defaultTestFile)) {
    throw new IllegalArgumentException("Cannot set test file unless also setting train file");
  }
  // If testFile is set, cannot set splitTestTrain
  if ((ARGS.testFile != OptArgs.defaultTestFile) && !Float.isNaN(ARGS.splitTestTrain)) {
    throw new IllegalArgumentException("Cannot have both testFile and splitTestTrain");
  }

  Sys sys = ARGS.gbm ? Sys.GBM__ : Sys.DRF__;

  // Selected column names followed by the response name; separators are ',' or tab.
  String[] cs = (ARGS.cols + "," + ARGS.response).split("[,\t]");

  // Set mtries: 0 means "pick a default" of sqrt(number of selected columns).
  if (ARGS.mtries == 0) {
    ARGS.mtries = (int) Math.sqrt(cs.length);
  }
  if (ARGS.mtries <= 0 || ARGS.mtries > cs.length) {
    throw new IllegalArgumentException("mtries " + ARGS.mtries + " out of bounds");
  }

  // Load data
  Timer t_load = new Timer();
  Key trainkey = Key.make("train.hex");
  Key testkey = Key.make("test.hex");
  Frame train = TestUtil.parseFrame(trainkey, ARGS.trainFile);
  Frame test = null;
  if (!Float.isNaN(ARGS.splitTestTrain)) {
    // NOTE(review): the split expression uses a fixed 0.7 threshold and does not
    // read ARGS.splitTestTrain. Say so in the log instead of silently ignoring the
    // requested value; wiring the value into the expression needs confirmation of
    // the option's intended meaning and of the number syntax Exec2 accepts.
    if (ARGS.splitTestTrain != 0.7f) {
      Log.info(sys, "splitTestTrain " + ARGS.splitTestTrain
          + " requested, but rows are split at the fixed threshold 0.7 (r<0.7 -> train, r>=0.7 -> test)");
    }
    water.exec.Exec2.exec("r=runif(train.hex,-1); test.hex=train.hex[r>=0.7,]; train.hex=train.hex[r<0.7,]").remove_and_unlock();
    // Re-fetch both frames by key, since the expression reassigned train.hex and created test.hex.
    train = UKV.get(trainkey);
    test = UKV.get(testkey);
  } else if (ARGS.testFile != null && ARGS.testFile.length() != 0) {
    // Null guard: a null testFile is treated like an empty one (no test frame).
    test = TestUtil.parseFrame(testkey, ARGS.testFile);
  }
  Log.info(sys, "Data loaded in " + t_load);

  // Pull out the response vector from the train data
  Vec response = train.subframe(new String[] { ARGS.response }).vecs()[0];

  // Build a Frame with just the requested columns.
  train = train.subframe(cs);
  if (test != null) {
    test = test.subframe(cs);
  }
  Vec[] vs = train.vecs();

  // Do rollups (touch min() on every column before the per-column logging below).
  for (Vec v : vs) {
    v.min();
  }
  for (int i = 0; i < train.numCols(); i++) {
    Log.info(sys, train._names[i] + ", " + vs[i].min() + " - " + vs[i].max() + (vs[i].naCnt() == 0 ? "" : (", missing=" + vs[i].naCnt())));
  }
  Log.info(sys, "Arguments used:\n" + ARGS.toString());

  Timer t_model = new Timer();
  SharedTreeModelBuilder stmb = ARGS.gbm ? new GBM() : new DRF();
  stmb.source = train;
  stmb.validation = test;
  stmb.classification = !ARGS.regression;
  stmb.response = response;
  stmb.ntrees = ARGS.ntrees;
  stmb.max_depth = ARGS.depth;
  stmb.min_rows = ARGS.min_rows;
  // NOTE(review): the key prefix is "DRF_Model_" even when a GBM is built; kept
  // unchanged so the destination key name stays the same.
  stmb.destination_key = Key.make("DRF_Model_" + ARGS.trainFile);
  if (ARGS.gbm) {
    // GBM-only setting.
    GBM gbm = (GBM) stmb;
    gbm.learn_rate = ARGS.learn;
  } else {
    // DRF-only settings.
    DRF drf = (DRF) stmb;
    drf.mtries = ARGS.mtries;
    drf.sample_rate = ARGS.sample;
    drf.seed = ARGS.seed;
  }

  // Invoke the selected builder (GBM or DRF) and block till the end
  stmb.invoke();
  Log.info(sys, "Model trained in " + t_model);
}
Use of hex.gbm.GBM in project h2o-2 by h2oai.
Class ModelSerializationTest, method prepareGBMModel.
/**
 * Parses {@code dataset} into a frame, runs a GBM builder over it, and returns
 * whatever is stored in UKV under the builder's {@code dest()} key.
 *
 * <p>The parsed frame is deleted before returning, whether or not the build succeeds.
 *
 * @param dataset        dataset to hand to {@code parseFrame}
 * @param ignores        accepted but not read by this method
 * @param response       index of the response column within the parsed frame
 * @param classification value assigned to the builder's {@code classification} field
 * @param ntrees         value assigned to the builder's {@code ntrees} field
 * @return the value fetched from UKV for the builder's destination key
 */
private GBMModel prepareGBMModel(String dataset, int[] ignores, int response, boolean classification, int ntrees) {
  final Frame frame = parseFrame(dataset);
  try {
    final GBM builder = new GBM();
    final Vec target = frame.vec(response);

    // Configure the builder: data first, then the training options.
    builder.source = frame;
    builder.response = target;
    builder.ntrees = ntrees;
    builder.classification = classification;
    builder.score_each_iteration = true;

    builder.invoke();

    final GBMModel trained = UKV.get(builder.dest());
    return trained;
  } finally {
    // Always release the parsed frame, even when the build throws.
    if (frame != null) {
      frame.delete();
    }
  }
}
Aggregations