Search in sources:

Example 1 with GLM

use of hex.glm.GLM in project h2o-3 by h2oai.

Method execute of the class TestCase.

/**
 * Runs this test case: loads the data sets, builds the model parameters, then trains either a
 * single model (when the {@code grid} field is false) or a grid search, and collects the metrics.
 *
 * <p>For a single model, the training metrics are paired with the cross-validation metrics when
 * {@code params._nfolds > 0}, otherwise with the validation metrics. For a grid search, the best
 * model is chosen by {@code modelSelectionCriteria} evaluated on each model's validation metrics,
 * using {@code higherIsBetter} to decide the direction of comparison.
 *
 * <p>Trained models (single-model path) and the grid (grid path) are deleted in {@code finally}
 * blocks; the in-memory {@code _output} objects are still read afterwards to build the result.
 *
 * @return the metrics, training time and chosen parameters for this test case
 * @throws Exception wrapping any failure raised while training or selecting the model, including an
 *     unsupported {@code algo} in single-model mode or a grid that yields no selectable model
 * @throws AssertionError if assertions are enabled and grid mode is used with an empty selection
 *     criterion
 */
public TestCaseResult execute() throws Exception, AssertionError {
    loadTestCaseDataSets();
    makeModelParameters();
    double startTime = 0, stopTime = 0;
    if (!grid) {
        Model.Output modelOutput = null;
        DRFModel drfModel = null;
        GLMModel glmModel = null;
        GBMModel gbmModel = null;
        DeepLearningModel dlModel = null;
        String bestModelJson = null;
        try {
            switch (algo) {
                case "drf":
                    DRF drfJob = new DRF((DRFModel.DRFParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                    startTime = System.currentTimeMillis();
                    drfModel = drfJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = drfModel._output;
                    bestModelJson = drfModel._parms.toJsonString();
                    break;
                case "glm":
                    GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                    AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                    startTime = System.currentTimeMillis();
                    glmModel = glmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = glmModel._output;
                    bestModelJson = glmModel._parms.toJsonString();
                    break;
                case "gbm":
                    GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                    startTime = System.currentTimeMillis();
                    gbmModel = gbmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = gbmModel._output;
                    bestModelJson = gbmModel._parms.toJsonString();
                    break;
                case "dl":
                    DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DL model.");
                    startTime = System.currentTimeMillis();
                    dlModel = dlJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = dlModel._output;
                    bestModelJson = dlModel._parms.toJsonString();
                    break;
                default:
                    // Without this, modelOutput would stay null and the metrics lookup below
                    // would fail with an uninformative NullPointerException.
                    throw new IllegalArgumentException("Unsupported algo: " + algo);
            }
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            // Remove the trained model from the cluster; modelOutput is still a live Java object.
            if (drfModel != null) {
                drfModel.delete();
            }
            if (glmModel != null) {
                glmModel.delete();
            }
            if (gbmModel != null) {
                gbmModel.delete();
            }
            if (dlModel != null) {
                dlModel.delete();
            }
        }
        removeTestCaseDataSetFrames();
        // With cross-validation, report CV metrics instead of validation metrics.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics), getMetrics(modelOutput._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        }
    } else {
        assert !modelSelectionCriteria.equals("");
        makeGridParameters();
        makeSearchCriteria();
        // Named gridResult so it does not shadow the boolean field `grid` tested above.
        Grid gridResult = null;
        Model bestModel = null;
        String bestModelJson = null;
        try {
            SchemaServer.registerAllSchemasIfNecessary();
            // TODO: Hack for PUBDEV-2812 -- instantiate each builder and its parameter schema once
            // so they are registered before the grid search runs.
            switch (algo) {
                case "drf":
                    if (!drfRegistered) {
                        new DRF(true);
                        new DRFParametersV3();
                        drfRegistered = true;
                    }
                    break;
                case "glm":
                    if (!glmRegistered) {
                        new GLM(true);
                        new GLMParametersV3();
                        glmRegistered = true;
                    }
                    break;
                case "gbm":
                    if (!gbmRegistered) {
                        new GBM(true);
                        new GBMParametersV3();
                        gbmRegistered = true;
                    }
                    break;
                case "dl":
                    if (!dlRegistered) {
                        new DeepLearning(true);
                        new DeepLearningParametersV3();
                        dlRegistered = true;
                    }
                    break;
            }
            startTime = System.currentTimeMillis();
            // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
            Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
            gridResult = gs.get();
            stopTime = System.currentTimeMillis();
            boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
            double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
            // NOTE(review): selection always uses validation metrics, even when params._nfolds > 0.
            for (Model m : gridResult.getModels()) {
                double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
                AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
                if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                    bestScore = validationMetricScore;
                    bestModel = m;
                    bestModelJson = bestModel._parms.toJsonString();
                }
            }
            // An empty grid, or one where no score beats the sentinel (e.g. all NaN), leaves
            // bestModel null; fail with a clear message instead of an NPE below.
            if (bestModel == null) {
                throw new IllegalStateException("Grid search produced no model with a usable " + modelSelectionCriteria + " score");
            }
            AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
            AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            if (gridResult != null) {
                gridResult.delete();
            }
        }
        removeTestCaseDataSetFrames();
        // With cross-validation, report CV metrics instead of validation metrics.
        if (params._nfolds > 0) {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._cross_validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        } else {
            return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics), getMetrics(bestModel._output._validation_metrics), stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
        }
    }
}
Also used : Grid(hex.grid.Grid) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMParametersV3(hex.schemas.GBMV3.GBMParametersV3) GBM(hex.tree.gbm.GBM) GBMModel(hex.tree.gbm.GBMModel) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) IOException(java.io.IOException) GridSearch(hex.grid.GridSearch) DeepLearningParametersV3(hex.schemas.DeepLearningV3.DeepLearningParametersV3) DeepLearningModel(hex.deeplearning.DeepLearningModel) SharedTreeModel(hex.tree.SharedTreeModel) DRF(hex.tree.drf.DRF) GLMParametersV3(hex.schemas.GLMV3.GLMParametersV3) DRFParametersV3(hex.schemas.DRFV3.DRFParametersV3)

