Search in sources :

Example 1 with DRF

Use of hex.tree.drf.DRF in project h2o-3 by h2oai.

From the class ModelSerializationTest, method prepareDRFModel.

/**
 * Parses {@code dataset}, trains a DRF model on it and returns the trained model.
 *
 * <p>When {@code classification} is requested and the response column is not already
 * categorical, the column is swapped for its categorical conversion and the updated frame is
 * written back to the DKV before training. The parsed frame is deleted before returning,
 * whether or not training succeeded.
 *
 * @param dataset        path handed to {@code parse_test_file}
 * @param ignoredColumns columns assigned to the parameters' {@code _ignored_columns}
 * @param response       name of the response column
 * @param classification whether the response must be categorical
 * @param ntrees         number of trees to build
 * @return the trained DRF model
 */
private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
    Frame frame = parse_test_file(dataset);
    try {
        boolean needsConversion = classification && !frame.vec(response).isCategorical();
        if (needsConversion) {
            int responseIdx = frame.find(response);
            // replace() hands back the displaced numeric vec, which is removed right away.
            frame.replace(responseIdx, frame.vec(response).toCategoricalVec()).remove();
            DKV.put(frame._key, frame);
        }

        DRFModel.DRFParameters parameters = new DRFModel.DRFParameters();
        parameters._train = frame._key;
        parameters._response_column = response;
        parameters._ignored_columns = ignoredColumns;
        parameters._ntrees = ntrees;
        parameters._score_each_iteration = true;

        DRF builder = new DRF(parameters);
        return builder.trainModel().get();
    } finally {
        if (frame != null) {
            frame.delete();
        }
    }
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) DRF(hex.tree.drf.DRF)

Example 2 with DRF

Use of hex.tree.drf.DRF in project h2o-3 by h2oai.

From the class TestCase, method execute.

/**
 * Runs this test case: loads its data sets, builds the model parameters, then either trains a
 * single model or, when {@code grid} is set, runs a grid search and keeps the best model.
 *
 * @return a result carrying the training metrics, the validation metrics (cross-validation
 *         metrics when {@code params._nfolds > 0}), the elapsed training time in milliseconds
 *         and the JSON form of the chosen model's parameters
 * @throws Exception      wrapping any failure raised during training or grid search, including
 *                        an unsupported {@code algo} in single-model mode and a grid search
 *                        that yields no selectable model
 * @throws AssertionError if assertions are enabled and grid mode is entered with an empty
 *                        {@code modelSelectionCriteria}
 */
public TestCaseResult execute() throws Exception, AssertionError {
    loadTestCaseDataSets();
    makeModelParameters();
    return grid ? runGridSearch() : runSingleModel();
}

/** Trains one model for {@code algo}, deletes it afterwards and reports its metrics. */
private TestCaseResult runSingleModel() throws Exception {
    double startTime = 0, stopTime = 0;
    Model trainedModel = null;
    Model.Output modelOutput = null;
    String bestModelJson = null;
    try {
        switch (algo) {
            case "drf": {
                DRF drfJob = new DRF((DRFModel.DRFParameters) params);
                AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                startTime = System.currentTimeMillis();
                trainedModel = drfJob.trainModel().get();
                stopTime = System.currentTimeMillis();
                break;
            }
            case "glm": {
                GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                startTime = System.currentTimeMillis();
                trainedModel = glmJob.trainModel().get();
                stopTime = System.currentTimeMillis();
                break;
            }
            case "gbm": {
                GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
                AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                startTime = System.currentTimeMillis();
                trainedModel = gbmJob.trainModel().get();
                stopTime = System.currentTimeMillis();
                break;
            }
            case "dl": {
                DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                AccuracyTestingSuite.summaryLog.println("Training DL model.");
                startTime = System.currentTimeMillis();
                trainedModel = dlJob.trainModel().get();
                stopTime = System.currentTimeMillis();
                break;
            }
            default:
                // Previously an unknown algo fell through silently and surfaced later as a
                // NullPointerException on the missing model output.
                throw new IllegalArgumentException("Unsupported algorithm: " + algo);
        }
        // Keep references to the output and parameter JSON; the model itself is deleted below.
        modelOutput = trainedModel._output;
        bestModelJson = trainedModel._parms.toJsonString();
    } catch (Exception e) {
        throw new Exception(e);
    } finally {
        if (trainedModel != null) {
            trainedModel.delete();
        }
    }
    removeTestCaseDataSetFrames();
    return buildResult(modelOutput, stopTime - startTime, bestModelJson);
}

