Search in sources :

Example 6 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMTest2 method testBalanceWithCrossValidation.

/**
 * Builds a class-balanced GBM classification model with n-fold cross-validation and checks
 * that the expected number of cross-validation models was produced and that the stored
 * model's job state is DONE.
 *
 * @param dataset      path of the dataset to parse
 * @param response     index of the response column in the parsed frame
 * @param ignored_cols column indices excluded from training
 * @param ntrees       number of trees to build
 * @param nfolds       number of cross-validation folds (and expected xval models)
 */
@Override
protected void testBalanceWithCrossValidation(String dataset, int response, int[] ignored_cols, int ntrees, int nfolds) {
    Frame f = parseFrame(dataset);
    GBMModel model = null;
    GBM gbm = new GBM();
    try {
        Vec respVec = f.vec(response);
        // Build a model
        gbm.source = f;
        gbm.response = respVec;
        gbm.ignored_cols = ignored_cols;
        gbm.classification = true;
        gbm.ntrees = ntrees;
        gbm.balance_classes = true;
        gbm.n_folds = nfolds;
        gbm.keep_cross_validation_splits = false;
        gbm.invoke();
        // Fail with a clear message rather than an NPE if no xval models were recorded
        Assert.assertNotNull("Cross validation models were not produced!", gbm.xval_models);
        Assert.assertEquals("Number of cross validation models is wrong!", nfolds, gbm.xval_models.length);
        model = UKV.get(gbm.dest());
        // HEX-1817: the stored model's job state is expected to be DONE
        Assert.assertEquals(Job.JobState.DONE, model.get_params().state);
    } finally {
        if (f != null) {
            f.delete();
        }
        // Clean up cross-validation models independently of the main model, so a failed
        // assertion above does not leak them; a key may not resolve if the job failed early.
        if (gbm.xval_models != null) {
            for (Key k : gbm.xval_models) {
                Model m = UKV.get(k);
                if (m != null) {
                    m.delete();
                }
            }
        }
        if (model != null) {
            model.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) GBMModel(hex.gbm.GBM.GBMModel) Vec(water.fvec.Vec)

Example 7 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMCheckpointTest method testCheckPointReconstruction.

/**
 * Verifies that continuing GBM training from a checkpoint reproduces a model trained in one go.
 *
 * <p>The test proceeds in three steps:
 * <ol>
 *   <li>Train a prior model with {@code ntreesInPriorModel} trees, collecting tree data after the build.</li>
 *   <li>Train a model from that checkpoint with {@code ntreesInANewModel} trees, collecting tree data
 *       after reconstruction.</li>
 *   <li>Assert the collected tree data match. Then train a fresh model with
 *       {@code ntreesInPriorModel + ntreesInANewModel} trees and assert its trees equal the
 *       checkpointed model's.</li>
 * </ol>
 * Presumably {@code ntreesInANewModel} counts only the trees added on top of the checkpoint,
 * given the sum used for the reference model — confirm against GBM's checkpoint semantics.
 *
 * @param dataset            path of the dataset to parse
 * @param response           index of the response column
 * @param classification     whether to train a classification model
 * @param ntreesInPriorModel number of trees in the checkpoint model
 * @param ntreesInANewModel  number of trees requested for the checkpointed run
 */
private void testCheckPointReconstruction(String dataset, int response, boolean classification, int ntreesInPriorModel, int ntreesInANewModel) {
    Frame f = parseFrame(dataset);
    GBMModel model = null;
    GBMModel modelFromCheckpoint = null;
    GBMModel modelFinal = null;
    try {
        Vec respVec = f.vec(response);
        // Build a model
        GBMWithHooks gbm = new GBMWithHooks();
        gbm.source = f;
        gbm.response = respVec;
        gbm.classification = classification;
        gbm.ntrees = ntreesInPriorModel;
        gbm.collectPoint = WhereToCollect.AFTER_BUILD;
        gbm.score_each_iteration = true;
        gbm.invoke();
        model = UKV.get(gbm.dest());
        // Build a checkpointed model
        GBMWithHooks gbmFromCheckpoint = new GBMWithHooks();
        gbmFromCheckpoint.source = f;
        gbmFromCheckpoint.response = respVec;
        gbmFromCheckpoint.classification = classification;
        gbmFromCheckpoint.ntrees = ntreesInANewModel;
        gbmFromCheckpoint.collectPoint = WhereToCollect.AFTER_RECONSTRUCTION;
        gbmFromCheckpoint.checkpoint = gbm.dest();
        gbmFromCheckpoint.score_each_iteration = true;
        gbmFromCheckpoint.invoke();
        modelFromCheckpoint = UKV.get(gbmFromCheckpoint.dest());
        // Check if reconstructed frame computation data are same
        assertArrayEquals("Tree data produced by gbm run and reconstructed from a model do not match!", gbm.treesCols, gbmFromCheckpoint.treesCols);
        // Build a model which contains old+new trees and compare prediction results
        GBM gbmFinal = new GBM();
        gbmFinal.source = f;
        gbmFinal.response = respVec;
        gbmFinal.classification = classification;
        gbmFinal.ntrees = ntreesInANewModel + ntreesInPriorModel;
        gbmFinal.score_each_iteration = true;
        gbmFinal.invoke();
        modelFinal = UKV.get(gbmFinal.dest());
        assertTreeModelEquals(modelFinal, modelFromCheckpoint);
    } finally {
        if (f != null) {
            f.delete();
        }
        if (model != null) {
            model.delete();
        }
        if (modelFromCheckpoint != null) {
            modelFromCheckpoint.delete();
        }
        if (modelFinal != null) {
            modelFinal.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) GBMModel(hex.gbm.GBM.GBMModel) Vec(water.fvec.Vec)

Example 8 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMTest method testReproducibility.

/**
 * Trains the same single-tree regression GBM several times on a rebalanced covtype frame
 * and asserts every run yields the same MSE (within 1e-15).
 */
@Test
public void testReproducibility() {
    Frame tfr = null;
    final int N = 5;
    double[] mses = new double[N];
    Scope.enter();
    try {
        // Load data, hack frames
        tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");
        // rebalance to 256 chunks
        Key dest = Key.make("df.rebalanced.hex");
        RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
        H2O.submitTask(rb);
        rb.join();
        tfr.delete();
        // Already deleted: clear so `finally` does not delete it twice if the lookup below fails
        tfr = null;
        tfr = DKV.get(dest).get();
        for (int i = 0; i < N; ++i) {
            GBM parms = new GBM();
            parms.source = tfr;
            parms.response = tfr.lastVec();
            parms.nbins = 1000;
            parms.ntrees = 1;
            parms.max_depth = 8;
            parms.learn_rate = 0.1;
            parms.min_rows = 10;
            parms.family = Family.AUTO;
            parms.classification = false;
            // Build a first model; all remaining models should be equal
            GBMModel gbm = parms.fork().get();
            try {
                mses[i] = gbm.mse();
            } finally {
                // Free each model even if reading its MSE fails
                gbm.delete();
            }
        }
    } finally {
        if (tfr != null) {
            tfr.delete();
        }
        // Must run on failure paths too, otherwise the scope opened above is never closed
        Scope.exit();
    }
    for (int i = 0; i < mses.length; ++i) {
        Log.info("trial: " + i + " -> mse: " + mses[i]);
    }
    // Trial 0 is the reference; compare every later trial against it (expected first)
    for (int i = 1; i < mses.length; ++i) {
        assertEquals(mses[0], mses[i], 1e-15);
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) Test(org.junit.Test)

Example 9 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class ModelSerializationTest method testGBMModelMultinomial.

/**
 * Round-trips a GBM model trained on the iris dataset through {@code saveAndLoad} and checks
 * that the reloaded model matches the original via {@code assertTreeModelEquals} and
 * {@code assertModelBinaryEquals}. Both models are deleted afterwards.
 */
@Test
public void testGBMModelMultinomial() throws IOException {
    GBMModel trained = null;
    GBMModel restored = null;
    try {
        // Arguments presumably: response column 4, classification = true, 5 trees — confirm
        // against prepareGBMModel's signature.
        trained = prepareGBMModel("smalldata/iris/iris.csv", EIA, 4, true, 5);
        restored = saveAndLoad(trained);
        // Compare the restored model against the one we trained
        assertTreeModelEquals(trained, restored);
        assertModelBinaryEquals(trained, restored);
    } finally {
        if (trained != null) {
            trained.delete();
        }
        if (restored != null) {
            restored.delete();
        }
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) Test(org.junit.Test)

Aggregations

GBMModel (hex.gbm.GBM.GBMModel): 9, Test (org.junit.Test): 6, File (java.io.File): 4, Frame (water.fvec.Frame): 2, Vec (water.fvec.Vec): 2, ConfusionMatrix (water.api.ConfusionMatrix): 1