Usage of water.fvec.Frame in project h2o-3 by h2oai.
From class ModelSerializationTest, method prepareGBMModel.
/**
 * Parses {@code dataset} and trains a GBM on it with per-iteration scoring enabled.
 *
 * <p>When {@code classification} is requested and the response column is not already categorical,
 * the response is replaced by a categorical copy (the previous Vec is removed) and the modified frame
 * is re-published to the DKV before training. The parsed frame is deleted before this method returns,
 * whether or not training succeeds.
 *
 * @param dataset        path of the test file to parse
 * @param ignoredColumns columns to exclude from training (passed through unchanged)
 * @param response       name of the response column
 * @param classification whether the response should be treated as categorical
 * @param ntrees         number of trees to build
 * @return the trained model
 */
private GBMModel prepareGBMModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) {
  final Frame trainFrame = parse_test_file(dataset);
  try {
    if (classification) {
      final boolean alreadyCategorical = trainFrame.vec(response).isCategorical();
      if (!alreadyCategorical) {
        // Swap in a categorical copy of the response and drop the original Vec it replaces.
        final int responseIndex = trainFrame.find(response);
        trainFrame.replace(responseIndex, trainFrame.vec(response).toCategoricalVec()).remove();
        // Re-publish the modified frame under its key, which is what the builder is given below.
        DKV.put(trainFrame._key, trainFrame);
      }
    }
    final GBMModel.GBMParameters params = new GBMModel.GBMParameters();
    params._train = trainFrame._key;
    params._response_column = response;
    params._ignored_columns = ignoredColumns;
    params._ntrees = ntrees;
    params._score_each_iteration = true;
    final GBM builder = new GBM(params);
    return builder.trainModel().get();
  } finally {
    if (trainFrame != null) {
      trainFrame.delete();
    }
  }
}
Usage of water.fvec.Frame in project h2o-3 by h2oai.
From class SSLEncryptionTest, method testGBMRegressionGaussian.
/**
 * Trains a one-tree, depth-1 Gaussian GBM on the Mfgdata regression fixture and scores the training
 * frame with the resulting model.
 *
 * <p>Column 1 is the response and column 2 the only predictor; column 0 (row id) and every column
 * after column 2 are ignored. Training and scoring durations are logged in milliseconds. The parsed
 * frame, the prediction frame and the model are removed in {@code finally}.
 */
private static void testGBMRegressionGaussian() {
  GBMModel gbm = null;
  Frame fr = null, fr2 = null;
  try {
    // System.nanoTime() is a monotonic clock intended for elapsed-time measurement; unlike the
    // wall-clock java.util.Date it is not skewed by system clock adjustments.
    final long startNanos = System.nanoTime();
    fr = parse_test_file("./smalldata/gbm_test/Mfgdata_gaussian_GBM_testing.csv");
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = fr._key;
    parms._distribution = gaussian;
    // Row in col 0, dependent in col 1, predictor in col 2
    parms._response_column = fr._names[1];
    parms._ntrees = 1;
    parms._max_depth = 1;
    parms._min_rows = 1;
    parms._nbins = 20;
    // Drop col 0 (row), keep col 1 (response), keep col 2 (only predictor), drop remaining cols.
    // Assumes the fixture has at least 3 columns, otherwise the array sizes below go negative.
    String[] xcols = parms._ignored_columns = new String[fr.numCols() - 2];
    xcols[0] = fr._names[0];
    System.arraycopy(fr._names, 3, xcols, 1, fr.numCols() - 3);
    parms._learn_rate = 1.0f;
    parms._score_each_iteration = true;
    GBM job = new GBM(parms);
    gbm = job.trainModel().get();
    Log.info(">>> GBM parsing and training took: " + ((System.nanoTime() - startNanos) / 1000000L) + " ms.");
    //HEX-1817
    Assert.assertTrue(job.isStopped());
    // Done building model; produce a score column with predictions
    final long scoringStartNanos = System.nanoTime();
    fr2 = gbm.score(fr);
    Log.info(">>> GBM scoring took: " + ((System.nanoTime() - scoringStartNanos) / 1000000L) + " ms.");
  } finally {
    if (fr != null)
      fr.remove();
    if (fr2 != null)
      fr2.remove();
    if (gbm != null)
      gbm.remove();
  }
}
Usage of water.fvec.Frame in project h2o-3 by h2oai.
From class GBMTest, method testModifiedHuberStability.
/**
 * Trains a single-tree GBM with the modified_huber distribution on a tiny two-column categorical
 * dataset. It then checks that in-H2O scoring and generated-Java scoring agree (tolerance 1e-15)
 * on both the training frame and an identical test frame.
 *
 * <p>All frames and the model are released in {@code finally}, and the Scope is always exited, so a
 * failing assertion no longer leaks them.
 */
