Usage of hex.deeplearning.DeepLearningModel.DeepLearningParameters in the h2o-3 project by h2oai.
Source: class DeepLearningTest, method testConvergenceAUC_ModifiedHuber.
/**
 * Checks that AUC-based early stopping ends a modified-Huber classification run well before
 * the deliberately huge epoch budget is exhausted.
 *
 * <p>Currently disabled. {@code @Test} is present so JUnit reports the method as ignored
 * (JUnit 4 only honors {@code @Ignore} on methods that are also annotated {@code @Test}).
 */
@Ignore
@Test
public void testConvergenceAUC_ModifiedHuber() {
  Frame tfr = null;
  DeepLearningModel dl = null;
  try {
    tfr = parse_test_file("./smalldata/logreg/prostate.csv");
    // Replace the response with a categorical copy so the model is trained as a classifier.
    for (String s : new String[] { "CAPSULE" }) {
      Vec resp = tfr.vec(s).toCategoricalVec();
      tfr.remove(s).remove();
      tfr.add(s, resp);
      DKV.put(tfr);
    }
    DeepLearningParameters parms = new DeepLearningParameters();
    parms._train = tfr._key;
    // Effectively unbounded: only early stopping is expected to end training.
    parms._epochs = 1000000;
    parms._response_column = "CAPSULE";
    parms._reproducible = true;
    parms._hidden = new int[] { 2, 2 };
    parms._seed = 0xdecaf;
    parms._variable_importances = true;
    parms._distribution = modified_huber;
    // Score as often as possible so the early-stopping check is evaluated frequently.
    parms._score_duty_cycle = 1.0;
    parms._score_interval = 0;
    // don't stop based on absolute classification error
    parms._classification_stop = -1;
    // early-stop on AUC, judged over 2 scoring rounds
    parms._stopping_rounds = 2;
    parms._stopping_metric = ScoreKeeper.StoppingMetric.AUC;
    // no tolerance margin when judging AUC improvement
    parms._stopping_tolerance = 0.0;
    dl = new DeepLearning(parms).trainModel().get();
    Assert.assertTrue("early stopping should end training before the epoch budget",
        dl.epoch_counter < parms._epochs);
  } finally {
    if (tfr != null) {
      tfr.delete();
    }
    if (dl != null) {
      dl.delete();
    }
  }
}
Usage of hex.deeplearning.DeepLearningModel.DeepLearningParameters in the h2o-3 project by h2oai.
Source: class DeepLearningTest, method testInitialWeightsAndBiases.
/**
 * Trains a model that exports its weights and biases. Then trains a second model with the same
 * architecture, initialized from those weights/biases with {@code _epochs = 0}. Both models must
 * report the same training MSE (to within 1e-6).
 */
@Test
public void testInitialWeightsAndBiases() {
  Frame tfr = null;
  DeepLearningModel dl1 = null;
  DeepLearningModel dl2 = null;
  try {
    tfr = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");
    // train DL model from scratch, exporting its weights/biases
    {
      DeepLearningParameters parms = new DeepLearningParameters();
      parms._train = tfr._key;
      parms._response_column = tfr.lastVecName();
      parms._activation = DeepLearningParameters.Activation.Tanh;
      parms._reproducible = true;
      parms._hidden = new int[] { 20, 20 };
      parms._seed = 0xdecad;
      parms._export_weights_and_biases = true;
      dl1 = new DeepLearning(parms).trainModel().get();
    }
    // train DL model starting from weights/biases from first model. Same activation/hidden
    // layout so the exported weights fit; with 0 epochs the initial weights are presumably
    // used as-is, so the scores should match dl1's.
    {
      DeepLearningParameters parms = new DeepLearningParameters();
      parms._train = tfr._key;
      parms._response_column = tfr.lastVecName();
      parms._activation = DeepLearningParameters.Activation.Tanh;
      parms._reproducible = true;
      parms._hidden = new int[] { 20, 20 };
      parms._seed = 0xdecad;
      parms._initial_weights = dl1._output.weights;
      parms._initial_biases = dl1._output.biases;
      parms._epochs = 0;
      dl2 = new DeepLearning(parms).trainModel().get();
    }
    Log.info("dl1 : MSE=" + dl1._output._training_metrics.mse());
    Log.info("dl2 : MSE=" + dl2._output._training_metrics.mse());
    Assert.assertTrue("dl2 training MSE should match dl1",
        Math.abs(dl1._output._training_metrics.mse() - dl2._output._training_metrics.mse()) < 1e-6);
  } finally {
    if (tfr != null) {
      tfr.delete();
    }
    if (dl1 != null) {
      dl1.delete();
    }
    if (dl2 != null) {
      dl2.delete();
    }
    // The exported weight/bias frames are referenced by key from dl1._output and removed
    // explicitly here. Guard on dl1 (and on the arrays) so that a failure before or during
    // training of dl1 is not masked by a NullPointerException thrown from this cleanup.
    if (dl1 != null) {
      if (dl1._output.weights != null) {
        for (Key f : dl1._output.weights) {
          f.remove();
        }
      }
      if (dl1._output.biases != null) {
        for (Key f : dl1._output.biases) {
          f.remove();
        }
      }
    }
  }
}
Usage of hex.deeplearning.DeepLearningModel.DeepLearningParameters in the h2o-3 project by h2oai.
Source: class DeepLearningTest, method testGaussian.
/**
 * Trains a reproducible Gaussian regression model on the Boston Housing data. Checks that both
 * the mean residual deviance and the MSE of the training metrics equal the pinned reference
 * value; for Gaussian loss the two quantities coincide.
 */