Example 2 with GLM

use of hex.glm.GLM in project h2o-3 by h2oai.

Method testWorkFlow of the class WorkFlowTest.

/**
 * End-to-end workflow test.
 *
 * <ol>
 *   <li>Load a set of files and split them into train, test and holdout frames.</li>
 *   <li>Light data munging: derive a day column from "starttime" and count bike starts per
 *       station per day.</li>
 *   <li>Build GBM and GLM models on train, using test as validation.</li>
 *   <li>Score both models on the holdout set.</li>
 * </ol>
 *
 * <p>If the files are missing ({@code load_files} returns null) the method returns silently, since
 * the files are large and this is not yet a JUnit test.
 *
 * @param files paths of the data files to load into "data.hex"
 */
private void testWorkFlow(String[] files) {
    try {
        Scope.enter();
        // 1- Load datasets
        Frame data = load_files("data.hex", files);
        if (data == null)
            return;
        // -------------------------------------------------
        // 2- light data munging
        // Convert start time to: Day since the Epoch
        Vec startime = data.vec("starttime");
        data.add(new TimeSplit().doIt(startime));
        // Now do a monster Group-By.  Count bike starts per-station per-day
        Vec days = data.vec("Days");
        long start = System.currentTimeMillis();
        Frame bph = new CountBikes(days).doAll(days, data.vec("start station name")).makeFrame(Key.make("bph.hex"));
        System.out.println("Groupby took " + (System.currentTimeMillis() - start));
        System.out.println(bph);
        System.out.println(bph.toString(10000, 20));
        data.remove();
        QuantileModel.QuantileParameters quantile_parms = new QuantileModel.QuantileParameters();
        quantile_parms._train = bph._key;
        Job<QuantileModel> job2 = new Quantile(quantile_parms).trainModel();
        QuantileModel quantile = job2.get();
        job2.remove();
        System.out.println(Arrays.deepToString(quantile._output._quantiles));
        quantile.remove();
        // Split into train, test and holdout sets
        Key[] keys = new Key[] { Key.make("train.hex"), Key.make("test.hex"), Key.make("hold.hex") };
        double[] ratios = new double[] { 0.6, 0.3, 0.1 };
        Frame[] frs = ShuffleSplitFrame.shuffleSplitFrame(bph, keys, ratios, 1234567689L);
        Frame train = frs[0];
        Frame test = frs[1];
        Frame hold = frs[2];
        bph.remove();
        System.out.println(train);
        System.out.println(test);
        // -------------------------------------------------
        // 3- build model on train; using test as validation
        // ---
        // Gradient Boosting Machine
        GBMModel.GBMParameters gbm_parms = new GBMModel.GBMParameters();
        // base Model.Parameters
        gbm_parms._train = train._key;
        gbm_parms._valid = test._key;
        // default is false
        gbm_parms._score_each_iteration = false;
        // SupervisedModel.Parameters
        gbm_parms._response_column = "bikes";
        // SharedTreeModel.Parameters
        // default is 50, 1000 is 0.90, 10000 is 0.91
        gbm_parms._ntrees = 500;
        // default is 5
        gbm_parms._max_depth = 6;
        // default
        gbm_parms._min_rows = 10;
        // default
        gbm_parms._nbins = 20;
        // GBMModel.Parameters
        // default
        gbm_parms._distribution = DistributionFamily.gaussian;
        // default
        gbm_parms._learn_rate = 0.1f;
        // Train model; block for results
        Job<GBMModel> job = new GBM(gbm_parms).trainModel();
        GBMModel gbm = job.get();
        job.remove();
        // ---
        // Build a GLM model also
        GLMModel.GLMParameters glm_parms = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
        // base Model.Parameters
        glm_parms._train = train._key;
        glm_parms._valid = test._key;
        // default is false
        glm_parms._score_each_iteration = false;
        // SupervisedModel.Parameters
        glm_parms._response_column = "bikes";
        // GLMModel.Parameters
        glm_parms._use_all_factor_levels = true;
        // Train model; block for results
        Job<GLMModel> glm_job = new GLM(glm_parms).trainModel();
        GLMModel glm = glm_job.get();
        glm_job.remove();
        // -------------------------------------------------
        // 4- Score on holdout set & report
        // Score on the holdout frame (previously this scored on train, so `hold` was never used).
        gbm.score(hold).remove();
        glm.score(hold).remove();
        // Cleanup: remove the trained models as well, matching the quantile model cleanup above.
        gbm.remove();
        glm.remove();
        train.remove();
        test.remove();
        hold.remove();
    } finally {
        Scope.exit();
    }
}
Also used : ShuffleSplitFrame(hex.splitframe.ShuffleSplitFrame) GLMModel(hex.glm.GLMModel) GLM(hex.glm.GLM) QuantileModel(hex.quantile.QuantileModel) GBMModel(hex.tree.gbm.GBMModel) GBM(hex.tree.gbm.GBM) Quantile(hex.quantile.Quantile)

