Use of hex.tree.gbm.GBM in project h2o-3 by h2oai.
The class ModelSerializationTest, method prepareGBMModel.
/**
 * Parses a test dataset, trains a GBM model on it and returns the trained model.
 *
 * <p>When {@code classification} is requested and the response column is not
 * already categorical, the response is converted to a categorical vec and the
 * updated frame is written back to the DKV before training. The parsed frame is
 * deleted before this method returns, whether or not training succeeded.
 *
 * @param dataset        path of the test file to parse
 * @param ignoredColumns column names passed to the builder as ignored columns
 * @param response       name of the response column
 * @param classification whether the response must be categorical
 * @param ntrees         number of trees to build
 * @return the trained GBM model
 */
private GBMModel prepareGBMModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
    Frame f = parse_test_file(dataset);
    try {
        boolean needsCategoricalResponse = classification && !f.vec(response).isCategorical();
        if (needsCategoricalResponse) {
            // Swap in the categorical version of the response and drop the vec it replaced.
            int responseIdx = f.find(response);
            f.replace(responseIdx, f.vec(response).toCategoricalVec()).remove();
            // Republish the modified frame under its existing key.
            DKV.put(f._key, f);
        }

        GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
        gbmParams._train = f._key;
        gbmParams._response_column = response;
        gbmParams._ignored_columns = ignoredColumns;
        gbmParams._ntrees = ntrees;
        gbmParams._score_each_iteration = true;

        GBM builder = new GBM(gbmParams);
        return builder.trainModel().get();
    } finally {
        if (f != null) {
            f.delete();
        }
    }
}
Use of hex.tree.gbm.GBM in project h2o-3 by h2oai.
The class SSLEncryptionTest, method testGBMRegressionGaussian.
/**
 * Parses the Mfgdata gaussian test file, trains a one-tree, depth-one GBM
 * regression on it, scores the training frame, and logs how long the
 * parse-and-train and scoring phases took.
 *
 * <p>Asserts that the builder job reports itself as stopped once training has
 * returned (HEX-1817). The parsed frame, the prediction frame and the model are
 * removed in the {@code finally} block.
 */
private static void testGBMRegressionGaussian() {
    GBMModel gbm = null;
    Frame fr = null, fr2 = null;
    try {
        long trainingStartMs = System.currentTimeMillis();
        fr = parse_test_file("./smalldata/gbm_test/Mfgdata_gaussian_GBM_testing.csv");

        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = fr._key;
        parms._distribution = gaussian;
        // Column 0 is the row id, column 1 the response, column 2 the single predictor.
        parms._response_column = fr._names[1];
        parms._ntrees = 1;
        parms._max_depth = 1;
        parms._min_rows = 1;
        parms._nbins = 20;

        // Ignore column 0 and every column from index 3 onwards; keep columns 1 and 2.
        int numCols = fr.numCols();
        String[] ignored = new String[numCols - 2];
        ignored[0] = fr._names[0];
        System.arraycopy(fr._names, 3, ignored, 1, numCols - 3);
        parms._ignored_columns = ignored;

        parms._learn_rate = 1.0f;
        parms._score_each_iteration = true;

        GBM job = new GBM(parms);
        gbm = job.trainModel().get();
        long trainingElapsedMs = System.currentTimeMillis() - trainingStartMs;
        Log.info(">>> GBM parsing and training took: " + trainingElapsedMs + " ms.");

        // HEX-1817: the job must report itself as stopped after get() returns.
        Assert.assertTrue(job.isStopped());

        // Model is built; score the training frame to produce a prediction frame.
        long scoringStartMs = System.currentTimeMillis();
        fr2 = gbm.score(fr);
        long scoringElapsedMs = System.currentTimeMillis() - scoringStartMs;
        Log.info(">>> GBM scoring took: " + scoringElapsedMs + " ms.");
    } finally {
        if (fr != null) {
            fr.remove();
        }
        if (fr2 != null) {
            fr2.remove();
        }
        if (gbm != null) {
            gbm.remove();
        }
    }
}
Use of hex.tree.gbm.GBM in project h2o-3 by h2oai.
The class PartialDependenceTest, method prostateBinary.
/**
 * Trains a GBM on the prostate data with the categorical column
 * {@code CAPSULE} as response, then computes partial dependence data with
 * 10 bins and logs every resulting table.
 *
 * <p>The test contains no assertions; it passes when the whole pipeline runs
 * without throwing. Frame, model and partial dependence object are removed in
 * the {@code finally} block.
 */
