Search in sources :

Example 1 with SharedTreeModelBuilder

Use of hex.gbm.SharedTreeModelBuilder in project h2o-2 by h2oai.

From the class Runner, method main.

/**
 * Do the work: sanity-check the options, load the train (and optional test)
 * data, then configure and run either a GBM or a DRF model builder.
 *
 * @param ARGS the options to run with; {@code ARGS.mtries} is overwritten
 *             with a derived value when it is passed in as 0
 * @throws IllegalArgumentException (a {@link RuntimeException}) if an option
 *             is out of bounds or two options conflict
 */
static void main(OptArgs ARGS) {
    // Finish building the cluster
    TestUtil.stall_till_cloudsize(ARGS.clusterSize);

    // Sanity check basic args. A NaN value fails both comparisons of a range
    // check, so a NaN splitTestTrain passes through untouched here.
    if (ARGS.ntrees <= 0 || ARGS.ntrees > 100000) {
        throw new IllegalArgumentException("ntrees " + ARGS.ntrees + " out of bounds");
    }
    if (ARGS.sample < 0 || ARGS.sample > 1.0f) {
        throw new IllegalArgumentException("sample " + ARGS.sample + " out of bounds");
    }
    if (ARGS.learn < 0 || ARGS.learn > 1.0f) {
        throw new IllegalArgumentException("learn " + ARGS.learn + " out of bounds");
    }
    if (ARGS.nbins < 2 || ARGS.nbins > 100000) {
        throw new IllegalArgumentException("nbins " + ARGS.nbins + " out of bounds");
    }
    if (ARGS.depth <= 0) {
        throw new IllegalArgumentException("depth " + ARGS.depth + " out of bounds");
    }
    if (ARGS.splitTestTrain < 0 || ARGS.splitTestTrain > 1.0f) {
        throw new IllegalArgumentException("splitTestTrain " + ARGS.splitTestTrain + " out of bounds");
    }

    // If trainFile is NOT set, you are doing the default file and cannot set testFile.
    // NOTE(review): these are reference comparisons (== / !=) against the
    // OptArgs default constants, not equals(); presumably that is how "left at
    // the default" is detected - confirm before switching to equals().
    if ((ARGS.trainFile == OptArgs.defaultTrainFile) && (ARGS.testFile != OptArgs.defaultTestFile)) {
        throw new IllegalArgumentException("Cannot set test file unless also setting train file");
    }
    // If testFile is set, cannot set splitTestTrain
    if ((ARGS.testFile != OptArgs.defaultTestFile) && !Float.isNaN(ARGS.splitTestTrain)) {
        throw new IllegalArgumentException("Cannot have both testFile and splitTestTrain");
    }

    Sys sys = ARGS.gbm ? Sys.GBM__ : Sys.DRF__;
    // Requested columns followed by the response column, split on comma or tab.
    String[] cs = (ARGS.cols + "," + ARGS.response).split("[,\t]");

    // Set mtries: 0 means "derive it" as the integer square root of the column count.
    // NOTE(review): cs includes the response column, so the count used here and
    // in the upper bound below is one more than the number of requested columns
    // - confirm that is intended.
    if (ARGS.mtries == 0) {
        ARGS.mtries = (int) Math.sqrt(cs.length);
    }
    if (ARGS.mtries <= 0 || ARGS.mtries > cs.length) {
        throw new IllegalArgumentException("mtries " + ARGS.mtries + " out of bounds");
    }

    // Load data
    Timer t_load = new Timer();
    Key trainkey = Key.make("train.hex");
    Key testkey = Key.make("test.hex");
    Frame train = TestUtil.parseFrame(trainkey, ARGS.trainFile);
    Frame test = null;
    if (!Float.isNaN(ARGS.splitTestTrain)) {
        // Split the parsed training frame into train/test via an Exec2 expression,
        // then re-fetch both frames by key since the expression reassigns them.
        // NOTE(review): the threshold is hard-coded to 0.7; the value of
        // ARGS.splitTestTrain is only used as an on/off switch - confirm
        // whether it was meant to be the split ratio.
        water.exec.Exec2.exec("r=runif(train.hex,-1); test.hex=train.hex[r>=0.7,]; train.hex=train.hex[r<0.7,]").remove_and_unlock();
        train = UKV.get(trainkey);
        test = UKV.get(testkey);
    } else if (ARGS.testFile.length() != 0) {
        test = TestUtil.parseFrame(testkey, ARGS.testFile);
    }
    Log.info(sys, "Data loaded in " + t_load);

    // Pull out the response vector from the train data
    Vec response = train.subframe(new String[] { ARGS.response }).vecs()[0];
    // Build a Frame with just the requested columns.
    train = train.subframe(cs);
    if (test != null) {
        test = test.subframe(cs);
    }
    Vec[] vs = train.vecs();

    // Do rollups
    for (Vec v : vs) {
        v.min();
    }
    // Log one summary line per column: name, min - max, and the missing count when non-zero.
    for (int i = 0; i < train.numCols(); i++) {
        Log.info(sys, train._names[i] + ", " + vs[i].min() + " - " + vs[i].max()
                + (vs[i].naCnt() == 0 ? "" : (", missing=" + vs[i].naCnt())));
    }
    Log.info(sys, "Arguments used:\n" + ARGS.toString());

    Timer t_model = new Timer();
    SharedTreeModelBuilder stmb = ARGS.gbm ? new GBM() : new DRF();
    // Settings shared by both builders.
    stmb.source = train;
    stmb.validation = test;
    stmb.classification = !ARGS.regression;
    stmb.response = response;
    stmb.ntrees = ARGS.ntrees;
    stmb.max_depth = ARGS.depth;
    stmb.min_rows = ARGS.min_rows;
    // NOTE(review): the key prefix is "DRF_Model_" even when a GBM is built -
    // confirm nothing looks the model up by this name before renaming it.
    stmb.destination_key = Key.make("DRF_Model_" + ARGS.trainFile);
    // Builder-specific settings.
    if (ARGS.gbm) {
        GBM gbm = (GBM) stmb;
        gbm.learn_rate = ARGS.learn;
    } else {
        DRF drf = (DRF) stmb;
        drf.mtries = ARGS.mtries;
        drf.sample_rate = ARGS.sample;
        drf.seed = ARGS.seed;
    }
    // Invoke the selected builder (GBM or DRF) and block till the end
    stmb.invoke();
    Log.info(sys, "Model trained in " + t_model);
}
Also used : Frame(water.fvec.Frame) Sys(water.util.Log.Tag.Sys) GBM(hex.gbm.GBM) SharedTreeModelBuilder(hex.gbm.SharedTreeModelBuilder) Vec(water.fvec.Vec)

Aggregations

GBM (hex.gbm.GBM)1 SharedTreeModelBuilder (hex.gbm.SharedTreeModelBuilder)1 Frame (water.fvec.Frame)1 Vec (water.fvec.Vec)1 Sys (water.util.Log.Tag.Sys)1