Use of `hex.deeplearning.DeepLearningModel.DeepLearningParameters` in project h2o-3 by h2oai.
Class `DeepLearningTest`, method `testInitialWeightsAndBiasesPartial`.
/**
 * Checks that a deep learning model can be warm-started from a partial set of
 * initial weights and biases.
 *
 * <p>A first model ({@code dl1}) is trained from scratch on BostonHousing with
 * weight/bias export enabled. A second model ({@code dl2}) is then trained with
 * {@code dl1}'s exported weights and biases as its initial state. The frame behind
 * {@code weights[1]} and the one behind {@code biases[0]} are removed and their
 * slots nulled first, so only part of the state is reused. The test asserts that
 * {@code dl2} ends with a lower training MSE than {@code dl1}.
 */
@Test
public void testInitialWeightsAndBiasesPartial() {
    Frame tfr = null;
    DeepLearningModel dl1 = null;
    DeepLearningModel dl2 = null;
    try {
        tfr = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");
        // train DL model from scratch
        {
            DeepLearningParameters parms = new DeepLearningParameters();
            parms._train = tfr._key;
            parms._response_column = tfr.lastVecName();
            parms._activation = DeepLearningParameters.Activation.Tanh;
            parms._reproducible = true;
            parms._hidden = new int[] { 20, 20 };
            parms._seed = 0xdecad;
            parms._export_weights_and_biases = true;
            dl1 = new DeepLearning(parms).trainModel().get();
        }
        // train DL model starting from weights/biases from first model
        {
            DeepLearningParameters parms = new DeepLearningParameters();
            parms._train = tfr._key;
            parms._response_column = tfr.lastVecName();
            parms._activation = DeepLearningParameters.Activation.Tanh;
            parms._reproducible = true;
            parms._hidden = new int[] { 20, 20 };
            parms._seed = 0xdecad;
            // These arrays are shared with dl1._output, so the nulling below also
            // nulls the corresponding slots in dl1's output arrays.
            parms._initial_weights = dl1._output.weights;
            parms._initial_biases = dl1._output.biases;
            // Drop part of the exported state so the warm start is only partial.
            parms._initial_weights[1].remove();
            parms._initial_weights[1] = null;
            parms._initial_biases[0].remove();
            parms._initial_biases[0] = null;
            parms._epochs = 10;
            dl2 = new DeepLearning(parms).trainModel().get();
        }
        Log.info("dl1 : MSE=" + dl1._output._training_metrics.mse());
        Log.info("dl2 : MSE=" + dl2._output._training_metrics.mse());
        // the second model is better since it got warm-started at least partially
        Assert.assertTrue(dl1._output._training_metrics.mse() > dl2._output._training_metrics.mse());
    } finally {
        if (tfr != null) {
            tfr.delete();
        }
        if (dl1 != null) {
            dl1.delete();
        }
        if (dl2 != null) {
            dl2.delete();
        }
        // dl1 is still null if the first training failed; dereferencing it here
        // would throw an NPE from finally and hide the original failure.
        if (dl1 != null) {
            if (dl1._output.weights != null) {
                for (Key f : dl1._output.weights) {
                    if (f != null) {
                        f.remove();
                    }
                }
            }
            if (dl1._output.biases != null) {
                for (Key f : dl1._output.biases) {
                    if (f != null) {
                        f.remove();
                    }
                }
            }
        }
    }
}
Use of `hex.deeplearning.DeepLearningModel.DeepLearningParameters` in project h2o-3 by h2oai.
Class `DeepLearningTest`, method `testNoRowWeights`.
/**
 * Trains a reproducible single-hidden-unit binomial model on
 * {@code smalldata/junit/no_weights.csv}. It pins the training AUC and MSE and
 * verifies that Java (POJO) scoring agrees with in-H2O scoring.
 *
 * <p>The model, all frames and the {@link Scope} are released in {@code finally},
 * so a failing assertion does not leak the model or leave the Scope stack
 * unbalanced for later tests.
 */
@Test
public void testNoRowWeights() {
    Frame tfr = null, pred = null, fr2 = null;
    DeepLearningModel dl = null;
    Scope.enter();
    try {
        tfr = parse_test_file("smalldata/junit/no_weights.csv");
        DKV.put(tfr);
        DeepLearningParameters parms = new DeepLearningParameters();
        parms._train = tfr._key;
        parms._response_column = "response";
        parms._reproducible = true;
        parms._seed = 0xdecaf;
        parms._l1 = 0.1;
        parms._epochs = 1;
        parms._hidden = new int[] { 1 };
        parms._classification_stop = -1;
        // Build a first model; all remaining models should be equal
        dl = new DeepLearning(parms).trainModel().get();
        pred = dl.score(parms.train());
        hex.ModelMetricsBinomial mm = hex.ModelMetricsBinomial.getFromDKV(dl, parms.train());
        assertEquals(0.7592592592592592, mm.auc_obj()._auc, 1e-8);
        double mse = dl._output._training_metrics.mse();
        assertEquals(0.314813341867078, mse, 1e-8);
        assertTrue(dl.testJavaScoring(tfr, fr2 = dl.score(tfr), 1e-5));
    } finally {
        // Delete the model first, matching the original ordering (model before frames).
        if (dl != null) {
            dl.delete();
        }
        if (tfr != null) {
            tfr.remove();
        }
        if (pred != null) {
            pred.remove();
        }
        if (fr2 != null) {
            fr2.remove();
        }
        // Always pop the Scope pushed above, even when an assertion failed.
        Scope.exit();
    }
}
Use of `hex.deeplearning.DeepLearningModel.DeepLearningParameters` in project h2o-3 by h2oai.
Class `DeepLearningTest`, method `testMiniBatch1`.
/**
 * Trains a reproducible 20x20 regression network on BostonHousing with a
 * mini-batch size of 1. It pins the resulting training MSE to a known value.
 */