@Test
public void prostateBinary() {
    Frame frame = null;
    GBMModel gbm = null;
    PartialDependence pdp = null;
    try {
        // Frame: parse, then replace each listed column by its categorical version.
        frame = parse_test_file("smalldata/prostate/prostate.csv");
        String[] categoricalCols = { "RACE", "GLEASON", "DPROS", "DCAPS", "CAPSULE" };
        for (String col : categoricalCols) {
            Vec original = frame.remove(col);
            frame.add(col, original.toCategoricalVec());
            original.remove();
        }
        DKV.put(frame);

        // Model
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = frame._key;
        parms._response_column = "CAPSULE";
        parms._ignored_columns = new String[] { "ID" };
        GBM builder = new GBM(parms);
        gbm = builder.trainModel().get();

        // PartialDependence (_cols is not set here)
        pdp = new PartialDependence(Key.<PartialDependence>make());
        pdp._nbins = 10;
        pdp._model_id = (Key) gbm._key;
        pdp._frame_id = frame._key;
        pdp.execImpl().get();

        for (TwoDimTable table : pdp._partial_dependence_data) {
            Log.info(table);
        }
    } finally {
        if (frame != null) {
            frame.remove();
        }
        if (gbm != null) {
            gbm.remove();
        }
        if (pdp != null) {
            pdp.remove();
        }
    }
}
Use of hex.tree.gbm.GBM in project h2o-3 by h2oai.
The class PartialDependenceTest, method prostateRegression.
/**
 * Trains a GBM on the prostate data with {@code AGE} as response, then
 * computes partial dependence data with 10 bins and logs every resulting table.
 *
 * <p>The test contains no assertions; it passes when the whole pipeline runs
 * without throwing. Frame, model and partial dependence object are removed in
 * the {@code finally} block.
 */
@Test
public void prostateRegression() {
    Frame frame = null;
    GBMModel gbm = null;
    PartialDependence pdp = null;
    try {
        // Frame: parse, then replace each listed column by its categorical version.
        frame = parse_test_file("smalldata/prostate/prostate.csv");
        String[] categoricalCols = { "RACE", "GLEASON", "DPROS", "DCAPS", "CAPSULE" };
        for (String col : categoricalCols) {
            Vec original = frame.remove(col);
            frame.add(col, original.toCategoricalVec());
            original.remove();
        }
        DKV.put(frame);

        // Model
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = frame._key;
        parms._response_column = "AGE";
        parms._ignored_columns = new String[] { "ID" };
        GBM builder = new GBM(parms);
        gbm = builder.trainModel().get();

        // PartialDependence
        pdp = new PartialDependence(Key.<PartialDependence>make());
        pdp._nbins = 10;
        pdp._model_id = (Key) gbm._key;
        pdp._frame_id = frame._key;
        pdp.execImpl().get();

        for (TwoDimTable table : pdp._partial_dependence_data) {
            Log.info(table);
        }
    } finally {
        if (frame != null) {
            frame.remove();
        }
        if (gbm != null) {
            gbm.remove();
        }
        if (pdp != null) {
            pdp.remove();
        }
    }
}
Use of hex.tree.gbm.GBM in project h2o-3 by h2oai.
The class TestCase, method execute.
/**
 * Runs this test case and returns its result.
 *
 * <p>After loading the data sets and building the model parameters, one of two
 * paths is taken depending on the {@code grid} flag:
 * <ul>
 *   <li><b>Single model</b>: trains a DRF, GLM, GBM or Deep Learning model
 *       according to {@code algo}, recording the wall-clock training time.</li>
 *   <li><b>Grid search</b>: runs a grid search and picks the model whose
 *       validation metric named by {@code modelSelectionCriteria} is best.</li>
 * </ul>
 * In both paths the second metric set reported is the cross-validation metrics
 * when {@code params._nfolds > 0}, and the validation metrics otherwise.
 *
 * <p>Trained models (or the grid) are deleted from the cluster in a
 * {@code finally} block; the in-memory output captured beforehand is what the
 * result is built from.
 *
 * @return the result holding training metrics, validation or cross-validation
 *         metrics, elapsed training time in milliseconds and the JSON of the
 *         (best) model's parameters
 * @throws Exception             wrapping any exception raised while training or searching
 * @throws IllegalStateException if {@code algo} produced no model output in the
 *                               single-model path
 * @throws AssertionError        if assertions are enabled and the grid path is
 *                               taken with an empty {@code modelSelectionCriteria}
 */