/**
 * Runs a grid search over {@code hyperParms}, picks the model with the best
 * {@code modelSelectionCriteria} validation score, deletes the grid and reports that model.
 */
private TestCaseResult runGridSearch() throws Exception {
    assert !modelSelectionCriteria.equals("");
    makeGridParameters();
    makeSearchCriteria();
    double startTime = 0, stopTime = 0;
    Grid searchResult = null;
    Model bestModel = null;
    String bestModelJson = null;
    try {
        SchemaServer.registerAllSchemasIfNecessary();
        // TODO: Hack for PUBDEV-2812
        switch (algo) {
            case "drf":
                if (!drfRegistered) {
                    new DRF(true);
                    new DRFParametersV3();
                    drfRegistered = true;
                }
                break;
            case "glm":
                if (!glmRegistered) {
                    new GLM(true);
                    new GLMParametersV3();
                    glmRegistered = true;
                }
                break;
            case "gbm":
                if (!gbmRegistered) {
                    new GBM(true);
                    new GBMParametersV3();
                    gbmRegistered = true;
                }
                break;
            case "dl":
                if (!dlRegistered) {
                    new DeepLearning(true);
                    new DeepLearningParametersV3();
                    dlRegistered = true;
                }
                break;
        }
        startTime = System.currentTimeMillis();
        // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
        searchResult = gs.get();
        stopTime = System.currentTimeMillis();
        boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
        // Start from the worst possible score so the first comparable model wins.
        double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
        for (Model m : searchResult.getModels()) {
            double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
            AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
            if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                bestScore = validationMetricScore;
                bestModel = m;
                bestModelJson = bestModel._parms.toJsonString();
            }
        }
        if (bestModel == null) {
            // Happens when the grid holds no models or no score beats the sentinel (e.g. NaN);
            // fail with a clear message rather than a NullPointerException.
            throw new IllegalStateException("Grid search produced no model with a comparable " + modelSelectionCriteria + " score");
        }
        AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
        AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
    } catch (Exception e) {
        throw new Exception(e);
    } finally {
        if (searchResult != null) {
            searchResult.delete();
        }
    }
    removeTestCaseDataSetFrames();
    return buildResult(bestModel._output, stopTime - startTime, bestModelJson);
}

/**
 * Packages the metrics of {@code output} into a result, reporting cross-validation metrics
 * when {@code params._nfolds > 0} and validation metrics otherwise.
 */
private TestCaseResult buildResult(Model.Output output, double elapsedMillis, String modelJson) throws Exception {
    return new TestCaseResult(testCaseId, getMetrics(output._training_metrics), getMetrics(params._nfolds > 0 ? output._cross_validation_metrics : output._validation_metrics), elapsedMillis, modelJson, this, trainingDataSet, testingDataSet);
}
Also used : Grid(hex.grid.Grid) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMParametersV3(hex.schemas.GBMV3.GBMParametersV3) GBM(hex.tree.gbm.GBM) GBMModel(hex.tree.gbm.GBMModel) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) IOException(java.io.IOException) GridSearch(hex.grid.GridSearch) DeepLearningParametersV3(hex.schemas.DeepLearningV3.DeepLearningParametersV3) GLMModel(hex.glm.GLMModel) DeepLearningModel(hex.deeplearning.DeepLearningModel) SharedTreeModel(hex.tree.SharedTreeModel) GBMModel(hex.tree.gbm.GBMModel) DRFModel(hex.tree.drf.DRFModel) DRF(hex.tree.drf.DRF) GLMParametersV3(hex.schemas.GLMV3.GLMParametersV3) DeepLearningModel(hex.deeplearning.DeepLearningModel) DRFParametersV3(hex.schemas.DRFV3.DRFParametersV3)

Example 3 with DRF