@Test
public void testGaussian() {
  Frame train = null;
  DeepLearningModel model = null;
  try {
    train = parse_test_file("./smalldata/gbm_test/BostonHousing.csv");

    final DeepLearningParameters p = new DeepLearningParameters();
    p._train = train._key;
    p._response_column = train.lastVecName();
    p._reproducible = true;
    p._hidden = new int[] { 20, 20 };
    p._seed = 0xdecaf;
    p._distribution = gaussian;
    model = new DeepLearning(p).trainModel().get();

    final ModelMetricsRegression metrics = (ModelMetricsRegression) model._output._training_metrics;
    final double expectedMse = 12.93808;
    Assert.assertEquals(expectedMse, metrics._mean_residual_deviance, 1e-5);
    Assert.assertEquals(expectedMse, metrics._MSE, 1e-5);
  } finally {
    if (train != null) {
      train.delete();
    }
    if (model != null) {
      model.deleteCrossValidationModels();
      model.delete();
    }
  }
}
Usage of hex.deeplearning.DeepLearningModel.DeepLearningParameters in the h2o-3 project by h2oai.
Source: class DeepLearningTest, method testCheckpointBackwards.
/**
 * Trains a model for 10 epochs, then tries to resume from its checkpoint while requesting only
 * 9 epochs. That backwards restart must be rejected with an {@link H2OIllegalArgumentException}.
 */
@Test
public void testCheckpointBackwards() {
  Frame frame = null;
  DeepLearningModel first = null;
  DeepLearningModel resumed = null;
  try {
    frame = parse_test_file("./smalldata/iris/iris.csv");

    DeepLearningParameters initial = new DeepLearningParameters();
    initial._train = frame._key;
    initial._epochs = 10;
    initial._response_column = "C5";
    initial._reproducible = true;
    initial._hidden = new int[] { 2, 2 };
    initial._seed = 0xdecaf;
    initial._variable_importances = true;
    first = new DeepLearning(initial).trainModel().get();

    DeepLearningParameters backwards = (DeepLearningParameters) initial.clone();
    backwards._epochs = 9;
    backwards._checkpoint = first._key;
    try {
      resumed = new DeepLearning(backwards).trainModel().get();
      Assert.fail("Should toss exception instead of reaching here");
    } catch (H2OIllegalArgumentException expected) {
      // Expected outcome: restarting from a checkpoint with fewer epochs must be refused.
    }
  } finally {
    if (frame != null) {
      frame.delete();
    }
    if (first != null) {
      first.delete();
    }
    if (resumed != null) {
      resumed.delete();
    }
  }
}
Usage of hex.deeplearning.DeepLearningModel.DeepLearningParameters in the h2o-3 project by h2oai.
Source: class DeepLearningTest, method testRegression.
/**
 * Trains a Laplace-distribution regression model on the Titanic data (response: "age") and
 * scores the training frame. Logs regression metrics of the predictions computed under the
 * Laplace, Gaussian and Poisson distributions.
 *
 * <p>The test makes no assertions: it only checks that training, scoring and metric computation
 * complete without throwing. Any failure propagates to JUnit, which reports its stack trace, so
 * no print-and-rethrow wrapper is needed.
 */
@Test
public void testRegression() {
  Frame train = null;
  Frame preds = null;
  DeepLearningModel model = null;
  Scope.enter();
  try {
    train = parse_test_file("./smalldata/junit/titanic_alt.csv");
    DeepLearningParameters parms = new DeepLearningParameters();
    parms._train = train._key;
    parms._response_column = "age";
    parms._reproducible = true;
    parms._hidden = new int[] { 20, 20 };
    parms._distribution = laplace;
    parms._seed = 0xdecaf;
    model = new DeepLearning(parms).trainModel().get();
    preds = model.score(train);
    // actual response values
    Vec targets = train.vec("age");
    // Metrics under the training distribution, then under two alternative distributions.
    ModelMetricsRegression mm = ModelMetricsRegression.make(preds.vec(0), targets, parms._distribution);
    Log.info(mm.toString());
    mm = ModelMetricsRegression.make(preds.vec(0), targets, gaussian);
    Log.info(mm.toString());
    mm = ModelMetricsRegression.make(preds.vec(0), targets, poisson);
    Log.info(mm.toString());
  } finally {
    if (model != null) {
      model.delete();
    }
    if (preds != null) {
      preds.remove();
    }
    if (train != null) {
      train.remove();
    }
    Scope.exit();
  }
}
Aggregations