@Test
public void testMiniBatch1() {
    Frame housing = null;
    DeepLearningModel model = null;
    try {
        housing = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");

        final DeepLearningParameters p = new DeepLearningParameters();
        p._train = housing._key;
        p._response_column = housing.lastVecName();
        p._reproducible = true;
        p._hidden = new int[] {20, 20};
        p._seed = 0xdecaf;
        p._mini_batch_size = 1;

        model = new DeepLearning(p).trainModel().get();
        final double trainingMse = model._output._training_metrics._MSE;
        Assert.assertEquals(12.938076268040659, trainingMse, 1e-6);
    } finally {
        if (housing != null) {
            housing.delete();
        }
        // Remove any CV sub-models before the main model itself.
        if (model != null) {
            model.deleteCrossValidationModels();
            model.delete();
        }
    }
}
Use of `hex.deeplearning.DeepLearningModel.DeepLearningParameters` in project h2o-3 by h2oai.
Class `DeepLearningTest`, method `testHuber`.
/**
 * Trains a reproducible 20x20 regression network on BostonHousing using the
 * {@code huber} distribution. It pins the training mean residual deviance to a
 * known value.
 */
@Test
public void testHuber() {
    Frame housing = null;
    DeepLearningModel model = null;
    try {
        housing = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");

        final DeepLearningParameters p = new DeepLearningParameters();
        p._train = housing._key;
        p._response_column = housing.lastVecName();
        p._reproducible = true;
        p._hidden = new int[] {20, 20};
        p._seed = 0xdecaf;
        p._distribution = huber;

        model = new DeepLearning(p).trainModel().get();
        final ModelMetricsRegression metrics = (ModelMetricsRegression) model._output._training_metrics;
        Assert.assertEquals(6.4964976811, metrics._mean_residual_deviance, 1e-5);
    } finally {
        if (housing != null) {
            housing.delete();
        }
        // Remove any CV sub-models before the main model itself.
        if (model != null) {
            model.deleteCrossValidationModels();
            model.delete();
        }
    }
}
Use of `hex.deeplearning.DeepLearningModel.DeepLearningParameters` in project h2o-3 by h2oai.
Class `DeepLearningTest`, method `testCheckpointOverwriteWithBestModel2`.
// Check that the restarted model honors the previous model as a best model so far
@Test
public void testCheckpointOverwriteWithBestModel2() {
Frame tfr = null;
DeepLearningModel dl = null;
DeepLearningModel dl2 = null;
Frame train = null, valid = null;
try {
tfr = parse_test_file("./smalldata/iris/iris.csv");
FrameSplitter fs = new FrameSplitter(tfr, new double[] { 0.8 }, new Key[] { Key.make("train"), Key.make("valid") }, null);
fs.compute2();
train = fs.getResult()[0];
valid = fs.getResult()[1];
DeepLearningParameters parms = new DeepLearningParameters();
parms._train = train._key;
parms._valid = valid._key;
parms._epochs = 10;
parms._response_column = "C5";
parms._reproducible = true;
parms._hidden = new int[] { 50, 50 };
parms._seed = 0xdecaf;
parms._train_samples_per_iteration = 0;
parms._score_duty_cycle = 1;
parms._score_interval = 0;
parms._stopping_rounds = 0;
parms._overwrite_with_best_model = true;
dl = new DeepLearning(parms).trainModel().get();
double ll1 = ((ModelMetricsMultinomial) dl._output._validation_metrics).logloss();
DeepLearningParameters parms2 = (DeepLearningParameters) parms.clone();
parms2._epochs = 20;
parms2._checkpoint = dl._key;
dl2 = new DeepLearning(parms2).trainModel().get();
double ll2 = ((ModelMetricsMultinomial) dl2._output._validation_metrics).logloss();
Assert.assertTrue(ll2 <= ll1);
} finally {
if (tfr != null)
tfr.delete();
if (dl != null)
dl.delete();
if (dl2 != null)
dl2.delete();
if (train != null)
train.delete();
if (valid != null)
valid.delete();
}
}
Aggregations