Use of hex.tree.drf.DRF in project h2o-3 by h2oai.

From the class XValPredictionsCheck, method testXValPredictions.

/**
 * Trains GBM, DRF, GLM and DeepLearning models on the iris data using one shared fold column,
 * keeping the cross-validation predictions, and passes each trained model to
 * {@code checkModel} together with the fold vec.
 */
@Test
public void testXValPredictions() {
    final int nfolds = 3;
    final String foldColumn = "foldId";
    final String classColumn = "class";
    Frame iris = null;
    try {
        // Parse the data and append a fold-assignment column to it.
        iris = parse_test_file("smalldata/iris/iris_wheader.csv");
        String[] foldNames = { foldColumn };
        Vec[] foldVecs = { AstKFold.kfoldColumn(iris.vec(classColumn).makeZero(), nfolds, 543216789) };
        Frame foldId = new Frame(foldNames, foldVecs);
        iris.add(foldId);
        DKV.put(iris);

        // Gradient boosting
        GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
        gbmParams._train = iris._key;
        gbmParams._response_column = classColumn;
        gbmParams._ntrees = 1;
        gbmParams._max_depth = 1;
        gbmParams._fold_column = foldColumn;
        gbmParams._distribution = DistributionFamily.multinomial;
        gbmParams._keep_cross_validation_predictions = true;
        GBMModel gbmModel = new GBM(gbmParams).trainModel().get();
        checkModel(gbmModel, foldId.anyVec(), 3);

        // Random forest
        DRFModel.DRFParameters drfParams = new DRFModel.DRFParameters();
        drfParams._train = iris._key;
        drfParams._response_column = classColumn;
        drfParams._ntrees = 1;
        drfParams._max_depth = 1;
        drfParams._fold_column = foldColumn;
        drfParams._distribution = DistributionFamily.multinomial;
        drfParams._keep_cross_validation_predictions = true;
        DRFModel drfModel = new DRF(drfParams).trainModel().get();
        checkModel(drfModel, foldId.anyVec(), 3);

        // Generalized linear model, using sepal_len as the response
        GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
        glmParams._train = iris._key;
        glmParams._response_column = "sepal_len";
        glmParams._fold_column = foldColumn;
        glmParams._keep_cross_validation_predictions = true;
        GLMModel glmModel = new GLM(glmParams).trainModel().get();
        checkModel(glmModel, foldId.anyVec(), 1);

        // Deep learning
        DeepLearningModel.DeepLearningParameters dlParams = new DeepLearningModel.DeepLearningParameters();
        dlParams._train = iris._key;
        dlParams._response_column = classColumn;
        dlParams._hidden = new int[] { 1 };
        dlParams._epochs = 1;
        dlParams._fold_column = foldColumn;
        dlParams._keep_cross_validation_predictions = true;
        DeepLearningModel dlModel = new DeepLearning(dlParams).trainModel().get();
        checkModel(dlModel, foldId.anyVec(), 3);
    } finally {
        if (iris != null) {
            iris.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMModel(hex.tree.gbm.GBMModel) GBM(hex.tree.gbm.GBM) DRF(hex.tree.drf.DRF) DeepLearningModel(hex.deeplearning.DeepLearningModel) Test(org.junit.Test)

Aggregations

DRF (hex.tree.drf.DRF): 3, DRFModel (hex.tree.drf.DRFModel): 3, DeepLearning (hex.deeplearning.DeepLearning): 2, DeepLearningModel (hex.deeplearning.DeepLearningModel): 2, GLM (hex.glm.GLM): 2, GLMModel (hex.glm.GLMModel): 2, GBM (hex.tree.gbm.GBM): 2, GBMModel (hex.tree.gbm.GBMModel): 2, Frame (water.fvec.Frame): 2, Grid (hex.grid.Grid): 1, GridSearch (hex.grid.GridSearch): 1, DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3): 1, DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3): 1, GBMParametersV3 (hex.schemas.GBMV3.GBMParametersV3): 1, GLMParametersV3 (hex.schemas.GLMV3.GLMParametersV3): 1, SharedTreeModel (hex.tree.SharedTreeModel): 1, IOException (java.io.IOException): 1, Test (org.junit.Test): 1