Example 1 with GLM

Use of hex.glm.GLM in project h2o-3 by h2oai.

From the class TestCase, method execute.

/**
 * Runs this test case and packages the outcome as a {@code TestCaseResult}.
 *
 * <p>Two paths, selected by the {@code grid} field:
 * <ul>
 *   <li>{@code grid == false}: a single model is built for the configured {@code algo}
 *       ("drf", "glm", "gbm" or "dl") from {@code params}.</li>
 *   <li>{@code grid == true}: a grid search is started over {@code hyperParms}, and the model
 *       whose validation metric named by {@code modelSelectionCriteria} is best is kept.</li>
 * </ul>
 * In both paths the elapsed time reported is {@code stopTime - startTime}, taken from
 * {@code System.currentTimeMillis()} (milliseconds), and the selected model's parameters are
 * reported as a JSON string.
 *
 * <p>NOTE(review): this listing appears truncated by extraction. Closing braces, the bodies of
 * the {@code finally} guards, and any {@code break} statements between {@code case} labels are
 * not present. Confirm against the original source before relying on the control flow shown.
 *
 * @return a result holding the training metrics plus either the cross-validation metrics
 *         (when {@code params._nfolds > 0}) or the validation metrics (otherwise)
 * @throws Exception any exception raised while training or searching, wrapped in a new
 *         {@code Exception}
 * @throws AssertionError declared by the signature; the grid path asserts that
 *         {@code modelSelectionCriteria} is not the empty string
 */
