Use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
From the class TestCase, method execute.
/**
 * Runs this test case: loads its data sets, builds the model parameters, and then either trains
 * a single model or runs a grid search, depending on the {@code grid} flag.
 *
 * <p>Single-model path: trains the DRF, GLM, GBM or Deep Learning model selected by
 * {@code algo}, timing only the {@code trainModel().get()} call. The trained model is deleted
 * in a {@code finally} block; its output object and parameter JSON are captured beforehand.
 *
 * <p>Grid path: runs a grid search over {@code hyperParms}, timing only the search itself, and
 * selects the model whose validation metric named by {@code modelSelectionCriteria} is best.
 * The grid is deleted in a {@code finally} block.
 *
 * <p>In both paths the result carries the training metrics plus the cross-validation metrics
 * when {@code params._nfolds > 0}, or the validation metrics otherwise.
 *
 * @return the result of this test case, including elapsed time in milliseconds and the JSON
 *         form of the (best) model's parameters
 * @throws Exception any exception raised while training or searching, wrapped in a new
 *         {@code Exception}
 * @throws IllegalArgumentException on the single-model path when {@code algo} is not one of
 *         "drf", "glm", "gbm" or "dl"
 * @throws AssertionError kept from the original signature; with assertions enabled the grid
 *         path asserts that {@code modelSelectionCriteria} is not empty
 */
