use of hex.grid.Grid in project h2o-3 by h2oai.
the class DRFGridTest method testDuplicatesCarsGrid.
//@Ignore("PUBDEV-1643")
/**
 * Verifies that a DRF grid search over a hyper space made only of duplicated
 * points does not build duplicate models: every model returned by the grid
 * must carry the same key as the first one.
 */
@Test
public void testDuplicatesCarsGrid() {
  Grid grid = null;
  Frame fr = null;
  Vec old = null;
  try {
    fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
    // Remove unique id
    fr.remove("name").remove();
    // Move the response to the last column
    old = fr.remove("economy");
    fr.add("economy", old);
    DKV.put(fr);
    // Hyper space in which every parameter lists the same value twice,
    // so all points of the space describe the same model
    HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>() {
      {
        put("_ntrees", new Integer[] { 5, 5 });
        put("_max_depth", new Integer[] { 2, 2 });
        put("_mtries", new Integer[] { -1, -1 });
        put("_sample_rate", new Double[] { .1, .1 });
      }
    };
    // Fire off a grid search
    DRFModel.DRFParameters params = new DRFModel.DRFParameters();
    params._train = fr._key;
    params._response_column = "economy";
    // Get the Grid for this modeling class and frame
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
    grid = gs.get();
    // Check that duplicate models have not been constructed
    Model[] models = grid.getModels();
    assertTrue("Number of returned models has to be > 0", models.length > 0);
    // All of them have to be the same model. Keys are compared by value
    // (equals) rather than by reference, so the check does not depend on
    // both references pointing at the very same Key instance.
    Key<Model> modelKey = models[0]._key;
    for (Model m : models) {
      assertTrue("Number of constructed models has to be equal to 1", modelKey.equals(m._key));
    }
  } finally {
    if (old != null) {
      old.remove();
    }
    if (fr != null) {
      fr.remove();
    }
    if (grid != null) {
      grid.remove();
    }
  }
}
use of hex.grid.Grid in project h2o-3 by h2oai.
the class TestCase method execute.
/**
 * Runs this test case: loads its data sets, builds the model parameters and
 * then either trains a single model or runs a grid search, depending on the
 * {@code grid} flag.
 *
 * @return the result holding the training metrics, the cross-validation
 *         metrics (when {@code params._nfolds > 0}) or the validation metrics
 *         (otherwise), the elapsed wall-clock time in milliseconds and the
 *         JSON of the (best) model's parameters
 * @throws Exception wrapping any exception raised while training or searching
 * @throws AssertionError if assertions are enabled and a grid test case has an
 *         empty model selection criteria
 */
public TestCaseResult execute() throws Exception, AssertionError {
  loadTestCaseDataSets();
  makeModelParameters();
  return grid ? executeGridSearch() : executeSingleModel();
}