public TestCaseResult execute() throws Exception, AssertionError {
    loadTestCaseDataSets();
    makeModelParameters();
    double startTime = 0, stopTime = 0;
    if (!grid) {
        Model.Output modelOutput = null;
        DRFModel drfModel = null;
        GLMModel glmModel = null;
        GBMModel gbmModel = null;
        DeepLearningModel dlModel = null;
        String bestModelJson = null;
        try {
            switch (algo) {
                case "drf":
                    DRF drfJob = new DRF((DRFModel.DRFParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DRF model.");
                    startTime = System.currentTimeMillis();
                    drfModel = drfJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = drfModel._output;
                    bestModelJson = drfModel._parms.toJsonString();
                    break;
                case "glm":
                    GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
                    AccuracyTestingSuite.summaryLog.println("Training GLM model.");
                    startTime = System.currentTimeMillis();
                    glmModel = glmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = glmModel._output;
                    bestModelJson = glmModel._parms.toJsonString();
                    break;
                case "gbm":
                    GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training GBM model.");
                    startTime = System.currentTimeMillis();
                    gbmModel = gbmJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = gbmModel._output;
                    bestModelJson = gbmModel._parms.toJsonString();
                    break;
                case "dl":
                    DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
                    AccuracyTestingSuite.summaryLog.println("Training DL model.");
                    startTime = System.currentTimeMillis();
                    dlModel = dlJob.trainModel().get();
                    stopTime = System.currentTimeMillis();
                    modelOutput = dlModel._output;
                    bestModelJson = dlModel._parms.toJsonString();
                    break;
            }
        } catch (Exception e) {
            // Kept as-is: callers see every training failure as a wrapping Exception.
            throw new Exception(e);
        } finally {
            // Remove whichever model was trained; modelOutput keeps the in-memory output.
            if (drfModel != null) {
                drfModel.delete();
            }
            if (glmModel != null) {
                glmModel.delete();
            }
            if (gbmModel != null) {
                gbmModel.delete();
            }
            if (dlModel != null) {
                dlModel.delete();
            }
        }
        removeTestCaseDataSetFrames();
        // No switch case matched (or the model had no output): fail with a clear
        // message instead of a NullPointerException on the dereference below.
        if (modelOutput == null) {
            throw new IllegalStateException("No model output available for algorithm '" + algo
                    + "' (expected one of: drf, glm, gbm, dl)");
        }
        // Report cross-validation metrics when CV is used, validation metrics otherwise.
        return new TestCaseResult(testCaseId,
                getMetrics(modelOutput._training_metrics),
                getMetrics(params._nfolds > 0 ? modelOutput._cross_validation_metrics : modelOutput._validation_metrics),
                stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
    } else {
        assert !modelSelectionCriteria.equals("");
        makeGridParameters();
        makeSearchCriteria();
        // Named gridResult so it does not shadow the boolean 'grid' flag tested above.
        Grid gridResult = null;
        Model bestModel = null;
        String bestModelJson = null;
        try {
            SchemaServer.registerAllSchemasIfNecessary();
            // TODO: Hack for PUBDEV-2812
            switch (algo) {
                case "drf":
                    if (!drfRegistered) {
                        new DRF(true);
                        new DRFParametersV3();
                        drfRegistered = true;
                    }
                    break;
                case "glm":
                    if (!glmRegistered) {
                        new GLM(true);
                        new GLMParametersV3();
                        glmRegistered = true;
                    }
                    break;
                case "gbm":
                    if (!gbmRegistered) {
                        new GBM(true);
                        new GBMParametersV3();
                        gbmRegistered = true;
                    }
                    break;
                case "dl":
                    if (!dlRegistered) {
                        new DeepLearning(true);
                        new DeepLearningParametersV3();
                        dlRegistered = true;
                    }
                    break;
            }
            startTime = System.currentTimeMillis();
            // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
            Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
            gridResult = gs.get();
            stopTime = System.currentTimeMillis();

            // Pick the model with the best validation score for the selection criterion.
            boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
            double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
            for (Model m : gridResult.getModels()) {
                double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
                AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
                if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
                    bestScore = validationMetricScore;
                    bestModel = m;
                    bestModelJson = bestModel._parms.toJsonString();
                }
            }
            // An empty grid, or scores that never beat the initial bound (e.g. NaN),
            // leave bestModel unset; report that instead of a NullPointerException.
            if (bestModel == null) {
                throw new IllegalStateException("Grid search produced no model with a usable '"
                        + modelSelectionCriteria + "' validation score");
            }
            AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
            AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
        } catch (Exception e) {
            // Kept as-is: callers see every grid-search failure as a wrapping Exception.
            throw new Exception(e);
        } finally {
            if (gridResult != null) {
                gridResult.delete();
            }
        }
        removeTestCaseDataSetFrames();
        // Report cross-validation metrics when CV is used, validation metrics otherwise.
        return new TestCaseResult(testCaseId,
                getMetrics(bestModel._output._training_metrics),
                getMetrics(params._nfolds > 0 ? bestModel._output._cross_validation_metrics : bestModel._output._validation_metrics),
                stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
    }
}
Aggregations