public TestCaseResult execute() throws Exception, AssertionError {
  loadTestCaseDataSets();
  makeModelParameters();
  double startTime = 0, stopTime = 0;
  if (!grid) {
    Model.Output modelOutput = null;
    DRFModel drfModel = null;
    GLMModel glmModel = null;
    GBMModel gbmModel = null;
    DeepLearningModel dlModel = null;
    String bestModelJson = null;
    try {
      switch (algo) {
        case "drf": {
          DRF drfJob = new DRF((DRFModel.DRFParameters) params);
          AccuracyTestingSuite.summaryLog.println("Training DRF model.");
          startTime = System.currentTimeMillis();
          drfModel = drfJob.trainModel().get();
          stopTime = System.currentTimeMillis();
          modelOutput = drfModel._output;
          bestModelJson = drfModel._parms.toJsonString();
          break;
        }
        case "glm": {
          GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
          AccuracyTestingSuite.summaryLog.println("Training GLM model.");
          startTime = System.currentTimeMillis();
          glmModel = glmJob.trainModel().get();
          stopTime = System.currentTimeMillis();
          modelOutput = glmModel._output;
          bestModelJson = glmModel._parms.toJsonString();
          break;
        }
        case "gbm": {
          GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
          AccuracyTestingSuite.summaryLog.println("Training GBM model.");
          startTime = System.currentTimeMillis();
          gbmModel = gbmJob.trainModel().get();
          stopTime = System.currentTimeMillis();
          modelOutput = gbmModel._output;
          bestModelJson = gbmModel._parms.toJsonString();
          break;
        }
        case "dl": {
          DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
          AccuracyTestingSuite.summaryLog.println("Training DL model.");
          startTime = System.currentTimeMillis();
          dlModel = dlJob.trainModel().get();
          stopTime = System.currentTimeMillis();
          modelOutput = dlModel._output;
          bestModelJson = dlModel._parms.toJsonString();
          break;
        }
      }
    } catch (Exception e) {
      // Callers see training failures as a plain Exception whose cause is the original failure.
      throw new Exception(e);
    } finally {
      // At most one of these is non-null; the captured output object is still used below.
      if (drfModel != null) {
        drfModel.delete();
      }
      if (glmModel != null) {
        glmModel.delete();
      }
      if (gbmModel != null) {
        gbmModel.delete();
      }
      if (dlModel != null) {
        dlModel.delete();
      }
    }
    removeTestCaseDataSetFrames();
    // No case matched: fail with a clear message instead of a NullPointerException below.
    if (modelOutput == null) {
      throw new IllegalArgumentException(
          "Unsupported algorithm for test case " + testCaseId + ": " + algo);
    }
    // Report cross-validation metrics when folds were requested, validation metrics otherwise.
    return new TestCaseResult(testCaseId, getMetrics(modelOutput._training_metrics),
        getMetrics(params._nfolds > 0
            ? modelOutput._cross_validation_metrics
            : modelOutput._validation_metrics),
        stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
  } else {
    assert !modelSelectionCriteria.equals("");
    makeGridParameters();
    makeSearchCriteria();
    Grid modelGrid = null;
    Model bestModel = null;
    String bestModelJson = null;
    try {
      SchemaServer.registerAllSchemasIfNecessary();
      // TODO: Hack for PUBDEV-2812
      switch (algo) {
        case "drf":
          if (!drfRegistered) {
            new DRF(true);
            new DRFParametersV3();
            drfRegistered = true;
          }
          break;
        case "glm":
          if (!glmRegistered) {
            new GLM(true);
            new GLMParametersV3();
            glmRegistered = true;
          }
          break;
        case "gbm":
          if (!gbmRegistered) {
            new GBM(true);
            new GBMParametersV3();
            gbmRegistered = true;
          }
          break;
        case "dl":
          if (!dlRegistered) {
            new DeepLearning(true);
            new DeepLearningParametersV3();
            dlRegistered = true;
          }
          break;
      }
      startTime = System.currentTimeMillis();
      // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
      Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms, new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
      modelGrid = gs.get();
      stopTime = System.currentTimeMillis();
      boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
      // Start from the worst possible score so the first comparable model always wins.
      double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
      for (Model m : modelGrid.getModels()) {
        double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
        AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " + validationMetricScore);
        if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
          bestScore = validationMetricScore;
          bestModel = m;
          bestModelJson = bestModel._parms.toJsonString();
        }
      }
      // An empty grid, or scores that never compare better (e.g. NaN), leave no winner.
      if (bestModel == null) {
        throw new IllegalStateException("Grid search for test case " + testCaseId
            + " produced no model with a comparable " + modelSelectionCriteria + " score.");
      }
      AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
      AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
    } catch (Exception e) {
      // Callers see search failures as a plain Exception whose cause is the original failure.
      throw new Exception(e);
    } finally {
      if (modelGrid != null) {
        modelGrid.delete();
      }
    }
    removeTestCaseDataSetFrames();
    // Report cross-validation metrics when folds were requested, validation metrics otherwise.
    return new TestCaseResult(testCaseId, getMetrics(bestModel._output._training_metrics),
        getMetrics(params._nfolds > 0
            ? bestModel._output._cross_validation_metrics
            : bestModel._output._validation_metrics),
        stopTime - startTime, bestModelJson, this, trainingDataSet, testingDataSet);
  }
}
Use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
From the class XValPredictionsCheck, method checkModel.
/**
 * Checks the per-fold cross-validation prediction frames recorded on a trained model and
 * removes them, together with the model, its cross-validation models and the combined
 * hold-out prediction frame.
 *
 * <p>For every per-fold prediction frame, each row whose fold id differs from the frame's
 * position in {@code _cross_validation_predictions} must hold 0 in every checked prediction
 * column.
 *
 * @param m      trained model; it is deleted by this method
 * @param foldId fold assignment column, one entry per training row
 * @param nclass number of prediction columns to check; 1 checks the frame's first vec,
 *               otherwise columns selected by {@code ArrayUtils.range(1, nclass)} are checked
 */