/** Trains one model for {@code algo}, keeps its output and parameter JSON, and deletes the model. */
private TestCaseResult executeSingleModel() throws Exception {
  double startTime = 0, stopTime = 0;
  Model.Output modelOutput = null;
  String bestModelJson = null;
  Model trained = null;
  try {
    switch (algo) {
      case "drf": {
        DRF drfJob = new DRF((DRFModel.DRFParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training DRF model.");
        startTime = System.currentTimeMillis();
        trained = drfJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      case "glm": {
        GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
        AccuracyTestingSuite.summaryLog.println("Training GLM model.");
        startTime = System.currentTimeMillis();
        trained = glmJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      case "gbm": {
        GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training GBM model.");
        startTime = System.currentTimeMillis();
        trained = gbmJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      case "dl": {
        DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training DL model.");
        startTime = System.currentTimeMillis();
        trained = dlJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      default:
        // Unknown algo: nothing is trained; reported after the data sets are cleaned up.
        break;
    }
    if (trained != null) {
      modelOutput = trained._output;
      bestModelJson = trained._parms.toJsonString();
    }
  } catch (Exception e) {
    throw new Exception(e);
  } finally {
    // The output and the parameter JSON were captured above, so the model itself can go.
    if (trained != null) {
      trained.delete();
    }
  }
  removeTestCaseDataSetFrames();
  if (modelOutput == null) {
    // Previously this surfaced as a bare NullPointerException while building the result.
    throw new IllegalStateException("No model output available for algo '" + algo + "' (supported: drf, glm, gbm, dl)");
  }
  return toTestCaseResult(modelOutput, stopTime - startTime, bestModelJson);
}

/** Runs a grid search for {@code algo}, selects the best model by {@code modelSelectionCriteria}, and deletes the grid. */
private TestCaseResult executeGridSearch() throws Exception {
  assert !modelSelectionCriteria.equals("");
  makeGridParameters();
  makeSearchCriteria();
  double startTime = 0, stopTime = 0;
  Grid searchGrid = null;
  Model bestModel = null;
  String bestModelJson = null;
  try {
    SchemaServer.registerAllSchemasIfNecessary();
    registerSchemasForAlgo();
    startTime = System.currentTimeMillis();
    // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
    searchGrid = gs.get();
    stopTime = System.currentTimeMillis();
    boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
    double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
    for (Model m : searchGrid.getModels()) {
      double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
      AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
      if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
        bestScore = validationMetricScore;
        bestModel = m;
        bestModelJson = bestModel._parms.toJsonString();
      }
    }
    if (bestModel == null) {
      // Empty grid, or no score compared better than the initial bound (e.g. NaN scores).
      throw new IllegalStateException("Grid search produced no model with a comparable " + modelSelectionCriteria + " score");
    }
    AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
    AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
  } catch (Exception e) {
    throw new Exception(e);
  } finally {
    if (searchGrid != null) {
      searchGrid.delete();
    }
  }
  removeTestCaseDataSetFrames();
  return toTestCaseResult(bestModel._output, stopTime - startTime, bestModelJson);
}

/** Instantiates the builder and parameter schema of {@code algo} once per algorithm. */
private void registerSchemasForAlgo() {
  // TODO: Hack for PUBDEV-2812
  switch (algo) {
    case "drf":
      if (!drfRegistered) {
        new DRF(true);
        new DRFParametersV3();
        drfRegistered = true;
      }
      break;
    case "glm":
      if (!glmRegistered) {
        new GLM(true);
        new GLMParametersV3();
        glmRegistered = true;
      }
      break;
    case "gbm":
      if (!gbmRegistered) {
        new GBM(true);
        new GBMParametersV3();
        gbmRegistered = true;
      }
      break;
    case "dl":
      if (!dlRegistered) {
        new DeepLearning(true);
        new DeepLearningParametersV3();
        dlRegistered = true;
      }
      break;
  }
}

/** Builds the result from a model output, using cross-validation metrics when {@code params._nfolds > 0} and validation metrics otherwise. */
private TestCaseResult toTestCaseResult(Model.Output output, double elapsedMillis, String parametersJson) throws Exception {
  // Add check if cv is used
  if (params._nfolds > 0) {
    return new TestCaseResult(testCaseId, getMetrics(output._training_metrics), getMetrics(output._cross_validation_metrics), elapsedMillis, parametersJson, this, trainingDataSet, testingDataSet);
  }
  return new TestCaseResult(testCaseId, getMetrics(output._training_metrics), getMetrics(output._validation_metrics), elapsedMillis, parametersJson, this, trainingDataSet, testingDataSet);
}
use of hex.grid.Grid in project h2o-3 by h2oai.
the class GridSearchHandler method handle.
/**
 * Handles a grid search "train" request: derives the algorithm from the route
 * URL, builds the grid search schema with that algorithm's default parameters
 * overlaid by the request parameters, starts the grid search and returns the
 * schema filled with the job.
 * Can throw any exception the called handler can throw.
 *
 * TODO: why does this do its own params filling?
 * TODO: why does this do its own sub-dispatch?
 */
@Override
S handle(int version, water.api.Route route, Properties parms, String postBody) throws Exception {
  // Only here for train or validate-parms
  final String handlerMethodName = route._handler_method.getName();
  if (!handlerMethodName.equals("train")) {
    throw water.H2O.unimpl();
  }

  // Peek out the desired algo from the URL, which looks like {}/{99}/{Grid}/{gbm}/
  final String[] urlParts = route._url.split("/");
  final String algoURLName = urlParts[3];
  // gbm -> GBM; deeplearning -> DeepLearning
  final String algoName = ModelBuilder.algoName(algoURLName);
  final String schemaDir = ModelBuilder.schemaDirectory(algoURLName);

  // Get the latest version of this algo: /99/Grid/gbm ==> GBMV3
  // Ok, i'm replacing one hack with another hack here, because SchemaServer.schema*() calls are getting eliminated.
  // There probably shouldn't be any reference to algoVersion here at all... TODO: unhack all of this
  final boolean usesV99Schema = algoName.equals("SVD") || algoName.equals("Aggregator") || algoName.equals("StackedEnsemble");
  final int algoVersion = usesV99Schema ? 99 : 3;

  // TODO: this is a horrible hack which is going to cause maintenance problems:
  final String versionSuffix = "V" + algoVersion;
  final String paramSchemaName = schemaDir + algoName + versionSuffix + "$" + ModelBuilder.paramName(algoURLName) + versionSuffix;

  // Build the Grid Search schema, and fill it from the parameters
  S gss = (S) new GridSearchSchema();
  gss.init_meta();
  gss.parameters = (P) TypeMap.newFreezable(paramSchemaName);
  gss.parameters.init_meta();
  gss.hyper_parameters = new IcedHashMap<>();

  // Start from the builder's default parameter settings ...
  ModelBuilder builder = ModelBuilder.make(algoURLName, null, null);
  gss.parameters.fillFromImpl(builder._parms);
  // ... then overlay the values passed in by the user
  gss.fillFromParms(parms);

  // Verify list of hyper parameters. Right now only names, no types.
  // Note: _validation_frame and _training_frame are still used at this point;
  // do not change those names yet.
  validateHyperParams((P) gss.parameters, gss.hyper_parameters);

  // Get actual parameters
  MP params = (MP) gss.parameters.createAndFillImpl();

  // The "validation_frame" hyper parameter is handed on under the name "valid"
  // (per the original note, the schema-level frame names are no longer valid from here on).
  Map<String, Object[]> sortedMap = new TreeMap<>(gss.hyper_parameters);
  if (sortedMap.containsKey("validation_frame")) {
    sortedMap.put("valid", sortedMap.remove("validation_frame"));
  }

  // Get/create a grid for given frame
  // FIXME: Grid ID is not pass to grid search builder!
  Key<Grid> destKey = null;
  if (gss.grid_id != null) {
    destKey = gss.grid_id.key();
  }

  // Create target grid search object (keep it private for now)
  // Start grid search and return the schema back with job key
  Job<Grid> gsJob = GridSearch.startGridSearch(
      destKey,
      params,
      sortedMap,
      new DefaultModelParametersBuilderFactory<MP, P>(),
      (HyperSpaceSearchCriteria) gss.search_criteria.createAndFillImpl());

  // Fill schema with job parameters
  // FIXME: right now we have to remove grid parameters which we sent back
  gss.hyper_parameters = null;
  // TODO: looks like it's currently always 0
  gss.total_models = gsJob._result.get().getModelCount();
  gss.job = new JobV3(gsJob);
  return gss;
}
use of hex.grid.Grid in project h2o-3 by h2oai.
the class GridSearchSchema method fillFromParms.
/**
 * Fills this schema from request properties. The {@code hyper_parameters},
 * {@code search_criteria} and {@code grid_id} entries are consumed here and
 * removed from {@code parms}; whatever remains is passed to the model
 * parameters schema without validity checking.
 *
 * @param parms request properties; the three entries above are removed from it
 * @return this schema
 * @throws H2OIllegalArgumentException if {@code search_criteria} has no
 *         strategy, names an unknown strategy, or gives a negative
 *         {@code max_runtime_secs} / {@code max_models} for RandomDiscrete
 */
@Override
public S fillFromParms(Properties parms) {
  if (parms.containsKey("hyper_parameters")) {
    Map<String, Object> parsed = water.util.JSONUtils.parse(parms.getProperty("hyper_parameters"));
    // Convert lists and singletons into arrays
    for (Map.Entry<String, Object> entry : parsed.entrySet()) {
      Object value = entry.getValue();
      Object[] values;
      if (value instanceof List) {
        values = ((List) value).toArray();
      } else {
        values = new Object[] { value };
      }
      hyper_parameters.put(entry.getKey(), values);
    }
    parms.remove("hyper_parameters");
  }

  if (!parms.containsKey("search_criteria")) {
    // Fall back to Cartesian if there's no search_criteria specified.
    search_criteria = new HyperSpaceSearchCriteriaV99.CartesianSearchCriteriaV99();
  } else {
    Properties criteria = water.util.JSONUtils.parseToProperties(parms.getProperty("search_criteria"));
    if (!criteria.containsKey("strategy")) {
      throw new H2OIllegalArgumentException("search_criteria.strategy", "null");
    }
    // TODO: move this into a factory method in HyperSpaceSearchCriteriaV99
    String strategy = (String) criteria.get("strategy");
    if ("Cartesian".equals(strategy)) {
      search_criteria = new HyperSpaceSearchCriteriaV99.CartesianSearchCriteriaV99();
    } else if ("RandomDiscrete".equals(strategy)) {
      search_criteria = new HyperSpaceSearchCriteriaV99.RandomDiscreteValueSearchCriteriaV99();
      // Both limits are optional, but must not be negative when given
      if (criteria.containsKey("max_runtime_secs")) {
        double maxRuntimeSecs = Double.parseDouble((String) criteria.get("max_runtime_secs"));
        if (maxRuntimeSecs < 0) {
          throw new H2OIllegalArgumentException("max_runtime_secs must be >= 0 (0 for unlimited time)", strategy);
        }
      }
      if (criteria.containsKey("max_models")) {
        int maxModels = Integer.parseInt((String) criteria.get("max_models"));
        if (maxModels < 0) {
          throw new H2OIllegalArgumentException("max_models must be >= 0 (0 for all models)", strategy);
        }
      }
    } else {
      throw new H2OIllegalArgumentException("search_criteria.strategy", strategy);
    }
    search_criteria.fillWithDefaults();
    search_criteria.fillFromParms(criteria);
    parms.remove("search_criteria");
  }

  if (parms.containsKey("grid_id")) {
    Key<Grid> gridKey = Key.<Grid>make(parms.getProperty("grid_id"));
    grid_id = new KeyV3.GridKeyV3(gridKey);
    parms.remove("grid_id");
  }

  // Do not check validity of parameters, GridSearch is tolerant of bad
  // parameters (on purpose, many hyper-param points in the grid might be
  // illegal for whatever reason).
  this.parameters.fillFromParms(parms, false);
  return (S) this;
}
use of hex.grid.Grid in project h2o-3 by h2oai.
the class GBMGridTest method testCarsGrid.
/**
 * Runs a GBM grid search on the cars data set over a hyper space that mixes
 * legal and illegal learn rates, then checks that:
 * models plus failures add up to the size of the hyper space, the grid reports
 * the same hyper parameter names, the built models cover the legal part of
 * the space, and the failed parameters cover the illegal part.
 */
@Test
public void testCarsGrid() {
  Grid<GBMModel.GBMParameters> grid = null;
  Frame cars = null;
  Vec cylinders = null;
  try {
    cars = parse_test_file("smalldata/junit/cars.csv");
    // Drop the unique id column
    cars.remove("name").remove();
    // Move the response to the last column, as a categorical
    cylinders = cars.remove("cylinders");
    cars.add("cylinders", cylinders.toCategoricalVec());
    DKV.put(cars);

    // Hyper parameter search space: three legal learn rates plus one illegal one
    final Double[] legalLearnRateOpts = new Double[] { 0.01, 0.1, 0.3 };
    final Double[] illegalLearnRateOpts = new Double[] { -1.0 };
    HashMap<String, Object[]> hyperParms = new HashMap<String, Object[]>() {
      {
        put("_ntrees", new Integer[] { 1, 2 });
        put("_distribution", new DistributionFamily[] { DistributionFamily.multinomial });
        put("_max_depth", new Integer[] { 1, 2, 5 });
        put("_learn_rate", ArrayUtils.join(legalLearnRateOpts, illegalLearnRateOpts));
      }
    };

    // Sorted names of the hyper parameters in use, and the size of the space
    String[] hyperParamNames = hyperParms.keySet().toArray(new String[0]);
    Arrays.sort(hyperParamNames);
    int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);

    // Fire off a grid search
    GBMModel.GBMParameters params = new GBMModel.GBMParameters();
    params._train = cars._key;
    params._response_column = "cylinders";
    Job<Grid> searchJob = GridSearch.startGridSearch(null, params, hyperParms);
    grid = (Grid<GBMModel.GBMParameters>) searchJob.get();

    // Every point of the hyper space has to end up as a model or as a failure
    int producedCount = grid.getModelCount() + grid.getFailureCount();
    Assert.assertEquals("Size of grid (models+failures) should match to size of hyper space", hyperSpaceSize, producedCount);

    // The grid has to report the same hyper parameter names
    String[] gridHyperNames = grid.getHyperNames();
    Arrays.sort(gridHyperNames);
    Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames, gridHyperNames);

    // Collect the hyper parameter values actually used by the built models
    Key<Model>[] modelKeys = grid.getModelKeys();
    Map<String, Set<Object>> usedHyperParams = GridTestUtils.initMap(hyperParamNames);
    for (Key<Model> modelKey : modelKeys) {
      GBMModel gbm = (GBMModel) modelKey.get();
      System.out.println(
          gbm._output._scored_train[gbm._output._ntrees]._mse
              + " "
              + Arrays.deepToString(ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(gbm._parms))));
      GridTestUtils.extractParams(usedHyperParams, gbm._parms, hyperParamNames);
    }
    // Built models are expected to cover the space restricted to the legal learn rates
    hyperParms.put("_learn_rate", legalLearnRateOpts);
    GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space", hyperParms, usedHyperParams);

    // Failed parameters are expected to cover the space restricted to the illegal learn rate
    Map<String, Set<Object>> failedHyperParams = GridTestUtils.initMap(hyperParamNames);
    for (Model.Parameters failedParams : grid.getFailedParameters()) {
      GridTestUtils.extractParams(failedHyperParams, failedParams, hyperParamNames);
    }
    hyperParms.put("_learn_rate", illegalLearnRateOpts);
    GridTestUtils.assertParamsEqual("Failed model parameters have to correspond to specified hyper space", hyperParms, failedHyperParams);
  } finally {
    if (cylinders != null) {
      cylinders.remove();
    }
    if (cars != null) {
      cars.remove();
    }
    if (grid != null) {
      grid.remove();
    }
  }
}
Aggregations