Search in sources :

Example 1 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class ModelSerializationTest method testGBMModelBinomial.

// Builds a GBM model via prepareGBMModel, passes it through saveAndLoad, and
// compares the original with the reloaded copy using both comparison helpers.
@Test
public void testGBMModelBinomial() throws IOException {
    GBMModel original = null;
    GBMModel restored = null;
    try {
        original = prepareGBMModel("smalldata/logreg/prostate.csv", ari(0), 1, true, 5);
        restored = saveAndLoad(original);
        // The reloaded model has to pass both equality checks against the original
        assertTreeModelEquals(original, restored);
        assertModelBinaryEquals(original, restored);
    } finally {
        // Delete whichever of the two models actually got created
        if (original != null) {
            original.delete();
        }
        if (restored != null) {
            restored.delete();
        }
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) Test(org.junit.Test)

Example 2 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMTest method testGBMTrainTest.

// Test-on-Train.  Slow test, needed to build a good model.
// Trains a GBM on ecology_model.csv, scores ecology_eval.csv with it and
// prints the resulting confusion matrix.
@Test
public void testGBMTrainTest() {
    File file1 = TestUtil.find_test_file("smalldata/gbm_test/ecology_model.csv");
    File file2 = TestUtil.find_test_file("smalldata/gbm_test/ecology_eval.csv");
    // Silently ignore if either file is not found.  Both are checked before
    // any key is made so that a missing second file leaves nothing behind.
    if (file1 == null || file2 == null)
        return;
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("train.hex");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("test.hex");
    // The Builder
    GBM gbm = new GBM();
    // The Model
    GBM.GBMModel gbmmodel = null;
    Frame ftest = null, fpreds = null;
    try {
        // Hand the frame to the builder right away so the finally block can
        // delete it even if one of the setup steps below throws
        Frame fr = gbm.source = ParseDataset2.parse(dest1, new Key[] { fkey1 });
        // Remove unique ID; too predictive
        UKV.remove(fr.remove("Site")._key);
        // Train on the outcome
        gbm.response = fr.vecs()[fr.find("Angaus")];
        gbm.ntrees = 5;
        gbm.max_depth = 10;
        gbm.learn_rate = 0.2f;
        gbm.min_rows = 10;
        gbm.nbins = 100;
        gbm.invoke();
        gbmmodel = UKV.get(gbm.dest());
        testHTML(gbmmodel);
        //HEX-1817
        Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
        // Parse the evaluation file and score it
        ftest = ParseDataset2.parse(dest2, new Key[] { fkey2 });
        fpreds = gbm.score(ftest);
        // Build a confusion matrix
        ConfusionMatrix CM = new ConfusionMatrix();
        CM.actual = ftest;
        CM.vactual = ftest.vecs()[ftest.find("Angaus")];
        CM.predict = fpreds;
        CM.vpredict = fpreds.vecs()[fpreds.find("predict")];
        // Start it, do it
        CM.invoke();
        StringBuilder sb = new StringBuilder();
        CM.toASCII(sb);
        System.out.println(sb);
    } finally {
        // Every cleanup step is null-guarded: an NPE thrown from here would
        // replace the exception that actually made the test fail.
        // Remove the original hex frame key
        if (gbm.source != null)
            gbm.source.delete();
        if (ftest != null)
            ftest.delete();
        if (fpreds != null)
            fpreds.delete();
        // Remove the model
        if (gbmmodel != null)
            gbmmodel.delete();
        if (gbm.response != null)
            UKV.remove(gbm.response._key);
        // Remove GBM Job
        gbm.remove();
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) ConfusionMatrix(water.api.ConfusionMatrix) File(java.io.File) Test(org.junit.Test)

Example 3 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMTest method testModelAdapt.

// Adapt a trained model to a test dataset with different enums.
// Trains a small GBM on KDDTrain and then scores KDDTest with it.
@Test
public void testModelAdapt() {
    File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz");
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("KDDTrain.hex");
    File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("KDDTest.hex");
    GBM gbm = new GBM();
    // The Model
    GBM.GBMModel gbmmodel = null;
    // Declared outside the try so the finally block can release them even
    // when parsing or scoring throws
    Frame ftest = null, preds = null;
    try {
        gbm.source = ParseDataset2.parse(dest1, new Key[] { fkey1 });
        // Response is col 41
        gbm.response = gbm.source.vecs()[41];
        gbm.ntrees = 2;
        gbm.max_depth = 8;
        gbm.learn_rate = 0.2f;
        gbm.min_rows = 10;
        gbm.nbins = 50;
        gbm.invoke();
        gbmmodel = UKV.get(gbm.dest());
        testHTML(gbmmodel);
        //HEX-1817
        Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
        // The test data set has a few more enums than the train
        ftest = ParseDataset2.parse(dest2, new Key[] { fkey2 });
        preds = gbm.score(ftest);
    } finally {
        // Every cleanup step is null-guarded: an NPE thrown from here would
        // replace the exception that actually made the test fail.
        if (ftest != null)
            ftest.delete();
        if (preds != null)
            preds.delete();
        // Remove the model
        if (gbmmodel != null)
            gbmmodel.delete();
        // Remove original hex frame key
        if (gbm.source != null)
            gbm.source.delete();
        if (gbm.response != null)
            UKV.remove(gbm.response._key);
        // Remove GBM Job
        gbm.remove();
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) File(java.io.File) Test(org.junit.Test)

Example 4 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMTest method basicGBM.

// Shared driver: parses fname into hexname, lets prep pick the response
// column, trains a small GBM (4 trees, depth 4) on every column and scores
// the training frame once.
//   fname      - test file to look up via TestUtil.find_test_file
//   hexname    - name of the key the parsed frame is stored under
//   prep       - returns the response column index; a negative value means
//                "regression" and encodes the index as its bitwise complement
//   validation - when true, a new Frame built from the source is used as the
//                validation set
//   family     - value assigned to gbm.family
// Returns null if the file is missing.  NOTE: the finally block calls
// delete() on the model before it is handed back, so the caller receives a
// model on which delete() has already been called.
public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean validation, Family family) {
    File file = TestUtil.find_test_file(fname);
    // Silently abort test if the file is missing
    if (file == null)
        return null;
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make(hexname);
    // The Builder
    GBM gbm = new GBM();
    // The Model
    GBM.GBMModel gbmmodel = null;
    try {
        Frame fr = gbm.source = ParseDataset2.parse(dest, new Key[] { fkey });
        UKV.remove(fkey);
        int idx = prep.prep(fr);
        if (idx < 0) {
            // Negative index: switch to regression and recover the real index
            gbm.classification = false;
            idx = ~idx;
        }
        gbm.response = fr.vecs()[idx];
        gbm.family = family;
        assert gbm.family != Family.bernoulli || gbm.classification;
        gbm.ntrees = 4;
        gbm.max_depth = 4;
        gbm.min_rows = 1;
        gbm.nbins = 50;
        // Select every column of the frame
        gbm.cols = new int[fr.numCols()];
        for (int i = 0; i < gbm.cols.length; i++) {
            gbm.cols[i] = i;
        }
        gbm.validation = validation ? new Frame(gbm.source) : null;
        gbm.learn_rate = .2f;
        gbm.score_each_iteration = true;
        gbm.invoke();
        gbmmodel = UKV.get(gbm.dest());
        testHTML(gbmmodel);
        //HEX-1817
        Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
        //System.out.println(gbmmodel.toJava());
        Frame preds = gbm.score(gbm.source);
        preds.delete();
        return gbmmodel;
    } finally {
        // Remove original hex frame key.  source is still null when the parse
        // itself failed; without the guard an NPE from here would replace the
        // parse exception.
        if (gbm.source != null)
            gbm.source.delete();
        // Remove validation dataset if specified
        if (gbm.validation != null)
            gbm.validation.delete();
        // Remove the model
        if (gbmmodel != null)
            gbmmodel.delete();
        // Remove GBM Job
        gbm.remove();
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) File(java.io.File)

Example 5 with GBMModel

use of hex.gbm.GBM.GBMModel in project h2o-2 by h2oai.

the class GBMTest method testGBMRegression.

// Single-tree, depth-1 regression on the Mfgdata file using only column 2 as
// predictor; checks the mean squared error of the predictions on the
// training frame against a fixed expected value.
@Test
public void testGBMRegression() {
    File file = TestUtil.find_test_file("./smalldata/gbm_test/Mfgdata_gaussian_GBM_testing.csv");
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make("mfg.hex");
    // The Builder
    GBM gbm = new GBM();
    // The Model
    GBM.GBMModel gbmmodel = null;
    // Declared outside the try so a failing assertion cannot leak the frame
    Frame preds = null;
    try {
        Frame fr = gbm.source = ParseDataset2.parse(dest, new Key[] { fkey });
        UKV.remove(fkey);
        // Regression
        gbm.classification = false;
        gbm.family = GBM.Family.AUTO;
        // Row in col 0, dependent in col 1, predictor in col 2
        gbm.response = fr.vecs()[1];
        gbm.ntrees = 1;
        gbm.max_depth = 1;
        gbm.min_rows = 1;
        gbm.nbins = 20;
        // Just column 2
        gbm.cols = new int[] { 2 };
        gbm.validation = null;
        gbm.learn_rate = 1.0f;
        gbm.score_each_iteration = true;
        gbm.invoke();
        gbmmodel = UKV.get(gbm.dest());
        //HEX-1817
        Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE);
        preds = gbm.score(gbm.source);
        // Mean squared error = CompErr's accumulated _sum divided by the row count
        double sq_err = new CompErr().doAll(gbm.response, preds.vecs()[0])._sum;
        double mse = sq_err / preds.numRows();
        assertEquals(79152.1233, mse, 0.1);
    } finally {
        // Remove the predictions, also when the assertion above failed
        if (preds != null)
            preds.delete();
        // Remove original hex frame key.  source is still null when the parse
        // itself failed; without the guard an NPE from here would replace the
        // parse exception.
        if (gbm.source != null)
            gbm.source.delete();
        // Remove validation dataset if specified
        if (gbm.validation != null)
            gbm.validation.delete();
        // Remove the model
        if (gbmmodel != null)
            gbmmodel.delete();
        // Remove GBM Job
        gbm.remove();
    }
}
Also used : GBMModel(hex.gbm.GBM.GBMModel) File(java.io.File) Test(org.junit.Test)

Aggregations

GBMModel (hex.gbm.GBM.GBMModel): 9, Test (org.junit.Test): 6, File (java.io.File): 4, Frame (water.fvec.Frame): 2, Vec (water.fvec.Vec): 2, ConfusionMatrix (water.api.ConfusionMatrix): 1