void checkModel(Model m, Vec foldId, int nclass) {
  // DRF scores out-of-bag instead of on the true training set, so its nobs might be different.
  final boolean isDrf = m instanceof DRFModel;
  if (!isDrf) {
    assertEquals(m._output._training_metrics._nobs, m._output._cross_validation_metrics._nobs);
  }
  m.delete();
  m.deleteCrossValidationModels();

  final Key holdoutFrameKey = m._output._cross_validation_holdout_predictions_frame_id;
  final Key[] foldPredictionKeys = m._output._cross_validation_predictions;

  // One-element holder so the anonymous task can read the fold currently being checked.
  final int[] currentFold = new int[1];
  for (Key predictionKey : foldPredictionKeys) {
    Frame foldPreds = DKV.getGet(predictionKey);
    assert foldPreds.numRows() == foldId.length();

    // Column 0 is the fold id; the remaining columns are the predictions to inspect.
    Vec[] columns = new Vec[nclass + 1];
    columns[0] = foldId;
    if (nclass == 1) {
      columns[1] = foldPreds.anyVec();
    } else {
      System.arraycopy(foldPreds.vecs(ArrayUtils.range(1, nclass)), 0, columns, 1, nclass);
    }

    new MRTask() {
      @Override
      public void map(Chunk[] cs) {
        Chunk folds = cs[0];
        for (int row = 0; row < folds._len; ++row) {
          if (folds.at8(row) == currentFold[0]) {
            continue;
          }
          // Row belongs to another fold: no prediction for this row!
          for (int col = 1; col < cs.length; ++col) {
            assert cs[col].atd(row) == 0;
          }
        }
      }
    }.doAll(columns);

    currentFold[0]++;
    foldPreds.delete();
  }
  holdoutFrameKey.remove();
}
Use of hex.tree.drf.DRFModel in project h2o-3 by h2oai.
From the class XValPredictionsCheck, method testXValPredictions.
/**
 * Trains GBM, DRF, GLM and Deep Learning models on the iris data with a user-supplied
 * 3-fold column and kept cross-validation predictions, passing each trained model to
 * {@code checkModel}. The training frame is removed when the test finishes.
 */
@Test
public void testXValPredictions() {
  final int nfolds = 3;
  Frame tfr = null;
  try {
    // Load the data and attach a fold-assignment column to the training frame.
    tfr = parse_test_file("smalldata/iris/iris_wheader.csv");
    Frame foldFrame = new Frame(
        new String[] { "foldId" },
        new Vec[] { AstKFold.kfoldColumn(tfr.vec("class").makeZero(), nfolds, 543216789) });
    tfr.add(foldFrame);
    DKV.put(tfr);
    final Vec foldVec = foldFrame.anyVec();

    // GBM: multinomial on "class".
    GBMModel.GBMParameters gbmParms = new GBMModel.GBMParameters();
    gbmParms._train = tfr._key;
    gbmParms._response_column = "class";
    gbmParms._ntrees = 1;
    gbmParms._max_depth = 1;
    gbmParms._fold_column = "foldId";
    gbmParms._distribution = DistributionFamily.multinomial;
    gbmParms._keep_cross_validation_predictions = true;
    GBMModel gbmModel = new GBM(gbmParms).trainModel().get();
    checkModel(gbmModel, foldVec, 3);

    // DRF: same setup as GBM.
    DRFModel.DRFParameters drfParms = new DRFModel.DRFParameters();
    drfParms._train = tfr._key;
    drfParms._response_column = "class";
    drfParms._ntrees = 1;
    drfParms._max_depth = 1;
    drfParms._fold_column = "foldId";
    drfParms._distribution = DistributionFamily.multinomial;
    drfParms._keep_cross_validation_predictions = true;
    DRFModel drfModel = new DRF(drfParms).trainModel().get();
    checkModel(drfModel, foldVec, 3);

    // GLM: uses "sepal_len" as the response, so a single prediction column is checked.
    GLMModel.GLMParameters glmParms = new GLMModel.GLMParameters();
    glmParms._train = tfr._key;
    glmParms._response_column = "sepal_len";
    glmParms._fold_column = "foldId";
    glmParms._keep_cross_validation_predictions = true;
    GLMModel glmModel = new GLM(glmParms).trainModel().get();
    checkModel(glmModel, foldVec, 1);

    // Deep Learning: one hidden layer with one unit, one epoch, on "class".
    DeepLearningModel.DeepLearningParameters dlParms = new DeepLearningModel.DeepLearningParameters();
    dlParms._train = tfr._key;
    dlParms._response_column = "class";
    dlParms._hidden = new int[] { 1 };
    dlParms._epochs = 1;
    dlParms._fold_column = "foldId";
    dlParms._keep_cross_validation_predictions = true;
    DeepLearningModel dlModel = new DeepLearning(dlParms).trainModel().get();
    checkModel(dlModel, foldVec, 3);
  } finally {
    if (tfr != null) {
      tfr.remove();
    }
  }
}
Aggregations