Use of hex.tree.gbm.GBMModel in project h2o-3 by h2oai.
The class TestCase, method execute.
/**
 * Runs this test case: loads its datasets, trains either a single {@code algo} model or a grid
 * search over {@code hyperParms}, and collects the resulting metrics.
 *
 * <p>Holdout metrics are the cross-validation metrics when {@code params._nfolds > 0} and the
 * validation metrics otherwise. The trained model (or grid) and the test case's dataset frames
 * are cleaned up even when training fails.
 *
 * @return training metrics, holdout metrics, training wall time in milliseconds and the chosen
 *     model's parameters as JSON
 * @throws Exception wrapping any failure raised while training or grid searching
 */
public TestCaseResult execute() throws Exception, AssertionError {
  loadTestCaseDataSets();
  makeModelParameters();
  return grid ? runGridSearch() : runSingleModel();
}

/**
 * Trains one model of type {@code algo} ("drf", "glm", "gbm" or "dl") from {@code params}.
 *
 * @throws Exception wrapping an IllegalArgumentException for an unsupported {@code algo}, or any
 *     training failure
 */
private TestCaseResult runSingleModel() throws Exception {
  Model model = null;
  Model.Output modelOutput = null;
  String bestModelJson = null;
  double startTime = 0, stopTime = 0;
  try {
    switch (algo) {
      case "drf": {
        DRF drfJob = new DRF((DRFModel.DRFParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training DRF model.");
        startTime = System.currentTimeMillis();
        model = drfJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      case "glm": {
        GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
        AccuracyTestingSuite.summaryLog.println("Training GLM model.");
        startTime = System.currentTimeMillis();
        model = glmJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      case "gbm": {
        GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training GBM model.");
        startTime = System.currentTimeMillis();
        model = gbmJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      case "dl": {
        DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training DL model.");
        startTime = System.currentTimeMillis();
        model = dlJob.trainModel().get();
        stopTime = System.currentTimeMillis();
        break;
      }
      default:
        // Without this, modelOutput stays null and the failure surfaces later as a bare NPE.
        throw new IllegalArgumentException("Unsupported algo: " + algo);
    }
    // Keep references to the output and parameters; the model itself is deleted in finally.
    modelOutput = model._output;
    bestModelJson = model._parms.toJsonString();
  } catch (Exception e) {
    throw new Exception(e);
  } finally {
    if (model != null) {
      model.delete();
    }
    // Remove the frames even if training failed, so they do not leak past this test case.
    removeTestCaseDataSetFrames();
  }
  return buildTestCaseResult(modelOutput, stopTime - startTime, bestModelJson);
}

/**
 * Runs a grid search for {@code algo} and picks the model with the best validation value of
 * {@code modelSelectionCriteria} (direction decided by {@code higherIsBetter}).
 *
 * @throws Exception wrapping an IllegalStateException when no grid model yields a comparable
 *     score, or any search failure
 */
private TestCaseResult runGridSearch() throws Exception {
  assert !modelSelectionCriteria.equals("");
  makeGridParameters();
  makeSearchCriteria();
  Grid gridResult = null;
  Model bestModel = null;
  String bestModelJson = null;
  double startTime = 0, stopTime = 0;
  try {
    SchemaServer.registerAllSchemasIfNecessary();
    registerAlgoSchemasForGrid();
    startTime = System.currentTimeMillis();
    // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
    gridResult = gs.get();
    stopTime = System.currentTimeMillis();
    boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
    double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
    for (Model m : gridResult.getModels()) {
      double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
      AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
      if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
        bestScore = validationMetricScore;
        bestModel = m;
        bestModelJson = bestModel._parms.toJsonString();
      }
    }
    if (bestModel == null) {
      // An empty grid (or only NaN scores) would otherwise NPE on bestModel below.
      throw new IllegalStateException("Grid search for algo " + algo + " produced no model with a comparable " + modelSelectionCriteria);
    }
    AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
    AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
  } catch (Exception e) {
    throw new Exception(e);
  } finally {
    if (gridResult != null) {
      gridResult.delete();
    }
    // Remove the frames even if the search failed, so they do not leak past this test case.
    removeTestCaseDataSetFrames();
  }
  return buildTestCaseResult(bestModel._output, stopTime - startTime, bestModelJson);
}

/**
 * Instantiates the builder and parameter schema for {@code algo} at most once, tracked by the
 * per-algo {@code *Registered} flags. TODO: Hack for PUBDEV-2812
 */
private void registerAlgoSchemasForGrid() {
  switch (algo) {
    case "drf":
      if (!drfRegistered) {
        new DRF(true);
        new DRFParametersV3();
        drfRegistered = true;
      }
      break;
    case "glm":
      if (!glmRegistered) {
        new GLM(true);
        new GLMParametersV3();
        glmRegistered = true;
      }
      break;
    case "gbm":
      if (!gbmRegistered) {
        new GBM(true);
        new GBMParametersV3();
        gbmRegistered = true;
      }
      break;
    case "dl":
      if (!dlRegistered) {
        new DeepLearning(true);
        new DeepLearningParametersV3();
        dlRegistered = true;
      }
      break;
  }
}

/**
 * Builds the result for a trained model's output. Holdout metrics are the cross-validation
 * metrics when {@code params._nfolds > 0}, otherwise the validation metrics.
 */
private TestCaseResult buildTestCaseResult(Model.Output output, double trainingTime, String bestModelJson) throws Exception {
  return new TestCaseResult(testCaseId, getMetrics(output._training_metrics), getMetrics(params._nfolds > 0 ? output._cross_validation_metrics : output._validation_metrics), trainingTime, bestModelJson, this, trainingDataSet, testingDataSet);
}
Use of hex.tree.gbm.GBMModel in project h2o-3 by h2oai.
The class WorkFlowTest, method testWorkFlow.
/**
 * End-to-end workflow test:
 * <ol>
 *   <li>load a set of files;</li>
 *   <li>light data munging: add a day-since-the-Epoch column and count bike starts per station
 *       per day;</li>
 *   <li>split the counts into train/test/holdout sets and build GBM and GLM models on train,
 *       using test as validation;</li>
 *   <li>score both models on the holdout set.</li>
 * </ol>
 *
 * <p>If the files are missing ({@code load_files} returns {@code null}), the method returns
 * without doing anything. The files are big and this is not yet a JUnit test.
 *
 * @param files the files handed to {@code load_files} to build the {@code data.hex} frame
 */
private void testWorkFlow(String[] files) {
  try {
    Scope.enter();
    // 1- Load datasets
    Frame data = load_files("data.hex", files);
    if (data == null) {
      return;
    }
    // -------------------------------------------------
    // 2- light data munging
    // Convert start time to: Day since the Epoch
    Vec startime = data.vec("starttime");
    data.add(new TimeSplit().doIt(startime));
    // Now do a monster Group-By. Count bike starts per-station per-day
    Vec days = data.vec("Days");
    long start = System.currentTimeMillis();
    Frame bph = new CountBikes(days).doAll(days, data.vec("start station name")).makeFrame(Key.make("bph.hex"));
    System.out.println("Groupby took " + (System.currentTimeMillis() - start));
    System.out.println(bph);
    System.out.println(bph.toString(10000, 20));
    data.remove();
    QuantileModel.QuantileParameters quantile_parms = new QuantileModel.QuantileParameters();
    quantile_parms._train = bph._key;
    Job<QuantileModel> job2 = new Quantile(quantile_parms).trainModel();
    QuantileModel quantile = job2.get();
    job2.remove();
    System.out.println(Arrays.deepToString(quantile._output._quantiles));
    quantile.remove();
    // Split into train, test and holdout sets
    Key[] keys = new Key[] { Key.make("train.hex"), Key.make("test.hex"), Key.make("hold.hex") };
    double[] ratios = new double[] { 0.6, 0.3, 0.1 };
    Frame[] frs = ShuffleSplitFrame.shuffleSplitFrame(bph, keys, ratios, 1234567689L);
    Frame train = frs[0];
    Frame test = frs[1];
    Frame hold = frs[2];
    bph.remove();
    System.out.println(train);
    System.out.println(test);
    // -------------------------------------------------
    // 3- build model on train; using test as validation
    // ---
    // Gradient Boosting Machine
    GBMModel.GBMParameters gbm_parms = new GBMModel.GBMParameters();
    // base Model.Parameters
    gbm_parms._train = train._key;
    gbm_parms._valid = test._key;
    // default is false
    gbm_parms._score_each_iteration = false;
    // SupervisedModel.Parameters
    gbm_parms._response_column = "bikes";
    // SharedTreeModel.Parameters
    // default is 50, 1000 is 0.90, 10000 is 0.91
    gbm_parms._ntrees = 500;
    // default is 5
    gbm_parms._max_depth = 6;
    // default
    gbm_parms._min_rows = 10;
    // default
    gbm_parms._nbins = 20;
    // GBMModel.Parameters
    // default
    gbm_parms._distribution = DistributionFamily.gaussian;
    // default
    gbm_parms._learn_rate = 0.1f;
    // Train model; block for results
    Job<GBMModel> job = new GBM(gbm_parms).trainModel();
    GBMModel gbm = job.get();
    job.remove();
    // ---
    // Build a GLM model also
    GLMModel.GLMParameters glm_parms = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
    // base Model.Parameters
    glm_parms._train = train._key;
    glm_parms._valid = test._key;
    // default is false
    glm_parms._score_each_iteration = false;
    // SupervisedModel.Parameters
    glm_parms._response_column = "bikes";
    // GLMModel.Parameters
    glm_parms._use_all_factor_levels = true;
    // Train model; block for results
    Job<GLMModel> glm_job = new GLM(glm_parms).trainModel();
    GLMModel glm = glm_job.get();
    glm_job.remove();
    // -------------------------------------------------
    // 4- Score on the holdout set, which neither model used for training or validation
    gbm.score(hold).remove();
    glm.score(hold).remove();
    // Cleanup: the trained models must be removed explicitly, like the frames
    gbm.remove();
    glm.remove();
    train.remove();
    test.remove();
    hold.remove();
  } finally {
    Scope.exit();
  }
}
Use of hex.tree.gbm.GBMModel in project h2o-3 by h2oai.
The class ModelSerializationTest, method testGBMModelMultinomial.
/**
 * Round-trips a multinomial GBM built on iris through {@code saveAndLoad} and checks that the
 * reloaded model and its compressed trees are binary-identical to the original.
 */
@Test
public void testGBMModelMultinomial() throws IOException {
  // Only the reloaded model is deleted here.
  // NOTE(review): the original model is never deleted in this method; presumably saveAndLoad
  // removes it before reloading - confirm.
  GBMModel restored = null;
  try {
    GBMModel original = prepareGBMModel("smalldata/iris/iris.csv", ESA, "C5", true, 5);
    CompressedTree[][] originalTrees = getTrees(original);
    restored = saveAndLoad(original);
    // Compare the models' binary form first, then their trees.
    assertModelBinaryEquals(original, restored);
    assertTreeEquals("Trees have to be binary same", originalTrees, getTrees(restored));
  } finally {
    if (restored != null) {
      restored.delete();
    }
  }
}
Use of hex.tree.gbm.GBMModel in project h2o-3 by h2oai.
The class XValPredictionsCheck, method testXValPredictions.
/**
 * Trains GBM, DRF, GLM and DeepLearning models on iris with an explicit fold column and
 * {@code _keep_cross_validation_predictions = true}. Each model is handed to {@code checkModel}
 * with the fold vector and its expected class count.
 */
@Test
public void testXValPredictions() {
  final int numFolds = 3;
  Frame iris = null;
  try {
    // Load data and attach a "foldId" column built by AstKFold.kfoldColumn.
    iris = parse_test_file("smalldata/iris/iris_wheader.csv");
    Vec foldVec = AstKFold.kfoldColumn(iris.vec("class").makeZero(), numFolds, 543216789);
    Frame foldFrame = new Frame(new String[] { "foldId" }, new Vec[] { foldVec });
    iris.add(foldFrame);
    DKV.put(iris);

    // NOTE(review): none of the models below are removed here; presumably checkModel cleans
    // them up - confirm.

    // GBM
    GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
    gbmParams._train = iris._key;
    gbmParams._response_column = "class";
    gbmParams._ntrees = 1;
    gbmParams._max_depth = 1;
    gbmParams._fold_column = "foldId";
    gbmParams._distribution = DistributionFamily.multinomial;
    gbmParams._keep_cross_validation_predictions = true;
    GBMModel gbmModel = new GBM(gbmParams).trainModel().get();
    checkModel(gbmModel, foldFrame.anyVec(), 3);

    // DRF
    DRFModel.DRFParameters drfParams = new DRFModel.DRFParameters();
    drfParams._train = iris._key;
    drfParams._response_column = "class";
    drfParams._ntrees = 1;
    drfParams._max_depth = 1;
    drfParams._fold_column = "foldId";
    drfParams._distribution = DistributionFamily.multinomial;
    drfParams._keep_cross_validation_predictions = true;
    DRFModel drfModel = new DRF(drfParams).trainModel().get();
    checkModel(drfModel, foldFrame.anyVec(), 3);

    // GLM (regression on sepal_len, hence 1 "class")
    GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
    glmParams._train = iris._key;
    glmParams._response_column = "sepal_len";
    glmParams._fold_column = "foldId";
    glmParams._keep_cross_validation_predictions = true;
    GLMModel glmModel = new GLM(glmParams).trainModel().get();
    checkModel(glmModel, foldFrame.anyVec(), 1);

    // DL
    DeepLearningModel.DeepLearningParameters dlParams = new DeepLearningModel.DeepLearningParameters();
    dlParams._train = iris._key;
    dlParams._response_column = "class";
    dlParams._hidden = new int[] { 1 };
    dlParams._epochs = 1;
    dlParams._fold_column = "foldId";
    dlParams._keep_cross_validation_predictions = true;
    DeepLearningModel dlModel = new DeepLearning(dlParams).trainModel().get();
    checkModel(dlModel, foldFrame.anyVec(), 3);
  } finally {
    if (iris != null) {
      iris.remove();
    }
  }
}
Use of hex.tree.gbm.GBMModel in project h2o-3 by h2oai.
The class PartialDependenceTest, method weatherBinary.
/**
 * Trains a GBM on the weather data with response "RainTomorrow" (presumably binary, per the test
 * name). It then computes partial dependence for three columns with 33 bins and logs each
 * resulting table. The frame, model and partial-dependence object are removed in {@code finally}.
 */
@Test
public void weatherBinary() {
  Frame weather = null;
  GBMModel gbm = null;
  PartialDependence pd = null;
  try {
    weather = parse_test_file("smalldata/junit/weather.csv");

    // Model
    GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
    gbmParams._train = weather._key;
    gbmParams._ignored_columns = new String[] { "Date", "RISK_MM", "EvapMM" };
    gbmParams._response_column = "RainTomorrow";
    gbm = new GBM(gbmParams).trainModel().get();

    // Partial dependence; the raw Key cast adapts the model's typed key to _model_id
    pd = new PartialDependence(Key.<PartialDependence>make());
    pd._nbins = 33;
    pd._cols = new String[] { "Sunshine", "MaxWindPeriod", "WindSpeed9am" };
    pd._model_id = (Key) gbm._key;
    pd._frame_id = weather._key;
    pd.execImpl().get();
    for (TwoDimTable table : pd._partial_dependence_data) {
      Log.info(table);
    }
  } finally {
    if (weather != null) {
      weather.remove();
    }
    if (gbm != null) {
      gbm.remove();
    }
    if (pd != null) {
      pd.remove();
    }
  }
}
Aggregations