@Ignore
public void testModifiedHuberStability() {
  Frame df = null, df2 = null, preds = null, preds2 = null;
  GBMModel gbm = null;
  // AdaptTestTrain leaks when it does inplace Vec adaptation, need a Scope to catch that stuff.
  // Entered before any allocation (as testModelLock does) so the finally block can always exit it.
  Scope.enter();
  try {
    String xy = "A,Y\nB,N\nA,N\nB,N\nA,Y\nA,Y";
    Key<Frame> tr = Key.make("train");
    df = ParseDataset.parse(tr, makeByteVec(Key.make("xy"), xy));
    String test = "A,Y\nB,N\nA,N\nB,N\nA,Y\nA,Y";
    Key<Frame> te = Key.make("test");
    df2 = ParseDataset.parse(te, makeByteVec(Key.make("te"), test));
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = tr;
    // The parsed CSV has no header, so the second column gets the default name "C2".
    parms._response_column = "C2";
    parms._min_rows = 1;
    parms._learn_rate = 1;
    parms._distribution = DistributionFamily.modified_huber;
    parms._ntrees = 1;
    GBM job = new GBM(parms);
    gbm = job.trainModel().get();
    preds = gbm.score(df);
    preds2 = gbm.score(df2);
    Log.info(df);
    Log.info(preds);
    Log.info(df2);
    Log.info(preds2);
    Assert.assertTrue(gbm.testJavaScoring(df, preds, 1e-15));
    Assert.assertTrue(gbm.testJavaScoring(df2, preds2, 1e-15));
    // Assert.assertTrue(Math.abs(preds.vec(0).at(0) - -2.5) < 1e-6);
    // Assert.assertTrue(Math.abs(preds.vec(0).at(1) - 1) < 1e-6);
    // Assert.assertTrue(Math.abs(preds.vec(0).at(2) - -2.5) < 1e-6);
    // Assert.assertTrue(Math.abs(preds.vec(0).at(3) - 1) < 1e-6);
    // Assert.assertTrue(Math.abs(preds.vec(0).at(4) - 0) < 1e-6);
    // Assert.assertTrue(Math.abs(preds.vec(0).at(5) - 1) < 1e-6);
  } finally {
    // Same release order as before: predictions, model, input frames, then the Scope.
    if (preds != null)
      preds.remove();
    if (preds2 != null)
      preds2.remove();
    if (gbm != null)
      gbm.remove();
    if (df != null)
      df.remove();
    if (df2 != null)
      df2.remove();
    Scope.exit();
  }
}
Usage of water.fvec.Frame in project h2o-3 by h2oai.
From class GBMTest, method testModelLock.
/**
 * A test of locking the input dataset during model building.
 *
 * <p>Starts a multinomial GBM build on the ecology dataset. It then attempts to delete the training
 * frame while the build is expected to still be running; the delete must fail with an
 * {@link IllegalArgumentException}, either thrown directly or as the cause of a RuntimeException.
 * Finally it waits for the model and checks that the job reports itself stopped.
 * The model and frame are released in {@code finally}, even if an assertion fails.
 */
