Search in sources:

Example 6 with Grid

use of hex.grid.Grid in project h2o-3 by h2oai.

From class DRFGridTest, method testDuplicatesCarsGrid.

//@Ignore("PUBDEV-1643")
/**
 * Runs a DRF grid search over a hyper space in which every hyper-parameter lists the same value
 * twice. It then checks that the grid returns at least one model and that every returned model
 * has the same key, i.e. duplicate hyper-parameter points did not produce distinct models.
 */
@Test
public void testDuplicatesCarsGrid() {
    Grid grid = null;
    Frame fr = null;
    Vec old = null;
    try {
        fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
        // Remove unique id
        fr.remove("name").remove();
        old = fr.remove("economy");
        // Move the response to the last column
        fr.add("economy", old);
        DKV.put(fr);
        // Hyper space where every parameter repeats the same value, so all points are duplicates
        HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>() {

            {
                put("_ntrees", new Integer[] { 5, 5 });
                put("_max_depth", new Integer[] { 2, 2 });
                put("_mtries", new Integer[] { -1, -1 });
                put("_sample_rate", new Double[] { .1, .1 });
            }
        };
        // Fire off a grid search
        DRFModel.DRFParameters params = new DRFModel.DRFParameters();
        params._train = fr._key;
        params._response_column = "economy";
        // Get the Grid for this modeling class and frame
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = gs.get();
        // Check that duplicate models have not been constructed
        Model[] models = grid.getModels();
        assertTrue("Number of returned models has to be > 0", models.length > 0);
        // All returned models must be the same model, i.e. share one key.
        // Compare keys by value: reference equality only works if every Key instance is interned.
        Key<Model> modelKey = models[0]._key;
        for (Model m : models) {
            assertTrue("Number of constructed models has to be equal to 1", modelKey.equals(m._key));
        }
    } finally {
        if (old != null) {
            old.remove();
        }
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Vec(water.fvec.Vec) Model(hex.Model) Test(org.junit.Test)

Example 7 with Grid

use of hex.grid.Grid in project h2o-3 by h2oai.

From class TestCase, method execute.

/**
 * Runs this test case by training a single model or a grid of models, then collects its metrics.
 *
 * <p>When {@code grid} is false, one model is trained for {@code algo} ("drf", "glm", "gbm" or
 * "dl") from {@code params}. Otherwise a grid search is run over {@code hyperParms} using
 * {@code searchCriteria}. From that grid, the model with the best validation score for
 * {@code modelSelectionCriteria} is selected. The trained model (or the grid) is deleted and the
 * test-case frames are removed before this method returns, on success and on failure.
 *
 * @return a result holding the training metrics and either the cross-validation metrics (when
 *     {@code params._nfolds > 0}) or the validation metrics. It also carries the elapsed
 *     training time in milliseconds and the chosen model's parameters as JSON.
 * @throws Exception wrapping any exception raised while training or searching. This includes an
 *     unsupported {@code algo} in single-model mode and a grid that yields no selectable model.
 */
public TestCaseResult execute() throws Exception, AssertionError {
    loadTestCaseDataSets();
    makeModelParameters();
    double startTime = 0, stopTime = 0;
    if (!grid) {
        Model.Output modelOutput = null;
        DRFModel drfModel = null;
        GLMModel glmModel = null;
        GBMModel gbmModel = null;
        DeepLearningModel dlModel = null;
        String bestModelJson = null;
        try {
            switch (algo) {
                case "drf":
                    DRF drfJob = new DRF((DRFModel.DRFParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                    startTime = System.currentTimeMillis();
                    drfModel = drfJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = drfModel._output;
                    bestModelJson = drfModel._parms.toJsonString();
                    break;
                case "glm":
                    GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                    AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                    startTime = System.currentTimeMillis();
                    glmModel = glmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = glmModel._output;
                    bestModelJson = glmModel._parms.toJsonString();
                    break;
                case "gbm":
                    GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                    startTime = System.currentTimeMillis();
                    gbmModel = gbmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = gbmModel._output;
                    bestModelJson = gbmModel._parms.toJsonString();
                    break;
                case "dl":
                    DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DL model.");
                    startTime = System.currentTimeMillis();
                    dlModel = dlJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = dlModel._output;
                    bestModelJson = dlModel._parms.toJsonString();
                    break;
                default:
                    // Without this, modelOutput stays null and the failure surfaces later as an NPE.
                    throw new IllegalArgumentException("Unsupported algo: " + algo);
            }
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            if (drfModel != null) {
                drfModel.delete();
            }
            if (glmModel != null) {
                glmModel.delete();
            }
            if (gbmModel != null) {
                gbmModel.delete();
            }
            if (dlModel != null) {
                dlModel.delete();
            }
            // Free the data sets even when training failed, so a failing case does not leak frames.
            removeTestCaseDataSetFrames();
        }
        // With cross-validation enabled, report CV metrics instead of validation metrics.
        return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics),
                getMetrics(params._nfolds > 0 ? modelOutput._cross_validation_metrics : modelOutput._validation_metrics),
                stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
    } else {
        assert !modelSelectionCriteria.equals("");
        makeGridParameters();
        makeSearchCriteria();
        // Named gridResult so it does not shadow the boolean `grid` field tested above.
        Grid gridResult = null;
        Model bestModel = null;
        String bestModelJson = null;
        try {
            SchemaServer.registerAllSchemasIfNecessary();
            // TODO: Hack for PUBDEV-2812. Instantiates the builder and its parameter schema once per
            // algo (guarded by the *Registered flags), presumably so they are registered before
            // the grid search needs them.
            switch (algo) {
                case "drf":
                    if (!drfRegistered) {
                        new DRF(true);
                        new DRFParametersV3();
                        drfRegistered = true;
                    }
                    break;
                case "glm":
                    if (!glmRegistered) {
                        new GLM(true);
                        new GLMParametersV3();
                        glmRegistered = true;
                    }
                    break;
                case "gbm":
                    if (!gbmRegistered) {
                        new GBM(true);
                        new GBMParametersV3();
                        gbmRegistered = true;
                    }
                    break;
                case "dl":
                    if (!dlRegistered) {
                        new DeepLearning(true);
                        new DeepLearningParametersV3();
                        dlRegistered = true;
                    }
                    break;
            }
            startTime = System.currentTimeMillis();
            // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
            Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
            gridResult = gs.get();
            stopTime = System.currentTimeMillis();
            boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
            double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
            for (Model m : gridResult.getModels()) {
                double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
                AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
                if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                    bestScore = validationMetricScore;
                    bestModel = m;
                    bestModelJson = bestModel._parms.toJsonString();
                }
            }
            // An empty grid, or one where every score is NaN, selects nothing; fail with a clear message.
            if (bestModel == null) {
                throw new IllegalStateException("Grid search produced no model with a usable " + modelSelectionCriteria + " score");
            }
            AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
            AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
        } catch (Exception e) {
            throw new Exception(e);
        } finally {
            if (gridResult != null) {
                gridResult.delete();
            }
            // Free the data sets even when the search failed, so a failing case does not leak frames.
            removeTestCaseDataSetFrames();
        }
        // With cross-validation enabled, report CV metrics instead of validation metrics.
        return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics),
                getMetrics(params._nfolds > 0 ? bestModel._output._cross_validation_metrics : bestModel._output._validation_metrics),
                stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
    }
}
Also used : Grid(hex.grid.Grid) GLM(hex.glm.GLM) DeepLearning(hex.deeplearning.DeepLearning) GBMParametersV3(hex.schemas.GBMV3.GBMParametersV3) GBM(hex.tree.gbm.GBM) GBMModel(hex.tree.gbm.GBMModel) DRFModel(hex.tree.drf.DRFModel) GLMModel(hex.glm.GLMModel) IOException(java.io.IOException) GridSearch(hex.grid.GridSearch) DeepLearningParametersV3(hex.schemas.DeepLearningV3.DeepLearningParametersV3) GLMModel(hex.glm.GLMModel) DeepLearningModel(hex.deeplearning.DeepLearningModel) SharedTreeModel(hex.tree.SharedTreeModel) GBMModel(hex.tree.gbm.GBMModel) DRFModel(hex.tree.drf.DRFModel) DRF(hex.tree.drf.DRF) GLMParametersV3(hex.schemas.GLMV3.GLMParametersV3) DeepLearningModel(hex.deeplearning.DeepLearningModel) DRFParametersV3(hex.schemas.DRFV3.DRFParametersV3)

Example 8 with Grid

use of hex.grid.Grid in project h2o-3 by h2oai.

From class GridSearchHandler, method handle.

// Invoke the handler with parameters.  Can throw any exception the called handler can throw.
// TODO: why does this do its own params filling?
// TODO: why does this do its own sub-dispatch?
/**
 * Handles a grid-search "train" request for the algorithm named in the route URL.
 *
 * <p>Builds a {@code GridSearchSchema} whose parameter schema is seeded with the algorithm's
 * default parameters and then overlaid with the request properties. It starts the grid search
 * and returns the schema carrying the job.
 *
 * @param version API version (not read here)
 * @param route the matched route; its handler method must be {@code train}
 * @param parms request properties; consumed by {@code fillFromParms}
 * @param postBody request body (not read here)
 * @return the filled grid-search schema with {@code job} and {@code total_models} set
 * @throws Exception {@code H2O.unimpl()} for non-train routes, or anything the called code throws
 */
@Override
S handle(int version, water.api.Route route, Properties parms, String postBody) throws Exception {
    // Only the train entry point is served here.
    if (!"train".equals(route._handler_method.getName())) {
        throw water.H2O.unimpl();
    }
    // URL layout {}/{99}/{Grid}/{gbm}/ : the algo's URL name is the fourth path segment.
    final String algoURLName = route._url.split("/")[3];
    // gbm -> GBM; deeplearning -> DeepLearning
    final String algoName = ModelBuilder.algoName(algoURLName);
    final String schemaDir = ModelBuilder.schemaDirectory(algoURLName);
    // Hard-coded schema versions: a few algos only have V99 schemas, the rest use V3.
    // SchemaServer.schema*() lookups are being removed, hence this hack. TODO: unhack all of this.
    final boolean onlyV99Schema = algoName.equals("SVD") || algoName.equals("Aggregator") || algoName.equals("StackedEnsemble");
    final int algoVersion = onlyV99Schema ? 99 : 3;
    // TODO: this is a horrible hack which is going to cause maintenance problems.
    // Result looks like <schemaDir><Algo>V<n>$<Params>V<n>.
    final String versionTag = "V" + algoVersion;
    final String paramSchemaName = schemaDir + algoName + versionTag + "$" + ModelBuilder.paramName(algoURLName) + versionTag;

    // Create the grid-search schema and an empty parameter schema of the right class.
    S gss = (S) new GridSearchSchema();
    gss.init_meta();
    gss.parameters = (P) TypeMap.newFreezable(paramSchemaName);
    gss.parameters.init_meta();
    gss.hyper_parameters = new IcedHashMap<>();

    // Seed with the builder's defaults, then overlay whatever the user sent.
    ModelBuilder defaultsBuilder = ModelBuilder.make(algoURLName, null, null);
    gss.parameters.fillFromImpl(defaultsBuilder._parms);
    gss.fillFromParms(parms);

    // Only hyper-parameter names are validated (no types). The schema-level names
    // (e.g. validation_frame) are still in use at this point.
    validateHyperParams((P) gss.parameters, gss.hyper_parameters);

    MP params = (MP) gss.parameters.createAndFillImpl();
    Map<String, Object[]> hyperSpace = new TreeMap<>(gss.hyper_parameters);
    // The implementation-side field is "valid", not "validation_frame": rename the hyper key.
    if (hyperSpace.containsKey("validation_frame")) {
        hyperSpace.put("valid", hyperSpace.remove("validation_frame"));
    }

    // NOTE(review): an older FIXME claimed the grid id is not passed to the grid search builder.
    // It is passed below as the destination key, so confirm whether that FIXME is stale.
    Key<Grid> destKey = gss.grid_id == null ? null : gss.grid_id.key();
    Job<Grid> gsJob = GridSearch.startGridSearch(destKey, params, hyperSpace, new DefaultModelParametersBuilderFactory<MP, P>(), (HyperSpaceSearchCriteria) gss.search_criteria.createAndFillImpl());

    // FIXME: the hyper parameters are not echoed back in the response for now.
    gss.hyper_parameters = null;
    // TODO: looks like it's currently always 0
    gss.total_models = gsJob._result.get().getModelCount();
    gss.job = new JobV3(gsJob);
    return gss;
}
Also used : Grid(hex.grid.Grid) ModelBuilder(hex.ModelBuilder) JobV3(water.api.schemas3.JobV3)

Example 9 with Grid

use of hex.grid.Grid in project h2o-3 by h2oai.

From class GridSearchSchema, method fillFromParms.

/**
 * Fills this grid-search schema from raw request properties.
 *
 * <p>Consumes and removes the {@code hyper_parameters}, {@code search_criteria} and
 * {@code grid_id} entries from {@code parms}. Everything left over is passed to the
 * model-parameter schema without strict validity checks.
 *
 * @param parms request properties; mutated by removing the grid-specific keys
 * @return this schema
 * @throws H2OIllegalArgumentException if {@code search_criteria} has no or an unknown
 *     {@code strategy}, or if a RandomDiscrete search has a negative {@code max_runtime_secs}
 *     or {@code max_models}
 */
@Override
public S fillFromParms(Properties parms) {
    if (parms.containsKey("hyper_parameters")) {
        Map<String, Object> parsed = water.util.JSONUtils.parse(parms.getProperty("hyper_parameters"));
        // Each hyper-parameter becomes an array of candidate values: JSON lists are unpacked,
        // single values are wrapped in a one-element array.
        for (Map.Entry<String, Object> entry : parsed.entrySet()) {
            Object raw = entry.getValue();
            Object[] candidates;
            if (raw instanceof List) {
                candidates = ((List) raw).toArray();
            } else {
                candidates = new Object[] { raw };
            }
            hyper_parameters.put(entry.getKey(), candidates);
        }
        parms.remove("hyper_parameters");
    }

    if (!parms.containsKey("search_criteria")) {
        // No search_criteria given: fall back to an exhaustive Cartesian search.
        search_criteria = new HyperSpaceSearchCriteriaV99.CartesianSearchCriteriaV99();
    } else {
        Properties criteria = water.util.JSONUtils.parseToProperties(parms.getProperty("search_criteria"));
        if (!criteria.containsKey("strategy")) {
            throw new H2OIllegalArgumentException("search_criteria.strategy", "null");
        }
        // TODO: move this into a factory method in HyperSpaceSearchCriteriaV99
        // Non-null here: Properties cannot store null values and the key is present.
        String strategy = (String) criteria.get("strategy");
        switch (strategy) {
            case "Cartesian":
                search_criteria = new HyperSpaceSearchCriteriaV99.CartesianSearchCriteriaV99();
                break;
            case "RandomDiscrete":
                search_criteria = new HyperSpaceSearchCriteriaV99.RandomDiscreteValueSearchCriteriaV99();
                // Reject negative limits up front; per the messages, 0 means "no limit".
                if (criteria.containsKey("max_runtime_secs")
                        && Double.parseDouble((String) criteria.get("max_runtime_secs")) < 0) {
                    throw new H2OIllegalArgumentException("max_runtime_secs must be >= 0 (0 for unlimited time)", strategy);
                }
                if (criteria.containsKey("max_models")
                        && Integer.parseInt((String) criteria.get("max_models")) < 0) {
                    throw new H2OIllegalArgumentException("max_models must be >= 0 (0 for all models)", strategy);
                }
                break;
            default:
                throw new H2OIllegalArgumentException("search_criteria.strategy", strategy);
        }
        search_criteria.fillWithDefaults();
        search_criteria.fillFromParms(criteria);
        parms.remove("search_criteria");
    }

    if (parms.containsKey("grid_id")) {
        Key<Grid> gridKey = Key.<Grid>make(parms.getProperty("grid_id"));
        grid_id = new KeyV3.GridKeyV3(gridKey);
        parms.remove("grid_id");
    }

    // Deliberately lenient (second argument false): GridSearch tolerates bad parameters,
    // since many points of a hyper space may be illegal for one reason or another.
    this.parameters.fillFromParms(parms, false);
    return (S) this;
}
Also used : Grid(hex.grid.Grid) H2OIllegalArgumentException(water.exceptions.H2OIllegalArgumentException) Properties(java.util.Properties) KeyV3(water.api.schemas3.KeyV3) List(java.util.List) IcedHashMap(water.util.IcedHashMap) Map(java.util.Map)

Example 10 with Grid

use of hex.grid.Grid in project h2o-3 by h2oai.

From class GBMGridTest, method testCarsGrid.

/**
 * Runs a GBM grid search on the cars data set over a hyper space containing both legal learning
 * rates and one illegal learning rate (-1.0). It then verifies four things:
 * <ul>
 *   <li>models + failures equal the size of the hyper space;</li>
 *   <li>the grid reports the same hyper-parameter names as were requested;</li>
 *   <li>the successfully built models cover exactly the legal part of the space;</li>
 *   <li>the failed parameter sets cover exactly the illegal learning rate.</li>
 * </ul>
 */
@Test
public void testCarsGrid() {
    Grid<GBMModel.GBMParameters> grid = null;
    Frame fr = null;
    Vec old = null;
    try {
        fr = parse_test_file("smalldata/junit/cars.csv");
        // Remove unique id
        fr.remove("name").remove();
        old = fr.remove("cylinders");
        // Move the response to the last column, as a categorical
        fr.add("cylinders", old.toCategoricalVec());
        DKV.put(fr);
        // Setup hyperparameter search space
        final Double[] legalLearnRateOpts = new Double[] { 0.01, 0.1, 0.3 };
        final Double[] illegalLearnRateOpts = new Double[] { -1.0 };
        HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>() {

            {
                put("_ntrees", new Integer[] { 1, 2 });
                put("_distribution", new DistributionFamily[] { DistributionFamily.multinomial });
                put("_max_depth", new Integer[] { 1, 2, 5 });
                put("_learn_rate", ArrayUtils.join(legalLearnRateOpts, illegalLearnRateOpts));
            }
        };
        // Names of the used hyper parameters, sorted for comparison
        String[] hyperParamNames = hyperParms.keySet().toArray(new String[hyperParms.size()]);
        Arrays.sort(hyperParamNames);
        int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);
        // Fire off a grid search
        GBMModel.GBMParameters params = new GBMModel.GBMParameters();
        params._train = fr._key;
        params._response_column = "cylinders";
        // Get the Grid for this modeling class and frame
        Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
        grid = (Grid<GBMModel.GBMParameters>) gs.get();
        // Make sure number of produced models match size of specified hyper space
        Assert.assertEquals("Size of grid (models+failures) should match to size of hyper space", hyperSpaceSize, grid.getModelCount() + grid.getFailureCount());
        //
        // Make sure that names of used parameters match
        //
        String[] gridHyperNames = grid.getHyperNames();
        Arrays.sort(gridHyperNames);
        Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames, gridHyperNames);
        //
        // Make sure that values of used parameters match as well to the specified values
        //
        Key<Model>[] mKeys = grid.getModelKeys();
        Map<String, Set<Object>> usedHyperParams = GridTestUtils.initMap(hyperParamNames);
        for (Key<Model> mKey : mKeys) {
            GBMModel gbm = (GBMModel) mKey.get();
            // Fetch the names afresh, unsorted, so they line up with getHyperValues().
            System.out.println(gbm._output._scored_train[gbm._output._ntrees]._mse + " " + Arrays.deepToString(ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(gbm._parms))));
            GridTestUtils.extractParams(usedHyperParams, gbm._parms, hyperParamNames);
        }
        // Remove illegal options
        hyperParms.put("_learn_rate", legalLearnRateOpts);
        GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space", hyperParms, usedHyperParams);
        // Verify model failure
        Map<String, Set<Object>> failedHyperParams = GridTestUtils.initMap(hyperParamNames);
        for (Model.Parameters failedParams : grid.getFailedParameters()) {
            GridTestUtils.extractParams(failedHyperParams, failedParams, hyperParamNames);
        }
        hyperParms.put("_learn_rate", illegalLearnRateOpts);
        GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space", hyperParms, failedHyperParams);
    } finally {
        if (old != null) {
            old.remove();
        }
        if (fr != null) {
            fr.remove();
        }
        if (grid != null) {
            grid.remove();
        }
    }
}
Also used : Frame(water.fvec.Frame) Set(java.util.Set) HashMap(java.util.HashMap) Grid(hex.grid.Grid) Vec(water.fvec.Vec) Model(hex.Model) Key(water.Key) Test(org.junit.Test)

Aggregations

Grid (hex.grid.Grid)15 HashMap (java.util.HashMap)12 Frame (water.fvec.Frame)12 Model (hex.Model)11 Test (org.junit.Test)11 Vec (water.fvec.Vec)6 ArrayList (java.util.ArrayList)3 Random (java.util.Random)3 Set (java.util.Set)3 DataInfo (hex.DataInfo)2 Key (water.Key)2 ModelBuilder (hex.ModelBuilder)1 DeepLearning (hex.deeplearning.DeepLearning)1 DeepLearningModel (hex.deeplearning.DeepLearningModel)1 GLM (hex.glm.GLM)1 GLMModel (hex.glm.GLMModel)1 GridSearch (hex.grid.GridSearch)1 DRFParametersV3 (hex.schemas.DRFV3.DRFParametersV3)1 DeepLearningParametersV3 (hex.schemas.DeepLearningV3.DeepLearningParametersV3)1 GBMParametersV3 (hex.schemas.GBMV3.GBMParametersV3)1