public TestCaseResult execute() throws Exception, AssertionError {
    double startTime = 0, stopTime = 0;
    if (!grid) {
        // ---- Single-model path ----
        Model.Output modelOutput = null;
        DRF drfJob;
        DRFModel drfModel = null;
        GLM glmJob;
        GLMModel glmModel = null;
        GBM gbmJob;
        GBMModel gbmModel = null;
        DeepLearning dlJob;
        DeepLearningModel dlModel = null;
        String bestModelJson = null;
        try {
            // Each case casts params to the algorithm-specific parameter type, times the
            // blocking trainModel().get() call, then records the model output and its parameters as JSON.
            // NOTE(review): no break is visible after any case; as listed, control would fall
            // through into the next case. Likely an extraction artifact - confirm.
            switch(algo) {
                case "drf":
                    drfJob = new DRF((DRFModel.DRFParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                    startTime = System.currentTimeMillis();
                    drfModel = drfJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = drfModel._output;
                    bestModelJson = drfModel._parms.toJsonString();
                case "glm":
                    // Unlike the other cases, the GLM job is given an explicit model key.
                    glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                    AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                    startTime = System.currentTimeMillis();
                    glmModel = glmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = glmModel._output;
                    bestModelJson = glmModel._parms.toJsonString();
                case "gbm":
                    gbmJob = new GBM((GBMModel.GBMParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                    startTime = System.currentTimeMillis();
                    gbmModel = gbmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = gbmModel._output;
                    bestModelJson = gbmModel._parms.toJsonString();
                case "dl":
                    dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DL model.");
                    startTime = System.currentTimeMillis();
                    dlModel = dlJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = dlModel._output;
                    bestModelJson = dlModel._parms.toJsonString();
        } catch (Exception e) {
            // Rethrown wrapped; the original exception is preserved as the cause.
            throw new Exception(e);
        } finally {
            // NOTE(review): the bodies of these null guards are missing from the listing;
            // presumably per-model cleanup - confirm against the original source.
            if (drfModel != null) {
            if (glmModel != null) {
            if (gbmModel != null) {
            if (dlModel != null) {
        //Add check if cv is used
        // With n-fold cross-validation enabled, report cross-validation metrics;
        // otherwise report metrics from the validation frame.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
    } else {
        // ---- Grid-search path ----
        assert !modelSelectionCriteria.equals("");
        Grid grid = null;
        Model bestModel = null;
        String bestModelJson = null;
        try {
            // One-time, per-algorithm instantiation of the builder and its V3 parameter schema,
            // guarded by the *Registered flags (see the PUBDEV-2812 TODO below).
            // NOTE(review): closing braces and any break statements are missing here as well.
            switch(// TODO: Hack for PUBDEV-2812
            algo) {
                case "drf":
                    if (!drfRegistered) {
                        new DRF(true);
                        new DRFParametersV3();
                        drfRegistered = true;
                case "glm":
                    if (!glmRegistered) {
                        new GLM(true);
                        new GLMParametersV3();
                        glmRegistered = true;
                case "gbm":
                    if (!gbmRegistered) {
                        new GBM(true);
                        new GBMParametersV3();
                        gbmRegistered = true;
                case "dl":
                    if (!dlRegistered) {
                        new DeepLearning(true);
                        new DeepLearningParametersV3();
                        dlRegistered = true;
            // Time the whole grid search, from start until gs.get() returns the grid.
            startTime = System.currentTimeMillis();
            // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
            Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
            grid = gs.get();
            stopTime = System.currentTimeMillis();
            boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
            // Seed with the worst finite score for the chosen direction so that any model
            // with a strictly better score replaces it.
            double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
            for (Model m : grid.getModels()) {
                // Selection always uses the validation metrics, even when _nfolds > 0.
                double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
                AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
                if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                    bestScore = validationMetricScore;
                    bestModel = m;
                    bestModelJson = bestModel._parms.toJsonString();
            // NOTE(review): bestModel is still null if the loop never selects a model
            // (e.g. the grid yields no models); the dereference below would then throw.
            AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
            AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
        } catch (Exception e) {
            // Rethrown wrapped; the original exception is preserved as the cause.
            throw new Exception(e);
        } finally {
            // NOTE(review): guard body missing from the listing; presumably grid cleanup - confirm.
            if (grid != null) {
        //Add check if cv is used
        // Same reporting rule as the single-model path, applied to the selected best model.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
Example 2 with GLM

Use of hex.glm.GLM in project h2o-3 by h2oai.

From the class WorkFlowTest, method testWorkFlow.

// End-to-end workflow test:
// 1- load set of files, train, test, holdout
// 2- light data munging
// 3- build model on train; using test as validation
// 4- score on holdout set
// If files are missing, silently fail - as the files are big and this is not
// yet a junit test
//
// Parameter: files - passed to load_files together with the name "data.hex".
// NOTE(review): this listing appears truncated by extraction - the body of the
// `if (data == null)` guard, the step-4 scoring/cleanup code and the `finally`
// body are not present. Confirm against the original source.
private void testWorkFlow(String[] files) {
    try {
        // 1- Load datasets
        Frame data = load_files("data.hex", files);
        // NOTE(review): guard body missing; per the header comment, presumably an early return.
        if (data == null)
        // -------------------------------------------------
        // 2- light data munging
        // Convert start time to: Day since the Epoch
        Vec startime = data.vec("starttime");
        data.add(new TimeSplit().doIt(startime));
        // Now do a monster Group-By.  Count bike starts per-station per-day
        Vec days = data.vec("Days");
        long start = System.currentTimeMillis();
        Frame bph = new CountBikes(days).doAll(days, data.vec("start station name")).makeFrame(Key.make("bph.hex"));
        System.out.println("Groupby took " + (System.currentTimeMillis() - start));
        System.out.println(bph.toString(10000, 20));
        // Build a quantile model on the grouped frame; the result is not used in the visible code.
        QuantileModel.QuantileParameters quantile_parms = new QuantileModel.QuantileParameters();
        quantile_parms._train = bph._key;
        Job<QuantileModel> job2 = new Quantile(quantile_parms).trainModel();
        QuantileModel quantile = job2.get();
        // Split into train, test and holdout sets
        // Ratios 0.6 / 0.3 / 0.1 line up positionally with the train / test / hold keys;
        // the fixed seed argument presumably makes the shuffle reproducible - confirm.
        Key[] keys = new Key[] { Key.make("train.hex"), Key.make("test.hex"), Key.make("hold.hex") };
        double[] ratios = new double[] { 0.6, 0.3, 0.1 };
        Frame[] frs = ShuffleSplitFrame.shuffleSplitFrame(bph, keys, ratios, 1234567689L);
        Frame train = frs[0];
        Frame test = frs[1];
        Frame hold = frs[2];
        // -------------------------------------------------
        // 3- build model on train; using test as validation
        // ---
        // Gradient Boosting Machine
        GBMModel.GBMParameters gbm_parms = new GBMModel.GBMParameters();
        // base Model.Parameters
        gbm_parms._train = train._key;
        gbm_parms._valid = test._key;
        // default is false
        gbm_parms._score_each_iteration = false;
        // SupervisedModel.Parameters
        gbm_parms._response_column = "bikes";
        // SharedTreeModel.Parameters
        // default is 50, 1000 is 0.90, 10000 is 0.91
        gbm_parms._ntrees = 500;
        // default is 5
        gbm_parms._max_depth = 6;
        // default
        gbm_parms._min_rows = 10;
        // default
        gbm_parms._nbins = 20;
        // GBMModel.Parameters
        // default
        gbm_parms._distribution = DistributionFamily.gaussian;
        // default
        gbm_parms._learn_rate = 0.1f;
        // Train model; block for results
        Job<GBMModel> job = new GBM(gbm_parms).trainModel();
        GBMModel gbm = job.get();
        // ---
        // Build a GLM model also
        // Same train/validation frames and response column as the GBM above, gaussian family.
        GLMModel.GLMParameters glm_parms = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
        // base Model.Parameters
        glm_parms._train = train._key;
        glm_parms._valid = test._key;
        // default is false
        glm_parms._score_each_iteration = false;
        // SupervisedModel.Parameters
        glm_parms._response_column = "bikes";
        // GLMModel.Parameters
        glm_parms._use_all_factor_levels = true;
        // Train model; block for results
        Job<GLMModel> glm_job = new GLM(glm_parms).trainModel();
        GLMModel glm = glm_job.get();
        // -------------------------------------------------
        // 4- Score on holdout set & report
        // Cleanup
        // NOTE(review): no scoring or cleanup statements are visible; `hold`, `gbm`, `glm`
        // and `quantile` are unused in this listing.
    } finally {
Example 3 with GLM

Use of hex.glm.GLM in project h2o-3 by h2oai.

From the class ModelSerializationTest, method prepareGLMModel.

/**
 * Parses a test dataset and trains a GLM model on it, blocking until training finishes.
 *
 * <p>NOTE(review): this listing appears truncated by extraction - the statement guarded by
 * {@code if (f != null)} and the closing braces are not present. Confirm against the original.
 *
 * @param dataset        path handed to {@code parse_test_file}
 * @param ignoredColumns assigned to {@code _ignored_columns} of the GLM parameters
 * @param response       assigned to {@code _response_column} of the GLM parameters
 * @param family         assigned to {@code _family} of the GLM parameters
 * @return the model returned by {@code new GLM(params).trainModel().get()}
 */
private GLMModel prepareGLMModel(String dataset, String[] ignoredColumns, String response, GLMModel.GLMParameters.Family family) {
    Frame f = parse_test_file(dataset);
    try {
        GLMModel.GLMParameters params = new GLMModel.GLMParameters();
        params._train = f._key;
        params._ignored_columns = ignoredColumns;
        params._response_column = response;
        params._family = family;
        // The return value is computed before the finally block below runs.
        return new GLM(params).trainModel().get();
    } finally {
        // NOTE(review): guarded statement missing from the listing; presumably releases the
        // parsed frame - confirm. f was already dereferenced above (f._key).
        if (f != null)
Example 4 with GLM

Use of hex.glm.GLM in project h2o-3 by h2oai.

From the class XValPredictionsCheck, method testXValPredictions.

/**
 * Trains GBM, DRF, GLM and DeepLearning models on the iris dataset, each with the fold column
 * name "foldId" and {@code _keep_cross_validation_predictions} enabled, and passes each model
 * to {@code checkModel} together with a separately built 3-fold assignment vector.
 *
 * <p>NOTE(review): this listing appears truncated by extraction - the statement guarded by
 * {@code if (tfr != null)} and the closing braces are not present. Confirm against the original.
 */
public void testXValPredictions() {
    final int nfolds = 3;
    Frame tfr = null;
    try {
        // Load data, hack frames
        tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
        // Single-column frame "foldId" built by AstKFold.kfoldColumn from a zeroed copy of the
        // "class" vec, with nfolds folds and a fixed seed.
        // NOTE(review): the visible code never adds this column to tfr, although every model
        // below sets _fold_column = "foldId"; the step may have been lost in extraction.
        Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
        // GBM
        // Deliberately tiny model: one tree of depth one.
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "class";
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._fold_column = "foldId";
        parms._distribution = DistributionFamily.multinomial;
        parms._keep_cross_validation_predictions = true;
        GBM job = new GBM(parms);
        GBMModel gbm = job.trainModel().get();
        // NOTE(review): the third argument is 3 for the "class" models and 1 for the GLM on
        // "sepal_len"; its meaning is defined by checkModel, which is not visible here.
        checkModel(gbm, foldId.anyVec(), 3);
        // DRF
        // Same settings as the GBM above.
        DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
        parmsDRF._train = tfr._key;
        parmsDRF._response_column = "class";
        parmsDRF._ntrees = 1;
        parmsDRF._max_depth = 1;
        parmsDRF._fold_column = "foldId";
        parmsDRF._distribution = DistributionFamily.multinomial;
        parmsDRF._keep_cross_validation_predictions = true;
        DRF drfJob = new DRF(parmsDRF);
        DRFModel drf = drfJob.trainModel().get();
        checkModel(drf, foldId.anyVec(), 3);
        // GLM
        // Uses the "sepal_len" column as the response instead of "class".
        GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
        parmsGLM._train = tfr._key;
        parmsGLM._response_column = "sepal_len";
        parmsGLM._fold_column = "foldId";
        parmsGLM._keep_cross_validation_predictions = true;
        GLM glmJob = new GLM(parmsGLM);
        GLMModel glm = glmJob.trainModel().get();
        checkModel(glm, foldId.anyVec(), 1);
        // DL
        // Minimal network: one hidden layer with a single unit, one epoch. No _distribution is
        // set here, unlike the GBM and DRF parameters above.
        DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
        parmsDL._train = tfr._key;
        parmsDL._response_column = "class";
        parmsDL._hidden = new int[] { 1 };
        parmsDL._epochs = 1;
        parmsDL._fold_column = "foldId";
        parmsDL._keep_cross_validation_predictions = true;
        DeepLearning dlJob = new DeepLearning(parmsDL);
        DeepLearningModel dl = dlJob.trainModel().get();
        checkModel(dl, foldId.anyVec(), 3);
    } finally {
        // NOTE(review): guarded statement missing from the listing; presumably removes the
        // parsed frame - confirm. Cleanup of foldId and the trained models is not visible either.
        if (tfr != null)