@Test
public void testModelLock() {
  GBM gbm = null;
  GBMModel model = null;
  Frame fr = null;
  Scope.enter();
  try {
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    fr = parse_test_file("smalldata/gbm_test/ecology_model.csv");
    // Remove unique ID
    fr.remove("Site").remove();
    int ci = fr.find("Angaus");
    // Convert response 'Angaus' to categorical; the replaced Vec is handed to Scope.track for cleanup
    Scope.track(fr.replace(ci, fr.vecs()[ci].toCategoricalVec()));
    // Update after hacking
    DKV.put(fr);
    parms._train = fr._key;
    // Train on the outcome
    parms._response_column = "Angaus";
    parms._ntrees = 10;
    parms._max_depth = 10;
    parms._min_rows = 1;
    parms._nbins = 20;
    parms._learn_rate = .2f;
    parms._distribution = DistributionFamily.multinomial;
    gbm = new GBM(parms);
    // Kick off the build; the model is only awaited via gbm.get() further down.
    gbm.trainModel();
    try {
      // Give the build a moment to start before trying to delete its input.
      Thread.sleep(100);
    } catch (InterruptedException e) {
      // Restore the interrupt status instead of silently discarding it.
      Thread.currentThread().interrupt();
    }
    try {
      Log.info("Trying illegal frame delete.");
      // Attempted delete while model-build is active
      fr.delete();
      Assert.fail("Should toss IAE instead of reaching here");
    } catch (IllegalArgumentException ignore) {
      // Expected: the frame is locked by the running build.
    } catch (RuntimeException re) {
      Assert.assertTrue("Expected IllegalArgumentException as cause, got: " + re,
          re.getCause() instanceof IllegalArgumentException);
    }
    Log.info("Getting model");
    model = gbm.get();
    //HEX-1817
    Assert.assertTrue(gbm.isStopped());
  } finally {
    if (model != null)
      model.delete();
    if (fr != null)
      fr.remove();
    Scope.exit();
  }
}
Usage of water.fvec.Frame in project h2o-3 by h2oai.
From class GBMTest, method testNfoldsColumn.
/**
 * Trains a GBM with cross-validation driven by a fold column. The numeric {@code cylinders}
 * column is re-added as categorical to serve as the fold column. The test checks that five
 * cross-validation models are produced.
 *
 * <p>Every resource created here is released in {@code finally}, whether or not the assertion
 * passes. That covers the replaced numeric {@code cylinders} Vec, the training frame, the CV
 * models, the main model and the kept fold-assignment frame.
 */
@Test
public void testNfoldsColumn() {
  Frame tfr = null;
  Vec old = null;
  GBMModel gbm1 = null;
  try {
    tfr = parse_test_file("smalldata/junit/cars_20mpg.csv");
    // Remove unique id
    tfr.remove("name").remove();
    DKV.put(tfr);
    GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
    parms._train = tfr._key;
    parms._response_column = "economy_20mpg";
    parms._fold_column = "cylinders";
    // Re-add the fold column as categorical. The detached original Vec is no longer owned by
    // tfr, so it must be removed separately (done in finally).
    old = tfr.remove("cylinders");
    tfr.add("cylinders", old.toCategoricalVec());
    DKV.put(tfr);
    parms._ntrees = 10;
    parms._keep_cross_validation_fold_assignment = true;
    GBM job1 = new GBM(parms);
    gbm1 = job1.trainModel().get();
    // The fixture is expected to yield 5 folds from the cylinders levels.
    Assert.assertEquals(5, gbm1._output._cross_validation_models.length);
  } finally {
    if (old != null)
      old.remove();
    if (tfr != null)
      tfr.remove();
    if (gbm1 != null) {
      gbm1.deleteCrossValidationModels();
      gbm1.delete();
      // Null-guarded so a missing fold-assignment frame cannot mask the original failure with an NPE.
      if (gbm1._output._cross_validation_fold_assignment_frame_id != null)
        gbm1._output._cross_validation_fold_assignment_frame_id.remove();
    }
  }
}
Aggregations