Example 3 with GLM

use of hex.glm.GLM in project h2o-3 by h2oai.

Method prepareGLMModel of the class ModelSerializationTest.

/**
 * Trains a GLM model on a parsed test data set; the parsed frame is deleted once training finishes.
 *
 * @param dataset        path of the test file to parse
 * @param ignoredColumns columns to exclude from training
 * @param response       name of the response column
 * @param family         GLM distribution family
 * @return the trained GLM model
 */
private GLMModel prepareGLMModel(String dataset, String[] ignoredColumns, String response, GLMModel.GLMParameters.Family family) {
    final Frame trainFrame = parse_test_file(dataset);
    try {
        final GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
        glmParams._train = trainFrame._key;
        glmParams._response_column = response;
        glmParams._ignored_columns = ignoredColumns;
        glmParams._family = family;
        final GLM builder = new GLM(glmParams);
        return builder.trainModel().get();
    } finally {
        // Only the parsed input frame is cleaned up; the trained model is returned to the caller.
        if (trainFrame != null) {
            trainFrame.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) GLMModel(hex.glm.GLMModel) GLM(hex.glm.GLM)

Example 4 with GLM

use of hex.glm.GLM in project h2o-3 by h2oai.

Method testXValPredictions of the class XValPredictionsCheck.

/**
 * Checks the kept cross-validation predictions of GBM, DRF, GLM and DeepLearning models trained
 * with an explicit fold column on the iris data set.
 *
 * <p>The fold column is built with {@code AstKFold.kfoldColumn} using {@code nfolds} folds and
 * appended to the training frame. The last argument to {@code checkModel} is 3 for the models
 * trained on "class" and 1 for the GLM trained on "sepal_len".
 *
 * <p>All trained models and the training frame are deleted in {@code finally}, so nothing leaks
 * into the DKV even when a check fails.
 */
@Test
public void testXValPredictions() {
    final int nfolds = 3;
    Frame tfr = null;
    GBMModel gbm = null;
    DRFModel drf = null;
    GLMModel glm = null;
    DeepLearningModel dl = null;
    try {
        // Load data, hack frames
        tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
        Frame foldId = new Frame(new String[] { "foldId" }, new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
        tfr.add(foldId);
        DKV.put(tfr);
        // GBM
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = tfr._key;
        parms._response_column = "class";
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._fold_column = "foldId";
        parms._distribution = DistributionFamily.multinomial;
        parms._keep_cross_validation_predictions = true;
        GBM job = new GBM(parms);
        gbm = job.trainModel().get();
        checkModel(gbm, foldId.anyVec(), 3);
        // DRF
        DRFModel.DRFParameters parmsDRF = new DRFModel.DRFParameters();
        parmsDRF._train = tfr._key;
        parmsDRF._response_column = "class";
        parmsDRF._ntrees = 1;
        parmsDRF._max_depth = 1;
        parmsDRF._fold_column = "foldId";
        parmsDRF._distribution = DistributionFamily.multinomial;
        parmsDRF._keep_cross_validation_predictions = true;
        DRF drfJob = new DRF(parmsDRF);
        drf = drfJob.trainModel().get();
        checkModel(drf, foldId.anyVec(), 3);
        // GLM
        GLMModel.GLMParameters parmsGLM = new GLMModel.GLMParameters();
        parmsGLM._train = tfr._key;
        parmsGLM._response_column = "sepal_len";
        parmsGLM._fold_column = "foldId";
        parmsGLM._keep_cross_validation_predictions = true;
        GLM glmJob = new GLM(parmsGLM);
        glm = glmJob.trainModel().get();
        checkModel(glm, foldId.anyVec(), 1);
        // DL
        DeepLearningModel.DeepLearningParameters parmsDL = new DeepLearningModel.DeepLearningParameters();
        parmsDL._train = tfr._key;
        parmsDL._response_column = "class";
        parmsDL._hidden = new int[] { 1 };
        parmsDL._epochs = 1;
        parmsDL._fold_column = "foldId";
        parmsDL._keep_cross_validation_predictions = true;
        DeepLearning dlJob = new DeepLearning(parmsDL);
        dl = dlJob.trainModel().get();
        checkModel(dl, foldId.anyVec(), 3);
    } finally {
        // Delete the models before the frame they were trained on.
        // NOTE(review): confirm whether Model.delete() also removes the kept CV predictions/models.
        if (gbm != null)
            gbm.delete();
        if (drf != null)
            drf.delete();
        if (glm != null)
            glm.delete();
        if (dl != null)
            dl.delete();
        if (tfr != null)
            tfr.remove();
    }
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMModel(hex.tree.gbm.GBMModel) GBM(hex.tree.gbm.GBM) DRF(hex.tree.drf.DRF) DeepLearningModel(hex.deeplearning.DeepLearningModel) Test(org.junit.Test)

Aggregations

GLM (hex.glm.GLM)4 GLMModel (hex.glm.GLMModel)4 GBM (hex.tree.gbm.GBM)3 GBMModel (hex.tree.gbm.GBMModel)3 DeepLearning (hex.deeplearning.DeepLearning)2 DeepLearningModel (hex.deeplearning.DeepLearningModel)2 DRF (hex.tree.drf.DRF)2 DRFModel (hex.tree.drf.DRFModel)2 Frame (water.fvec.Frame)2 Grid (hex.grid.Grid)1 GridSearch (hex.grid.GridSearch)1 Quantile (hex.quantile.Quantile)1 QuantileModel (hex.quantile.QuantileModel)1 DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3)1 DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3)1 GBMParametersV3 (hex.schemas.GBMV3.GBMParametersV3)1 GLMParametersV3 (hex.schemas.GLMV3.GLMParametersV3)1 ShuffleSplitFrame (hex.splitframe.ShuffleSplitFrame)1 SharedTreeModel (hex.tree.SharedTreeModel)1 IOException (java.io.IOException)1